bert的原理我在第一篇论文介绍了,不赘述 下面写一下我对bert embedding 和XLNet embedding的理解与两者之间的不同 bert embedding 可选择的预,L表示的是transformer的层数,H表示输出的维度,A表示mutil-head attention的个数训练模型,每一层transformer的输出值,理论上来说都可以作为句向量,但是到底应该取哪一层呢,根据hanxiao大神的实验数据,最佳结果是取倒数第二层,最后一层的值太接近于目标,前面几层的值可能语义还未充分的学习到。 xlnet embedding XLNet在文中指出的,第一个预训练阶段因为采取引入[Mask]标记来Mask掉部分单词的训练模式,而Fine-tuning阶段是看不到这种被强行加入的Mask标记的,所以两个阶段存在使用模式不一致的情形,这可能会带来一定的性能损失;另外一个是,Bert在第一个预训练阶段,假设句子中多个单词被Mask掉,这些被Mask掉的单词之间没有任何关系,是条件独立的,而有时候这些单词之间是有关系的,XLNet则考虑了这种关系. 为了解决上文提到的问题,作者提出了排列语言模型,该模型不再对传统的AR模型的序列的值按顺序进行建模,而是最大化所有可能的序列的因式分解顺序的期望对数似然。 这里简单说下“双流自注意力机制”,一个是内容流自注意力,其实就是标准的Transformer的计算过程;主要是引入了Query流自注意力,这个是干嘛的呢?其实就是用来代替Bert的那个[Mask]标记的,因为XLNet希望抛掉[Mask]标记符号,但是比如知道上文单词x1,x2,要预测单词x3,此时在x3对应位置的Transformer最高层去预测这个单词,但是输入侧不能看到要预测的单词x3,Bert其实是直接引入[Mask]标记来覆盖掉单词x3的内容的,等于说[Mask]是个通用的占位符号。而XLNet因为要抛掉[Mask]标记,但是又不能看到x3的输入,于是Query流,就直接忽略掉x3输入了,只保留这个位置信息,用参数w来代表位置的embedding编码。其实XLNet只是扔了表面的[Mask]占位符号,内部还是引入Query流来忽略掉被Mask的这个单词。和Bert比,只是实现方式不同而已。 除了上文提到的优化点,作者还将transformer-xl的两个最重要的技术点应用了进来,即相对位置编码与片段循环机制。
核心代码 加载与训练模型,根据需要选择第几层weight
def __init__(self): self.config_path, self.checkpoint_path, self.dict_path, self.max_seq_len = config_name, ckpt_name, vocab_file, max_seq_len # 全局使用,使其可以django、flask、tornado等调用 global graph graph = tf.get_default_graph() global model model = load_trained_model_from_checkpoint(self.config_path, self.checkpoint_path, seq_len=self.max_seq_len) print(model.output) print(len(model.layers)) # lay = model.layers #一共104个layer,其中前八层包括token,pos,embed等, # 每8层(MultiHeadAttention,Dropout,Add,LayerNormalization) resnet # 一共12层 layer_dict = [7] layer_0 = 7 for i in range(12): layer_0 = layer_0 + 8 layer_dict.append(layer_0) print("kkkkkkkkk") print(layer_indexes) print(len(layer_indexes)) # 输出它本身 if len(layer_indexes) == 0: encoder_layer = model.output # 分类如果只有一层,就只取最后那一层的weight,取得不正确 elif len(layer_indexes) == 1: if layer_indexes[0] in [i+1 for i in range(13)]: encoder_layer = model.get_layer(index=layer_dict[layer_indexes[0]]).output else: encoder_layer = model.get_layer(index=layer_dict[-1]).output # 否则遍历需要取的层,把所有层的weight取出来并拼接起来shape:768*层数 else: # layer_indexes must be [1,2,3,......13] # all_layers = [model.get_layer(index=lay).output if lay is not 1 else model.get_layer(index=lay).output[0] for lay in layer_indexes] all_layers = [model.get_layer(index=layer_dict[lay-1]).output if lay in [i+1 for i in range(13)] else model.get_layer(index=layer_dict[-1]).output # 如果给出不正确,就默认输出最后一层 for lay in layer_indexes] print(layer_indexes) print(all_layers) # 其中layer==1的output是格式不对,第二层输入input是list all_layers_select = [] for all_layers_one in all_layers: all_layers_select.append(all_layers_one) encoder_layer = Add()(all_layers_select) print(encoder_layer.shape) print("KerasBertEmbedding:") print(encoder_layer.shape) output_layer = NonMaskingLayer()(encoder_layer) model = Model(model.inputs, output_layer) # model.summary(120) # reader tokenizer self.token_dict = {} with codecs.open(self.dict_path, 'r', 'utf8') as reader: for line in reader: token = line.strip() self.token_dict[token] = len(self.token_dict) self.tokenizer = Tokenizer(self.token_dict)对最终的weight进行处理,把填充的地方设置为0,这里定义了两个方法,一个mul_mask 和一个masked_reduce_mean,我们先看masked_reduce_mean(encoder_layer, input_mask)这里调用方法时传入的是encoder_layer即输出值,与input_mask即是否有有效文本,masked_reduce_mean方法中又调用了mul_mask方法,即先把input_mask进行了一个维度扩展,然后与encoder_layer相乘,为什么要维度扩展呢,我们看下两个值的维度,我们还是假设序列的最大长度是20,那么encoder_layer的维度为[20,768],为了把无效的位置的内容置为0,input_mask的维度为[20],扩充之后变成了[20,1],两个值相乘,便把input_mask为0的位置的encoder_layer的值改为了0, 然后把相乘得到的值在axis=1的位置进行相加最后除以input_mask在axis=1的维度的和
mul_mask = lambda x, m: x * np.expand_dims(m, axis=-1) masked_reduce_mean = lambda x, m: np.sum(mul_mask(x, m), axis=1) / (np.sum(m, axis=1, keepdims=True) + 1e-9) pools = [] for i in range(len(predicts)): pred = predicts[i] masks = input_masks.tolist() mask_np = np.array([masks[i]]) pooled = masked_reduce_mean(pred, mask_np) pooled = pooled.tolist() pools.append(pooled[0]) print('bert:', pools)xlnet同理
最终结果 bert某个词的词向量与总体相似度 xlnet词embedding相似度 参考 https://blog.csdn.net/hwaust2020/article/details/106785098/?ops_request_misc=&request_id=&biz_id=102&utm_term=bertmask缺点&utm_medium=distribute.pc_search_result.none-task-blog-2allsobaiduweb~default-3-106785098