【论文笔记】M-Walk: Learning to Walk over Graphs using Monte Carlo Tree Search

    技术2022-07-16  74

    本文用了强化学习,在知识图谱上游走,寻找目标节点。

    一、简介

    大概意思就是,在知识图谱上,给出一个起始节点和查询(query),然后找到目标节点。

     图G包含节点和边。

    如下图,给出起始节点Obama,query:citizenship,目标节点是USA。

     

     

    我们要学习一个方法来预测。

    我们我们将f作为强化学习的agent。他要学习搜索策略(search policy)

    训练的时候,我们给出,让f自己学习路径,如果他走到,就给他一个正的reward(1分),其他时候是0分(没停或者停错地方都是0)。学完后只给出,预测。(这个reward只用在Q learning的时候了)

     

    所以设计了一个神经网络的agent,叫M-walk。用RNN将历史路径转化为一个向量,用来学policy和Q function 。reward稀疏,所以用带蒙特卡洛树搜索的RNN,生成路径。

     

    二、用马尔科夫决策过程来进行图的游走

    (S,A,R,P) s是state,a是action,r是reward function,p是state transition probability

    初始状态s0和下一个状态的表示,如上图所示。

    是连接点nt的所有边,是nt的所有邻居节点。

    st包括1)到t时刻所有走过的节点(包括他们的邻居和邻边) 2)动作 3)初始query q构成。

     

    集合S由所有可能出现的st构成。

    在状态st,agent有以下动作可以选择:1)选择中的一条边,他连接到点 2)选择STOP,则就是要预测的。通常是随着时间t而改变的。

    t时刻的动作集合由下图表示,A是所有时刻的At的并集。

    选择stop之后,输出

     

    如果输出是(即输出了正确的答案),则reward=1,否则为0.

    这可以看出来,reward是非常稀疏的,只有走到正确的位置才有reward。但是由于图是已知静态确定的,所以如果确定了上一个状态和动作,那么下一个状态时确定的。(文中说这有助于解决reward稀疏。)

     

    π是policy(给出状态s,选择动作a),Q是Q function(在状态s下选择动作a,它的Q value是多少,即之后的长期收益是多少)

     

    三、M-walk agent

    3.1 π和Q的神经网路结构

    用RNN获得当前状态st的表达ht

    ht分为三个部分:

    1) 将上个时间的状态、动作、当前节点,综合。

    2)综合了nt的邻居n'节点,以及nt和n'之间的边e,代表第n'个候选动作(不包括STOP动作)

    3)  综合了和,用来判断STOP的概率。

     

    所以π和Q的计算。

    u0是将hst,hAt通过一个full-connected neural network。(这里没说这两个h要怎么整合到一起,应该是拼接)

    un'是hst和hn't做内积(即点乘,对应位相乘,求和)

    u0(STOP的分数),un'(邻居的分数)都是一个数字

    Q是对每个数字做sigmoid

    (这里做sigmoid,将q value化到0-1,因为这个模型的分数只有0和1,q value=0代表在当前s采取a,预期的总reward是0,是找不到的,如果是1代表未来可能找到。)

    π是做温度参数为τ的softmax

    关于温度参数

     

    3.2 训练算法

    传统的使用蒙特卡罗方法的REINFORCE,需要sample一个完整的序列,sample的效率很低,而且reward稀疏。所以sample的时候使用PUCT算法的变体。

     

    π是上面提到的策略分数(softmax算的),c和β用来控制探索的程度。N是visit count。W是走(s-a)这条边上的蒙特卡罗树的total action reward。

    PUCT算法最开始倾向于选择在状态s下出现少的action(式子的前半部分), 后来倾向于选择分数高的(式子的后半部分)。

    当PUCT算法选择了STOP,或者到达了最大探索数(应该是强行选择STOP),则停止。使用

    用下面的式子,更新上一个式子中的N和W。γ是衰减因子(discount factor).

    主要目标就是多生成reward为正的路径。

    然后用DQN网络,寻找更好的π就是max Q

    (由于Q和π共享参数,且算的时候只用了sigmoid和softmax这种没参数的函数,所以训一个就行)

    莫烦python-DQN网络代码详解(pytorch)

     3.3预测算法

    已知(ns,q)求nT。利用π在G上寻找nT。

    一种方法是用训练好的π去寻找。然而这并没有用MDP的转移模型(?)(下方这个公式)

    所以利用上面训练好的模型π、Q去生成蒙特卡罗树,就像训练时那样(Q stop作为上文提到的V进行更新)。但是可能有多路径到达同一个终止节点n。走不同路径,就有不同的叶子节点是n。

    怎么比较选择哪个终点n(而且n需要综合多条路径),需要算一个分数,排序。

    N是蒙特卡罗树的总模拟数量

    综合叶子节点是n的情况,求n的分数。

    在所有的候选节点中,我们选择score最大的。

     

    3.4 RNN encoder

     qt约等于右边的式子

    所以st大约可以写成

    st由两部分组成 1)  代表候选动作(包括STOP) 2)qt代表历史

    所以用两个不同的神经网络去编码他们

    前面说过,ht分为三个部分:

    1) 将上个时间的状态、动作、当前节点,综合。

    2)综合了nt的邻居n'节点,以及nt和n'之间的边e,代表第n'个候选动作(包括STOP动作)

    3)  综合了和,用来判断STOP的概率。

    求 2)的方法很简单,就是边和点的表达通过full-connected neural network

    求 3) 的方法,就是max 2)的结果,因为每一次的节点数可能都不一样,这样可以得到统一的结果

    求1) 就是编码qt 使用gru的思想

    可以看出,q就相当于rnn里的hidden,初始是query q,之后为qt,rnn的输入是

     

    Processed: 0.015, SQL: 9