Paper Notes: Approximately Optimal Approximate Reinforcement Learning

This paper proposes a Conservative Policy Iteration algorithm that finds an approximately optimal policy within a relatively small number of steps; it is also the predecessor of the well-known Trust Region Policy Optimisation.

Section 2 of the paper gives some basic RL definitions:

$$V_\pi(s) = (1 - \gamma)\, E_{s,a}\Big[\sum_{t=0}^{\infty} \gamma^t R(s_t,a_t) \,\Big|\, \pi, s\Big]$$

$$Q_\pi(s,a) = (1 - \gamma)\,R(s,a) + \gamma\, E_{s'}\big[V_\pi(s') \mid \pi, s, a\big]$$

The factor $1-\gamma$ is there so that $V_\pi(s), Q_\pi(s,a) \in [0, R_{\max}]$, because

$$E_{s,a}\Big[\sum_{t=0}^{\infty} \gamma^t R(s_t,a_t)\,\Big|\,\pi\Big] \leq R_{\max} + \gamma R_{\max} + \dots = R_{\max}/(1 - \gamma)$$

so the advantage function satisfies

$$A_\pi(s,a) = Q_\pi(s,a) - V_\pi(s) \in [-R_{\max}, R_{\max}]$$

It also defines the discounted future state distribution:

$$d_{\pi, D}(s) = (1 - \gamma) \sum_{t=0}^{\infty} \gamma^t \Pr(s_t = s \mid \pi, D)$$

Given a start state distribution $D$, the policy-optimization objective is

$$\eta_D (\pi) = E_{s \sim D} [V_\pi(s)] = E_{s \sim d_{\pi,D},\, a \sim \pi} [R(s,a)]$$

Expanding this, we find:

$$\tau = (s_0, a_0, s_1, a_1, \dots),\qquad s_0 \sim D, \quad a_t \sim \pi(s_t),\quad s_{t+1} \sim \Pr(s' \mid s=s_t, a=a_t)$$

$$\Pr(\tau) = D(s_0) \prod_{t=0}^{\infty} \pi(s_t,a_t)\,\Pr(s_{t+1} \mid s_t, a_t)$$

$$R(\tau) = (1-\gamma)\sum_{t=0}^{\infty}\gamma^t R(s_t,a_t)$$

$$\eta(\pi) = E_{\tau}[R(\tau)]$$

From this it is easy to derive:

$$\eta(\pi) = \sum_{s} d_{\pi,D}(s)\sum_{a}\pi(s,a)\,R(s,a)$$
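As a quick sanity check of this identity, here is a minimal numerical sketch on a random tabular MDP (the names `P`, `Rsa`, `D` and so on are mine, not the paper's): it computes $d_{\pi,D}$ in closed form and compares the two expressions for $\eta_D(\pi)$.

```python
import numpy as np

rng = np.random.default_rng(0)
nS, nA, gamma = 4, 3, 0.9

# Random tabular MDP: P[s, a, s'] transition probs, Rsa[s, a] rewards in [0, 1].
P = rng.random((nS, nA, nS))
P /= P.sum(axis=2, keepdims=True)
Rsa = rng.random((nS, nA))
D = np.full(nS, 1.0 / nS)                        # start-state distribution
pi = rng.random((nS, nA))
pi /= pi.sum(axis=1, keepdims=True)              # pi[s, a] = pi(a | s)

P_pi = np.einsum('sa,sap->sp', pi, P)            # state-to-state kernel under pi
r_pi = (pi * Rsa).sum(axis=1)                    # expected one-step reward per state

# Normalized value function: V_pi = (1 - gamma) (I - gamma P_pi)^{-1} r_pi
V = (1 - gamma) * np.linalg.solve(np.eye(nS) - gamma * P_pi, r_pi)
eta_via_V = D @ V                                # eta_D(pi) = E_{s~D}[V_pi(s)]

# Discounted future state distribution: d^T = (1 - gamma) D^T (I - gamma P_pi)^{-1}
d = (1 - gamma) * np.linalg.solve((np.eye(nS) - gamma * P_pi).T, D)
eta_via_d = (d[:, None] * pi * Rsa).sum()        # sum_s d(s) sum_a pi(s,a) R(s,a)

print(eta_via_V, eta_via_d)                      # the two numbers agree
```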

Section 3 of the paper presents two main difficulties faced by policy optimization: Figure 3.1 illustrates the sparse-reward problem, and Figure 3.2 illustrates the flat-plateau gradient problem. We discuss the case of Figure 3.2.

The MDP shown in that figure has two states $i, j$, an initial state distribution $p$, and an initial policy $\pi$, given by

$$p(i) = 0.8,\quad p(j) = 0.2$$

$$\pi(i,1) = 0.8,\quad \pi(i,2) = 0.2$$

$$\pi(j,1) = 0.9,\quad \pi(j,2) = 0.1$$

$$R(i,1) = 1,\quad R(j,1) = 2,\quad R(i,2) = R(j,2) = 0$$

Clearly, the optimal behaviour is to sit in the self-loop at state $j$.

Consider a parameterized (softmax) policy

$$\pi_\theta(s,a) = \pi(a \mid s) = \frac{e^{\theta_{s,a}}}{\sum_{a'} e^{\theta_{s,a'}}}, \qquad \theta \in \mathbb{R}^{|S| \times |A|}$$

For this MDP, $\theta$ is a $2 \times 2$ matrix. Differentiating the objective gives

$$\nabla_\theta \eta(\pi) = \sum_{s,a} d(s)\, \nabla_\theta\pi(s,a)\, Q_\pi(s,a)$$

What we want here is clearly to increase $\theta_{i,2}$ and $\theta_{j,1}$. However,

$$\nabla_{\theta_{i,2}} \eta = d(i)\, Q(i,2)\,\pi(i,2)\big(1-\pi(i,2)\big) - d(i)\, Q(i,1)\,\pi(i,1)\,\pi(i,2)$$

$$\nabla_{\theta_{j,1}} \eta = d(j)\, Q(j,1)\,\pi(j,1)\big(1-\pi(j,1)\big) - d(j)\, Q(j,2)\,\pi(j,1)\,\pi(j,2)$$

In the first gradient, $Q(i,1) \gg Q(i,2)$, so the negative term largely cancels (or even outweighs) the positive one; in the second, $d(j) \ll d(i)$, so the whole expression is scaled down by a tiny visitation weight.

Either way, the policy gradient is very small and learning becomes far too slow. This is exactly the kind of problem the paper sets out to solve.
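To make this concrete, the sketch below computes the exact quantities in these gradient expressions for the two-state MDP. The notes do not spell out the transition structure or $\gamma$, so I assume that action 1 is a self-loop and action 2 deterministically switches to the other state, and pick $\gamma = 0.9$; the numbers for $p$, $\pi$ and $R$ are the ones above.

```python
import numpy as np

gamma = 0.9                                  # assumed; not specified in the notes
nS, nA = 2, 2                                # state 0 = i, state 1 = j; action 0 = "1", action 1 = "2"

# Assumed dynamics: action "1" self-loops, action "2" jumps to the other state.
P = np.zeros((nS, nA, nS))
P[0, 0, 0] = 1.0                             # (i, 1): stay in i, reward 1
P[0, 1, 1] = 1.0                             # (i, 2): move to j, reward 0
P[1, 0, 1] = 1.0                             # (j, 1): stay in j, reward 2
P[1, 1, 0] = 1.0                             # (j, 2): move to i, reward 0
Rsa = np.array([[1.0, 0.0],
                [2.0, 0.0]])
D = np.array([0.8, 0.2])                     # p(i), p(j)
pi = np.array([[0.8, 0.2],
               [0.9, 0.1]])                  # the initial policy above

P_pi = np.einsum('sa,sap->sp', pi, P)
r_pi = (pi * Rsa).sum(axis=1)
V = np.linalg.solve(np.eye(nS) - gamma * P_pi, r_pi)   # V_pi, without the (1-gamma) factor
Q = Rsa + gamma * np.einsum('sap,p->sa', P, V)         # Q_pi(s, a)
d = np.linalg.solve((np.eye(nS) - gamma * P_pi).T, D)  # discounted visitation weights

# The two competing terms in each gradient, exactly as in the expressions above
# (action indices: "1" -> 0, "2" -> 1).
g_i2 = d[0] * Q[0, 1] * pi[0, 1] * (1 - pi[0, 1]) - d[0] * Q[0, 0] * pi[0, 0] * pi[0, 1]
g_j1 = d[1] * Q[1, 0] * pi[1, 0] * (1 - pi[1, 0]) - d[1] * Q[1, 1] * pi[1, 0] * pi[1, 1]
print("d =", d, " Q =", Q)
print("grad wrt theta_{i,2} =", g_i2, " grad wrt theta_{j,1} =", g_j1)
```

Varying `gamma` and the assumed dynamics makes it easy to inspect how the two terms in each gradient trade off against each other.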


The paper considers the following mixture policy:

$$\pi_{new} = (1 - \alpha)\, \pi + \alpha\, \pi', \qquad \alpha \in [0, 1]$$

where $\pi'$ is a candidate improvement policy (informally, $\pi' \approx \pi + \nabla\pi$; later it will be the greedy policy with respect to an approximate advantage function).

Define the policy advantage

$$A_\pi(\pi') = \sum_s d(s,\pi) \sum_a \pi'(s,a)\, A_\pi(s,a)$$
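In code the policy advantage is just a weighted sum. The sketch below (my own helpers, on a random tabular MDP, using the normalized definitions from Section 2) computes $A_\pi(\pi')$ and also checks that $A_\pi(\pi) = 0$, a fact used repeatedly below.

```python
import numpy as np

rng = np.random.default_rng(1)
nS, nA, gamma = 5, 3, 0.95

P = rng.random((nS, nA, nS)); P /= P.sum(axis=2, keepdims=True)
Rsa = rng.random((nS, nA))
D = np.full(nS, 1.0 / nS)

def values(pi):
    """Exact normalized V_pi, Q_pi, A_pi and discounted visitation d_pi."""
    P_pi = np.einsum('sa,sap->sp', pi, P)
    r_pi = (pi * Rsa).sum(axis=1)
    V = (1 - gamma) * np.linalg.solve(np.eye(nS) - gamma * P_pi, r_pi)
    Q = (1 - gamma) * Rsa + gamma * np.einsum('sap,p->sa', P, V)
    d = (1 - gamma) * np.linalg.solve((np.eye(nS) - gamma * P_pi).T, D)
    return V, Q, Q - V[:, None], d

def policy_advantage(pi, pi_prime):
    """A_pi(pi') = sum_s d(s, pi) sum_a pi'(s, a) A_pi(s, a)."""
    _, _, A, d = values(pi)
    return (d[:, None] * pi_prime * A).sum()

pi = rng.random((nS, nA)); pi /= pi.sum(axis=1, keepdims=True)
pi_prime = rng.random((nS, nA)); pi_prime /= pi_prime.sum(axis=1, keepdims=True)

print(policy_advantage(pi, pi_prime))   # can be positive or negative
print(policy_advantage(pi, pi))         # always ~0, since sum_a pi(s,a) A_pi(s,a) = 0
```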

Lemma 1 is as follows:

Lemma 1.

$$\eta(\pi_{new}) - \eta(\pi) \geq \alpha A_\pi(\pi') - \frac{2\alpha^2\gamma\epsilon}{1 - \gamma(1-\alpha)}, \qquad \epsilon = \frac{1}{1-\gamma} \max_s \sum_a \pi'(s,a)\, A_\pi(s,a)$$

To prove the lemma, first note that

$$\begin{aligned} \nabla_\alpha \eta (\pi_{new}) &= \sum_s d(s, \pi_{new}) \sum_a \nabla_\alpha\big[(1 - \alpha)\,\pi + \alpha\,\pi'\big]\, Q^{\pi_{new}} \\ &= \sum_s d(s, \pi_{new}) \sum_a (\pi' - \pi)\,\big(V^{\pi_{new}} + A^{\pi_{new}}\big) \end{aligned}$$

As $\alpha \to 0$ we have $\pi_{new} \to \pi$, so

$$\begin{aligned} \nabla_\alpha \eta (\pi_{new}) \big|_{\alpha=0} &= \sum_s d(s,\pi) \sum_a (V^\pi + A^\pi)(\pi' - \pi) \\ &= \sum_s d(s,\pi)\, V^\pi \sum_a (\pi' - \pi) + \sum_s d(s,\pi) \sum_a A^\pi(\pi' - \pi) \\ &= \sum_s d(s,\pi) \sum_a A^\pi(\pi' - \pi) \\ &= \sum_s d(s,\pi) \sum_a A^\pi\pi' \\ &= A_\pi(\pi') \end{aligned}$$

The second- and third-to-last lines hold because

$$\sum_a (\pi' - \pi) = \sum_a \pi' - \sum_a \pi = 1 - 1 = 0$$

$$\begin{aligned} \sum_s d(s, \pi) \sum_a A^\pi\pi &= \sum_s d(s,\pi) \sum_a (Q^\pi - V^\pi)\,\pi \\ &= \sum_s d(s,\pi) \sum_a \pi\, Q^\pi - \sum_s d(s,\pi)\, V^\pi \sum_a \pi \\ &= \sum_s d(s,\pi)\, V^\pi - \sum_s d(s,\pi)\, V^\pi \\ &= 0 \end{aligned}$$

using $V^\pi(s) = \sum_a \pi(s,a)\, Q^\pi(s,a)$ and $\sum_a \pi(s,a) = 1$.

By a Taylor expansion in $\alpha$,

$$\eta(\pi_{new}) = \eta(\pi) + \alpha\, \nabla_\alpha \eta(\pi_{new})\big|_{\alpha=0} + O(\alpha^2)$$

and therefore

$$\eta(\pi_{new}) - \eta(\pi) = \alpha A_\pi(\pi') + O(\alpha^2)$$
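This first-order relation can be checked numerically by comparing a finite difference of $\eta$ in $\alpha$ at $\alpha = 0$ against $A_\pi(\pi')$. Below is a minimal sketch on a random tabular MDP; following the convention used in the proofs, it drops the $(1-\gamma)$ normalization factors, and all names (`P`, `Rsa`, `eta`, ...) are my own.

```python
import numpy as np

rng = np.random.default_rng(2)
nS, nA, gamma = 5, 3, 0.9

P = rng.random((nS, nA, nS)); P /= P.sum(axis=2, keepdims=True)
Rsa = rng.random((nS, nA))
D = np.full(nS, 1.0 / nS)

def eta(pi):
    """eta(pi) = E_{s0~D}[V_pi(s0)], without the (1 - gamma) factor."""
    P_pi = np.einsum('sa,sap->sp', pi, P)
    r_pi = (pi * Rsa).sum(axis=1)
    V = np.linalg.solve(np.eye(nS) - gamma * P_pi, r_pi)
    return D @ V

def policy_advantage(pi, pi_prime):
    """A_pi(pi') with the unnormalized discounted visitation d_pi."""
    P_pi = np.einsum('sa,sap->sp', pi, P)
    r_pi = (pi * Rsa).sum(axis=1)
    V = np.linalg.solve(np.eye(nS) - gamma * P_pi, r_pi)
    A = Rsa + gamma * np.einsum('sap,p->sa', P, V) - V[:, None]
    d = np.linalg.solve((np.eye(nS) - gamma * P_pi).T, D)
    return (d[:, None] * pi_prime * A).sum()

pi = rng.random((nS, nA)); pi /= pi.sum(axis=1, keepdims=True)
pi_prime = rng.random((nS, nA)); pi_prime /= pi_prime.sum(axis=1, keepdims=True)

h = 1e-5
finite_diff = (eta((1 - h) * pi + h * pi_prime) - eta(pi)) / h
print(finite_diff, policy_advantage(pi, pi_prime))   # agree up to O(h)
```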

Let us put this on hold for a moment and state Lemma 2:

Lemma 2.

$$\eta(\hat\pi) - \eta(\pi) = \sum_s d(s, \hat\pi) \sum_a \hat\pi(s,a)\, A^\pi(s,a)$$
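This identity is the well-known performance difference lemma. Before the proof, here is a quick numerical sanity check on a random tabular MDP; as in the proof below, the value functions and the visitation $d(s,\hat\pi)$ are taken without the $(1-\gamma)$ normalization factor, and the names in the sketch are my own.

```python
import numpy as np

rng = np.random.default_rng(3)
nS, nA, gamma = 6, 2, 0.9

P = rng.random((nS, nA, nS)); P /= P.sum(axis=2, keepdims=True)
Rsa = rng.random((nS, nA))
D = np.full(nS, 1.0 / nS)

def solve(pi):
    """Unnormalized V_pi, A_pi and discounted visitation d_pi (starting from D)."""
    P_pi = np.einsum('sa,sap->sp', pi, P)
    r_pi = (pi * Rsa).sum(axis=1)
    V = np.linalg.solve(np.eye(nS) - gamma * P_pi, r_pi)
    A = Rsa + gamma * np.einsum('sap,p->sa', P, V) - V[:, None]
    d = np.linalg.solve((np.eye(nS) - gamma * P_pi).T, D)
    return V, A, d

pi = rng.random((nS, nA)); pi /= pi.sum(axis=1, keepdims=True)
pi_hat = rng.random((nS, nA)); pi_hat /= pi_hat.sum(axis=1, keepdims=True)

V_pi, A_pi, _ = solve(pi)
V_hat, _, d_hat = solve(pi_hat)

lhs = D @ V_hat - D @ V_pi                       # eta(pi_hat) - eta(pi)
rhs = (d_hat[:, None] * pi_hat * A_pi).sum()     # sum_s d(s, pi_hat) sum_a pi_hat A^pi
print(lhs, rhs)                                  # identical up to numerical error
```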

Proof (the derivation below works with unnormalized value functions, i.e., without the $(1-\gamma)$ factor, and $d(s, s_1, \hat\pi)$ denotes $\sum_{t\geq 1} \gamma^{t-1}\Pr(s_t = s \mid \hat\pi, s_1)$):

$$\begin{aligned} \eta(\hat\pi \mid s_1) = V^{\hat\pi}(s_1) &= E_{s, a \sim \hat\pi}\Big[\sum_{t=1}^\infty \gamma^{t-1}R(s_t, a_t)\,\Big|\,s_1\Big] \\ &= \sum_{t=1}^{\infty} \gamma^{t-1} E_{s,a \sim \hat\pi}\big[R(s_t, a_t) + V^\pi(s_t) - V^\pi(s_t) \,\big|\, s_1\big] \\ &= \sum_{t=1}^{\infty} \gamma^{t-1} E_{s,a \sim \hat{\pi}}\big[R(s_t, a_t) + \gamma V^\pi(s_{t+1}) - V^\pi(s_t) \,\big|\, s_1\big] + V^\pi(s_1) \\ &= \sum_{t=1}^{\infty} \gamma^{t-1} E_{s,a \sim \hat\pi}\big[Q^\pi(s_t, a_t) - V^\pi(s_t) \,\big|\, s_1\big] + V^\pi(s_1) \\ &= \sum_{s} d(s, s_1, \hat\pi)\sum_a \hat\pi(s,a)\,A^\pi(s,a) + V^\pi(s_1) \end{aligned}$$

The third-to-last line holds because

$$\begin{aligned} & \sum_{t=1}^{\infty} \gamma^{t-1} E_{s,a \sim \hat\pi}\big[\gamma V^\pi(s_{t+1}) \,\big|\, s_1\big] + V^\pi(s_1) \\ =\;& E_{s,a \sim \hat\pi}\big[V^\pi(s_1) + \gamma V^\pi(s_2) + \gamma^2 V^\pi(s_3) + \dots \,\big|\, s_1\big] \\ =\;& \sum_{t=1}^{\infty} \gamma^{t-1} E_{s,a \sim \hat\pi}\big[V^\pi(s_t)\,\big|\,s_1\big] \end{aligned}$$

Since both policies share the same initial state distribution $D$, taking the expectation over $s_1 \sim D$ gives

$$\eta(\hat\pi) = \sum_s d(s, \hat\pi) \sum_a \hat\pi(s,a)\, A^\pi(s,a) + \eta(\pi)$$

$$\eta(\hat\pi) - \eta(\pi) = \sum_s d(s, \hat\pi) \sum_a \hat\pi(s,a)\, A^\pi(s,a)$$

This completes the proof of Lemma 2. Returning to the proof of Lemma 1 and comparing the two expressions, we have

$$\begin{aligned} \eta(\pi_{new}) - \eta(\pi) &= \alpha A_\pi(\pi') + O(\alpha^2) \\ \eta(\pi_{new}) - \eta(\pi) &= \sum_s d(s, \pi_{new}) \sum_a \pi_{new}(s,a)\, A^\pi(s,a) \end{aligned}$$

Guided by this, we now bound the remainder:

$$\begin{aligned} \eta(\pi_{new}) - \eta(\pi) &= \sum_s d(s, \pi_{new}) \sum_a \pi_{new}(s,a)\, A^\pi(s,a) \\ &= E_{s \sim \pi_{new}}\Big[\sum_t \gamma^{t-1} \sum_a \pi_{new}A^\pi\Big] \\ &= E_{s \sim \pi_{new}}\Big[\sum_t \gamma^{t-1} \sum_a \alpha\, \pi' A^\pi\Big] \end{aligned}$$

Here we again used the property that $\sum_a \pi A^\pi = 0$. Now think of $\pi_{new}$ as choosing, at every step, between $\pi(s,a)$ and $\pi'(s,a)$ with probabilities $(1-\alpha, \alpha)$. The probability of having always chosen $\pi$ in the first $t$ steps is $1 - p_t = (1-\alpha)^t$, so $p_t = 1 - (1-\alpha)^t$. Let $n_t$ denote the number of times $\pi'$ was chosen in the first $t$ steps; then $n_t = 0$ means $\pi_{new}$ behaved exactly like $\pi$ during the first $t$ steps. Therefore

$$\begin{aligned} & \eta(\pi_{new}) - \eta(\pi) \\ =\;& E_{s \sim \pi_{new}}\Big[\sum_t \gamma^{t-1} \sum_a \alpha\, \pi' A^\pi\Big] \\ =\;& \alpha\sum_t (1-p_t)\, \gamma^{t-1} E_{s\sim\pi_{new}}\Big[\sum_a \pi'A^\pi \,\Big|\, n_t = 0\Big] + \alpha \sum_t p_t\, \gamma^{t-1} E_{s\sim \pi_{new}}\Big[\sum_a \pi'A^\pi \,\Big|\, n_t > 0\Big] \\ =\;& \alpha \sum_t (1 - p_t)\, \gamma^{t-1} E_{s\sim \pi}\Big[\sum_a \pi'A^\pi\,\Big|\,n_t=0\Big] + \alpha \sum_t p_t\, \gamma^{t-1} E_{s\sim \pi_{new}}\Big[\sum_a \pi'A^\pi \,\Big|\, n_t > 0\Big] \\ =\;& \alpha \sum_t \gamma^{t-1} E_{s\sim \pi}\Big[\sum_a \pi'A^\pi\Big] - \alpha\sum_t p_t\, \gamma^{t-1}E_{s\sim\pi}\Big[\sum_a \pi'A^\pi\,\Big|\,n_t=0\Big] + \alpha\sum_t p_t\, \gamma^{t-1}E_{s\sim\pi_{new}}\Big[\sum_a \pi'A^\pi\,\Big|\,n_t>0\Big] \\ =\;& \alpha A_\pi(\pi') - \alpha\sum_t p_t\, \gamma^{t-1}E_{s\sim\pi}\Big[\sum_a \pi'A^\pi\,\Big|\,n_t=0\Big] + \alpha\sum_t p_t\, \gamma^{t-1}E_{s\sim\pi_{new}}\Big[\sum_a \pi'A^\pi\,\Big|\,n_t>0\Big] \\ \geq\;& \alpha A_\pi(\pi') - 2 \alpha \sum_t p_{t-1}\, \gamma^{t-1} \max_s \sum_a \pi'A^\pi \\ =\;& \alpha A_\pi(\pi') - 2 \alpha \epsilon(1-\gamma) \sum_t\big(1 - (1 - \alpha)^{t-1}\big)\gamma^{t-1} \\ =\;& \alpha A_\pi(\pi') - 2\alpha\epsilon(1-\gamma)\Big(\frac{1}{1-\gamma} - \frac{1}{1 - (1-\alpha)\gamma}\Big) \\ =\;& \alpha A_\pi(\pi') - 2\alpha\epsilon\,\frac{\alpha\gamma}{1-(1-\alpha)\gamma} \\ \geq\;& \alpha A_\pi(\pi') - \frac{2\alpha^2\epsilon}{1-\gamma} \end{aligned}$$

Here $\max_s \sum_a \pi'A^\pi = \epsilon(1-\gamma)$ by the definition of $\epsilon$, so

$$\eta(\pi_{new}) - \eta(\pi) \geq \alpha A_\pi(\pi') - 2\alpha\epsilon\,\frac{\alpha\gamma}{1-(1-\alpha)\gamma}$$

This completes the proof of Lemma 1.

Looking back, what does Lemma 1 tell us? It says that if we use this kind of mixture update, we obtain a guaranteed lower bound on the improvement in policy performance. With the lemma in hand, we next choose the mixture coefficient $\alpha$.

We know that (writing $R$ for $R_{\max}$)

$$\epsilon = \frac{\max_s \sum_a \pi'A^\pi}{1-\gamma} \leq \frac{R}{1-\gamma}$$

so the lower bound

$$\alpha A_\pi(\pi') - \frac{2\alpha^2 \epsilon}{1 - \gamma} \geq \alpha A_\pi(\pi') - \frac{2\alpha^2 R}{(1-\gamma)^2}$$

is non-negative whenever $\alpha \leq (1-\gamma)^2 A_\pi(\pi') / 2R$.

We take

$$\alpha = \frac{(1-\gamma)^2\, A_\pi(\pi')}{4R}$$

which guarantees

$$\eta(\pi_{new}) - \eta(\pi) \geq \frac{A_\pi(\pi')^2\,(1-\gamma)^2}{8R}$$
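Where do these constants come from? Using $\epsilon \leq R/(1-\gamma)$, the guaranteed improvement is at least $g(\alpha) = \alpha A_\pi(\pi') - 2\alpha^2 R/(1-\gamma)^2$, a quadratic in $\alpha$ that is maximized exactly at the $\alpha$ chosen above; plugging it in gives

$$\alpha A_\pi(\pi') - \frac{2\alpha^2 R}{(1-\gamma)^2} = \frac{(1-\gamma)^2 A_\pi(\pi')^2}{4R} - \frac{(1-\gamma)^2 A_\pi(\pi')^2}{8R} = \frac{(1-\gamma)^2 A_\pi(\pi')^2}{8R}.$$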

With this theoretical groundwork in place, let us look at the algorithm the authors propose. Suppose we have an advantage function approximator

$$\hat A^\pi(s,a) \approx A^\pi(s,a)$$

We define the estimated (greedy) policy advantage

$$\hat A = \sum_s d(s, \pi)\, \max_a \hat A(s,a)$$

and require that it satisfies

$$(1-\gamma)\, \hat A \geq (1-\gamma) \max_{\pi'} A_\pi(\pi') - \delta/3$$

How can this be ensured? We use a neural network as the advantage function approximator,

$$f_w(s,a) = \hat A^\pi(s,a)$$

trained to minimize the loss

$$(1-\gamma)\sum_s d^\pi(s)\, \max_a \big|A^\pi(s,a) - f_w(s,a)\big|$$

This loss can be estimated from trajectories sampled under $\pi$; the number of trajectories needed is at least of the order (this $\epsilon$ is an estimation-accuracy parameter, distinct from the $\epsilon$ of Lemma 1)

$$\frac{R^2}{\epsilon^2} \log \frac{R^2}{\epsilon^2}$$
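As a rough illustration of this sampling step (not the paper's exact estimator), the sketch below estimates advantages under $\pi$ by Monte Carlo: it rolls out truncated trajectories, averages discounted returns per $(s,a)$ pair into $\hat Q$, and sets $f_w = \hat Q - \hat V$ with $\hat V(s) = \sum_a \pi(s,a)\hat Q(s,a)$. A tabular $f_w$ stands in for the neural network; the MDP and all names are my own.

```python
import numpy as np

rng = np.random.default_rng(4)
nS, nA, gamma = 4, 2, 0.9

P = rng.random((nS, nA, nS)); P /= P.sum(axis=2, keepdims=True)
Rsa = rng.random((nS, nA))
D = np.full(nS, 1.0 / nS)
pi = rng.random((nS, nA)); pi /= pi.sum(axis=1, keepdims=True)

T, T_use = 120, 60                          # rollout length; only early steps are used,
                                            # so truncation bias is at most ~gamma**60

def rollout():
    """One truncated trajectory (s_t, a_t, r_t) sampled under pi."""
    s = rng.choice(nS, p=D)
    traj = []
    for _ in range(T):
        a = rng.choice(nA, p=pi[s])
        traj.append((s, a, Rsa[s, a]))
        s = rng.choice(nS, p=P[s, a])
    return traj

# Every-visit Monte Carlo estimate of Q_pi(s, a) from discounted returns.
ret_sum = np.zeros((nS, nA))
ret_cnt = np.zeros((nS, nA))
for _ in range(1500):                       # number of sampled trajectories
    traj = rollout()
    G = 0.0
    for t in range(T - 1, -1, -1):          # backward pass accumulates returns
        s, a, r = traj[t]
        G = r + gamma * G
        if t < T_use:                       # keep only returns with a long enough tail
            ret_sum[s, a] += G
            ret_cnt[s, a] += 1

Q_hat = ret_sum / np.maximum(ret_cnt, 1)
V_hat = (pi * Q_hat).sum(axis=1)
f_w = Q_hat - V_hat[:, None]                # tabular advantage estimate

# Compare with the exact (unnormalized) advantages.
P_pi = np.einsum('sa,sap->sp', pi, P)
V = np.linalg.solve(np.eye(nS) - gamma * P_pi, (pi * Rsa).sum(axis=1))
A = Rsa + gamma * np.einsum('sap,p->sa', P, V) - V[:, None]
print(np.abs(f_w - A).max())                # estimation error; shrinks with more trajectories
```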

If $(1-\gamma)\hat A \leq 2\delta/3$, the algorithm stops updating the policy; otherwise it updates

$$\pi \leftarrow (1-\alpha)\,\pi + \alpha\,\pi', \qquad \pi'(s) = \arg\max_a f_w(s,a), \qquad \alpha = \Big(\hat A - \frac{\delta}{3(1-\gamma)}\Big)\frac{(1-\gamma)^2}{4R}$$

Whenever the algorithm does perform an update, we have $(1-\gamma)\hat A \geq 2\delta/3$, and hence

$$A_\pi(\pi') \geq \hat A - \frac{\delta}{3(1-\gamma)} \geq \frac{\delta}{3(1-\gamma)}$$

Therefore

$$\eta(\pi_{new}) - \eta(\pi) \geq \frac{A_\pi(\pi')^2 (1-\gamma)^2}{8R} \geq \frac{\delta^2}{9(1-\gamma)^2}\cdot\frac{(1-\gamma)^2}{8R} = \frac{\delta^2}{72R}$$

Since $\eta(\pi) \leq R/(1-\gamma)$ for any policy $\pi$, at most

$$\frac{R}{1-\gamma}\cdot\frac{72R}{\delta^2}$$

updates are needed before the algorithm terminates with

$$(1-\gamma)\max_{\pi'} A_\pi(\pi') \leq \delta$$

To summarize, the algorithm is as follows (a minimal code sketch follows the list):

  1. Initialize $f_w(s,a)$ and a random policy $\pi$; fix $\delta$ and $\gamma$.
  2. Sample trajectories with $\pi$, estimate $A_\pi(s,a)$ and update $f_w$, then compute $\hat A = \sum_s d(s,\pi)\max_a f_w(s,a)$.
  3. If $(1-\gamma)\hat A \leq 2\delta/3$, terminate and return the policy $\pi$.
  4. Otherwise set $\pi'(s) = \arg\max_a f_w(s,a)$ and $\pi \leftarrow (1-\alpha)\pi + \alpha\pi'$, then go back to step 2.
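Here is a minimal, self-contained sketch of this loop on the two-state example from Section 3, with the same assumed dynamics as earlier (action 1 self-loops, action 2 switches state) and $\gamma$, $\delta$ chosen by me. For clarity it uses exact advantages in place of the fitted $f_w$, and it follows the proofs' convention of dropping the $(1-\gamma)$ factors inside $A$ and $d$; it is meant to illustrate the conservative mixture update and the stopping rule, not the sampling part.

```python
import numpy as np

gamma, delta = 0.9, 0.05                       # gamma and delta are my choices
R_max = 2.0                                    # the "R" appearing in the bounds
nS, nA = 2, 2                                  # state 0 = i, state 1 = j

# Two-state MDP from Section 3, with assumed dynamics:
# action 0 ("1") self-loops, action 1 ("2") switches to the other state.
P = np.zeros((nS, nA, nS))
P[0, 0, 0] = P[0, 1, 1] = P[1, 0, 1] = P[1, 1, 0] = 1.0
Rsa = np.array([[1.0, 0.0],
                [2.0, 0.0]])
D = np.array([0.8, 0.2])
pi = np.array([[0.8, 0.2],
               [0.9, 0.1]])

def advantages_and_d(pi):
    """Exact A_pi(s,a) and discounted visitation d_pi(s), without (1 - gamma) factors."""
    P_pi = np.einsum('sa,sap->sp', pi, P)
    r_pi = (pi * Rsa).sum(axis=1)
    V = np.linalg.solve(np.eye(nS) - gamma * P_pi, r_pi)
    A = Rsa + gamma * np.einsum('sap,p->sa', P, V) - V[:, None]
    d = np.linalg.solve((np.eye(nS) - gamma * P_pi).T, D)
    return A, d

for it in range(10_000):
    A, d = advantages_and_d(pi)                # exact advantages stand in for f_w
    A_hat = (d * A.max(axis=1)).sum()          # \hat A = sum_s d(s) max_a A(s, a)
    if (1 - gamma) * A_hat <= 2 * delta / 3:   # stopping rule from step 3
        break
    pi_prime = np.eye(nA)[A.argmax(axis=1)]    # greedy pi'(s) = argmax_a A(s, a), one-hot
    alpha = (A_hat - delta / (3 * (1 - gamma))) * (1 - gamma) ** 2 / (4 * R_max)
    alpha = float(np.clip(alpha, 0.0, 1.0))
    pi = (1 - alpha) * pi + alpha * pi_prime   # conservative mixture update

print(it, pi)   # pi ends up concentrated on (i, action 2) and (j, action 1)
```

Note how $\alpha$ shrinks together with $\hat A$, so the updates become very small near the end; that caution is exactly what buys the improvement guarantee.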

That is the gist of Approximately Optimal Approximate RL.