Paper Notes: Trust Region Policy Optimisation

As an entry point into reinforcement learning theory, the idea behind TRPO is very simple, but the proofs are so tedious that they have scared off a large crowd of enthusiasts. Can't be helped; academia just loves this kind of paper.

To really understand TRPO, it is worth reading its predecessor first:

Paper Notes: Approximately Optimal Approximate Reinforcement Learning

With that paper as background, TRPO becomes much easier to follow; it is essentially old wine in a new bottle.

For any two policies $\pi, \hat{\pi}$, if $D_{TV}^{\max}(\pi, \hat{\pi}) = \max_s D_{TV}(\pi, \hat{\pi}) = \max_s \sum_a \left|\pi(a \mid s) - \hat{\pi}(a \mid s)\right| \leq \alpha$, then

$$\eta(\hat{\pi}) - \eta(\pi) \geq A_{\pi}(\hat{\pi}) - \frac{4\epsilon\alpha^2\gamma}{(1-\gamma)^2}, \qquad \epsilon = \max_{s,a}\left|A_{\pi}(s,a)\right|$$

The proof of this step is almost identical to Lemma 1 of the earlier paper, only the definition of $\epsilon$ changes, so I will not repeat it. The appendix of the paper also gives a second, more concise proof, which is worth a look.

Since $D_{TV}(\pi, \hat{\pi})^2 \leq D_{KL}(\pi \| \hat{\pi})$, if we define $C = 4\epsilon\gamma/(1-\gamma)^2$, then:

$$\eta(\hat{\pi}) \geq \eta(\pi) + A_{\pi}(\hat{\pi}) - C\, D_{KL}(\hat{\pi}, \pi)$$

In other words, as long as $D_{KL}(\pi, \hat{\pi})$ is small enough, improving the right-hand side of the inequality guarantees that the left-hand side improves as well. Iterating this gives a steady, monotonic improvement of the policy. That leads to the following optimisation problem, which is the essence of TRPO:

$$\begin{aligned}
& \max_{\theta} \;\; \eta(\pi_{\theta_{old}}) + \sum_{s} d(s, \pi_{old}) \sum_{a} \pi_{\theta}(s, a)\, A_{\pi_{old}}(s, a) \\
& \text{subject to} \quad D_{KL}(\pi \| \pi_{old}) \leq \delta
\end{aligned}$$

The region carved out by the constraint on $D_{KL}(\pi \| \pi_{old})$ is the so-called Trust Region.
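
The bound suggests maximising the importance-sampled surrogate minus a KL penalty. Here is a minimal PyTorch sketch of that penalised objective; the function name and arguments are illustrative, not from any particular codebase, and the advantages and old log-probabilities are assumed to be precomputed from rollouts:

```python
import torch

def penalized_surrogate(new_log_probs, old_log_probs, advantages, kl, C):
    """Surrogate objective minus C * KL, i.e. the right-hand side of the bound.

    new_log_probs, old_log_probs, advantages: 1-D tensors over sampled (s, a);
    kl: scalar tensor, mean KL divergence between the old and new policies;
    C: the penalty coefficient 4 * eps * gamma / (1 - gamma)**2 from the bound.
    """
    ratio = torch.exp(new_log_probs - old_log_probs)  # pi_theta(a|s) / pi_old(a|s)
    surrogate = (ratio * advantages).mean()           # importance-sampled advantage term
    return surrogate - C * kl                         # maximise this penalised objective
```

TRPO itself replaces the penalty coefficient $C$ with the hard constraint $D_{KL} \leq \delta$ shown above, because the theoretically justified $C$ would force impractically small steps.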


OK, the theory is very "elegant". How about the experimental results? All I can say is: meh. For a concrete implementation, click here.

That codebase mainly implements PPO, with TRPO done on the side, and its experiments slap PPO hard in the face: in practice the gains come mostly from assorted tricks, not from the clipped objective PPO advertises. See the authors' paper for details.

The TRPO implementation lives mainly in src/policy_gradients/steps.py:trpo_step; I've copied it here with a brief explanation:

```python
def trpo_step(all_states, actions, old_log_ps, rewards, returns, not_dones, advs, net, params, store, opt_step):
    '''
    Trust Region Policy Optimization
    Runs K epochs of TRPO as in https://arxiv.org/abs/1502.05477
    Inputs:
    - all_states, the historical value of all the states
    - actions, the actions that the policy sampled
    - old_log_ps, the probability of the actions that the policy sampled
    - advs, advantages as estimated by GAE
    - net, policy network to train [WILL BE MUTATED]
    - params, additional placeholder for parameters like EPS
    Returns:
    - The TRPO loss; main job is to mutate the net
    '''    
    # Initial setup
    initial_parameters = flatten(net.parameters()).clone()
    pds = net(all_states)
    action_log_probs = net.get_loglikelihood(pds, actions)

    # Calculate losses
    surr_rew = surrogate_reward(advs, new=action_log_probs, old=old_log_ps).mean()
    grad = ch.autograd.grad(surr_rew, net.parameters(), retain_graph=True)
    flat_grad = flatten(grad)

    # Make fisher product estimator
    num_samples = int(all_states.shape[0] * params.FISHER_FRAC_SAMPLES)
    selected = np.random.choice(range(all_states.shape[0]), num_samples, replace=False)
    
    detached_selected_pds = select_prob_dists(pds, selected, detach=True)
    selected_pds = select_prob_dists(pds, selected, detach=False)
    
    kl = net.calc_kl(detached_selected_pds, selected_pds).mean()
    g = flatten(ch.autograd.grad(kl, net.parameters(), create_graph=True))
    def fisher_product(x, damp_coef=1.):
        contig_flat = lambda q: ch.cat([y.contiguous().view(-1) for y in q])
        z = g @ x
        hv = ch.autograd.grad(z, net.parameters(), retain_graph=True)
        return contig_flat(hv).detach() + x*params.DAMPING * damp_coef

    # Find KL constrained gradient step
    step = cg_solve(fisher_product, flat_grad, params.CG_STEPS)

    max_step_coeff = (2 * params.MAX_KL / (step @ fisher_product(step)))**(0.5)
    max_trpo_step = max_step_coeff * step

    if store and params.SHOULD_LOG_KL:
        kl_approximation_logging(all_states, pds, flat_grad, step, net, store)
        kl_vs_second_order_approx(all_states, pds, net, max_trpo_step, params, store, opt_step)

    # Backtracking line search
    with ch.no_grad():
        # Backtracking function
        def backtrack_fn(s):
            assign(initial_parameters + s.data, net.parameters())
            test_pds = net(all_states)
            test_action_log_probs = net.get_loglikelihood(test_pds, actions)
            new_reward = surrogate_reward(advs, new=test_action_log_probs, old=old_log_ps).mean()
            if new_reward <= surr_rew or net.calc_kl(pds, test_pds).mean() > params.MAX_KL:
                return -float('inf')
            return new_reward - surr_rew
        expected_improve = flat_grad @ max_trpo_step
        final_step = backtracking_line_search(backtrack_fn, max_trpo_step,
                                              expected_improve,
                                              num_tries=params.MAX_BACKTRACK)

        assign(initial_parameters + final_step, net.parameters())

    return surr_rew
```

`grad = ch.autograd.grad(surr_rew, net.parameters(), retain_graph=True)` corresponds to $\nabla_\theta \eta(\pi_\theta)$.

`kl = net.calc_kl(detached_selected_pds, selected_pds).mean()` corresponds to $D_{KL}(\pi \| \pi_{old})$.
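
The `fisher_product` closure built from this `kl` is the standard Hessian-vector product trick (double backpropagation): differentiate the KL once with `create_graph=True`, dot the resulting gradient with the input vector, then differentiate again. A stripped-down, generic sketch of the same idea, with illustrative names rather than the repo's API:

```python
import torch

def make_fisher_vector_product(kl, params, damping=0.1):
    """Build v -> F v + damping * v, where F is the Hessian of `kl` w.r.t. `params`.

    kl: scalar tensor, mean KL between the detached (old) and current policy;
    params: list of parameter tensors; damping: small diagonal term for stability.
    """
    grads = torch.autograd.grad(kl, params, create_graph=True)   # first backprop, keep the graph
    flat_grad_kl = torch.cat([g.reshape(-1) for g in grads])

    def fvp(v):
        gv = (flat_grad_kl * v).sum()                             # scalar (dKL/dtheta) . v
        hv = torch.autograd.grad(gv, params, retain_graph=True)   # second backprop gives H v
        flat_hv = torch.cat([h.contiguous().reshape(-1) for h in hv])
        return flat_hv.detach() + damping * v                     # F v plus a damping term
    return fvp
```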

`step = cg_solve(fisher_product, flat_grad, params.CG_STEPS)` is the trickier part to follow: it involves something called the conjugate gradient method. Interested readers can head over here.

In short, we use the conjugate gradient method to compute the natural policy gradient; see Section 7 of the TRPO paper, which ties into my Paper Notes: Natural Policy Gradient. Policy optimisation is, at heart, an optimisation problem, and an optimisation problem has to answer two questions: the step direction and the step size.

To measure how close two policies are, we normally use the KL divergence rather than the Euclidean distance, so for policy optimisation the natural gradient should, in theory, give a better update direction than the standard gradient. The natural gradient is defined as $\hat{\nabla}_\theta \eta(\pi_\theta) = F^{-1}\nabla_\theta \eta(\pi_\theta)$, where $F$ is the so-called Fisher information matrix, which approximates the second derivative of the KL divergence. The natural gradient update is:

$$d = \sqrt{\frac{2\epsilon}{g^T F^{-1} g}}\, F^{-1} g \quad \text{where} \quad g = \nabla_\theta \eta(\pi_\theta)$$
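
As a quick sanity check on that coefficient: approximate the KL divergence to second order, $D_{KL} \approx \frac{1}{2}\Delta\theta^T F \Delta\theta$, take the step along the natural direction $\Delta\theta = \beta F^{-1} g$, and solve for the largest $\beta$ that keeps the KL within $\epsilon$:

$$\frac{1}{2}\beta^2 (F^{-1}g)^T F (F^{-1}g) = \frac{1}{2}\beta^2\, g^T F^{-1} g \leq \epsilon \;\Longrightarrow\; \beta = \sqrt{\frac{2\epsilon}{g^T F^{-1} g}}$$

This is exactly what `max_step_coeff = (2 * params.MAX_KL / (step @ fisher_product(step)))**(0.5)` computes, since `step` $\approx F^{-1}g$ and hence `step @ fisher_product(step)` $\approx g^T F^{-1} g$.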

Back to TRPO: `cg_solve` uses the conjugate gradient method to compute the natural gradient. Why conjugate gradient? Because $\theta$ is so high-dimensional that forming the Fisher information matrix explicitly would be far too slow and far too large ($N^2$ entries); conjugate gradient only ever needs Fisher-vector products.
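
I have not walked through the repo's `cg_solve` line by line, but a textbook conjugate gradient loop that approximately solves $Fx = g$ using only the Fisher-vector product callback would look roughly like this (a sketch with illustrative names):

```python
import torch

def conjugate_gradient(fvp, g, n_steps=10, tol=1e-10):
    """Approximately solve F x = g, given only the product function fvp(v) = F v."""
    x = torch.zeros_like(g)            # start from x = 0
    r = g.clone()                      # residual r = g - F x
    p = g.clone()                      # search direction
    rs_old = r @ r
    for _ in range(n_steps):
        Fp = fvp(p)
        alpha = rs_old / (p @ Fp)      # step length along p
        x = x + alpha * p
        r = r - alpha * Fp
        rs_new = r @ r
        if rs_new < tol:               # residual is small enough, stop early
            break
        p = r + (rs_new / rs_old) * p  # next direction, conjugate to the previous ones
        rs_old = rs_new
    return x                           # x ~= F^{-1} g, the natural gradient direction
```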

With the step direction in hand, the next job is the step size, which is handled by the function backtracking_line_search.

Its job is to keep shrinking the step size until `new_reward > surr_rew and net.calc_kl(pds, test_pds).mean() < params.MAX_KL`, i.e. $\eta(\pi_{new}) > \eta(\pi_{old})$ and $D_{KL}(\pi_{new} \| \pi_{old}) < \epsilon$. So after all this talk, the core of TRPO in the code boils down to one check: `net.calc_kl(pds, test_pds).mean() < params.MAX_KL` is the so-called TRPO constraint.
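
I have not copied the repo's backtracking_line_search here, but the idea is standard: exponentially shrink the proposed step and accept the first scale whose realised improvement is good enough relative to the predicted (linear) improvement. A hedged sketch, where `criterion` plays the role of `backtrack_fn` from the code above and the accept ratio and decay factor are illustrative defaults:

```python
def backtracking_search(criterion, max_step, expected_improve,
                        num_tries=10, accept_ratio=0.1, decay=0.5):
    """Return the first shrunken step that criterion accepts, or a zero step.

    criterion(step) returns -inf if the step violates the KL constraint or fails
    to improve the surrogate reward, and the actual improvement otherwise.
    """
    scale = 1.0
    for _ in range(num_tries):
        step = scale * max_step
        improvement = criterion(step)
        # accept if the realised gain is at least a fraction of the predicted one
        if improvement / (expected_improve * scale) > accept_ratio:
            return step
        scale *= decay                 # otherwise shrink the step and try again
    return 0.0 * max_step              # no acceptable step found: keep the old parameters
```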

Now for the interesting question: if we drop this constraint, do the experimental results actually get much worse? You can verify that for yourself, hahaha. So, dear reader: after spending this much time and effort understanding something hyped to the heavens, only to discover it is basically a con, do you plan to stay in this business and keep the con going, or do you plan to quit?