[TensorFlow] A utility class for handling parameter names and matrix values in checkpoints

    Tech  2025-10-12

    0x00 Preface

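    TensorFlow model parameter files are currently not as convenient to work with as PyTorch ones, and a task I am working on requires "copying specific rows of certain parameter matrices into certain other rows". In PyTorch this is manageable, because a state dict is just a set of tensors wrapped in an OrderedDict, a basic Python data structure. The same operation is more cumbersome in TensorFlow, so I implemented this utility class, CheckpointMonitor, to make such edits more efficient.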
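    Once a parameter is available as a numpy array, the row copy described above comes down to a single fancy-indexing assignment. A minimal sketch; the helper copy_rows and the row indices are hypothetical, not part of the original tool:

```python
import numpy as np


def copy_rows(matrix, src_rows, dst_rows):
    """Return a copy of matrix where row dst_rows[i] is replaced by row src_rows[i].

    copy_rows is a hypothetical helper; src_rows/dst_rows are index lists of equal length.
    """
    if len(src_rows) != len(dst_rows):
        raise ValueError("src_rows and dst_rows must have the same length")
    out = matrix.copy()  # leave the original array untouched
    # Read from the untouched `matrix`, so overlapping source/destination
    # rows still receive the original values.
    out[dst_rows] = matrix[src_rows]
    return out


# Example: a 4x3 matrix whose rows 2 and 3 should mirror rows 0 and 1
w = np.arange(12, dtype=np.float32).reshape(4, 3)
w_new = copy_rows(w, src_rows=[0, 1], dst_rows=[2, 3])
```

    With CheckpointMonitor, the same step would be get_var_data(name), then the row copy, then modify_var_data(name, new_matrix).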

    0x01 Features

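    - Modify any parameter matrix in a TensorFlow checkpoint (ckpt) file.
    - Rename parameters one at a time or in batch, leaving all other attributes of each parameter unchanged. Batch renaming takes a user-defined function that maps every input parameter name to an output name; this step is needed, for example, when converting parameters between TensorFlow and PyTorch.
    - Save the modified parameters back as a TensorFlow checkpoint (Figure 1) or as a PyTorch file (Figure 2).
    - Filter, inspect, and modify all or part of the values of any parameter matrix; with this class, everything is handled as numpy arrays.
    - Automatically preserve the order of the parameters in the model file; existing parameters can also be extended, for example by concatenation.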
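    A rename function for batch renaming is just a plain str -> str function. The sketch below assumes a BERT-style TensorFlow naming scheme and a matching PyTorch scheme; the function name tf_to_torch_name and every mapping rule in it are hypothetical illustrations, not a standard conversion table:

```python
def tf_to_torch_name(var_name):
    """Hypothetical TF -> PyTorch renaming rule, for illustration only.

    'bert/encoder/layer_0/output/dense/kernel'
        -> 'bert.encoder.layer.0.output.dense.weight'
    """
    # Assumed mapping of TF leaf names to common PyTorch parameter names
    leaf_map = {"kernel": "weight", "gamma": "weight", "beta": "bias"}
    parts = var_name.split("/")
    parts[-1] = leaf_map.get(parts[-1], parts[-1])
    converted = []
    for part in parts:
        # Assumed convention: 'layer_0' -> 'layer', '0' (indexed sub-module)
        prefix, _, index = part.rpartition("_")
        if prefix and index.isdigit():
            converted.extend([prefix, index])
        else:
            converted.append(part)
    return ".".join(converted)


# A CheckpointMonitor instance would apply it to every name in one call:
# monitor.modify_var_names(tf_to_torch_name)
```

    Renaming alone does not change tensor layouts: a TensorFlow dense kernel is stored as [in_features, out_features], whereas torch.nn.Linear.weight is [out_features, in_features], so such matrices would also need a transpose (var_data.T) via modify_var_data.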

    0x02 API list

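    - __init__(checkpoint_path): path of the checkpoint to load
    - list_variables(): return all parameters in the current checkpoint together with their shapes
    - list_target_variables(pattern): like list_variables, but only the parameters whose names contain pattern (Figure 3)
    - get_var_data(var_name): return the parameter with that name as a numpy array
    - save_model(new_checkpoint_path=None, model_name=None, method='pt'): save the model as a TensorFlow checkpoint (method='tf') or a PyTorch file (method='pt', the default)
    - modify_var_name(old_name, new_name): rename one parameter
    - modify_var_names(rename_func): rename parameters in batch
    - modify_var_data(var_name, var_data): replace the value of a parameter

    That is all for now; more may be added as needed (for example, tools for encryption/decryption or model compression could be integrated into this class).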
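    Because modify_var_data records the shape of whatever array it receives, it can also store a larger matrix, which is how the concatenation-style extension mentioned in 0x01 would work. A minimal numpy sketch for appending rows to, say, an embedding table; the helper extend_rows and its random initialization scheme are assumptions:

```python
import numpy as np


def extend_rows(matrix, n_new, init_from=None, seed=0):
    """Append n_new rows to a 2-D parameter matrix (hypothetical helper).

    If init_from is given, new row i is a copy of row init_from[i];
    otherwise new rows are drawn from N(0, std(matrix)^2), an assumed init scheme.
    """
    if init_from is not None:
        if len(init_from) != n_new:
            raise ValueError("init_from must name one source row per new row")
        new_rows = matrix[init_from]
    else:
        rng = np.random.default_rng(seed)
        new_rows = rng.normal(0.0, matrix.std(), size=(n_new, matrix.shape[1]))
    # Keep the original dtype (e.g. float32) so the stored dtype is unchanged
    return np.concatenate([matrix, new_rows.astype(matrix.dtype)], axis=0)


embedding = np.arange(6, dtype=np.float32).reshape(3, 2)
extended = extend_rows(embedding, 2, init_from=[0, 2])
```

    The extended array would then be written back with modify_var_data(var_name, extended); any other parameter whose shape must match it (for example an output bias sized to the vocabulary) would need the same treatment.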

    0x03 requirements

    - python >= 3.6 (lower versions untested)
    - tensorflow >= 1.15 (lower versions untested)
    - torch >= 1.4 (only needed for saving as PyTorch)
    - numpy

    0x04 Source Code

```python
import os
# CPU mode: hide all GPUs before TensorFlow is imported
os.environ['CUDA_LAUNCH_BLOCKING'] = ""
os.environ['CUDA_VISIBLE_DEVICES'] = ""

from collections import OrderedDict

import numpy as np
import tensorflow as tf


class CheckpointMonitor(object):
    """Load, inspect, rename, modify and re-save the variables of a TF checkpoint.

    All variable values are kept in memory as numpy arrays.

    # CPU mode
    import os
    os.environ['CUDA_LAUNCH_BLOCKING'] = ""
    os.environ['CUDA_VISIBLE_DEVICES'] = ""
    """

    def __init__(self, checkpoint_path=None):
        if checkpoint_path is None:
            # default path for testing
            checkpoint_path = '/data/sharedata/model_files/model.ckpt-250042'
        self.saver = None
        self.graph = None
        self.dump_path = './'
        self.checkpoint_path = checkpoint_path
        self.default_dump_name = 'my_modified_model'
        self.var_name_list = []               # keeps the variable order
        self.var_shape_dict = OrderedDict()   # name -> shape (list of int)
        self.var_data_dict = OrderedDict()    # name -> np.ndarray
        self.init_vars()

    def reload(self, checkpoint_path=None):
        self.__init__(checkpoint_path=checkpoint_path)

    def init_vars(self, checkpoint_path=None):
        if checkpoint_path is None:
            checkpoint_path = self.checkpoint_path
        self.var_shape_dict = OrderedDict(self.list_variables(checkpoint_path))
        self.var_name_list = list(self.var_shape_dict.keys())
        self.var_data_dict = OrderedDict()
        for var_name in self.var_name_list:
            # read every variable from disk as np.ndarray
            self.var_data_dict[var_name] = self.get_var_data(var_name, checkpoint_path)

    def sort_var_dicts(self):
        # re-order both dicts so they follow var_name_list
        self.var_data_dict = OrderedDict(
            (name, self.var_data_dict[name]) for name in self.var_name_list)
        self.var_shape_dict = OrderedDict(
            (name, self.var_shape_dict[name]) for name in self.var_name_list)

    def list_variables(self, checkpoint_path=None):
        # return a list of (var_name, shape) tuples stored in the checkpoint
        if checkpoint_path is None:
            checkpoint_path = self.checkpoint_path
        # tf.train.list_variables (also exposed as tf.contrib.framework.list_variables in TF 1.x)
        return tf.train.list_variables(checkpoint_path)

    def list_target_variables(self, pattern, checkpoint_path=None):
        # return the (var_name, shape) pairs whose name contains `pattern`
        if checkpoint_path is None:
            if len(self.var_shape_dict) != 0:
                # lazy loading: use the (possibly modified) in-memory copy
                var_list = self.var_shape_dict.items()
            else:
                # cold boot: read the names from disk
                var_list = self.list_variables(self.checkpoint_path)
        else:
            var_list = self.list_variables(checkpoint_path)
        return [(name, shape) for (name, shape) in var_list if pattern in name]

    def get_var_data(self, var_name, checkpoint_path=None):
        # return the value of `var_name` as np.ndarray
        if checkpoint_path is None:
            if len(self.var_data_dict) != 0:
                # lazy loading: in-memory copy (None if the name is unknown)
                return self.var_data_dict.get(var_name)
            checkpoint_path = self.checkpoint_path
        return tf.train.load_variable(checkpoint_path, var_name)

    @staticmethod
    def generate_rename_func(old_name_list, new_name_list):
        # build a rename_func from two parallel lists of names
        def fn(var_name):
            if var_name in old_name_list:
                return new_name_list[old_name_list.index(var_name)]
            return var_name
        return fn

    def modify_var_name(self, old_name, new_name, inplace=True):
        # rename one variable, keeping its position, value and shape
        if new_name == old_name:
            return
        if new_name in self.var_data_dict:
            raise ValueError("Variable name already exists: {}".format(new_name))
        var_index = self.var_name_list.index(old_name)
        self.var_name_list[var_index] = new_name
        self.var_data_dict[new_name] = self.var_data_dict.pop(old_name)
        self.var_shape_dict[new_name] = self.var_shape_dict.pop(old_name)
        if inplace:
            # the renamed entry was appended at the end; restore the order
            self.sort_var_dicts()

    def modify_var_names(self, rename_func=None):
        # rename variables in batch: new_name = rename_func(old_name)
        if rename_func is None:
            rename_func = lambda _name: _name
        for var_name in list(self.var_name_list):  # iterate over a snapshot
            new_name = rename_func(var_name)
            if new_name != var_name:
                self.modify_var_name(var_name, new_name, inplace=False)
                print('Re-naming {} to {}.'.format(var_name, new_name))
        self.sort_var_dicts()

    def modify_var_data(self, var_name, var_data):
        # replace the value (and recorded shape) of an existing variable
        assert isinstance(var_data, np.ndarray)
        if var_name not in self.var_name_list:
            raise KeyError("Invalid variable name: {}. You can get available variable "
                           "names by calling list_variables()".format(var_name))
        self.var_shape_dict[var_name] = list(var_data.shape)
        self.var_data_dict[var_name] = var_data

    def generate_var_dict_for_torch(self, var_list=None):
        import torch  # optional dependency, only needed for PyTorch export
        if var_list is None:
            var_list = self.var_data_dict.items()
        torch_model_dict = OrderedDict()
        for var_name, var_data in var_list:
            torch_model_dict[var_name] = torch.tensor(var_data)
        return torch_model_dict

    def generate_var_list_for_saver(self, var_list=None):
        # create one variable per entry in the current default graph
        if var_list is None:
            var_list = self.var_data_dict.items()
        return [tf.compat.v1.Variable(var_data, name=var_name)
                for var_name, var_data in var_list]

    def save_model(self, new_checkpoint_path=None, model_name=None, method='pt'):
        if new_checkpoint_path is None:
            new_checkpoint_path = self.dump_path
        if not os.path.exists(new_checkpoint_path):
            os.makedirs(new_checkpoint_path)
        if model_name is None:
            model_name = self.default_dump_name
        checkpoint_path = os.path.join(new_checkpoint_path, model_name)
        method_dict = {
            'pt': self.save_model_as_pt,
            'torch': self.save_model_as_pt,
            'pytorch': self.save_model_as_pt,
            'tf': self.save_model_as_tf,
            'ckpt': self.save_model_as_tf,
            'tensorflow': self.save_model_as_tf,
        }
        method_dict[method](checkpoint_path)

    def save_model_as_pt(self, checkpoint_path):
        import torch
        var_dict = self.generate_var_dict_for_torch()
        checkpoint = OrderedDict({'model': var_dict})
        torch.save(checkpoint, checkpoint_path + '.pt')
        print("Checkpoint saving finished !\n{}".format(checkpoint_path + '.pt'))

    def save_model_as_tf(self, checkpoint_path):
        # build the variables in a fresh graph so that repeated saves do not
        # collide with variables left over in the default graph
        self.graph = tf.Graph()
        with self.graph.as_default():
            var_list = self.generate_var_list_for_saver()
            # Construct the Saver
            self.saver = tf.compat.v1.train.Saver(var_list=var_list)
            with tf.compat.v1.Session(graph=self.graph) as sess:
                # Necessary! Initialize the variables before saving.
                sess.run(tf.compat.v1.global_variables_initializer())
                self.saver.save(sess, checkpoint_path)
        print("Checkpoint saving finished !\n{}".format(checkpoint_path))
```

    0x05 Demo

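    Figure 1: load the original TF model → modify a single value → save it back → load the new TF model → verify the change

    Figure 2: load the original TF model → modify a single value → save it as a PyTorch model → load the new PyTorch model → verify the change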
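    The "verify the change" step in both figures amounts to comparing parameter values before and after the edit. A minimal sketch that works on plain {name: np.ndarray} dicts, such as the class's var_data_dict on the TF side, or torch.load(path)['model'] converted with .numpy() on the PyTorch side; the helper diff_var_dicts is hypothetical, not part of CheckpointMonitor, and assumes numeric arrays:

```python
import numpy as np


def diff_var_dicts(before, after, atol=0.0):
    """Compare two {var_name: np.ndarray} dicts of numeric arrays (hypothetical helper).

    Reports names present on only one side, names whose shape changed, and,
    for same-shape variables, how many elements differ by more than atol
    (atol=0.0 means exact comparison).
    """
    report = {
        "removed": sorted(set(before) - set(after)),
        "added": sorted(set(after) - set(before)),
        "reshaped": [],
        "changed": {},
    }
    for name in before.keys() & after.keys():
        a, b = np.asarray(before[name]), np.asarray(after[name])
        if a.shape != b.shape:
            report["reshaped"].append(name)
            continue
        n_diff = int(np.count_nonzero(~np.isclose(a, b, rtol=0.0, atol=atol)))
        if n_diff:
            report["changed"][name] = n_diff
    report["reshaped"].sort()
    return report
```

    After a single-value edit such as the one in the figures, a correct result would list exactly one changed element for the edited variable and nothing else.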
