本文涉及到的是中国大学慕课《人工智能实践:Tensorflow笔记》第四讲第四节的内容,通过tensorflow实现神经网络模型的断点续训。
神经网络模型的断点续训指的是将训练好的模型保存下来,并在之后的运行中直接调用的训练方法。
保存模型使用以下代码,将模型参数保存到callbacks中,并将callbacks添加到模型训练的history中。
callbacks = tf.keras.callbacks.ModelCheckpoint( ( filepath= = 路径文件名, save_weights_only= = True/False, save_best_only= = True/False) ) history = model.fit( ( callbacks=[cp_callback] )读取模型使用以下代码,基本原理是生成.ckpt文件的同时会生成对应的索引表(.index文件),通过判断索引表是否存在来决定是否导入保存的模型。
checkpoint_save_path = "./checkpoint/mnist.ckpt" # 给出模型保存的路径以及文件名 if os.path.exists(checkpoint_save_path + '.index'): # 通过索引表判断保存的模型是否存在 print('-------------load the model-----------------') # 是,则打印"导入模型" model.load_weights(checkpoint_save_path) # 导入模型从神经网络搭建的六步法来看,与DL with python(6)——Keras实现手写数字识别(全连接网络)中直接导入mnist数据的代码相比,断点续训的代码在第一步、第四步和第五步有所改动。
# 第一步,导入相关模块,os模块用于判断文件是否存在 import tensorflow as tf import os # 第二步,导入数据集 mnist = tf.keras.datasets.mnist (x_train, y_train), (x_test, y_test) = mnist.load_data() x_train, x_test = x_train / 255.0, x_test / 255.0 # 第三步,搭建网络结构 model = tf.keras.models.Sequential([ tf.keras.layers.Flatten(), tf.keras.layers.Dense(128, activation='relu'), tf.keras.layers.Dense(10, activation='softmax') ]) # 第四步,配置训练方法 model.compile(optimizer='adam', loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=False), metrics=['sparse_categorical_accuracy']) # 导入保存的模型,第二次运行才可以进行的操作 checkpoint_save_path = "./checkpoint/mnist.ckpt" # 给出模型保存的路径以及文件名 if os.path.exists(checkpoint_save_path + '.index'): # 通过索引表判断保存文件是否存在 print('-------------load the model-----------------') # 是,则打印"导入模型" model.load_weights(checkpoint_save_path) # 导入模型 # 保存模型,第一次运行执行这一步操作 cp_callback = tf.keras.callbacks.ModelCheckpoint(filepath=checkpoint_save_path, # 模型保存路径 save_weights_only=True, # 只保留模型参数 save_best_only=True) # 只保留最优结果 # 第五步,执行训练,依次为训练集样本,训练集标签,小批量大小32,训练轮次5,测试集,训练集循环1轮次进行一次测试 history = model.fit(x_train, y_train, batch_size=32, epochs=5, validation_data=(x_test, y_test), validation_freq=1, callbacks=[cp_callback]) # 最后添加callbacks回调选项,将前面保存的模型参数赋给后面的模型 # 第六步,打印网络结构和参数统计 model.summary()第一次运行代码后,会得到一个checkpoint文件夹,其中含有四个文件,含有模型的相关信息。 第一次运行,模型的最终表现
60000/60000 [==============================] - 5s 77us/sample - loss: 0.0454 - sparse_categorical_accuracy: 0.9863 - val_loss: 0.0865 - val_sparse_categorical_accuracy: 0.9752然后第二次运行,在第一次的基础上进行训练,最终表现如下,可以看到各方面指标都有了很大的进步。
60000/60000 [==============================] - 4s 69us/sample - loss: 0.0180 - sparse_categorical_accuracy: 0.9943 - val_loss: 0.0785 - val_sparse_categorical_accuracy: 0.9782