【Pytorch】Model存储

    技术2022-07-14  76

    1. 保存加载模型权重

    pytorch保存数据的格式为.t7文件或者.pth文件,t7文件是沿用torch7中读取模型权重的方式。而pth文件是python中存储文件的常用格式。而在keras中则是使用.h5文件。

    # 保存模型示例代码 print('===> Saving models...') state = { 'state': model.state_dict(), 'epoch': epoch # 将epoch一并保存 } if not os.path.isdir('checkpoint'): os.mkdir('checkpoint') torch.save(state, './checkpoint/autoencoder.t7')

    pytorch读取数据使用的方法和我们平时使用预训练参数所用的方法是一样的,都是使用load_state_dict这个函数。

    print('===> Try resume from checkpoint') if os.path.isdir('checkpoint'): try: checkpoint = torch.load('./checkpoint/autoencoder.t7') model.load_state_dict(checkpoint['state']) # 从字典中依次读取 start_epoch = checkpoint['epoch'] print('===> Load last checkpoint data') except FileNotFoundError: print('Can\'t found autoencoder.t7') else: start_epoch = 0 print('===> Start from scratch') params=model.state_dict() for k,v in params.items(): print(k) #打印网络中的变量名 print(params['conv1.weight']) #打印conv1的weight print(params['conv1.bias']) #打印conv1的bias

    net[“model”] 详解:键model所对应的值是一个OrderedDict,而这个OrderedDict字典里面又存储着所有的每一层的参数名称以及对应的参数值。

    net[“optimizer”]详解:只有两个key,一个是state,一个是param_groups 其中state所对应的值又是一个字典类型,param_groups对应的值是一个列表。

    net[“scheduler”] 详解, net[“iteration”] 详解

    Processed: 0.015, SQL: 9