[Unitree Robot Reinforcement Learning] (Part 2): Python Implementation and Analysis of the ActorCritic and ActorCriticRecurrent Networks

Author: zh路西法    Date: 2026/3/17

Preface


1 The ActorCritic Class

1-1 ActorCritic Network Review
  • We covered the Actor-Critic framework in Part 5 of the beginner-friendly RL series; here is a quick recap of the core ideas:
  • In reinforcement learning, Actor-Critic is a framework that learns a policy and a value function at the same time. In formulas it consists of a policy network $\pi_\theta(a|s)$ and a value network $V_\phi(s)$, where:
    • Actor (policy network): tells you "which action $a$ to take in the current state $s$", i.e. $a_t \sim \pi_\theta(a_t|s_t)$
      * outputs an action distribution (for continuous actions, usually a Gaussian $\mathcal{N}(\mu, \sigma^2)$)
      * during training actions are sampled to encourage exploration; at inference time the mean action is used to reduce randomness
    • Critic (value network): tells you "how good the current state $s$ is", i.e. the state value $V_\phi(s_t) \approx \mathbb{E}[R_t \mid s_t]$
      * used to compute the advantage function $A_t = R_t - V_\phi(s_t)$
      * the advantage measures how much better an action is than average, and is what the Actor uses to update its policy
  • Policy updates depend on the value estimate
    • The Actor's gradient comes from the policy gradient theorem: $\nabla_\theta J(\theta) = \mathbb{E}_t \big[ A_t \nabla_\theta \log \pi_\theta(a_t|s_t) \big]$
    • The Critic minimizes a mean squared error: $L_\text{critic} = \mathbb{E}_t \big[(V_\phi(s_t) - R_t)^2\big]$
  • Role of the advantage function
    • improves training stability: actions better than average are reinforced, worse ones are suppressed
    • reduces the variance of the policy gradient, so training converges faster
  • Summary:
    • The Actor decides, the Critic evaluates, and the advantage function bridges the two, giving the policy both a direction and a reference point during training (a minimal loss-computation sketch follows this list).

1-2 Full Code at a Glance
  • First, clone the rsl_rl repository:
git clone https://github.com/leggedrobotics/rsl_rl.git
cd rsl_rl
git checkout v1.0.2
  • actor_critic.py can be found in the modules folder (rsl_rl/modules) under the project root:
# SPDX-FileCopyrightText: Copyright (c) 2021 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
# 
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
# Copyright (c) 2021 ETH Zurich, Nikita Rudin

import numpy as np

import torch
import torch.nn as nn
from torch.distributions import Normal
from torch.nn.modules import rnn

class ActorCritic(nn.Module):
    is_recurrent = False
    def __init__(self,  num_actor_obs,
                        num_critic_obs,
                        num_actions,
                        actor_hidden_dims=[256, 256, 256],
                        critic_hidden_dims=[256, 256, 256],
                        activation='elu',
                        init_noise_std=1.0,
                        **kwargs):
        if kwargs:
            print("ActorCritic.__init__ got unexpected arguments, which will be ignored: " + str([key for key in kwargs.keys()]))
        super(ActorCritic, self).__init__()

        activation = get_activation(activation)

        mlp_input_dim_a = num_actor_obs
        mlp_input_dim_c = num_critic_obs

        # Policy
        actor_layers = []
        actor_layers.append(nn.Linear(mlp_input_dim_a, actor_hidden_dims[0]))
        actor_layers.append(activation)
        for l in range(len(actor_hidden_dims)):
            if l == len(actor_hidden_dims) - 1:
                actor_layers.append(nn.Linear(actor_hidden_dims[l], num_actions))
            else:
                actor_layers.append(nn.Linear(actor_hidden_dims[l], actor_hidden_dims[l + 1]))
                actor_layers.append(activation)
        self.actor = nn.Sequential(*actor_layers)

        # Value function
        critic_layers = []
        critic_layers.append(nn.Linear(mlp_input_dim_c, critic_hidden_dims[0]))
        critic_layers.append(activation)
        for l in range(len(critic_hidden_dims)):
            if l == len(critic_hidden_dims) - 1:
                critic_layers.append(nn.Linear(critic_hidden_dims[l], 1))
            else:
                critic_layers.append(nn.Linear(critic_hidden_dims[l], critic_hidden_dims[l + 1]))
                critic_layers.append(activation)
        self.critic = nn.Sequential(*critic_layers)

        print(f"Actor MLP: {self.actor}")
        print(f"Critic MLP: {self.critic}")

        # Action noise
        self.std = nn.Parameter(init_noise_std * torch.ones(num_actions))
        self.distribution = None
        # disable args validation for speedup
        Normal.set_default_validate_args = False

        # seems that we get better performance without init
        # self.init_memory_weights(self.memory_a, 0.001, 0.)
        # self.init_memory_weights(self.memory_c, 0.001, 0.)

    @staticmethod
    # not used at the moment
    def init_weights(sequential, scales):
        [torch.nn.init.orthogonal_(module.weight, gain=scales[idx]) for idx, module in
         enumerate(mod for mod in sequential if isinstance(mod, nn.Linear))]


    def reset(self, dones=None):
        pass

    def forward(self):
        raise NotImplementedError

    @property
    def action_mean(self):
        return self.distribution.mean

    @property
    def action_std(self):
        return self.distribution.stddev

    @property
    def entropy(self):
        return self.distribution.entropy().sum(dim=-1)

    def update_distribution(self, observations):
        mean = self.actor(observations)
        self.distribution = Normal(mean, mean*0. + self.std)

    def act(self, observations, **kwargs):
        self.update_distribution(observations)
        return self.distribution.sample()

    def get_actions_log_prob(self, actions):
        return self.distribution.log_prob(actions).sum(dim=-1)

    def act_inference(self, observations):
        actions_mean = self.actor(observations)
        return actions_mean

    def evaluate(self, critic_observations, **kwargs):
        value = self.critic(critic_observations)
        return value

def get_activation(act_name):
    if act_name == "elu":
        return nn.ELU()
    elif act_name == "selu":
        return nn.SELU()
    elif act_name == "relu":
        return nn.ReLU()
    elif act_name == "crelu":
        return nn.ReLU()
    elif act_name == "lrelu":
        return nn.LeakyReLU()
    elif act_name == "tanh":
        return nn.Tanh()
    elif act_name == "sigmoid":
        return nn.Sigmoid()
    else:
        print("invalid activation function!")
        return None
  • Next, let's look at what each function does.

1-3 The Initialization Function
class ActorCritic(nn.Module):
    is_recurrent = False
    def __init__(self,  num_actor_obs,
                        num_critic_obs,
                        num_actions,
                        actor_hidden_dims=[256, 256, 256],
                        critic_hidden_dims=[256, 256, 256],
                        activation='elu',
                        init_noise_std=1.0,
                        **kwargs):
        if kwargs:
            print("ActorCritic.__init__ got unexpected arguments, which will be ignored: " + str([key for key in kwargs.keys()]))
        super(ActorCritic, self).__init__()

        activation = get_activation(activation)

        mlp_input_dim_a = num_actor_obs
        mlp_input_dim_c = num_critic_obs

        # Policy
        actor_layers = []
        actor_layers.append(nn.Linear(mlp_input_dim_a, actor_hidden_dims[0]))
        actor_layers.append(activation)
        for l in range(len(actor_hidden_dims)):
            if l == len(actor_hidden_dims) - 1:
                actor_layers.append(nn.Linear(actor_hidden_dims[l], num_actions))
            else:
                actor_layers.append(nn.Linear(actor_hidden_dims[l], actor_hidden_dims[l + 1]))
                actor_layers.append(activation)
        self.actor = nn.Sequential(*actor_layers)

        # Value function
        critic_layers = []
        critic_layers.append(nn.Linear(mlp_input_dim_c, critic_hidden_dims[0]))
        critic_layers.append(activation)
        for l in range(len(critic_hidden_dims)):
            if l == len(critic_hidden_dims) - 1:
                critic_layers.append(nn.Linear(critic_hidden_dims[l], 1))
            else:
                critic_layers.append(nn.Linear(critic_hidden_dims[l], critic_hidden_dims[l + 1]))
                critic_layers.append(activation)
        self.critic = nn.Sequential(*critic_layers)

        print(f"Actor MLP: {self.actor}")
        print(f"Critic MLP: {self.critic}")

        # Action noise
        self.std = nn.Parameter(init_noise_std * torch.ones(num_actions))
        self.distribution = None
        # disable args validation for speedup
        Normal.set_default_validate_args = False

        # seems that we get better performance without init
        # self.init_memory_weights(self.memory_a, 0.001, 0.)
        # self.init_memory_weights(self.memory_c, 0.001, 0.)

1-3-1 Hyperparameters
def __init__(self,  
             num_actor_obs,
             num_critic_obs,
             num_actions,
             actor_hidden_dims=[256, 256, 256],
             critic_hidden_dims=[256, 256, 256],
             activation='elu',
             init_noise_std=1.0,
             **kwargs):
  • Let's look at the hyperparameters:
    • num_actor_obs: input dimension of the Actor network, i.e. how many observation values the policy network sees
    • num_critic_obs: input dimension of the Critic network, i.e. how many observation values the value network sees
    • num_actions: dimension of the action space
    • actor_hidden_dims: number of units in each hidden layer of the Actor network
    • critic_hidden_dims: number of units in each hidden layer of the Critic network
    • activation: hidden-layer activation function, one of the options provided by get_activation(): "elu", "selu", "relu", "crelu", "lrelu", "tanh", "sigmoid"
    • init_noise_std: initial standard deviation $\sigma$ of the Actor's action noise (a hypothetical instantiation is sketched below)
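  • As a quick sanity check, a hypothetical instantiation could look like the following; the observation and action sizes are made up for illustration and do not come from any particular robot configuration.

import torch
from rsl_rl.modules import ActorCritic   # assuming rsl_rl is installed (pip install -e .)

policy = ActorCritic(num_actor_obs=48,
                     num_critic_obs=48,
                     num_actions=12,
                     actor_hidden_dims=[256, 256, 256],
                     critic_hidden_dims=[256, 256, 256],
                     activation='elu',
                     init_noise_std=1.0)
# The constructor prints the two MLPs, e.g. Linear(48->256) -> ELU -> ... -> Linear(256->12) for the actor.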

1-3-2 Building the Actor MLP and the Critic MLP
# Policy
actor_layers = []
actor_layers.append(nn.Linear(mlp_input_dim_a, actor_hidden_dims[0]))
actor_layers.append(activation)
for l in range(len(actor_hidden_dims)):
    if l == len(actor_hidden_dims) - 1:
        actor_layers.append(nn.Linear(actor_hidden_dims[l], num_actions))
    else:
        actor_layers.append(nn.Linear(actor_hidden_dims[l], actor_hidden_dims[l + 1]))
        actor_layers.append(activation)
self.actor = nn.Sequential(*actor_layers)

# Value function
critic_layers = []
critic_layers.append(nn.Linear(mlp_input_dim_c, critic_hidden_dims[0]))
critic_layers.append(activation)
for l in range(len(critic_hidden_dims)):
    if l == len(critic_hidden_dims) - 1:
        critic_layers.append(nn.Linear(critic_hidden_dims[l], 1))
    else:
        critic_layers.append(nn.Linear(critic_hidden_dims[l], critic_hidden_dims[l + 1]))
        critic_layers.append(activation)
self.critic = nn.Sequential(*critic_layers)

print(f"Actor MLP: {self.actor}")
print(f"Critic MLP: {self.critic}")
  • The two networks can be summarized in a table (a quick shape check follows below):

| Feature | Actor network | Critic network |
| --- | --- | --- |
| Input | num_actor_obs | num_critic_obs (may differ) |
| Output | action vector $\mu_\theta(s)$ | scalar state value $V_\phi(s)$ |
| Output layer size | action-space dimension | 1 |
| Role | selects actions (the policy) | evaluates states (the value function) |
| Update | trained with the policy gradient | computes the advantage that guides the Actor |

  • One network outputs the action vector $\mu_\theta(s)$ for the current state; the other outputs a scalar $V_\phi(s)$ that evaluates that state.
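  • A small shape check illustrates the table, assuming the hypothetical policy object from the previous sketch: the actor maps each observation to an action-sized vector, the critic to a single scalar.

import torch

obs = torch.randn(4, 48)            # a batch of 4 hypothetical observations
action_mean = policy.actor(obs)     # shape (4, 12): one mu_theta(s) per sample
value = policy.critic(obs)          # shape (4, 1): one scalar V_phi(s) per sample
print(action_mean.shape, value.shape)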

1-3-3 Action Noise Initialization
# Action noise
self.std = nn.Parameter(init_noise_std * torch.ones(num_actions))
self.distribution = None
Normal.set_default_validate_args = False
  • In a continuous action space, actions are sampled from a normal distribution: $a \sim \mathcal{N}(\mu(s), \sigma^2)$
  • where:
    • self.std: the standard deviation $\sigma$ of the action distribution, which controls exploration (see the sketch below)
    • self.distribution: the action distribution object Normal(mean, std), used for sampling and for computing log_prob and entropy
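  • Note that self.std is a single learnable vector shared across all states (state-independent noise). A minimal sketch of what that means in practice, with toy sizes:

import torch
from torch.distributions import Normal

num_actions = 12
std = torch.nn.Parameter(1.0 * torch.ones(num_actions))  # mirrors init_noise_std = 1.0
mean = torch.zeros(3, num_actions)                        # stand-in for the actor output of 3 envs

dist = Normal(mean, std)       # std broadcasts across the batch dimension
explore = dist.sample()        # stochastic actions used while collecting experience
exploit = dist.mean            # deterministic actions used at inference time
# Because std is an nn.Parameter, the PPO update can widen or shrink exploration over training.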

1-4 Helper Properties for the Loss Computation
@property
def action_mean(self):
    return self.distribution.mean

@property
def action_std(self):
    return self.distribution.stddev

@property
def entropy(self):
    return self.distribution.entropy().sum(dim=-1)
  • These three are utility properties:
  • action_mean: returns the mean $\mu_\theta(s)$ of the action distribution for the current state
  • action_std: returns the standard deviation $\sigma$ of the action distribution for the current state
  • entropy: returns the entropy $H[\pi_\theta]$ of the action distribution
  • All three are accessed in algorithms/ppo.py when the loss is computed:
mu_batch = self.actor_critic.action_mean
sigma_batch = self.actor_critic.action_std
entropy_batch = self.actor_critic.entropy

1-5 The Core Function act()
def update_distribution(self, observations):
    mean = self.actor(observations)
    self.distribution = Normal(mean, mean*0. + self.std)

def act(self, observations, **kwargs):
    self.update_distribution(observations)
    return self.distribution.sample()
  • The Actor network produces the action mean: $\mu_\theta(s_t) = \text{Actor}(s_t)$
  • which is combined with the standard deviation into a Gaussian action distribution: $a_t \sim \mathcal{N}\big(\mu_\theta(s_t), \sigma^2 \mathbf{I}\big)$
  • Note: mean*0. only creates a zero tensor with the same shape as mean, so that the standard deviation broadcasts correctly.
  • In algorithms/ppo.py, the first step of every training mini-batch is this Actor forward pass, which recomputes the action distribution under the current policy (a usage sketch follows the snippet below):
self.actor_critic.act(obs_batch, masks=masks_batch, hidden_states=hid_states_batch[0])
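  • To summarize how act() is used in the two phases, here is a schematic sketch; it assumes the hypothetical policy and obs objects from the earlier sketches.

# Experience collection: stochastic actions for exploration
actions = policy.act(obs)                # refreshes self.distribution, then samples from it

# Deployment on the robot: deterministic mean action, no sampling
actions_det = policy.act_inference(obs)

# PPO update: act() is called again on the stored observations so that
# self.distribution reflects the *current* policy before log-probs are computed.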

1-6 get_actions_log_prob()
def get_actions_log_prob(self, actions):
    return self.distribution.log_prob(actions).sum(dim=-1)
  • Purpose: compute the log-probability of the actions, $\log \pi_\theta(a_t|s_t)$ (a schematic of how this feeds the PPO ratio follows the snippet below)
  • Called directly in algorithms/ppo.py:
actions_log_prob_batch = self.actor_critic.get_actions_log_prob(actions_batch)
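  • This log-probability is what the PPO surrogate objective is built from: the probability ratio between the current and the old policy. A simplified schematic (actions_batch, old_log_prob_batch and advantages are illustrative names; the real ppo.py also adds value-loss clipping, an entropy bonus, etc.):

import torch

# assumes policy.act(obs_batch) has just refreshed the distribution
new_log_prob = policy.get_actions_log_prob(actions_batch)   # log pi_new(a|s), summed over action dims
ratio = torch.exp(new_log_prob - old_log_prob_batch)        # pi_new / pi_old
clip = 0.2
surrogate_loss = -torch.min(ratio * advantages,
                            torch.clamp(ratio, 1.0 - clip, 1.0 + clip) * advantages).mean()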

1-7 The Core Function evaluate()
def evaluate(self, critic_observations, **kwargs):
    value = self.critic(critic_observations)
    return value
  • Critic forward pass that computes the state value $V_\phi(s)$
    • used to compute the advantage function $A_t = R_t - V_\phi(s_t)$ (a sketch follows the snippet below)
  • Called directly in algorithms/ppo.py:
value_batch = self.actor_critic.evaluate(critic_obs_batch)
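  • The value returned here feeds the advantage estimate. In the simplest form used in this article's recap (the actual PPO implementation uses GAE, but the idea is the same); critic_obs and returns are illustrative names:

values = policy.evaluate(critic_obs).squeeze(-1)   # V_phi(s_t), shape (batch,)
advantages = returns - values.detach()             # A_t = R_t - V_phi(s_t)
value_loss = (values - returns).pow(2).mean()      # critic regression loss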

1-8 Network Structure

(diagram) input state s → MLP hidden layers → Actor output μ(s) / Critic output V(s)


2 ActorCriticRecurrent

2-1 Background: RNNs
  • RNN stands for Recurrent Neural Network.
  • Purpose: processing sequential or time-series data, e.g. natural language, robot control, stock-price prediction.
  • Core idea: the output at the current time step depends not only on the current input but also on the previous hidden state, which gives the network an internal "memory".
  • Mathematically: $h_t = f(W_{xh} x_t + W_{hh} h_{t-1} + b_h)$, where:
    • $x_t$: the current input (e.g. the observation vector $o_t$)
    • $h_{t-1}$: the hidden state from the previous step, i.e. a summary of the history so far, the network's "memory"
    • $h_t$: the current hidden state
    • $f$: a nonlinear activation function (tanh, ReLU, ...)
  • Main advantages (a minimal recurrence sketch follows this list):
    1. Remembers history: the hidden state retains the important past information of the sequence
    2. Handles partial observability (POMDPs): it can infer state information missing from the current observation
    3. Models temporal dependencies: actions and values can depend on many past steps, not just the current observation
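  • The recurrence above fits in a few lines of PyTorch. This is a bare single-layer RNN step for illustration; the rsl_rl code uses nn.LSTM / nn.GRU, which add gating on top of the same idea.

import torch

obs_dim, hidden_dim = 8, 16
W_xh = torch.randn(hidden_dim, obs_dim) * 0.1
W_hh = torch.randn(hidden_dim, hidden_dim) * 0.1
b_h = torch.zeros(hidden_dim)

h = torch.zeros(hidden_dim)                          # initial "memory"
for t in range(5):                                   # a short observation sequence
    x_t = torch.randn(obs_dim)
    h = torch.tanh(W_xh @ x_t + W_hh @ h + b_h)      # h_t = f(W_xh x_t + W_hh h_{t-1} + b_h)
print(h)                                             # the final hidden state summarizes the whole sequence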

2-2 RNN Memory vs. Replay Buffer
  • A replay buffer is a container in reinforcement learning that stores past experience (state, action, reward, next_state).
  • Its purpose is to sample stored experience offline for training, which breaks temporal correlation and improves training stability.
  • Example: in DQN, or in PPO's rollout storage, thousands of transitions can be stored and then sampled as mini-batches to update the network parameters.

| Feature | RNN memory | Replay buffer |
| --- | --- | --- |
| Data type | hidden state (an internal network vector) | full state/action/reward trajectories |
| Role | captures temporal dependencies for sequential decisions | stores past experience for training |
| Update | advances at every forward time step | filled online, sampled randomly offline |
| Lifetime | valid within one trajectory/episode | persists across many episodes |

  • In short, the RNN is the network's internal "memory" that helps it make sequential decisions, while the replay buffer is the training algorithm's "memory bank" that helps it learn a more robust policy.

2-3 ActorCriticRecurrent
  • ActorCriticRecurrent is the RNN version of ActorCritic: it adds a recurrent network so that the Actor and Critic can remember history, which suits partially observable environments (POMDPs) and time-dependent tasks.
  • Main features:
    1. Both the Actor and the Critic gain a "memory module" (RNN/LSTM/GRU), while their roles stay the same:
      • the Actor's memory output → determines the action distribution ($\mu_\theta(s_t)$)
      • the Critic's memory output → estimates the value of the current state ($V_\phi(s_t)$)
    2. The network keeps updating its memory over time:
      • at each time step, the current observation plus the previous memory produce the new hidden state $h_t$
      • in plain words, the network keeps folding new information into its memory so that it has context over the whole action sequence.
    3. The memory is cleared at the start of each new episode:
      • reset(dones) zeroes the hidden states at the corresponding positions
      • in plain words, every new episode starts with a blank memory, so the previous episode cannot interfere with current decisions.
    4. Training and inference use it differently:
      • training (batch updates): the RNN processes whole sequences, using masks and the saved hidden states for the forward pass
      • inference (experience collection): each step uses the hidden state from the previous step to generate an action
  • In plain words: the RNN gives the RL networks a "memory", letting the policy and value function use past information when making decisions, which makes them more robust in complex robot-control or time-series tasks.

2-4 Full Code at a Glance
# SPDX-FileCopyrightText: Copyright (c) 2021 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
# 
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
# Copyright (c) 2021 ETH Zurich, Nikita Rudin

import numpy as np

import torch
import torch.nn as nn
from torch.distributions import Normal
from torch.nn.modules import rnn
from .actor_critic import ActorCritic, get_activation
from rsl_rl.utils import unpad_trajectories

class ActorCriticRecurrent(ActorCritic):
    is_recurrent = True
    def __init__(self,  num_actor_obs,
                        num_critic_obs,
                        num_actions,
                        actor_hidden_dims=[256, 256, 256],
                        critic_hidden_dims=[256, 256, 256],
                        activation='elu',
                        rnn_type='lstm',
                        rnn_hidden_size=256,
                        rnn_num_layers=1,
                        init_noise_std=1.0,
                        **kwargs):
        if kwargs:
            print("ActorCriticRecurrent.__init__ got unexpected arguments, which will be ignored: " + str(kwargs.keys()),)

        super().__init__(num_actor_obs=rnn_hidden_size,
                         num_critic_obs=rnn_hidden_size,
                         num_actions=num_actions,
                         actor_hidden_dims=actor_hidden_dims,
                         critic_hidden_dims=critic_hidden_dims,
                         activation=activation,
                         init_noise_std=init_noise_std)

        activation = get_activation(activation)

        self.memory_a = Memory(num_actor_obs, type=rnn_type, num_layers=rnn_num_layers, hidden_size=rnn_hidden_size)
        self.memory_c = Memory(num_critic_obs, type=rnn_type, num_layers=rnn_num_layers, hidden_size=rnn_hidden_size)

        print(f"Actor RNN: {self.memory_a}")
        print(f"Critic RNN: {self.memory_c}")

    def reset(self, dones=None):
        self.memory_a.reset(dones)
        self.memory_c.reset(dones)

    def act(self, observations, masks=None, hidden_states=None):
        input_a = self.memory_a(observations, masks, hidden_states)
        return super().act(input_a.squeeze(0))

    def act_inference(self, observations):
        input_a = self.memory_a(observations)
        return super().act_inference(input_a.squeeze(0))

    def evaluate(self, critic_observations, masks=None, hidden_states=None):
        input_c = self.memory_c(critic_observations, masks, hidden_states)
        return super().evaluate(input_c.squeeze(0))

    def get_hidden_states(self):
        return self.memory_a.hidden_states, self.memory_c.hidden_states


class Memory(torch.nn.Module):
    def __init__(self, input_size, type='lstm', num_layers=1, hidden_size=256):
        super().__init__()
        # RNN
        rnn_cls = nn.GRU if type.lower() == 'gru' else nn.LSTM
        self.rnn = rnn_cls(input_size=input_size, hidden_size=hidden_size, num_layers=num_layers)
        self.hidden_states = None

    def forward(self, input, masks=None, hidden_states=None):
        batch_mode = masks is not None
        if batch_mode:
            # batch mode (policy update): need saved hidden states
            if hidden_states is None:
                raise ValueError("Hidden states not passed to memory module during policy update")
            out, _ = self.rnn(input, hidden_states)
            out = unpad_trajectories(out, masks)
        else:
            # inference mode (collection): use hidden states of last step
            out, self.hidden_states = self.rnn(input.unsqueeze(0), self.hidden_states)
        return out

    def reset(self, dones=None):
        # When the RNN is an LSTM, self.hidden_states_a is a list with hidden_state and cell_state
        for hidden_state in self.hidden_states:
            hidden_state[..., dones, :] = 0.0
  • We focus on the parts that differ from ActorCritic.

2-5 The Memory Class
class Memory(torch.nn.Module):
    def __init__(self, input_size, type='lstm', num_layers=1, hidden_size=256):
        super().__init__()
        # RNN
        rnn_cls = nn.GRU if type.lower() == 'gru' else nn.LSTM
        self.rnn = rnn_cls(input_size=input_size, hidden_size=hidden_size, num_layers=num_layers)
        self.hidden_states = None

    def forward(self, input, masks=None, hidden_states=None):
        batch_mode = masks is not None
        if batch_mode:
            # batch mode (policy update): need saved hidden states
            if hidden_states is None:
                raise ValueError("Hidden states not passed to memory module during policy update")
            out, _ = self.rnn(input, hidden_states)
            out = unpad_trajectories(out, masks)
        else:
            # inference mode (collection): use hidden states of last step
            out, self.hidden_states = self.rnn(input.unsqueeze(0), self.hidden_states)
        return out

    def reset(self, dones=None):
        # When the RNN is an LSTM, self.hidden_states_a is a list with hidden_state and cell_state
        for hidden_state in self.hidden_states:
            hidden_state[..., dones, :] = 0.0
  • Memory is a module that wraps an RNN; its main purposes are:
    1. providing hidden-state memory for the Actor and Critic in ActorCriticRecurrent
    2. distinguishing the forward pass in training (batch mode) from inference (collection mode)
    3. managing initialization and resetting of the hidden states
2-5-1 Constructor (inherits from torch.nn.Module)
def __init__(self, input_size, type='lstm', num_layers=1, hidden_size=256):
    super().__init__()
    rnn_cls = nn.GRU if type.lower() == 'gru' else nn.LSTM
    self.rnn = rnn_cls(input_size=input_size, hidden_size=hidden_size, num_layers=num_layers)
    self.hidden_states = None
  • input_size: input dimension of the RNN at each step (e.g. the dimension of the observation vector)
  • type: RNN type, either 'lstm' or 'gru'
  • num_layers: number of stacked RNN layers
  • hidden_size: dimension of the RNN hidden state
  • self.hidden_states: stores the RNN hidden states so they can be carried forward step by step in inference mode (a shape check is sketched below)
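  • One detail worth remembering: for an LSTM the hidden state is a tuple (h, c), while a GRU has only h, which is why reset() iterates over the stored hidden states. A small shape check with toy sizes (PyTorch's hidden-state layout is (num_layers, batch, hidden_size)):

import torch
import torch.nn as nn

lstm = nn.LSTM(input_size=48, hidden_size=256, num_layers=1)
gru = nn.GRU(input_size=48, hidden_size=256, num_layers=1)

x = torch.randn(1, 4, 48)            # (seq_len=1, batch=4 envs, input_size)
out, (h, c) = lstm(x)                # LSTM hidden state is the tuple (h, c)
print(out.shape, h.shape, c.shape)   # (1, 4, 256) (1, 4, 256) (1, 4, 256)
out, h = gru(x)                      # GRU hidden state is a single tensor
print(out.shape, h.shape)            # (1, 4, 256) (1, 4, 256)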

2-5-2 The forward Function
def forward(self, input, masks=None, hidden_states=None):
    batch_mode = masks is not None
    if batch_mode:
        # batch mode (policy update)
        if hidden_states is None:
            raise ValueError("Hidden states not passed to memory module during policy update")
        out, _ = self.rnn(input, hidden_states)
        out = unpad_trajectories(out, masks)
    else:
        # inference mode (collection)
        out, self.hidden_states = self.rnn(input.unsqueeze(0), self.hidden_states)
    return out
  • batch_mode: when masks is not None, we are in the training phase (processing a whole batch of trajectories at once)
    • hidden_states (the hidden states saved during collection) must be passed in
    • the RNN processes the sequences → outputs a sequence of hidden states
    • unpad_trajectories removes the padded time steps so the sequence lengths are correct again
  • inference mode: when masks is None, we are collecting experience or running the policy
    • the input is the observation of the current time step
    • input.unsqueeze(0) adds the time dimension
    • self.hidden_states is updated → reused at the next step
    • the RNN output for the current time step is returned
  • In plain words (both call patterns are sketched below):
    • training mode → process the whole sequence at once, update in batches
    • inference mode → step by step, with the previous hidden state carrying the history
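  • A schematic of the two call patterns; shapes follow PyTorch's (seq_len, batch, feature) RNN convention, the import path assumes rsl_rl v1.0.2, and the padded-batch details handled by unpad_trajectories are only hinted at in the comments.

import torch
from rsl_rl.modules.actor_critic_recurrent import Memory   # assumed module path

memory = Memory(input_size=48, type='lstm', hidden_size=256)

# Inference / collection: one time step for all envs, hidden state kept inside the module
obs_t = torch.randn(4, 48)           # (num_envs, obs_dim)
out = memory(obs_t)                  # unsqueezed to (1, 4, 48) internally; memory.hidden_states is updated

# Training / policy update: a padded sequence plus the hidden states saved during collection
# padded_obs: (max_seq_len, num_trajectories, obs_dim); masks mark the valid (non-padded) steps
# out = memory(padded_obs, masks=traj_masks, hidden_states=saved_hidden_states)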

2-5-3 The reset() Function
def reset(self, dones=None):
    for hidden_state in self.hidden_states:
        hidden_state[..., dones, :] = 0.0
  • dones is a boolean vector indicating which environments/samples have just finished their episode
  • the hidden states of those samples are set to zero

2-6 Hidden States in Training Mode (Batch Updates)
  1. A training update usually feeds an entire batch (several trajectories) into the network at once.
  2. The RNN processes the whole sequence:
    • input at each time step: the current state plus the hidden state from the previous step
    • output: the sequence of hidden states $h_t$
  3. Padding is removed (trajectories have different lengths, so the short ones were padded).
  4. The outputs go into the MLPs:
    • the Actor produces the action distribution
    • the Critic produces the state value
  5. These outputs are used to compute the policy gradient and the value-function loss, and the network parameters are updated.
  6. Note: during training, each whole sequence is processed in one pass, starting from the hidden states that were saved for that trajectory during collection.

2-7 Hidden States in Inference Mode (Experience Collection / Real Control)
  1. Each step processes only the current state.
  2. The RNN uses the previous hidden state $h_{t-1}$ to compute the current hidden state $h_t$.
  3. The output goes into the MLPs:
    • the Actor outputs the action $\mu(s_t)$
    • the Critic optionally outputs the state value $V(s_t)$
  4. The action is executed; the environment returns a reward and the next state (a rollout-loop sketch follows this list).
  5. The role of reset(dones):
    • some environments have just finished (done=True), and their hidden states must be zeroed
    • this prevents the memory of the previous episode from affecting decisions in the next one.
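  • Putting the inference-mode pieces together, a collection loop looks roughly like this; env, storage and num_steps_per_env are hypothetical names, not the actual rsl_rl runner API.

obs = env.reset()                                 # hypothetical vectorized-env API
for step in range(num_steps_per_env):
    with torch.no_grad():
        actions = policy.act(obs)                 # memory_a advances its hidden state internally
        values = policy.evaluate(obs)             # memory_c advances its hidden state internally
    obs, rewards, dones, infos = env.step(actions)
    policy.reset(dones)                           # zero the hidden states of environments that just finished
    storage.add(obs, actions, rewards, dones, values)   # hypothetical rollout storage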

2-8 Network Structure

(diagram) input state s → RNN / Memory (hidden state h_t) → MLP hidden layers → Actor output μ(s) / Critic output V(s)




Summary

  • This installment analyzed the Python implementations of ActorCritic and ActorCriticRecurrent in the rsl_rl repository: it reviewed the core Actor-Critic principles, focused on how ActorCriticRecurrent adds an RNN/Memory module to remember history, distinguished how hidden states are handled in training versus inference, and walked through the functions for network construction, action sampling, and value estimation, laying the groundwork for understanding the policy and value networks in complex robot-control tasks.
  • Corrections are welcome. Thanks for reading!
