深度探索MADDPG算法及其改进策略理解
引子
深度强化学习可以分为两类:单智能体算法和多智能体算法,单智能体算法从DQN开始有policy gradient、actor critic、dpg、ppo、ddpg、sac等等,它们解决的是环境中存在一个智能体的情况(或者是多个智能体可以转化为一个智能体决策的情况),但是在某些环境(environment)下,似乎单智能体算法就有些心有余而力不足,例如足球比赛亦或是追逐游戏。如果依旧对每个agent采用单智能体算法会出现如下情况:在第 i i i个agent做出动作 a i a_i ai的情况下由于其余agent的动作 a j , j ≠ i a_j, j\ne i aj,j=i未知,会导致第 i i i个agent收到的奖励 r e w a r d reward reward不稳定,也就是对于单个agent来说,环境是不稳定(unstable)的。
从另一个方面来考虑,大多数DRL算法都沿用了其开山鼻祖DQN的replay buffer机制,即在一个合适的时机通过sample buffer得到无序的训练数据用以训练网络,但在一个不稳定的环境下可能出现下面情况:buffer中存在相同状态、相同动作下奖励reward却不同的数据,这会直接导致训练的震荡,甚至是崩溃。
多智能体深度强化学习的算法应运而生,这篇blog主要介绍一种DDPG算法的多智能体版本,即MADDPG算法(Multi-Agent Deep Deterministic Policy Gradient)。
核心思想
不想讲很多没用的。如果有DDPG基础的应该知道DDPG有4个网络:actor,target_actor,critic,target_critic。其中带target的网络是定期由其对应网络复制参数而来(可以全复制也可以采用soft update机制)。actor网络我们也可以叫它策略网络,负责为agent做决策,其输入为agent的状态state,critic网络我们一般称其为批评家,负责评估actor做这个决策的价值,用Movan的一句话就是critic像是坐在actor头上的指挥家一样,凭借其长远的目光教导actor做出越来越好的决策。因此critic网络的输入为状态加动作即state+action,用RL的通用术语讲,这个critic就是常听到的 Q Q Q函数。值得注意的是,单智能体DDPG的critic网络输入的仅仅是该agent的state和action,并不涉及其他agent的信息,因此直接使用单智能体算法在多智能体环境中会出现因无法获取够信息从而训练效果不好的问题。
MADDPG与DDPG的最大不同在于critic网络,你不是说信息获取得不够吗?行,我给你critic更多的信息,这里给出原论文的更新critic网络的公式加以说明:
L
(
θ
i
)
=
E
x
,
a
,
r
,
x
′
[
(
Q
i
μ
(
x
,
a
1
,
…
,
a
N
)
−
y
)
2
]
,
y
=
r
i
+
γ
Q
i
μ
′
(
x
′
,
a
1
′
,
…
,
a
N
′
)
∣
a
j
′
=
μ
j
′
(
o
j
)
\mathcal{L}\left(\theta_{i}\right)=\mathbb{E}_{\mathbf{x}, a, r, \mathbf{x}^{\prime}}\left[\left(Q_{i}^{\mu}\left(\mathbf{x}, a_{1}, \ldots, a_{N}\right)-y\right)^{2}\right], \quad y=r_{i}+\left.\gamma Q_{i}^{\boldsymbol{\mu}^{\prime}}\left(\mathbf{x}^{\prime}, a_{1}^{\prime}, \ldots, a_{N}^{\prime}\right)\right|_{a_{j}^{\prime}=\boldsymbol{\mu}_{j}^{\prime}\left(o_{j}\right)}
L(θi)=Ex,a,r,x′[(Qiμ(x,a1,…,aN)−y)2],y=ri+γQiμ′(x′,a1′,…,aN′)∣∣∣aj′=μj′(oj)
观察它
Q
Q
Q函数的输入为
x
,
a
1
,
…
,
a
N
\mathbf{x}, a_{1}, \ldots, a_{N}
x,a1,…,aN其中
a
1
,
…
,
a
N
a_{1}, \ldots, a_{N}
a1,…,aN很好理解,为其他所有agent的动作,
x
\mathbf{x}
x在原文中的描述如下:In the simplest case, x could consist of the observations of all agents, x = (o1, …, oN), however we could also include additional state information if available. 在一般的情况下
x
\mathbf{x}
x取所有agent的状态。讲到这里其实发现MADDPG看着复杂其实简单,就是critic从原来的只注重自身的经验到现在的注重全局经验。一句话概括:分布式的actor和集中式的critic。何谓分布式actor,即actor的输入只使用了其自身的状态state而没有其他agent状态的输入,在测试阶段只需要每个actor对每个agent做出指导即可。何谓集中式critic,critic集中了环境中的所有信息用以指导其actor。引用原文的一张图加以说明MADDPG的思想:
更新方式
critic的更新方式之前已经给出公式了,说白了就是TD式更新,与单智能体的DDPG如出一辙。
既然critic是对actor的“评委”,即critic输出的是对actor的“认可度”,那么actor的更新方向自然是向着让critic给自己打更高分的方向,即:
∇
θ
i
J
(
μ
i
)
=
E
x
,
a
∼
D
[
∇
θ
i
μ
i
(
a
i
∣
o
i
)
∇
a
i
Q
i
μ
(
x
,
a
1
,
…
,
a
N
)
∣
a
i
=
μ
i
(
o
i
)
]
,
\nabla_{\theta_{i}} J\left(\boldsymbol{\mu}_{i}\right)=\mathbb{E}_{\mathbf{x}, a \sim \mathcal{D}}\left[\left.\nabla_{\theta_{i}} \boldsymbol{\mu}_{i}\left(a_{i} \mid o_{i}\right) \nabla_{a_{i}} Q_{i}^{\mu}\left(\mathbf{x}, a_{1}, \ldots, a_{N}\right)\right|_{a_{i}=\boldsymbol{\mu}_{i}\left(o_{i}\right)}\right],
∇θiJ(μi)=Ex,a∼D[∇θiμi(ai∣oi)∇aiQiμ(x,a1,…,aN)∣ai=μi(oi)],
个人不太喜欢这种高深的公式表达,一句话说明白就是actor的loss是对应critic输出的
Q
Q
Q值加个负号(深度学习架构是最小化loss,即最大化
Q
Q
Q值,即最大化critic的打分)。
附上一段torch更新MADDPG网络的核心代码(可能没有上下文会难懂一点):
for agent_idx, (actor_c, actor_t, critic_c, critic_t, opt_a, opt_c) in \
enumerate(zip(actors_cur, actors_tar, critics_cur, critics_tar, optimizers_a, optimizers_c)):
_obs_n_o, _action_n, _rew_n, _obs_n_n, _done_n = memory.sample(
arglist.batch_size, agent_idx)
rew = torch.tensor(_rew_n, device=arglist.device, dtype=torch.float)
done_n = torch.tensor(_done_n, device=arglist.device, dtype=torch.float)
action_cur_o = torch.from_numpy(_action_n).to(arglist.device, torch.float)
obs_n_o = torch.from_numpy(_obs_n_o).to(arglist.device, torch.float)
obs_n_n = torch.from_numpy(_obs_n_n).to(arglist.device, torch.float)
action_tar = torch.cat([a_t(obs_n_n[:, obs_size[idx][0]:obs_size[idx][1]]).detach() \
for idx, a_t in enumerate(actors_tar)], dim=1)
q = critic_c(obs_n_o, action_cur_o).reshape(-1) # q
with torch.no_grad():
q_ = critic_t(obs_n_n, action_tar).reshape(-1) # q_
tar_value = q_ * arglist.gamma * (1 - done_n) + rew # q_*gamma*done + reward
loss_c = torch.nn.MSELoss()(q, tar_value) # bellman equation
opt_c.zero_grad()
loss_c.backward()
opt_c.step()
# --use the data to update the ACTOR
# There is no need to cal other agent's action
policy_c_new = actor_c(
obs_n_o[:, obs_size[agent_idx][0]:obs_size[agent_idx][1]])
# update the aciton of this agent
action_cur_o[:, action_size[agent_idx][0]:action_size[agent_idx][1]] = policy_c_new
loss_a = torch.mul(-1, torch.mean(critic_c(obs_n_o, action_cur_o)))
opt_a.zero_grad()
loss_a.backward()
opt_a.step()
伪代码
在具体实现细节上网上的代码都会存在些许出入,但整体框架按照上图来就是没错的。
优势和劣势
原文给出了一些实验结果:
推荐阅读
-
【摩尔线程+Colossal-AI强强联手】MusaBert登上CLUE榜单TOP10:技术细节揭秘 - 技术实力:摩尔线程凭借"软硬兼备"的技术底蕴,让MusaBert得以从底层优化到顶层。其内置多功能GPU配备AI加速和并行计算模块,提供了全面的AI与科学计算支持,为AI推理和低资源条件下的大模型训练等场景带来了高效、经济且环保的算力。 - 算法层面亮点:依托Colossal-AI AI大模型开发系统,MusaBert在训练过程中展现出了卓越的并行性能与易用性,特别在预处理阶段对DataLoader进行了优化,适应低资源环境高效处理海量数据。同时,通过精细的建模优化、领域内数据增强以及Adan优化器等手段,挖掘和展示了预训练语言模型出色的语义理解潜力。基于MusaBert,摩尔线程自主研发的MusaSim通过对比学习方法微调,结合百万对标注数据,MusaSim在多个任务如语义相似度、意图识别和情绪分析中均表现出色。 - 数据资源丰富:MusaBert除了自家高质量语义相似数据外,还融合了悟道开源200GB数据、CLUE社区80GB数据,以及浪潮公司提供的1TB高质量数据,保证模型即便在较小规模下仍具备良好性能。 当前,MusaBert已成功应用于摩尔线程的智能客服与数字人项目,并广泛服务于语义相似度、情绪识别、阅读理解与声韵识别等领域。为了降低大模型开发和应用难度,MusaBert及其相关高质量模型代码已在Colossal-AI仓库开源,可快速训练优质中文BERT模型。同时,通过摩尔线程与潞晨科技的深度合作,仅需一张多功能GPU单卡便能高效训练MusaBert或更大规模的GPT2模型,显著降低预训练成本,进一步推动双方在低资源大模型训练领域的共享目标。 MusaBert荣登CLUE榜单TOP10,象征着摩尔线程与潞晨科技联合研发团队在中文预训练研究领域的领先地位。展望未来,双方将携手探索更大规模的自然语言模型研究,充分运用上游数据资源,产出更为强大的模型并开源。持续强化在摩尔线程多功能GPU上的大模型训练能力,特别是在消费级显卡等低资源环境下,致力于降低使用大模型训练的门槛与成本,推动人工智能更加普惠。而潞晨科技作为重要合作伙伴,将继续发挥关键作用。
-
深度探索MADDPG算法及其改进策略理解