很多小伙伴在使用TensorFlow2.x的时候会进行自定义的循环,也就是自己采用for循环来逐个Epoch循环;同时又想将此时的网络图绘制在TensorBoard中。这个时候问题就出现了:TensorBoard在2.0以后的版本中的的网络图是默认在model.fit之中自动绘制的;
# 使用fit函数的时候会自动绘制网络计算图 model.fit(trrain_dataset, epoch=10, ......)倘若想要自定义训练循环则又需要手动绘制网络图。
# 自定义寻来你循环的时候,TensorFlow不会帮助我们绘制网络计算图 for epooch in range(1, EPOCHS): SDG... LOSS... Record...而网络上关于TensorFlow2.x绘制网络图的说明是少之又少,于是我决定写这篇博客来帮助大家来实现网络图的绘制。
直接给大家展示代码
import tensorflow as tf from tensorflow.keras.layers import Dense, Flatten, Conv2D from tensorflow.keras.datasets import mnist from tensorflow.python.ops import summary_ops_v2 # 需要引入这个模块 logs_dir='你的自定义的日志目录' # 你创建的模型 class ClassModel(tf.keras.Model): def __init__(self, ...): super(ClassModel, self).__init__() self.d1 = Dense(128, activation='relu') self.d2 = Dense(self.num_classes, activation='softmax') ... # 其他操作 @tf.function # 需要使用tf.function def call(self, inputs): inputs = self.d1(inputs) output = self.d2(inputs) return output # inputs可以是符合你输入数据形状的输入数据 inputs=training_dataset model=ClassModel() # 开始创建网络计算图 graph_writer = tf.summary.create_file_writer(logdir=logs_dir) with graph_writer.as_default(): graph=model.call.get_concrete_function(inputs).graph summary_ops_v2.graph(graph.as_graph_def()) graph_writer.close()通过这个流程,就可以构建出你的网络模型图了。 在这个过程中,有几点注意事项
from tensorflow.python.ops import summary_ops_v2 需要引入这个模块自定义模型中的call需要使用tf.function注解标注inputs可以为任何符合网络输入形状的数据,比如我的网络输入为(None, 32, 32, 3),那么我就可以令inputs=tf.ones((64, 32, 32, 3)),也就是说可以使用该数据跑通这个模型即可使用tf.summary的FileWriter来进行绘制绘制结果可以在TensorBoard的URL之中查看:
其实这也是笔者找了很多文档都没发现,然后自己研究出来的方法。希望可以帮到大家。如果大家有任何问题,可以添加笔者QQ进行讨论:1574143668. 请大家在学习与工作的过程中不要忘记互联网创立的初衷——分享。