0x00 Preface
 
At present, TensorFlow model parameter files are not as convenient to work with as PyTorch parameter files, and a task at hand requires that, in a few parameter matrices, the parameters of specific rows be copied into certain other rows. In PyTorch this is easy, since a checkpoint is just a group of tensors wrapped in an OrderedDict, i.e. basic Python data structures. Doing the same thing in TensorFlow is considerably more cumbersome, so I implemented the utility class CheckpointMonitor to make this kind of processing more efficient.
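To make the requirement concrete: once a parameter matrix is available as a numpy array, the row copy itself is a one-liner. A minimal sketch (the matrix shape and row indices here are made up):

import numpy as np

weights = np.random.rand(100, 768)         # stand-in for a real parameter matrix
weights[[10, 11], :] = weights[[3, 5], :]  # copy rows 3 and 5 into rows 10 and 11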
 
0x01 Features and API
 
- Modify any parameter matrix in a TensorFlow checkpoint (ckpt)
- Rename variables individually or in batch, keeping all other attributes of each variable unchanged
  - Batch renaming works by passing in a function: every input variable name is mapped through the custom function to its output name. This step is needed, for example, when converting parameters between TensorFlow and PyTorch; see the sketch after this list.
- Save the modified parameters back as a TensorFlow checkpoint (Figure 1 below) or as a PyTorch file (Figure 2 below)
- Filter, inspect, and modify all or part of any parameter matrix; outside the tool class, everything is handled as plain numpy arrays
- Automatically maintain the variable order in the model file; existing parameters can also be extended, e.g. by concatenating matrices
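For instance, a batch-rename function for TF-to-PyTorch name conversion might look like the sketch below (the kernel-to-weight mapping is a hypothetical example, not a complete conversion rule):

# Hypothetical mapping from TF-style names to PyTorch-style names.
def rename_func(var_name):
    new_name = var_name.replace('/', '.')
    new_name = new_name.replace('kernel', 'weight')
    return new_name

# monitor.modify_var_names(rename_func)  # every variable name goes through rename_func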
0x02 API List
 
- __init__(checkpoint_path): initialize with the path to a checkpoint
- list_variables(): list all variables in the current checkpoint together with their shapes
- list_target_variables(pattern): like list_variables, but returns only the variables whose names match the filter (Figure 3)
- get_var_data(var_name): get the variable with the given name from the model file, as a numpy array
- save_model(path, method='tf'): save the model back as a TensorFlow checkpoint or as a PyTorch file
- modify_var_name(old_name, new_name): rename a single variable
- modify_var_names(rename_func): batch-rename variables
- modify_var_data(var_name, var_data): modify a variable's values

These are the current APIs; more may be added as needs arise (for example, encryption/decryption or model-compression utilities could all be folded into this class). A typical end-to-end session is sketched below.
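Putting the APIs together (the variable name 'bert/embeddings/word_embeddings' and the row indices are assumptions; substitute names from your own checkpoint):

monitor = CheckpointMonitor('/path/to/model.ckpt-250042')

print(monitor.list_variables())                     # all (name, shape) pairs
print(monitor.list_target_variables('embeddings'))  # filter by substring

name = 'bert/embeddings/word_embeddings'            # hypothetical variable name
data = monitor.get_var_data(name)                   # numpy array
data[[10, 11], :] = data[[3, 5], :]                 # copy specific rows to other rows
monitor.modify_var_data(name, data)

monitor.save_model('./new_ckpt', method='tf')       # or method='pt' for a PyTorch file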
0x03 requirements
 
- python >= 3.6 (lower versions untested)
- tensorflow >= 1.15 (lower versions untested; the code relies on tf.contrib, so TF 1.x is required)
- torch >= 1.4 (only needed if you want to save as a PyTorch file)
- numpy
0x04 Source Code
 
import os
os.environ['CUDA_LAUNCH_BLOCKING'] = ""
os.environ['CUDA_VISIBLE_DEVICES'] = ""

import numpy as np
import tensorflow as tf
from collections import OrderedDict


class CheckpointMonitor(object):
    """
    # CPU mode
    import os
    os.environ['CUDA_LAUNCH_BLOCKING'] = ""
    os.environ['CUDA_VISIBLE_DEVICES'] = ""
    """

    def __init__(self, checkpoint_path=None):
        if checkpoint_path is None:
            checkpoint_path = '/data/sharedata/model_files/model.ckpt-250042'

        self.saver = None
        self.graph = None
        self.dump_path = './'
        self.checkpoint_path = checkpoint_path
        self.default_dump_name = 'my_modified_model'
        self.var_name_list = []
        self.var_shape_dict = OrderedDict()
        self.var_data_dict = OrderedDict()
        self.init_vars()

    def reload(self, checkpoint_path=None):
        self.__init__(checkpoint_path=checkpoint_path)

    def init_vars(self, checkpoint_path=None):
        """Read every variable of the checkpoint into memory as a numpy array."""
        if checkpoint_path is None:
            checkpoint_path = self.checkpoint_path
        self.var_shape_dict = OrderedDict(self.list_variables(checkpoint_path))
        self.var_name_list = list(self.var_shape_dict.keys())
        for var_name in self.var_name_list:
            var_data = self.get_var_data(var_name, checkpoint_path)
            self.var_data_dict.update({var_name: var_data})

    def sort_var_dicts(self):
        """Re-order both dicts so they follow self.var_name_list."""
        self.var_data_dict = OrderedDict(
            [(var_name, self.var_data_dict[var_name])
             for var_name in self.var_name_list])
        self.var_shape_dict = OrderedDict(
            [(var_name, self.var_shape_dict[var_name])
             for var_name in self.var_name_list])

    def list_variables(self, checkpoint_path=None):
        """Return a list of (name, shape) tuples for all checkpoint variables."""
        if checkpoint_path is None:
            checkpoint_path = self.checkpoint_path
        return tf.contrib.framework.list_variables(checkpoint_path)

    def list_target_variables(self, pattern, checkpoint_path=None):
        """Like list_variables, but keep only names containing `pattern`."""
        if checkpoint_path is None:
            if len(self.var_shape_dict) != 0:
                # Use the in-memory cache when it has been populated.
                var_list = self.var_shape_dict.items()
                return [(name, shape) for (name, shape) in var_list
                        if pattern in name]
            else:
                checkpoint_path = self.checkpoint_path
        var_list = self.list_variables(checkpoint_path)
        return [(name, shape) for (name, shape) in var_list
                if pattern in name]

    def get_var_data(self, var_name, checkpoint_path=None):
        """Return the variable's value as a numpy array."""
        if checkpoint_path is None:
            if len(self.var_data_dict) != 0:
                return self.var_data_dict.get(var_name)
            checkpoint_path = self.checkpoint_path
        return tf.contrib.framework.load_variable(checkpoint_path, var_name)

    @staticmethod
    def generate_rename_func(old_name_list, new_name_list):
        """Build a rename function from two parallel lists of names."""
        def fn(var_name):
            if var_name in old_name_list:
                return new_name_list[old_name_list.index(var_name)]
            return var_name
        return fn

    def modify_var_name(self, old_name, new_name, inplace=True):
        var_index = self.var_name_list.index(old_name)
        self.var_name_list[var_index] = new_name
        self.var_data_dict[new_name] = self.var_data_dict[old_name]
        self.var_shape_dict[new_name] = self.var_shape_dict[old_name]
        del self.var_data_dict[old_name]
        del self.var_shape_dict[old_name]
        if inplace:
            # Restore the original variable order after the key shuffle above.
            self.sort_var_dicts()

    def modify_var_names(self, rename_func=None):
        """Apply `rename_func` to every variable name."""
        if rename_func is None:
            rename_func = lambda _name: _name
        # Iterate over a copy, since modify_var_name mutates var_name_list.
        for var_name in list(self.var_name_list):
            new_name = rename_func(var_name)
            if new_name != var_name:
                self.modify_var_name(var_name, new_name, inplace=False)
                print('Re-naming {} to {}.'.format(var_name, new_name))
        self.sort_var_dicts()

    def modify_var_data(self, var_name, var_data):
        assert isinstance(var_data, np.ndarray)
        if var_name not in self.var_name_list:
            print("Invalid variable name: {}".format(var_name))
            print("You can get available variable names by calling list_variables()")
            return
        self.var_shape_dict[var_name] = list(var_data.shape)
        self.var_data_dict[var_name] = var_data

    def generate_var_dict_for_torch(self, var_list=None):
        """Convert the cached numpy arrays into an OrderedDict of torch tensors."""
        import torch
        if var_list is None:
            var_list = self.var_data_dict.items()
        torch_model_dict = OrderedDict()
        for var_name, var_data in var_list:
            var = torch.tensor(var_data)
            torch_model_dict.update({var_name: var})
        return torch_model_dict

    def generate_var_list_for_saver(self, var_list=None):
        """Create a tf.Variable (in the default graph) for every cached array."""
        if var_list is None:
            var_list = self.var_data_dict.items()
        saver_var_list = []
        for var_name, var_data in var_list:
            var = tf.Variable(var_data, name=var_name)
            saver_var_list.append(var)
        return saver_var_list

    def save_model(self, new_checkpoint_path=None, model_name=None, method='pt'):
        if new_checkpoint_path is None:
            new_checkpoint_path = self.dump_path
        if not os.path.exists(new_checkpoint_path):
            os.makedirs(new_checkpoint_path)
        if model_name is None:
            model_name = self.default_dump_name
        checkpoint_path = os.path.join(new_checkpoint_path, model_name)

        # Dispatch on the requested output format.
        method_dict = {
            'pt': self.save_model_as_pt,
            'tf': self.save_model_as_tf,
            'ckpt': self.save_model_as_tf,
            'torch': self.save_model_as_pt,
            'pytorch': self.save_model_as_pt,
            'tensorflow': self.save_model_as_tf,
        }
        method_dict[method](checkpoint_path)

    def save_model_as_pt(self, checkpoint_path):
        import torch
        var_dict = self.generate_var_dict_for_torch()
        checkpoint = OrderedDict({'model': var_dict})
        torch.save(checkpoint, checkpoint_path + '.pt')
        print("Checkpoint saving finished!\n{}".format(checkpoint_path + '.pt'))

    def save_model_as_tf(self, checkpoint_path):
        with tf.Session() as sess:
            var_list = self.generate_var_list_for_saver()
            self.saver = tf.train.Saver(var_list=var_list)
            # The initializer assigns each variable its (possibly modified) numpy data.
            sess.run(tf.global_variables_initializer())
            self.saver.save(sess, checkpoint_path)
            print("Checkpoint saving finished!\n{}".format(checkpoint_path))
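
A usage note on the static helper generate_rename_func: it builds a rename function from two parallel name lists, which can then be handed to modify_var_names. A minimal sketch with hypothetical names:

fn = CheckpointMonitor.generate_rename_func(
    old_name_list=['layer_0/kernel'],
    new_name_list=['layer_0/weight'])
monitor.modify_var_names(fn)  # renames 'layer_0/kernel' to 'layer_0/weight'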
 
0x05 Demo
 
Figure 1: load the original TF model → modify a single value → save back → load the new TF model → verify the change
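In code, the Figure 1 workflow corresponds roughly to the following sketch (variable name and value are hypothetical):

monitor = CheckpointMonitor('/path/to/model.ckpt-250042')
name = 'bert/embeddings/word_embeddings'   # hypothetical variable name
data = monitor.get_var_data(name)
data[0, 0] = 12345.0                       # modify a single value
monitor.modify_var_data(name, data)
monitor.save_model('./new_ckpt', model_name='model_new', method='tf')

monitor.reload('./new_ckpt/model_new')     # load the new TF checkpoint
print(monitor.get_var_data(name)[0, 0])    # -> 12345.0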
 
 
Figure 2: load the original TF model → modify a single value → save as a PyTorch model → load the new PyTorch model → verify the change
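And the Figure 2 workflow, again with hypothetical names; note that save_model_as_pt stores the weights under the 'model' key:

import torch

monitor = CheckpointMonitor('/path/to/model.ckpt-250042')
name = 'bert/embeddings/word_embeddings'   # hypothetical variable name
data = monitor.get_var_data(name)
data[0, 0] = 12345.0                       # modify a single value
monitor.modify_var_data(name, data)
monitor.save_model('./new_ckpt', model_name='model_new', method='pt')

checkpoint = torch.load('./new_ckpt/model_new.pt')  # weights live under 'model'
print(checkpoint['model'][name][0, 0])              # -> tensor(12345.)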