pytorch中参数的保存(save),加载操作(load)

    技术2022-07-10  124

    最近写程序,遇到了保存和加载参数的问题,随通过查阅,留下笔记。

    文章目录

    参数的保存参数的加载

    参数的保存

    首先,参数的保存用的是 torch.save(),具体操作:

    for epoch in range(num_epoch): #训练数据集的迭代次数,这里cifar10数据集将迭代2次 train_loss = 0.0 for batch_idx, data in enumerate(trainloader, 0): #初始化 inputs, labels = data #获取数据 optimizer.zero_grad() #先将梯度置为0 #优化过程 outputs = net(inputs) #将数据输入到网络,得到第一轮网络前向传播的预测结果outputs loss = criterion(outputs, labels) #预测结果outputs和labels通过之前定义的交叉熵计算损失 loss.backward() #误差反向传播 optimizer.step() #随机梯度下降方法(之前定义)优化权重 #查看网络训练状态 train_loss += loss.item() if batch_idx % 2000 == 1999: #每迭代2000个batch打印看一次当前网络收敛情况 print('[%d, ]] loss: %.3f' % (epoch + 1, batch_idx + 1, train_loss / 2000)) train_loss = 0.0 print('Saving epoch %d model ...' % (epoch + 1)) #####参数保存########### state = { 'net': net.state_dict(), 'epoch': epoch + 1, } # 1 、 先建立一个字典 if not os.path.isdir('checkpoint'): os.mkdir('checkpoint') # 2 、 建立一个保存参数的文件夹 torch.save(state, './checkpoint/sence15_epoch_%d.ckpt' % (epoch + 1))# 3 、保存操作 # 因为在for epoch in range(num_epoch)这个循环中,所以可以 保存每一个epoch的参数,如果不在这个循环中, #而是循环完成在保存,则保存的是最后一个epoch的参数 print('Finished Training')

    结果如图所示

    参数的加载

    checkpoint = torch.load('./checkpoint/sence15_epoch_60.ckpt')#载入现有模型 net.load_state_dict(checkpoint['net']) start_epoch = checkpoint['epoch']

    参考链接: https://blog.csdn.net/weixin_38145317/article/details/103582549. 这个链接写的很简单凝练,可以参考

    Processed: 0.010, SQL: 9