Double DQN with Prioritized Experience Reply

    技术2026-03-20  19

    Double DQN with Prioritized Experience Reply

    Double DQN部分Prioritized Experience Reply部分其他compute_nstep_return()函数 主要是学习了一下 tianshou的源代码来增加一些水平。看看好的开源库总是不错的。记录一下学到的知识。

    Double DQN部分

    target net一直是eval的只需要将new net的更新从 r + γ Q ′ ( s ′ , a r g m a x ( Q ′ ( s ′ , a ′ ) ) ) r + \gamma Q^{'}(s', argmax(Q'(s', a'))) r+γQ(s,argmax(Q(s,a))) 变成 r + γ Q ′ ( s ′ , a r g m a x ( Q ( s ′ , a ′ ) ) ) r + \gamma Q^{'}(s', argmax(Q(s', a'))) r+γQ(s,argmax(Q(s,a))) 即可。并不需要弄个 Q A Q_A QA Q B Q_B QB一起更新啥的,还是new和old两个相对的网络。

    Prioritized Experience Reply部分

    虽然sum tree的结构很神奇,但是tianshou实现的时候用了很简单的indice = np.random.choice( self._size, batch_size, p=(self.weight / self.weight.sum())[:self._size], replace=self._replace) 直接用p这个参数搞定了。

    另外我感觉这里他没有把important sampling的结果归一化?impt_weight = Batch( impt_weight=1 / np.power( self._size * (batch.weight / self._weight_sum), self._beta))

    其他

    compute_nstep_return()函数

    首先,buffer是一个顺序存储结构,从不乱序

    然后唯一的限制就是说一定是最后一个episode完整跑完了再做,即使它从前面覆盖了第一个episode也行

    这个buffer占满了之后就从第一个位置开始继续覆盖放了

    然后,indice是sample的样本在buffer中的坐标,是这个代码体系中从buffer sample后自带的

    从每个indice+n_step开始算reward,倒着乘以gamma累加

    然后万一开始的indice是一个被覆盖了的episode,当走到done这个位置时,

    令return=0,也就是刚刚算的done之后的都不要了,相当于最后算的是新的: ----------|<-倒着算–| -------done----------- ----------|<-不要了->| 不要也就是下面那个reward=0

    代码中的gammas是说最后算完之后q = return + target_q * gamma的gamma

    因为有done 的情况下,不一定都是n,所以特殊算了一下

    returns = np.zeros_like(indice) gammas = np.zeros_like(indice) + n_step done, rew, buf_len = buffer.done, buffer.rew, len(buffer) for n in range(n_step - 1, -1, -1): now = (indice + n) % buf_len gammas[done[now] > 0] = n returns[done[now] > 0] = 0 returns = (rew[now] - mean) / std + gamma * returns terminal = (indice + n_step - 1) % buf_len target_q = target_q_fn(buffer, terminal).squeeze() target_q[gammas != n_step] = 0 returns = to_torch_as(returns, target_q) gammas = to_torch_as(gamma ** gammas, target_q) batch.returns = target_q * gammas + returns return batch

    Processed: 0.008, SQL: 9