Skip to main content

GRU (门控循环单元) 从零实现

GRU(Gated Recurrent Unit,门控循环单元)是 LSTM 的一个"简化进阶版"。

如果说普通 RNN 是"金鱼记忆"(容易忘事),LSTM 是"精密的文件柜"(功能强大但结构复杂),那么 GRU 就是一个更轻量、更高效的现代版文件柜

为什么需要 GRU?

我们知道普通 RNN 有梯度消失的问题,无法捕捉长距离依赖。LSTM 通过引入三个"门"(输入门、遗忘门、输出门)解决了这个问题,但参数很多,计算慢。

GRU (2014年提出) 在保持 LSTM 记忆能力的同时,将结构简化为两个门

名称作用
rtr_t重置门 (Reset Gate)决定在计算新候选记忆时,要忽略多少之前的隐状态
ztz_t更新门 (Update Gate)决定要保留多少旧状态,以及要写入多少新状态

核心区别: GRU 没有单独的"细胞状态 (Cell State)",它的隐状态 hth_t 既负责记忆,也负责输出。

数学原理

在时间步 tt,给定输入 xtx_t 和上一时刻隐状态 ht1h_{t-1}

步骤 1:计算门控

使用 Sigmoid 函数 (σ\sigma) 将值压缩到 [0,1][0, 1]

rt=σ(Wirxt+Whrht1+br)r_t = \sigma(W_{ir} x_t + W_{hr} h_{t-1} + b_r) zt=σ(Wizxt+Whzht1+bz)z_t = \sigma(W_{iz} x_t + W_{hz} h_{t-1} + b_z)

步骤 2:计算候选隐状态

这里使用重置门 rtr_t。如果 rt0r_t \approx 0,意味着之前的记忆 ht1h_{t-1} 被"切断",模型就像在处理序列的第一个词一样:

h~t=tanh(Winxt+rt(Whnht1)+bn)\tilde{h}_t = \tanh(W_{in} x_t + r_t \odot (W_{hn} h_{t-1}) + b_n)

\odot 代表逐元素相乘)

步骤 3:计算最终隐状态

利用更新门 ztz_t 进行线性插值

  • 如果 zt1z_t \approx 1:主要保留旧记忆
  • 如果 zt0z_t \approx 0:主要使用新的候选状态
ht=(1zt)h~t+ztht1h_t = (1 - z_t) \odot \tilde{h}_t + z_t \odot h_{t-1}

信息流图示

                    ┌────────────────────────────────────────┐
│ GRU Cell │
│ │
h_{t-1} ─────────▶│ ┌──────┐ ┌──────┐ │
│ │ │ r_t │ │ z_t │ │
│ │ │重置门│ │更新门│ │
│ │ └──┬───┘ └──┬───┘ │
│ │ │ │ │
│ │ ▼ │ │
│ │ ┌──────┐ │ │
│ │ │~h_t │ │ │──▶ h_t
└─────────────│─▶│候选态│──────┴──────────────────────│
│ └──────┘ ▲ │
│ │ │
x_t ──────────────│────────────────┘ │
│ │
└────────────────────────────────────────┘

PyTorch 从零实现

我们将手动实现上述逻辑,不使用 nn.GRU

import torch
import torch.nn as nn

class GRUFromScratch(nn.Module):
def __init__(self, input_size, hidden_size, output_size):
super(GRUFromScratch, self).__init__()
self.hidden_size = hidden_size

# --- 1. 定义更新门 (Update Gate) 的权重 ---
self.x2z = nn.Linear(input_size, hidden_size) # x -> z
self.h2z = nn.Linear(hidden_size, hidden_size) # h -> z

# --- 2. 定义重置门 (Reset Gate) 的权重 ---
self.x2r = nn.Linear(input_size, hidden_size) # x -> r
self.h2r = nn.Linear(hidden_size, hidden_size) # h -> r

# --- 3. 定义候选状态 (Candidate State) 的权重 ---
self.x2n = nn.Linear(input_size, hidden_size) # x -> n
self.h2n = nn.Linear(hidden_size, hidden_size) # h -> n

# --- 4. 输出层 ---
self.output_layer = nn.Linear(hidden_size, output_size)

def forward(self, x, hidden=None):
"""
:param x: (batch_size, seq_len, input_size)
:param hidden: (batch_size, hidden_size)
"""
batch_size = x.size(0)
seq_len = x.size(1)

if hidden is None:
hidden = torch.zeros(batch_size, self.hidden_size).to(x.device)

outputs = []

# === 时间步循环 ===
for t in range(seq_len):
x_t = x[:, t, :]

# 1. 计算更新门 z_t
z_t = torch.sigmoid(self.x2z(x_t) + self.h2z(hidden))

# 2. 计算重置门 r_t
r_t = torch.sigmoid(self.x2r(x_t) + self.h2r(hidden))

# 3. 计算候选隐状态 n_t
# r_t 作用于 h 经过线性变换后的结果
h_reset = r_t * self.h2n(hidden)
n_t = torch.tanh(self.x2n(x_t) + h_reset)

# 4. 计算最终隐状态 h_t
# 软切换:z_t 控制保留旧信息,(1-z_t) 控制接受新信息
hidden = (1 - z_t) * n_t + z_t * hidden

# 5. 计算输出
out_t = self.output_layer(hidden)
outputs.append(out_t)

outputs = torch.stack(outputs, dim=1)
return outputs, hidden

测试代码

INPUT_SIZE = 10
HIDDEN_SIZE = 20
OUTPUT_SIZE = 5
BATCH_SIZE = 3
SEQ_LEN = 6

gru_model = GRUFromScratch(INPUT_SIZE, HIDDEN_SIZE, OUTPUT_SIZE)
dummy_input = torch.randn(BATCH_SIZE, SEQ_LEN, INPUT_SIZE)

output, final_hidden = gru_model(dummy_input)

print(f"GRU 输出形状: {output.shape}") # [3, 6, 5]
print(f"GRU 最终隐状态: {final_hidden.shape}") # [3, 20]

输出:

GRU 输出形状: torch.Size([3, 6, 5])
GRU 最终隐状态: torch.Size([3, 20])

代码关键细节

Sigmoid vs Tanh 的选择

组件激活函数原因
门 (ztz_t, rtr_t)Sigmoid门的物理意义是"比例/开关",值必须在 [0,1][0, 1]
候选状态 (h~t\tilde{h}_t)Tanh真正的数据信息,值域 [1,1][-1, 1],保持梯度稳定

软门控 (Soft Gating)

hidden = (1 - z_t) * n_t + z_t * hidden

这行代码完全可微。模型可以通过反向传播自己学习:

  • 遇到句号时,zt0z_t \approx 0(全更新,遗忘上一句)
  • 遇到连词时,zt1z_t \approx 1(保持记忆)

GRU vs LSTM 对比

特性GRULSTM
门数量2 个3 个 + 细胞状态
参数量少(约为 LSTM 的 75%)
训练速度稍慢
表现小数据集/短序列往往更优超长序列/大数据集潜力更大

经验法则: 实践中通常先尝试 GRU(训练快,效果通常和 LSTM 差不多)。如果 GRU 效果遇到瓶颈,或者需要处理非常复杂的长依赖关系,再切换到 LSTM。

参考资料