当前位置: 首页 > news >正文

Policy Gradient原理和Python实现

今天重温一下RL on-policy算法的始祖:Policy Gradient算法。第一节先讲原理。第二节讲Python代码实现。第三节讲Policy Gradient算法的引申思考。

一、Policy Gradient原理

要讲Policy Gradient算法,需要先简要介绍一下Markov奖励过程。整个RL算法的体系几乎都是建立在Markov奖励过程上的,我们需要先将这个过程用数学的语言建模。

Markov奖励过程可以简单的表示为这样一个序列过程(记为\(\tau\)):

\[s_0 \rightarrow a_0 \rightarrow r_0 \rightarrow s_1 \rightarrow a_1 \rightarrow r_1 \rightarrow s_2 \rightarrow ... \rightarrow r_N \rightarrow s_{N+1} \]

其中s为状态,a代表动作,r代表奖励。从状态\(s_0\)开始,智能体agent选择一个动作\(a_0\),获得一个奖励\(r_0\),然后转移到下一个状态\(s_1\),如此往复直到一个序列结束。

那么我们的目标是什么呢?我们是想建模一个agent(\(\pi_\theta\))可以帮我们做出动作决策,使得一条决策序列(\(\tau\))的累积奖励尽可能的大。当然这个决策是概率决策,状态的转移也是概率转移,所以序列是有随机性的。因此我们的目标是使得决策序列累积奖励的期望尽可能的大,表示为:

\[\max_{\theta} J(\theta) = \max_{\theta} E_{\tau \sim \pi_{\theta}}[R(\tau)] \]

一般地,序列的累积奖励表示为:

\[R(\tau) = \sum_{t=0}^{N} \gamma^{t} r_{t} \]

观察我们的目标函数\(J(\theta)\),我们需要用\(\pi_\theta\)来表示它,才能优化\(\theta\)参数。我们先做变形:

\[\begin{aligned} J(\theta) &= E_{\tau \sim \pi_{\theta}}[R(\tau)] \\ &= \int_{\tau} P(\tau|\pi_{\theta}) R(\tau) \end{aligned} \]

\(P(\tau|\pi_{\theta})\)代表的是在策略\(\pi_{\theta}\)下,轨迹\(\tau\)出现的概率。可以看到\(J(\theta)\)实际上是一个相当复杂的函数,包含一个积分,需要将\(\pi_{\theta}\)下所有可能的轨迹\(\tau\)都考虑进来,显然直接求出这个目标函数的表达式是不现实的。

现在,我们尝试对目标\(J(\theta)\)求导:

\[\begin{aligned} \nabla_{\theta} J(\theta) &= \nabla_{\theta}\int_{\tau} P(\tau|\pi_{\theta}) R(\tau) \\ &= \int_{\tau} \nabla_{\theta} P(\tau|\pi_{\theta}) R(\tau) \\ &= \int_{\tau} P(\tau|\pi_{\theta}) \nabla_{\theta} log P(\tau|\pi_{\theta}) R(\tau)\\ &= E_{\tau \sim \pi_{\theta}}[\nabla_{\theta} log P(\tau|\pi_{\theta}) R(\tau)] \end{aligned} \]

其中\(P(\tau|\pi_{\theta})\)可以进一步表示为:

\[\begin{aligned} P(\tau|\pi_{\theta}) &= \rho(s_0) \prod_{t=0}^{N} \pi_{\theta}(a_t|s_t)P(s_{t+1}|s_t, a_t) \\ \end{aligned} \]

其中\(\rho(s_0)\)\(P(s_{t+1}|s_t, a_t)\)对于\(\theta\)来讲都是常数项。因此我们可以得到最终化简的策略梯度公式:(更加详细的推导可以参考spinning up文档):

\[E_{\tau \sim \pi_\theta}[\sum_{t=0}^{N} \nabla_\theta \log \pi_\theta(a_t|s_t) R(\tau)] \]

公式中的期望可以用模特卡洛采样方法消去,于是我们得到了目标函数J(\theta)的一个随机梯度:

\[\hat{g} = \frac{1}{|D|} \sum_{\tau \in D} \sum_{t=0}^{N}\nabla_\theta \log \pi_\theta(a_t|s_t) R(\tau) \]

