bert得到向量

    技术2022-07-17  77

    modeling 和 tokenization 是 BERT 官方 GitHub 仓库中的两个代码文件(modeling.py、tokenization.py),使用时需放在 bert_demo 包目录下。仓库链接:https://github.com/google-research/bert

    chinese_L-12_H-768_A-12 是在中文语料上预训练的 BERT 模型,下载后解压到脚本的运行目录下。下载链接:https://storage.googleapis.com/bert_models/2018_11_03/chinese_L-12_H-768_A-12.zip

    """Extract per-token BERT embeddings with the google-research/bert code.

    Requires TensorFlow 1.x (the script uses ``tf.placeholder`` and
    ``tf.Session``) and the ``modeling`` / ``tokenization`` modules from the
    BERT repository placed in a ``bert_demo`` package.
    """

    import os

    import numpy as np
    import tensorflow as tf

    from bert_demo import modeling
    from bert_demo import tokenization

    # Directory holding bert_config.json, bert_model.ckpt.* and vocab.txt.
    MODEL_DIR = 'chinese_L-12_H-768_A-12'


    class bert_vec():
        """Wrap a pre-trained BERT encoder and expose its last-layer output.

        The model is built once in ``__init__`` inside a private graph, so
        several instances can coexist in one process without variable-name
        collisions. Call ``close()`` (or use the instance as a context
        manager) to release the TensorFlow session.
        """

        def __init__(self, model_dir=MODEL_DIR):
            """Build the BERT graph and load the checkpoint weights.

            Args:
                model_dir: Directory containing ``bert_config.json`` and the
                    ``bert_model.ckpt`` checkpoint files.
            """
            self.graph = tf.Graph()
            with self.graph.as_default():
                # Inputs are [batch, sequence] integer matrices; both
                # dimensions are left dynamic.
                self.input_ids = tf.placeholder(
                    tf.int32, shape=[None, None], name='input_ids')
                self.input_mask = tf.placeholder(
                    tf.int32, shape=[None, None], name='input_masks')
                self.segment_ids = tf.placeholder(
                    tf.int32, shape=[None, None], name='segment_ids')

                bert_config = modeling.BertConfig.from_json_file(
                    os.path.join(model_dir, 'bert_config.json'))

                # Inference only: is_training=False.
                self.model = modeling.BertModel(
                    config=bert_config,
                    is_training=False,
                    input_ids=self.input_ids,
                    input_mask=self.input_mask,
                    token_type_ids=self.segment_ids,
                    use_one_hot_embeddings=False
                )

                # Output of the last encoder layer, fetched by get_embedding().
                self.sequence_output = self.model.get_sequence_output()

                # Checkpoint prefix (not a single file on disk).
                init_checkpoint = os.path.join(model_dir, 'bert_model.ckpt')

                # Map the graph's trainable variables onto the checkpoint so
                # their initializers read the pre-trained values.
                tvars = tf.trainable_variables()
                assignment_map, _ = modeling.get_assignment_map_from_checkpoint(
                    tvars, init_checkpoint)
                tf.train.init_from_checkpoint(init_checkpoint, assignment_map)

                init_op = tf.global_variables_initializer()

            self.sess = tf.Session(graph=self.graph)
            self.sess.run(init_op)

        def get_embedding(self, char_lists, mask_lists, seg_lists):
            """Run the encoder and return the last layer's output.

            All three arguments are batches (lists of equal-length integer
            lists); every sequence in a batch must have the same length so
            the batch converts to a 2-D array.

            Args:
                char_lists: Token ids, one list per sequence.
                mask_lists: Attention masks matching ``char_lists``.
                seg_lists: Segment (token type) ids matching ``char_lists``.

            Returns:
                The value of ``BertModel.get_sequence_output()`` for the
                batch, as returned by ``Session.run``.
            """
            feed_data = {
                self.input_ids: np.asarray(char_lists),
                self.input_mask: np.asarray(mask_lists),
                self.segment_ids: np.asarray(seg_lists),
            }
            return self.sess.run(self.sequence_output, feed_dict=feed_data)

        def close(self):
            """Release the TensorFlow session held by this instance."""
            self.sess.close()

        def __enter__(self):
            return self

        def __exit__(self, exc_type, exc_value, traceback):
            self.close()


    def _tokens_to_ids(tokenizer, tokens):
        """Convert tokens to vocabulary ids, mapping unknown tokens to [UNK].

        ``convert_tokens_to_ids`` raises ``KeyError`` for a token missing
        from the vocabulary, so unknown tokens are substituted first.
        """
        known = [tok if tok in tokenizer.vocab else '[UNK]' for tok in tokens]
        return tokenizer.convert_tokens_to_ids(known)


    def main():
        """Embed one demo sentence character by character."""
        string = '设置一个随机种子'

        # Character-level tokens framed by BERT's special markers.
        char_list = ['[CLS]'] + list(string) + ['[SEP]']

        tokenizer = tokenization.FullTokenizer(
            vocab_file=os.path.join(MODEL_DIR, 'vocab.txt'))
        char_ids = _tokens_to_ids(tokenizer, char_list)

        # No padding in a single-sentence batch: attend to every position.
        mask_list = [1] * len(char_ids)
        # Single sentence: every token belongs to segment 0.
        seg_list = [0] * len(char_ids)

        bertVec = bert_vec()
        try:
            # BERT embedding for the one-sentence batch.
            embedding = bertVec.get_embedding(
                [char_ids], [mask_list], [seg_list])
        finally:
            bertVec.close()
        return embedding


    if __name__ == '__main__':
        main()
    Processed: 0.014, SQL: 9