tensorflow得到图的所有节点名称以及得到节点输出

    技术2023-11-25  92

    # 默认图的所有节点名称 # tensor_name_list = [tensor.name for tensor in tf.get_default_graph().as_graph_def().node] #图graph的所有节点名称 # tensor_name_list = [tensor.name for tensor in graph.as_graph_def().node] # print(tensor_name_list) #数据处理 string = '设置一个随机种子' char_list = ['[CLS]'] + list(string) +['[SEP]'] #不做masked处理 mask_list = [1] * (len(string)+2) #不做分词处理 seg_list = [0] * (len(string)+2) # 根据bert的词表做一个char_to_id的操作 # 未登录词会报错,更改报错代码使未登录词时为'[UNK]' # 也可以自己实现 token = tokenization.FullTokenizer(vocab_file='chinese_L-12_H-768_A-12/vocab.txt') char_list = token.convert_tokens_to_ids(char_list) char_lists = [char_list] mask_lists = [mask_list] seg_lists = [seg_list] input_ids = sess.graph.get_tensor_by_name('input_ids:0') input_mask = sess.graph.get_tensor_by_name('input_masks:0') segment_ids = sess.graph.get_tensor_by_name('segment_ids:0') # bert12层transformer,取最后一层的输出 output = sess.graph.get_tensor_by_name('bert/encoder/layer_11/output/LayerNorm/batchnorm/add_1:0') feed_data = {input_ids: np.asarray(char_lists), input_mask: np.asarray(mask_lists), segment_ids: np.asarray(seg_lists)} embedding = sess.run(output, feed_dict=feed_data) #bert输出向量结果分批次没有节点,这里reshape成bert_model.get_sequence_output()的形状 embedding = np.reshape(embedding, (len(char_lists), len(char_lists[0]), -1))
    Processed: 0.013, SQL: 9