虽然sum tree的结构很神奇,但是tianshou实现的时候用了很简单的indice = np.random.choice( self._size, batch_size, p=(self.weight / self.weight.sum())[:self._size], replace=self._replace) 直接用p这个参数搞定了。
另外我感觉这里他没有把important sampling的结果归一化?impt_weight = Batch( impt_weight=1 / np.power( self._size * (batch.weight / self._weight_sum), self._beta))
首先,buffer是一个顺序存储结构,从不乱序
然后唯一的限制就是说一定是最后一个episode完整跑完了再做,即使它从前面覆盖了第一个episode也行
这个buffer占满了之后就从第一个位置开始继续覆盖放了
然后,indice是sample的样本在buffer中的坐标,是这个代码体系中从buffer sample后自带的
从每个indice+n_step开始算reward,倒着乘以gamma累加
然后万一开始的indice是一个被覆盖了的episode,当走到done这个位置时,
令return=0,也就是刚刚算的done之后的都不要了,相当于最后算的是新的: ----------|<-倒着算–| -------done----------- ----------|<-不要了->| 不要也就是下面那个reward=0
代码中的gammas是说最后算完之后q = return + target_q * gamma的gamma
因为有done 的情况下,不一定都是n,所以特殊算了一下
returns = np.zeros_like(indice) gammas = np.zeros_like(indice) + n_step done, rew, buf_len = buffer.done, buffer.rew, len(buffer) for n in range(n_step - 1, -1, -1): now = (indice + n) % buf_len gammas[done[now] > 0] = n returns[done[now] > 0] = 0 returns = (rew[now] - mean) / std + gamma * returns terminal = (indice + n_step - 1) % buf_len target_q = target_q_fn(buffer, terminal).squeeze() target_q[gammas != n_step] = 0 returns = to_torch_as(returns, target_q) gammas = to_torch_as(gamma ** gammas, target_q) batch.returns = target_q * gammas + returns return batch