Preface
- Unitree RL GYM is an open-source reinforcement learning (RL) control example project for Unitree robots, used to train, test, and deploy legged-robot control policies. The repository supports several Unitree robot models, including Go2, H1, H1_2, and G1. Repository link

- This series walks through the repository's core code, algorithm implementations, and training tutorials. It assumes the reader already has some background in reinforcement learning and programming, so certain fundamentals and basic code logic are not explained. Readers interested in RL basics can read my introductory series:
- Part 1: [RL Made Simple] (1) Q-Learning is really just a lookup table - CSDN blog
- Part 2: [RL Made Simple] (2) Sarsa, the conservatives' victory - CSDN blog
- Part 3: [RL Made Simple] (3) DQN: when the lookup table gets a brain - CSDN blog
- Part 4: [RL Made Simple] (4) Policy Gradients: mastering policy sampling - CSDN blog
- Part 5: [RL Made Simple] (5) Actor-Critic and A3C, the complete victory of multithreading - CSDN blog
- Part 6: [RL Made Simple] (6) DDPG and TD3: taking the best from everyone - CSDN blog
- Part 7: [RL Made Simple] (7) PPO, the safety valve of policy updates - CSDN blog
- Prerequisites for reading this series:
  - Python syntax and an understanding of object-oriented encapsulation
  - basic PyTorch usage
  - neural network fundamentals
  - reinforcement learning fundamentals, at least familiarity with Policy Gradient, Actor-Critic, and PPO
- This installment covers the Python implementation of the ActorCritic and ActorCriticRecurrent networks in the rsl_rl repository.
1 The ActorCritic Class
1-1 Review of the Actor-Critic Network
- We introduced the Actor-Critic network in Part 5 of the introductory RL series; here is a quick recap of the core ideas.
- In reinforcement learning, Actor-Critic is a framework that learns a policy and a value function at the same time. In symbols:
  - $\pi_\theta(a|s)$ (policy network)
  - $V_\phi(s)$ (value network)
- Actor (policy network): tells you "which action $a$ to take in the current state $s$": $a_t \sim \pi_\theta(a_t|s_t)$
  - Outputs an action distribution (for continuous actions, usually a Gaussian $\mathcal{N}(\mu, \sigma^2)$)
  - Actions are sampled during training to encourage exploration; at inference time the mean action is used to reduce randomness
- Critic (value network): tells you "how good the current state $s$ is", i.e. the state value $V_\phi(s_t) \approx \mathbb{E}[R_t \mid s_t]$
  - Used to compute the advantage function $A_t = R_t - V_\phi(s_t)$
  - The advantage measures how much better an action is than average and is what the Actor's policy update refers to
- The policy update relies on the value estimate:
  - The Actor's gradient comes from the policy gradient theorem: $\nabla_\theta J(\theta) = \mathbb{E}_t \big[ A_t \nabla_\theta \log \pi_\theta(a_t|s_t) \big]$
  - The Critic minimizes a mean squared error: $L_\text{critic} = \mathbb{E}_t \big[ (V_\phi(s_t) - R_t)^2 \big]$
- Role of the advantage function:
  - Improves training stability: only actions better than average are reinforced
  - Reduces the variance of the policy gradient, so training converges faster
- Summary: the Actor decides, the Critic evaluates, and the advantage function bridges the two, giving the policy both a direction and a reference point during training.
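- To make the recap concrete, here is a minimal, self-contained sketch of one Actor-Critic update step with a diagonal Gaussian policy. It is not code from the repository; the network sizes and data are made up for illustration:

```python
import torch
from torch.distributions import Normal

# toy batch: 4 observations of dim 8, 3-dim actions, made-up returns
obs = torch.randn(4, 8)
actions = torch.randn(4, 3)
returns = torch.randn(4)                         # discounted returns R_t

actor = torch.nn.Linear(8, 3)                    # stands in for the actor MLP
critic = torch.nn.Linear(8, 1)                   # stands in for the critic MLP
std = torch.ones(3)                              # fixed sigma for this toy example

dist = Normal(actor(obs), std)                   # pi_theta(a|s)
log_prob = dist.log_prob(actions).sum(-1)        # log pi_theta(a_t|s_t)
value = critic(obs).squeeze(-1)                  # V_phi(s_t)

advantage = returns - value.detach()             # A_t = R_t - V_phi(s_t)
actor_loss = -(advantage * log_prob).mean()      # policy gradient objective
critic_loss = (value - returns).pow(2).mean()    # MSE value loss
(actor_loss + critic_loss).backward()
```

- The Critic is regressed towards the returns, while the Actor is pushed in the direction weighted by the advantage.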
1-2 The Full Code at a Glance
- Open the rsl_rl repository:

```bash
git clone https://github.com/leggedrobotics/rsl_rl.git
cd rsl_rl
git checkout v1.0.2
```
- In the modules folder under the project root you can find actor_critic.py:
```python
# SPDX-FileCopyrightText: Copyright (c) 2021 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
#    list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
#    this list of conditions and the following disclaimer in the documentation
#    and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
#    contributors may be used to endorse or promote products derived from
#    this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
# Copyright (c) 2021 ETH Zurich, Nikita Rudin

import numpy as np

import torch
import torch.nn as nn
from torch.distributions import Normal
from torch.nn.modules import rnn

class ActorCritic(nn.Module):
    is_recurrent = False
    def __init__(self, num_actor_obs,
                 num_critic_obs,
                 num_actions,
                 actor_hidden_dims=[256, 256, 256],
                 critic_hidden_dims=[256, 256, 256],
                 activation='elu',
                 init_noise_std=1.0,
                 **kwargs):
        if kwargs:
            print("ActorCritic.__init__ got unexpected arguments, which will be ignored: " + str([key for key in kwargs.keys()]))
        super(ActorCritic, self).__init__()

        activation = get_activation(activation)

        mlp_input_dim_a = num_actor_obs
        mlp_input_dim_c = num_critic_obs

        # Policy
        actor_layers = []
        actor_layers.append(nn.Linear(mlp_input_dim_a, actor_hidden_dims[0]))
        actor_layers.append(activation)
        for l in range(len(actor_hidden_dims)):
            if l == len(actor_hidden_dims) - 1:
                actor_layers.append(nn.Linear(actor_hidden_dims[l], num_actions))
            else:
                actor_layers.append(nn.Linear(actor_hidden_dims[l], actor_hidden_dims[l + 1]))
                actor_layers.append(activation)
        self.actor = nn.Sequential(*actor_layers)

        # Value function
        critic_layers = []
        critic_layers.append(nn.Linear(mlp_input_dim_c, critic_hidden_dims[0]))
        critic_layers.append(activation)
        for l in range(len(critic_hidden_dims)):
            if l == len(critic_hidden_dims) - 1:
                critic_layers.append(nn.Linear(critic_hidden_dims[l], 1))
            else:
                critic_layers.append(nn.Linear(critic_hidden_dims[l], critic_hidden_dims[l + 1]))
                critic_layers.append(activation)
        self.critic = nn.Sequential(*critic_layers)

        print(f"Actor MLP: {self.actor}")
        print(f"Critic MLP: {self.critic}")

        # Action noise
        self.std = nn.Parameter(init_noise_std * torch.ones(num_actions))
        self.distribution = None
        # disable args validation for speedup
        Normal.set_default_validate_args = False

        # seems that we get better performance without init
        # self.init_memory_weights(self.memory_a, 0.001, 0.)
        # self.init_memory_weights(self.memory_c, 0.001, 0.)

    @staticmethod
    # not used at the moment
    def init_weights(sequential, scales):
        [torch.nn.init.orthogonal_(module.weight, gain=scales[idx]) for idx, module in
         enumerate(mod for mod in sequential if isinstance(mod, nn.Linear))]

    def reset(self, dones=None):
        pass

    def forward(self):
        raise NotImplementedError

    @property
    def action_mean(self):
        return self.distribution.mean

    @property
    def action_std(self):
        return self.distribution.stddev

    @property
    def entropy(self):
        return self.distribution.entropy().sum(dim=-1)

    def update_distribution(self, observations):
        mean = self.actor(observations)
        self.distribution = Normal(mean, mean*0. + self.std)

    def act(self, observations, **kwargs):
        self.update_distribution(observations)
        return self.distribution.sample()

    def get_actions_log_prob(self, actions):
        return self.distribution.log_prob(actions).sum(dim=-1)

    def act_inference(self, observations):
        actions_mean = self.actor(observations)
        return actions_mean

    def evaluate(self, critic_observations, **kwargs):
        value = self.critic(critic_observations)
        return value

def get_activation(act_name):
    if act_name == "elu":
        return nn.ELU()
    elif act_name == "selu":
        return nn.SELU()
    elif act_name == "relu":
        return nn.ReLU()
    elif act_name == "crelu":
        return nn.ReLU()
    elif act_name == "lrelu":
        return nn.LeakyReLU()
    elif act_name == "tanh":
        return nn.Tanh()
    elif act_name == "sigmoid":
        return nn.Sigmoid()
    else:
        print("invalid activation function!")
        return None
```
- Next, let's look at what each function does.
1-3 The Initialization Function
```python
class ActorCritic(nn.Module):
    is_recurrent = False
    def __init__(self, num_actor_obs,
                 num_critic_obs,
                 num_actions,
                 actor_hidden_dims=[256, 256, 256],
                 critic_hidden_dims=[256, 256, 256],
                 activation='elu',
                 init_noise_std=1.0,
                 **kwargs):
        if kwargs:
            print("ActorCritic.__init__ got unexpected arguments, which will be ignored: " + str([key for key in kwargs.keys()]))
        super(ActorCritic, self).__init__()

        activation = get_activation(activation)

        mlp_input_dim_a = num_actor_obs
        mlp_input_dim_c = num_critic_obs

        # Policy
        actor_layers = []
        actor_layers.append(nn.Linear(mlp_input_dim_a, actor_hidden_dims[0]))
        actor_layers.append(activation)
        for l in range(len(actor_hidden_dims)):
            if l == len(actor_hidden_dims) - 1:
                actor_layers.append(nn.Linear(actor_hidden_dims[l], num_actions))
            else:
                actor_layers.append(nn.Linear(actor_hidden_dims[l], actor_hidden_dims[l + 1]))
                actor_layers.append(activation)
        self.actor = nn.Sequential(*actor_layers)

        # Value function
        critic_layers = []
        critic_layers.append(nn.Linear(mlp_input_dim_c, critic_hidden_dims[0]))
        critic_layers.append(activation)
        for l in range(len(critic_hidden_dims)):
            if l == len(critic_hidden_dims) - 1:
                critic_layers.append(nn.Linear(critic_hidden_dims[l], 1))
            else:
                critic_layers.append(nn.Linear(critic_hidden_dims[l], critic_hidden_dims[l + 1]))
                critic_layers.append(activation)
        self.critic = nn.Sequential(*critic_layers)

        print(f"Actor MLP: {self.actor}")
        print(f"Critic MLP: {self.critic}")

        # Action noise
        self.std = nn.Parameter(init_noise_std * torch.ones(num_actions))
        self.distribution = None
        # disable args validation for speedup
        Normal.set_default_validate_args = False

        # seems that we get better performance without init
        # self.init_memory_weights(self.memory_a, 0.001, 0.)
        # self.init_memory_weights(self.memory_c, 0.001, 0.)
```
1-3-1 Hyperparameters
```python
def __init__(self,
             num_actor_obs,
             num_critic_obs,
             num_actions,
             actor_hidden_dims=[256, 256, 256],
             critic_hidden_dims=[256, 256, 256],
             activation='elu',
             init_noise_std=1.0,
             **kwargs):
```
- Let's go through the hyperparameters:
  - num_actor_obs: input dimension of the Actor network, i.e. the number of observations the policy network can see
  - num_critic_obs: input dimension of the Critic network, i.e. the number of observations the value network can see
  - num_actions: dimension of the action space
  - actor_hidden_dims: number of neurons in each hidden layer of the Actor network
  - critic_hidden_dims: number of neurons in each hidden layer of the Critic network
  - activation: hidden-layer activation function, one of the options provided by get_activation(): "elu", "selu", "relu", "crelu", "lrelu", "tanh", "sigmoid"
  - init_noise_std: initial standard deviation $\sigma$ of the Actor's action noise
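- As a usage sketch, constructing the network might look like this. The observation and action sizes below are hypothetical (not taken from any particular robot config), and it assumes the ActorCritic class and its get_activation helper from the listing above are available:

```python
# hypothetical sizes, for illustration only
policy = ActorCritic(num_actor_obs=48,
                     num_critic_obs=48,
                     num_actions=12,
                     actor_hidden_dims=[256, 256, 256],
                     critic_hidden_dims=[256, 256, 256],
                     activation='elu',
                     init_noise_std=1.0)
# __init__ prints the resulting Actor MLP and Critic MLP structures
```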
1-3-2 Building the Actor MLP and Critic MLP
```python
# Policy
actor_layers = []
actor_layers.append(nn.Linear(mlp_input_dim_a, actor_hidden_dims[0]))
actor_layers.append(activation)
for l in range(len(actor_hidden_dims)):
    if l == len(actor_hidden_dims) - 1:
        actor_layers.append(nn.Linear(actor_hidden_dims[l], num_actions))
    else:
        actor_layers.append(nn.Linear(actor_hidden_dims[l], actor_hidden_dims[l + 1]))
        actor_layers.append(activation)
self.actor = nn.Sequential(*actor_layers)

# Value function
critic_layers = []
critic_layers.append(nn.Linear(mlp_input_dim_c, critic_hidden_dims[0]))
critic_layers.append(activation)
for l in range(len(critic_hidden_dims)):
    if l == len(critic_hidden_dims) - 1:
        critic_layers.append(nn.Linear(critic_hidden_dims[l], 1))
    else:
        critic_layers.append(nn.Linear(critic_hidden_dims[l], critic_hidden_dims[l + 1]))
        critic_layers.append(activation)
self.critic = nn.Sequential(*critic_layers)

print(f"Actor MLP: {self.actor}")
print(f"Critic MLP: {self.critic}")
```
- The two networks can be summarized in a table:
| Feature | Actor network | Critic network |
|---|---|---|
| Input | num_actor_obs | num_critic_obs (may differ) |
| Output | action vector $\mu_\theta(s)$ | scalar state value $V_\phi(s)$ |
| Output layer size | action-space dimension | 1 |
| Role | selects actions (policy) | evaluates states (value function) |
| Used for | the policy-gradient update | computing the advantage that guides the Actor |
- One network outputs the action vector $\mu_\theta(s)$ for the current state; the other outputs a scalar state value $V_\phi(s)$ that evaluates the current state.
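- For example, with the hypothetical sizes num_actor_obs=48 and num_actions=12 and the default actor_hidden_dims=[256, 256, 256], the construction loop above is equivalent to the following Sequential (a standalone sketch that mirrors, rather than reproduces, the repository code):

```python
import torch.nn as nn

# what the construction loop builds for the sizes above (illustrative only)
actor = nn.Sequential(
    nn.Linear(48, 256), nn.ELU(),   # input layer + activation
    nn.Linear(256, 256), nn.ELU(),  # hidden layer 1 -> hidden layer 2
    nn.Linear(256, 256), nn.ELU(),  # hidden layer 2 -> hidden layer 3
    nn.Linear(256, 12),             # last hidden layer -> action dimension, no activation
)
```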
1-3-3 Action Noise Initialization
```python
# Action noise
self.std = nn.Parameter(init_noise_std * torch.ones(num_actions))
self.distribution = None
Normal.set_default_validate_args = False
```
- In a continuous action space, actions are sampled from a normal distribution: $a \sim \mathcal{N}(\mu(s), \sigma^2)$
- where:
  - self.std: the standard deviation $\sigma$ of the action distribution, which controls exploration (it is an nn.Parameter, so it is learned during training)
  - self.distribution: the action distribution object Normal(mean, std), used for sampling and for computing log_prob and the entropy
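- A small standalone sketch of what this sets up (the batch size and action dimension are made up): every action dimension gets its own learnable standard deviation, shared across the whole batch via broadcasting:

```python
import torch
from torch.distributions import Normal

num_actions = 12                                           # illustrative
std = torch.nn.Parameter(1.0 * torch.ones(num_actions))    # one learnable sigma per action dim

mean = torch.zeros(4, num_actions)        # e.g. actor output for a batch of 4 observations
dist = Normal(mean, mean * 0. + std)      # std is broadcast to the batch shape
actions = dist.sample()                   # shape (4, 12); noise scale controlled by sigma
```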
1-4 Helper Properties for the Loss Computation
```python
@property
def action_mean(self):
    return self.distribution.mean

@property
def action_std(self):
    return self.distribution.stddev

@property
def entropy(self):
    return self.distribution.entropy().sum(dim=-1)
```
- These three are small utility properties:
  - action_mean: returns the mean $\mu_\theta(s)$ of the current action distribution
  - action_std: returns the standard deviation $\sigma$ of the current action distribution
  - entropy: returns the entropy $H[\pi_\theta]$ of the action distribution
- All three are called in algorithms/ppo.py when the loss is computed:
```python
mu_batch = self.actor_critic.action_mean
sigma_batch = self.actor_critic.action_std
entropy_batch = self.actor_critic.entropy
```
1-5 The Core Function act()
```python
def update_distribution(self, observations):
    mean = self.actor(observations)
    self.distribution = Normal(mean, mean*0. + self.std)

def act(self, observations, **kwargs):
    self.update_distribution(observations)
    return self.distribution.sample()
```
- The Actor network produces the action mean: $\mu_\theta(s_t) = \text{Actor}(s_t)$
- which is combined with the standard deviation into a Gaussian action distribution: $a_t \sim \mathcal{N}\big(\mu_\theta(s_t), \sigma^2 \mathbf{I}\big)$
- Note: `mean*0.` simply creates a zero tensor with the same shape as `mean`, so that the standard deviation broadcasts correctly.
- In algorithms/ppo.py, the first step of every mini-batch update is a forward pass of the Actor that recomputes the action distribution under the current policy:
```python
self.actor_critic.act(obs_batch, masks=masks_batch, hidden_states=hid_states_batch[0])
```
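- To contrast the two entry points, here is a hypothetical usage sketch (the sizes are made up; it assumes the ActorCritic class from the listing above): act() samples from the Gaussian while collecting experience, whereas act_inference() returns the deterministic mean used at deployment:

```python
import torch

policy = ActorCritic(num_actor_obs=48, num_critic_obs=48, num_actions=12)  # hypothetical sizes
obs = torch.randn(32, 48)                              # a batch of 32 observations

stochastic = policy.act(obs)                           # a ~ N(mu(s), sigma^2): exploration
log_prob = policy.get_actions_log_prob(stochastic)     # log pi(a|s), shape (32,)
deterministic = policy.act_inference(obs)              # just mu(s): deterministic deployment
value = policy.evaluate(obs)                           # V(s), shape (32, 1)
```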
1-6 get_actions_log_prob()
```python
def get_actions_log_prob(self, actions):
    return self.distribution.log_prob(actions).sum(dim=-1)
```
- Purpose: compute the log-probability of an action, $\log \pi_\theta(a_t|s_t)$
- It is called directly in algorithms/ppo.py:
```python
actions_log_prob_batch = self.actor_critic.get_actions_log_prob(actions_batch)
```
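- For a diagonal Gaussian the joint log-probability factorizes over the action dimensions, which is why the per-dimension log_prob values are summed along the last axis. A quick standalone check (the values are arbitrary):

```python
import torch
from torch.distributions import Normal

dist = Normal(torch.zeros(3), torch.ones(3))    # 3 independent action dimensions
a = torch.tensor([0.1, -0.5, 0.3])

joint = dist.log_prob(a).sum(dim=-1)            # log pi(a|s) = sum_i log N(a_i; mu_i, sigma_i)
manual = sum(Normal(torch.tensor(0.0), torch.tensor(1.0)).log_prob(x) for x in a)
print(torch.allclose(joint, manual))            # True
```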
1-7 The Core Function evaluate()
```python
def evaluate(self, critic_observations, **kwargs):
    value = self.critic(critic_observations)
    return value
```
- A forward pass of the Critic computes the state value $V_\phi(s)$,
- which is then used to compute the advantage $A_t = R_t - V_\phi(s_t)$
- It is called directly in algorithms/ppo.py:
```python
value_batch = self.actor_critic.evaluate(critic_obs_batch)
```
1-8 Network Structure
- Actor path: input state s → MLP hidden layers → action mean μ(s)
- Critic path: input state s → MLP hidden layers → state value V(s)
2 ActorCriticRecurrent
2-1 Background: RNNs
- RNN stands for Recurrent Neural Network.
- Use cases: sequential or time-series tasks such as natural language, robot control, or stock-price prediction.
- Core idea: the output at the current time step depends not only on the current input but also on the hidden state from the previous step, giving the network an internal "memory".
- Formula (a minimal sketch of this recurrence follows at the end of this subsection): $h_t = f(W_{xh} x_t + W_{hh} h_{t-1} + b_h)$, where:
  - $x_t$: the current input (e.g. the observation vector $o_t$)
  - $h_{t-1}$: the hidden state from the previous step, which summarizes the history and acts as the network's "memory"
  - $h_t$: the current hidden state
  - $f$: a nonlinear activation (tanh, ReLU, etc.)
- Main advantages:
  - Remembers history: the hidden state retains important information from earlier in the sequence
  - Handles partial observability (POMDPs): it can infer state information that is missing from the current observation
  - Stronger temporal modeling: actions and values can depend on many past steps, not only on the current observation
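- Here is the promised minimal standalone sketch of the recurrence (all sizes and data are made up), showing how each hidden state depends on the current input and the previous hidden state:

```python
import torch

input_dim, hidden_dim = 4, 8                     # illustrative sizes
W_xh = torch.randn(hidden_dim, input_dim)
W_hh = torch.randn(hidden_dim, hidden_dim)
b_h = torch.zeros(hidden_dim)

h = torch.zeros(hidden_dim)                      # initial memory h_0
for t in range(5):                               # walk through a short observation sequence
    x_t = torch.randn(input_dim)                 # current observation o_t
    h = torch.tanh(W_xh @ x_t + W_hh @ h + b_h)  # h_t = f(W_xh x_t + W_hh h_{t-1} + b_h)
```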
2-2 RNN Memory vs. the Replay Buffer
- A **replay buffer** is a container in reinforcement learning that stores past experience (state, action, reward, next_state).
- Its purpose is to sample historical experience off-line for training, breaking temporal correlations and improving training stability.
- Example: in DQN the buffer can hold thousands of trajectories (PPO keeps an analogous, but on-policy, rollout storage), and random mini-batches are sampled from it to update the network parameters.
| Feature | RNN memory | Replay buffer |
|---|---|---|
| Data type | hidden states (internal network vectors) | complete state, action, reward trajectories |
| Role | captures temporal dependencies for sequential decision-making | stores past experience used to train the network |
| Update | updated dynamically at every forward time step | sampled at random, off-line |
| Lifetime | valid within one trajectory or episode | can persist across many episodes |
- The RNN is the network's internal "memory" that helps with sequential decision-making; the replay buffer is the training-time "memory bank" that helps the algorithm learn a more robust policy.
2-3 ActorCriticRecurrent
- ActorCriticRecurrent is the RNN version of ActorCritic: a recurrent network lets the Actor and the Critic remember history, which suits partially observable environments (POMDPs) and tasks with temporal dependencies.
- Main features:
  - Both the Actor and the Critic get a "memory module" (RNN/LSTM/GRU); their roles are unchanged:
    - The Actor's memory output → determines the action distribution ($\mu_\theta(s_t)$)
    - The Critic's memory output → estimates the value of the current state ($V_\phi(s_t)$)
  - The memory is updated over time
    - At every time step, the network combines the current observation with the previous memory to update the hidden state $h_t$
    - In plain words: the network keeps folding new information into its memory, so it has context over the whole action sequence.
  - The memory is cleared at the start of every new episode
    - reset(dones) zeroes the hidden state at the corresponding positions
    - In plain words: memory is wiped each new round, so that experience from the previous episode does not interfere with current decisions.
  - Training and inference are used differently
    - Training (batch update): the RNN processes whole sequences, doing the forward pass with masks and the saved hidden states
    - Inference (experience collection): each step uses the previous step's hidden state to generate an action
- In plain words: the RNN adds "memory" to the RL networks, letting the policy and the value function use past information, which makes them more robust in complex robot control or time-series tasks.
2-4 The Full Code at a Glance
```python
# SPDX-FileCopyrightText: Copyright (c) 2021 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
#    list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
#    this list of conditions and the following disclaimer in the documentation
#    and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
#    contributors may be used to endorse or promote products derived from
#    this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
# Copyright (c) 2021 ETH Zurich, Nikita Rudin

import numpy as np

import torch
import torch.nn as nn
from torch.distributions import Normal
from torch.nn.modules import rnn
from .actor_critic import ActorCritic, get_activation
from rsl_rl.utils import unpad_trajectories

class ActorCriticRecurrent(ActorCritic):
    is_recurrent = True
    def __init__(self, num_actor_obs,
                 num_critic_obs,
                 num_actions,
                 actor_hidden_dims=[256, 256, 256],
                 critic_hidden_dims=[256, 256, 256],
                 activation='elu',
                 rnn_type='lstm',
                 rnn_hidden_size=256,
                 rnn_num_layers=1,
                 init_noise_std=1.0,
                 **kwargs):
        if kwargs:
            print("ActorCriticRecurrent.__init__ got unexpected arguments, which will be ignored: " + str(kwargs.keys()),)

        super().__init__(num_actor_obs=rnn_hidden_size,
                         num_critic_obs=rnn_hidden_size,
                         num_actions=num_actions,
                         actor_hidden_dims=actor_hidden_dims,
                         critic_hidden_dims=critic_hidden_dims,
                         activation=activation,
                         init_noise_std=init_noise_std)

        activation = get_activation(activation)

        self.memory_a = Memory(num_actor_obs, type=rnn_type, num_layers=rnn_num_layers, hidden_size=rnn_hidden_size)
        self.memory_c = Memory(num_critic_obs, type=rnn_type, num_layers=rnn_num_layers, hidden_size=rnn_hidden_size)

        print(f"Actor RNN: {self.memory_a}")
        print(f"Critic RNN: {self.memory_c}")

    def reset(self, dones=None):
        self.memory_a.reset(dones)
        self.memory_c.reset(dones)

    def act(self, observations, masks=None, hidden_states=None):
        input_a = self.memory_a(observations, masks, hidden_states)
        return super().act(input_a.squeeze(0))

    def act_inference(self, observations):
        input_a = self.memory_a(observations)
        return super().act_inference(input_a.squeeze(0))

    def evaluate(self, critic_observations, masks=None, hidden_states=None):
        input_c = self.memory_c(critic_observations, masks, hidden_states)
        return super().evaluate(input_c.squeeze(0))

    def get_hidden_states(self):
        return self.memory_a.hidden_states, self.memory_c.hidden_states


class Memory(torch.nn.Module):
    def __init__(self, input_size, type='lstm', num_layers=1, hidden_size=256):
        super().__init__()
        # RNN
        rnn_cls = nn.GRU if type.lower() == 'gru' else nn.LSTM
        self.rnn = rnn_cls(input_size=input_size, hidden_size=hidden_size, num_layers=num_layers)
        self.hidden_states = None

    def forward(self, input, masks=None, hidden_states=None):
        batch_mode = masks is not None
        if batch_mode:
            # batch mode (policy update): need saved hidden states
            if hidden_states is None:
                raise ValueError("Hidden states not passed to memory module during policy update")
            out, _ = self.rnn(input, hidden_states)
            out = unpad_trajectories(out, masks)
        else:
            # inference mode (collection): use hidden states of last step
            out, self.hidden_states = self.rnn(input.unsqueeze(0), self.hidden_states)
        return out

    def reset(self, dones=None):
        # When the RNN is an LSTM, self.hidden_states_a is a list with hidden_state and cell_state
        for hidden_state in self.hidden_states:
            hidden_state[..., dones, :] = 0.0
```
- We will focus on the parts that differ from ActorCritic.
2-5 The Memory Class
```python
class Memory(torch.nn.Module):
    def __init__(self, input_size, type='lstm', num_layers=1, hidden_size=256):
        super().__init__()
        # RNN
        rnn_cls = nn.GRU if type.lower() == 'gru' else nn.LSTM
        self.rnn = rnn_cls(input_size=input_size, hidden_size=hidden_size, num_layers=num_layers)
        self.hidden_states = None

    def forward(self, input, masks=None, hidden_states=None):
        batch_mode = masks is not None
        if batch_mode:
            # batch mode (policy update): need saved hidden states
            if hidden_states is None:
                raise ValueError("Hidden states not passed to memory module during policy update")
            out, _ = self.rnn(input, hidden_states)
            out = unpad_trajectories(out, masks)
        else:
            # inference mode (collection): use hidden states of last step
            out, self.hidden_states = self.rnn(input.unsqueeze(0), self.hidden_states)
        return out

    def reset(self, dones=None):
        # When the RNN is an LSTM, self.hidden_states_a is a list with hidden_state and cell_state
        for hidden_state in self.hidden_states:
            hidden_state[..., dones, :] = 0.0
```
- Memory is a module that wraps the RNN. Its main purposes are:
  - providing hidden-state memory for the Actor and the Critic inside ActorCriticRecurrent
  - separating the training (batch mode) and inference (inference mode) forward passes
  - managing the initialization and resetting of the hidden states
2-5-1 Constructor (inherits from torch.nn.Module)
```python
def __init__(self, input_size, type='lstm', num_layers=1, hidden_size=256):
    super().__init__()
    rnn_cls = nn.GRU if type.lower() == 'gru' else nn.LSTM
    self.rnn = rnn_cls(input_size=input_size, hidden_size=hidden_size, num_layers=num_layers)
    self.hidden_states = None
```
- input_size: dimension of the RNN input at each step (e.g. the observation vector dimension)
- type: RNN type, 'lstm' or 'gru'
- num_layers: number of stacked RNN layers
- hidden_size: dimension of the RNN hidden state
- self.hidden_states: stores the RNN hidden states so they can be carried forward step by step in inference mode
2-5-2 The Forward Function forward()
```python
def forward(self, input, masks=None, hidden_states=None):
    batch_mode = masks is not None
    if batch_mode:
        # batch mode (policy update)
        if hidden_states is None:
            raise ValueError("Hidden states not passed to memory module during policy update")
        out, _ = self.rnn(input, hidden_states)
        out = unpad_trajectories(out, masks)
    else:
        # inference mode (collection)
        out, self.hidden_states = self.rnn(input.unsqueeze(0), self.hidden_states)
    return out
```
- batch_mode: when masks is not None, we are in the training phase (a whole batch of trajectories is processed at once)
  - the hidden_states saved during collection must be passed in
  - the RNN processes the sequences and outputs a sequence of hidden states
  - unpad_trajectories removes the padded time steps, so every trajectory keeps its true length
- inference mode: when masks is None, we are running the policy or collecting experience
  - the input is the observation of the current time step
  - input.unsqueeze(0) adds the time dimension
  - self.hidden_states is updated and reused at the next step
  - the RNN output for the current time step is returned
- In plain words (see the shape sketch below):
  - training mode → process the whole sequence at once, update in batches
  - inference mode → step one at a time, using the previous step's hidden state as memory
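- A standalone shape sketch of the two modes using an nn.LSTM (all sizes are made up; PyTorch's default layout is (seq_len, batch, feature)):

```python
import torch
import torch.nn as nn

obs_dim, hidden_size, num_envs = 48, 256, 8     # illustrative sizes
lstm = nn.LSTM(input_size=obs_dim, hidden_size=hidden_size, num_layers=1)

# inference mode: a single time step for all envs, hidden state carried between calls
obs_t = torch.randn(num_envs, obs_dim)
hidden = None                                   # (h, c); PyTorch initializes it to zeros
out, hidden = lstm(obs_t.unsqueeze(0), hidden)  # out: (1, num_envs, hidden_size)

# batch mode: a padded sequence of T steps, started from the saved hidden states
T = 24
padded_seq = torch.randn(T, num_envs, obs_dim)
out_seq, _ = lstm(padded_seq, hidden)           # out_seq: (T, num_envs, hidden_size)
```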
2-5-3 The Reset Function reset()
```python
def reset(self, dones=None):
    for hidden_state in self.hidden_states:
        hidden_state[..., dones, :] = 0.0
```
- dones is a boolean vector indicating which environments/samples have just finished their episode
- the hidden states of those samples are zeroed (a small demo follows)
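- Here is a small standalone demo of the indexing trick used in reset() (all sizes are made up): the boolean dones vector selects which environments' memories get zeroed:

```python
import torch

num_layers, num_envs, hidden_size = 1, 8, 256   # illustrative sizes
h = torch.randn(num_layers, num_envs, hidden_size)
c = torch.randn(num_layers, num_envs, hidden_size)

dones = torch.tensor([False, True, False, False, True, False, False, False])
for state in (h, c):                            # for an LSTM: hidden state and cell state
    state[..., dones, :] = 0.0                  # same indexing as Memory.reset

print(h[:, 1].abs().sum().item())               # 0.0 -- env 1's memory was cleared
print(h[:, 0].abs().sum().item())               # non-zero -- env 0's memory is untouched
```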
2-6 Hidden-State Flow in Training Mode (Batch Training)
- A training update usually feeds a whole batch (several trajectories) through the network at once.
- The RNN processes the entire sequence:
  - input at each time step: the current state plus the hidden state from the previous step
  - output: the sequence of hidden states $h_t$
  - the padding is then removed (trajectories have different lengths, so shorter ones were padded)
- The RNN outputs are fed into the MLPs:
  - the Actor produces the action distribution
  - the Critic produces the state value
- These outputs are used to compute the policy-gradient and value losses and to update the network parameters.
- Note: during training the hidden states are processed for the whole sequence in one pass, and the next batch can continue from the hidden states saved from the previous one.
2-7 Hidden-State Flow in Inference Mode (Experience Collection / Real Control)
- Each step processes only the current state.
- The RNN uses the previous hidden state $h_{t-1}$ to update the current hidden state $h_t$.
- The result is fed into the MLPs:
  - the Actor outputs the action $\mu(s_t)$
  - the Critic optionally outputs the state value $V(s_t)$
- The action is executed, and the environment returns a reward and the next state.
- The role of reset(dones) (a sketch of this loop follows):
  - some environments have just finished (done=True), so their hidden states must be cleared
  - this keeps the previous episode's memory from influencing the next round of decisions.
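- Putting the inference flow together, a hypothetical per-step collection loop might look like the sketch below. DummyEnv, its sizes, and the step count are made up for illustration; only ActorCriticRecurrent comes from the listing above, and this is not the rsl_rl runner code:

```python
import torch

class DummyEnv:
    """Stand-in environment for illustration only (not part of rsl_rl)."""
    def __init__(self, num_envs=8, obs_dim=48):
        self.num_envs, self.obs_dim = num_envs, obs_dim
    def reset(self):
        return torch.randn(self.num_envs, self.obs_dim)
    def step(self, actions):
        obs = torch.randn(self.num_envs, self.obs_dim)
        rewards = torch.zeros(self.num_envs)
        dones = torch.rand(self.num_envs) < 0.05   # a few episodes end at random
        return obs, rewards, dones, {}

env = DummyEnv()
policy = ActorCriticRecurrent(num_actor_obs=48, num_critic_obs=48, num_actions=12)  # hypothetical sizes

obs = env.reset()
for step in range(100):
    actions = policy.act(obs)       # Memory updates its hidden state internally (inference mode)
    obs, rewards, dones, infos = env.step(actions)
    policy.reset(dones)             # zero the memories of environments whose episode just ended
```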
2-8 Network Structure
- Actor path: input state s → Memory (RNN, hidden state h_t) → MLP hidden layers → action mean μ(s)
- Critic path: input state s → Memory (RNN, hidden state h_t) → MLP hidden layers → state value V(s)

Summary
- This installment analyzed the Python implementation of ActorCritic and ActorCriticRecurrent in the rsl_rl repository. We reviewed the core ideas of Actor-Critic, looked at how ActorCriticRecurrent adds an RNN/Memory module so the networks can remember history, distinguished how hidden states are handled in training versus inference, and walked through the functions for network construction, action sampling, and value evaluation, laying the groundwork for understanding the policy and value networks used in complex robot control tasks.
- If you spot any mistakes, please point them out. Thanks for reading!