这里\(R(\tau)\)可以看作一个常数,在一次控制序列结束以后,可以直接计算出该序列的折扣奖励值。有了这个随机梯度我们就可以利用梯度下降法优化策略\(\pi_{\theta}\)

二、策略梯度的实现

策略梯度代码实现如下:

import torch
import torch.nn as nn
from torch.distributions.categorical import Categorical
from torch.optim import Adam
import numpy as np
import gym
from gym.spaces import Discrete, Boxdef mlp(sizes, activation=nn.Tanh, output_activation=nn.Identity):# Build a feedforward neural network.layers = []for j in range(len(sizes)-1):act = activation if j < len(sizes)-2 else output_activationlayers += [nn.Linear(sizes[j], sizes[j+1]), act()]return nn.Sequential(*layers)def train(env_name='CartPole-v0', hidden_sizes=[32], lr=1e-2, epochs=50, batch_size=5000, render=False):# make environment, check spaces, get obs / act dimsenv = gym.make(env_name)assert isinstance(env.observation_space, Box), \"This example only works for envs with continuous state spaces."assert isinstance(env.action_space, Discrete), \"This example only works for envs with discrete action spaces."obs_dim = env.observation_space.shape[0]n_acts = env.action_space.n# make core of policy networklogits_net = mlp(sizes=[obs_dim]+hidden_sizes+[n_acts])# make function to compute action distributiondef get_policy(obs):logits = logits_net(obs)return Categorical(logits=logits)# make action selection function (outputs int actions, sampled from policy)def get_action(obs):return get_policy(obs).sample().item()# make loss function whose gradient, for the right data, is policy gradientdef compute_loss(obs, act, weights):logp = get_policy(obs).log_prob(act)return -(logp * weights).mean()# make optimizeroptimizer = Adam(logits_net.parameters(), lr=lr)# for training policydef train_one_epoch():# make some empty lists for logging.batch_obs = []          # for observationsbatch_acts = []         # for actionsbatch_weights = []      # for R(tau) weighting in policy gradientbatch_rets = []         # for measuring episode returnsbatch_lens = []         # for measuring episode lengths# reset episode-specific variablesobs = env.reset()       # first obs comes from starting distributiondone = False            # signal from environment that episode is overep_rews = []            # list for rewards accrued throughout ep# render first episode of each epochfinished_rendering_this_epoch = False# collect experience by acting in the environment with current policywhile True:# renderingif (not finished_rendering_this_epoch) and render:env.render()# save obsbatch_obs.append(obs.copy())# act in the environmentact = get_action(torch.as_tensor(obs, dtype=torch.float32))obs, rew, done, _ = env.step(act)# save action, rewardbatch_acts.append(act)ep_rews.append(rew)if done:# if episode is over, record info about episodeep_ret, ep_len = sum(ep_rews), len(ep_rews)batch_rets.append(ep_ret)batch_lens.append(ep_len)# the weight for each logprob(a|s) is R(tau)batch_weights += [ep_ret] * ep_len# reset episode-specific variablesobs, done, ep_rews = env.reset(), False, []# won't render again this epochfinished_rendering_this_epoch = True# end experience loop if we have enough of itif len(batch_obs) > batch_size:break# take a single policy gradient update stepoptimizer.zero_grad()batch_loss = compute_loss(obs=torch.as_tensor(batch_obs, dtype=torch.float32),act=torch.as_tensor(batch_acts, dtype=torch.int32),weights=torch.as_tensor(batch_weights, dtype=torch.float32))batch_loss.backward()optimizer.step()return batch_loss, batch_rets, batch_lens# training loopfor i in range(epochs):batch_loss, batch_rets, batch_lens = train_one_epoch()print('epoch: %3d \t loss: %.3f \t return: %.3f \t ep_len: %.3f'%(i, batch_loss, np.mean(batch_rets), np.mean(batch_lens)))if __name__ == '__main__':import argparseparser = argparse.ArgumentParser()parser.add_argument('--env_name', '--env', type=str, default='CartPole-v0')parser.add_argument('--render', action='store_true')parser.add_argument('--lr', type=float, default=1e-2)args = parser.parse_args()print('\nUsing simplest formulation of policy gradient.\n')train(env_name=args.env_name, render=args.render, lr=args.lr)

代码实现中有2个比较关键的位置:

  1. \(R(\tau)\)的计算
ep_ret = sum(ep_rews)
batch_weights += [ep_ret] * ep_len

可以看出,在最原始的policy gradient实现中,\(R(\tau)\)对于一条序列的所有s和a来说都是一个常数(就是最终的累积折扣奖励)

  1. Loss的计算
def compute_loss(obs, act, weights):logp = get_policy(obs).log_prob(act)return -(logp * weights).mean()

其中weights代表的是\(R(\tau)\)。可以看出,代码中实际计算的损失函数是:

\[\tilde{J(\theta)} = \frac{1}{|D|} \sum_{\tau \in D} \sum_{t=0}^{N} \log \pi_\theta(a_t|s_t) R(\tau) \]

一个需要强调的点是:这个损失函数和我们在第一节中推导的损失函数\(J(\theta)\)实际上并不相等(实际上可能是2个值相差很大的函数),只是恰好这两个函数在\(\theta\)处的梯度相等。因此我们可以用这个梯度优化策略\(\pi_{\theta}\)。这个损失函数一般被称为代理函数(surrogate function)

三、策略梯度的引申思考

虽然策略梯度的推导的过程比较简单,但是它是RL算法的基石,要深刻理解它,必须理解一个重要的内容:Surrogate Function

回忆策略梯度算法的推导过程:RL算法的目标是最大化目标函数\(J(\theta)\),于是我们尝试用策略\(\pi_\theta\)去表示\(J(\theta)\)。但是由于\(J(\theta)\)中有一个积分项和\(\pi_\theta\)有关,我们无法遍历所有的情况去求出这个积分。

所以我们转而去分析目标函数的导数——策略梯度。策略梯度可以用随机梯度的形式表达出来,有了这个梯度我们就知道目标函数在\(\theta\)处的优化方向了。但是我们为了能够进行梯度下降,我们需要构造一个surrogate function使得它的梯度就是目标梯度。而这个surrogate function本身的大小没有任何意义,这就是为什么我们总是说:在训练的时候,policy gradient的loss值没有参考意义

这个解决问题的思路似乎跟score function的思路有类似的地方。score function解决的是复杂分布的似然函数\(\log p(x)\)无法准确表示,转而求它的导数\(\nabla_\theta \log p(x)\)(score function)。而求出这个导数以后,可以通过Langevin Dynamics采样方法用score function实现在\(p(x)\)上的近似采样,也就解决了复杂分布的采样问题。通过这个方法得到的模型被称为score-based model,是生成模型的一种。这启发我们在目标不易求得的时候,我们可以分析其梯度的性质,梯度中可能存在解决问题的钥匙。

http://www.sczhlp.com/news/6861/

相关文章:

  • 记一些oi啸寄巧
  • 25.8.6模拟赛
  • 考前建议
  • RS232与RS485通信协议深度对比
  • Linux系统入门第四章 --磁盘管理和LVM
  • 部落冲突coc到5000杯后如何快速掉杯
  • [河南萌新联赛2025第(四)场]H (DP 图论)
  • WinForm 实现火绒杀毒界面
  • 【通信模型】Actors with Tokio
  • 开此侧门(25夏收集)
  • 《硅谷甄选》项目笔记
  • 腾讯游戏安全2023安卓初赛题解
  • Linux系统入门指南第二章 -- 安装及管理程序
  • 8.6总结
  • 20250806 HT-071
  • 博客园头像 - Charon
  • 【通信模型】你想知道的关于 actor 模型但可能不敢问的所有信息(译文)
  • 第二十三篇
  • ZROJ3265 猜数游戏
  • 安装 PVE
  • 003-存储读取数据的方案
  • 004-多线程
  • 安科瑞分布式光伏监控系统:筑牢10kV光伏电站的智能监测防线 - 实践
  • 什么是供应链 - 智慧园区
  • 泛微e8获取当前操作者并与明细行的项目负责人对比,不同就隐藏明细行
  • 25.8.5python模块
  • 在K8S中,同⼀个Pod的不同容器互相可以访问是怎么做到的?
  • 基于图像识别与分类的中国蛇类识别系统 - 教程
  • 对于项目调用方法的解析
  • 在K8S中,不同的Pod之间互相可以访问是怎么做到的?