循环神经网络与序列建模实战

自然语言处理、时间序列预测、语音识别——这些任务有一个共同特点:数据是序列化的,前面的信息会影响后面的理解。传统的前馈神经网络无法直接处理这种序列依赖关系,循环神经网络(RNN)正是为解决这个问题而生。

这篇文章会从RNN的基本原理讲起,深入分析梯度消失/爆炸问题的数学本质,然后讲解LSTM和GRU如何解决这个问题。最后通过代码实战展示如何用PyTorch实现LSTM进行文本分类,以及Seq2Seq+Attention的机器翻译实现。

RNN基础:处理序列数据

RNN的核心思想是”记忆”——每个时间步的隐藏状态不仅取决于当前输入,还取决于上一时间步的隐藏状态。这让它能够处理任意长度的序列,捕捉时间步之间的依赖关系。

RNN的前向传播

简单RNN(也称为Elman网络)的前向传播公式:

其中是时间步t的隐藏状态,是输入,是权重矩阵。tanh激活函数把输出压缩到[-1, 1]区间。

import torch
import torch.nn as nn
import numpy as np
 
# 手动实现简单RNN
class SimpleRNN(nn.Module):
    """手动实现简单RNN,理解其前向传播"""
    def __init__(self, input_size, hidden_size, output_size):
        super().__init__()
        self.hidden_size = hidden_size
        
        # 权重初始化
        self.W_xh = nn.Parameter(torch.randn(input_size, hidden_size) * 0.1)
        self.W_hh = nn.Parameter(torch.randn(hidden_size, hidden_size) * 0.1)
        self.W_hy = nn.Parameter(torch.randn(hidden_size, output_size) * 0.1)
        self.b_h = nn.Parameter(torch.zeros(hidden_size))
        self.b_y = nn.Parameter(torch.zeros(output_size))
        
        self.tanh = nn.Tanh()
    
    def forward(self, x, h_prev=None):
        """
        x: (batch, input_size) 或 (batch, seq_len, input_size)
        返回: output, hidden_state
        """
        if h_prev is None:
            h_prev = torch.zeros(x.size(0), self.hidden_size)
        
        # 如果输入是3D (batch, seq, input),处理成序列
        if len(x.shape) == 3:
            outputs = []
            for t in range(x.size(1)):
                h_prev = self.tanh(x[:, t] @ self.W_xh + h_prev @ self.W_hh + self.b_h)
                outputs.append(h_prev)
            output = torch.stack(outputs, dim=1)
            return output, h_prev
        
        # 单步前向
        h = self.tanh(x @ self.W_xh + h_prev @ self.W_hh + self.b_h)
        y = h @ self.W_hy + self.b_y
        
        return y, h
 
# 测试RNN前向传播
print("=" * 60)
print("RNN 前向传播演示")
print("=" * 60)
 
rnn = SimpleRNN(input_size=10, hidden_size=20, output_size=5)
 
# 单步输入
x_single = torch.randn(2, 10)  # batch=2, input=10
h_prev = torch.randn(2, 20)
 
y, h = rnn(x_single, h_prev)
print(f"单步输入: {x_single.shape}")
print(f"隐藏状态输入: {h_prev.shape}")
print(f"输出: {y.shape}")
print(f"隐藏状态输出: {h.shape}")
 
# 序列输入
x_seq = torch.randn(2, 5, 10)  # batch=2, seq_len=5, input=10
output_seq, final_h = rnn(x_seq)
print(f"\n序列输入: {x_seq.shape}")
print(f"序列输出: {output_seq.shape}")  # (batch, seq, hidden)
print(f"最终隐藏状态: {final_h.shape}")
 
# PyTorch内置RNN对比
rnn_pytorch = nn.RNN(input_size=10, hidden_size=20, num_layers=1, batch_first=True)
output_pytorch, h_pytorch = rnn_pytorch(x_seq)
print(f"\nPyTorch RNN输出: {output_pytorch.shape}")
print(f"PyTorch 隐藏状态: {h_pytorch.shape}")

梯度消失与梯度爆炸:RNN的致命缺陷

RNN的训练通过时间反向传播(BPTT)。问题在于,梯度在反向传播过程中需要经过每一个时间步,如果路径上的梯度都小于1,梯度会指数衰减到接近0(梯度消失);如果存在大于1的梯度,梯度会指数增长(梯度爆炸)。

数学本质

考虑RNN的一个简化版本:

假设没有输入,递归应用T步:

求梯度:

的范数(谱半径)决定了梯度行为:

  • 如果λ_max(W) < 1,梯度指数衰减
  • 如果λ_max(W) > 1,梯度指数增长

实际RNN中,梯度还涉及激活函数的导数。tanh的导数最大值为1,sigmoid的导数最大值为0.25,这进一步加剧了梯度消失。

梯度裁剪:治标不治本

梯度裁剪是应对梯度爆炸的常用技巧:当梯度范数超过阈值时,将梯度缩放到阈值范围内。这能防止训练崩溃,但无法解决长期依赖问题——梯度消失仍然存在。

# 梯度裁剪示例
print("=" * 60)
print("梯度消失与梯度爆炸")
print("=" * 60)
 
# 模拟RNN中的梯度传播
def simulate_gradient_flow(W, steps=50, use_clip=False, clip_value=5.0):
    """模拟梯度在RNN中随时间的传播"""
    gradients = []
    grad = torch.eye(W.size(0))  # 初始梯度
    
    for _ in range(steps):
        grad = grad @ W.T  # 反向传播
        
        if use_clip:
            grad_norm = torch.norm(grad)
            if grad_norm > clip_value:
                grad = grad * (clip_value / grad_norm)
        
        gradients.append(torch.norm(grad).item())
    
    return gradients
 
# 测试不同W的梯度行为
np.random.seed(42)
torch.manual_seed(42)
 
W_stable = torch.randn(10, 10) * 0.1  # 小权重,梯度消失
W_unstable = torch.randn(10, 10) * 1.5  # 大权重,梯度爆炸
 
grad_stable = simulate_gradient_flow(W_stable)
grad_unstable = simulate_gradient_flow(W_unstable)
grad_clipped = simulate_gradient_flow(W_unstable, use_clip=True)
 
plt.figure(figsize=(12, 4))
 
plt.subplot(1, 2, 1)
plt.plot(grad_stable, label='Stable (small weights)', color='green')
plt.plot(grad_unstable, label='Unstable (large weights)', color='red')
plt.plot(grad_clipped, label='Clipped', color='blue')
plt.xlabel('Time Step')
plt.ylabel('Gradient Norm')
plt.title('Gradient Magnitude Over Time')
plt.legend()
plt.grid(True, alpha=0.3)
 
plt.subplot(1, 2, 2)
plt.plot(np.log10(grad_stable), label='Stable', color='green')
plt.plot(np.log10(grad_unstable), label='Unstable', color='red')
plt.xlabel('Time Step')
plt.ylabel('Log10(Gradient Norm)')
plt.title('Log Scale - Gradient Magnitude')
plt.legend()
plt.grid(True, alpha=0.3)
 
plt.tight_layout()
plt.savefig('gradient_flow.png', dpi=150)
 
# 梯度裁剪在训练中的使用
print("""
训练RNN时的梯度裁剪:
```python
# 梯度裁剪
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=5.0)
 
# 或者
torch.nn.utils.clip_grad_value_(model.parameters(), clip_value=1.0)

""")


## LSTM:长短期记忆网络

LSTM(Long Short-Term Memory)是Hochreiter和Schmidhuber在1997年提出的,专门解决长期依赖问题。其核心是引入门控机制,允许网络学习何时记住、何时遗忘。

### 三大门的作用

LSTM引入了三个门控单元:

**遗忘门(Forget Gate)**:决定从细胞状态中丢弃什么信息。它查看$h_{t-1}$和$x_t$,输出一个0到1之间的数。1表示"完全保留",0表示"完全丢弃"。

**输入门(Input Gate)**:决定把什么新信息存储到细胞状态中。它由两部分组成:一个sigmoid层决定哪些值要更新,一个tanh层创建候选值向量。

**输出门(Output Gate)**:决定输出什么。它基于细胞状态,但经过过滤——先运行sigmoid决定细胞状态的哪些部分输出,然后把细胞状态通过tanh处理,再与sigmoid门的输出相乘。

```python
# 手动实现LSTM
class LSTMCell(nn.Module):
    """手动实现LSTM Cell,理解门控机制"""
    def __init__(self, input_size, hidden_size):
        super().__init__()
        self.hidden_size = hidden_size
        
        # 合并输入和隐藏状态到同一个矩阵,减少计算
        self.W = nn.Parameter(torch.randn(input_size + hidden_size, 4 * hidden_size) * 0.1)
        self.b = nn.Parameter(torch.zeros(4 * hidden_size))
    
    def forward(self, x, state):
        """
        x: (batch, input_size)
        state: tuple of (h, c) each (batch, hidden_size)
        """
        h, c = state
        
        # 拼接输入和隐藏状态
        combined = torch.cat([x, h], dim=1)
        
        # 计算四个门的值
        gates = combined @ self.W + self.b
        
        # 分割成四个门
        i, f, g, o = gates.chunk(4, dim=1)
        
        # Sigmoid激活
        i = torch.sigmoid(i)  # 输入门
        f = torch.sigmoid(f)  # 遗忘门
        o = torch.sigmoid(o)  # 输出门
        g = torch.tanh(g)     # 候选值
        
        # 更新细胞状态
        c_new = f * c + i * g
        
        # 计算新的隐藏状态
        h_new = o * torch.tanh(c_new)
        
        return h_new, c_new

class ManualLSTM(nn.Module):
    """手动实现完整LSTM"""
    def __init__(self, input_size, hidden_size, num_layers=1):
        super().__init__()
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        
        self.cells = nn.ModuleList([
            LSTMCell(input_size if i == 0 else hidden_size, hidden_size)
            for i in range(num_layers)
        ])
    
    def forward(self, x, state=None):
        """
        x: (batch, seq_len, input_size)
        state: tuple of (h, c) each (num_layers, batch, hidden_size)
        """
        batch, seq_len, _ = x.shape
        
        # 初始化状态
        if state is None:
            h = [torch.zeros(batch, self.hidden_size) for _ in range(self.num_layers)]
            c = [torch.zeros(batch, self.hidden_size) for _ in range(self.num_layers)]
        else:
            h, c = state
        
        outputs = []
        for t in range(seq_len):
            x_t = x[:, t, :]
            for layer in range(self.num_layers):
                h[layer], c[layer] = self.cells[layer](
                    x_t if layer == 0 else h[layer-1],
                    (h[layer], c[layer])
                )
            outputs.append(h[-1])
        
        output = torch.stack(outputs, dim=1)  # (batch, seq, hidden)
        final_state = (torch.stack(h, dim=0), torch.stack(c, dim=0))
        
        return output, final_state

# PyTorch内置LSTM对比
print("=" * 60)
print("LSTM 门控机制演示")
print("=" * 60)

# 使用PyTorch的LSTMCell展示门控行为
lstm_cell = nn.LSTMCell(input_size=10, hidden_size=20)
print(f"LSTMCell 参数量: {sum(p.numel() for p in lstm_cell.parameters()):,}")

# 可视化门控机制
def visualize_lstm_gates():
    """展示LSTM各门的作用"""
    fig, axes = plt.subplots(2, 2, figsize=(12, 10))
    
    axes[0, 0].text(0.5, 0.7, "遗忘门 (Forget Gate)", fontsize=16, ha='center')
    axes[0, 0].text(0.5, 0.5, "σ([xₜ, hₜ₋₁] · Wf + bf)", fontsize=12, ha='center')
    axes[0, 0].text(0.5, 0.3, "决定丢弃哪些旧信息", fontsize=12, ha='center', 
                   style='italic', color='blue')
    axes[0, 0].axis('off')
    
    axes[0, 1].text(0.5, 0.7, "输入门 (Input Gate)", fontsize=16, ha='center')
    axes[0, 1].text(0.5, 0.5, "i = σ([xₜ, hₜ₋₁] · Wi + bi)", fontsize=12, ha='center')
    axes[0, 1].text(0.5, 0.3, "g = tanh([xₜ, hₜ₋₁] · Wg + bg)", fontsize=12, ha='center')
    axes[0, 1].text(0.5, 0.1, "决定添加哪些新信息", fontsize=12, ha='center',
                   style='italic', color='blue')
    axes[0, 1].axis('off')
    
    axes[1, 0].text(0.5, 0.8, "细胞状态更新", fontsize=16, ha='center')
    axes[1, 0].text(0.5, 0.5, "cₜ = f ⊙ cₜ₋₁ + i ⊙ g", fontsize=14, ha='center')
    axes[1, 0].text(0.5, 0.2, "遗忘门×旧状态 + 输入门×候选值", fontsize=12, ha='center')
    axes[1, 0].axis('off')
    
    axes[1, 1].text(0.5, 0.7, "输出门 (Output Gate)", fontsize=16, ha='center')
    axes[1, 1].text(0.5, 0.5, "o = σ([xₜ, hₜ₋₁] · Wo + bo)", fontsize=12, ha='center')
    axes[1, 1].text(0.5, 0.3, "hₜ = o ⊙ tanh(cₜ)", fontsize=14, ha='center')
    axes[1, 1].text(0.5, 0.1, "决定输出什么信息", fontsize=12, ha='center',
                   style='italic', color='blue')
    axes[1, 1].axis('off')
    
    plt.suptitle("LSTM 三大门机制", fontsize=18, y=1.02)
    plt.tight_layout()
    plt.savefig('lstm_gates.png', dpi=150, bbox_inches='tight')

visualize_lstm_gates()

LSTM如何解决梯度消失

LSTM的门控机制允许梯度直接流动,而不经过复杂的非线性变换。

关键在于细胞状态的加法更新:

这个加法操作意味着梯度可以通过加法直接传回来,不会因为链式法则乘以多个权重矩阵。这就像一条”高速公路”,让梯度可以无衰减地流动。

遗忘门是一个可以学习到的值(0到1之间),网络可以学习让它接近1,在需要长期依赖时保持梯度流动。

# 展示LSTM的梯度流
def lstm_gradient_flow():
    """展示LSTM中梯度如何流动"""
    print("""
    标准RNN的梯度流:
    ┌─────────────────────────────────────────────────┐
    │ h₀ → h₁ → h₂ → ... → hₜ → ... → hT           │
    │ ↓     ↓     ↓           ↓           ↓          │
    │ ∂L/∂h₀ ← ∂L/∂h₁ ← ∂L/∂h₂ ← ... ← ∂L/∂hT     │
    │          ↑每个箭头都是矩阵乘法+激活函数导数        │
    └─────────────────────────────────────────────────┘
    梯度需要乘以 Wᵀ 和 激活函数导数
    如果 |λₘₐₓ(W)| < 1: 梯度指数衰减
    如果 |λₘₐₓ(W)| > 1: 梯度指数增长
 
    LSTM的梯度流:
    ┌─────────────────────────────────────────────────┐
    │ c₀ ──────────────────→ c₁ ──→ ... ──→ cT       │
    │ ↓      ↓                     ↓           ↓     │
    │ h₀     h₁                    hₜ          hT     │
    │                    ↑                        ↑   │
    │              加法连接: ∂cₜ/∂cₜ₋₁ = fₜ (可学习) │
    └─────────────────────────────────────────────────┘
    细胞状态通过加法更新,梯度直接流动
    "细胞状态传送带"让长期依赖成为可能
    """)
 
lstm_gradient_flow()

GRU:LSTM的简化版

GRU(Gated Recurrent Unit)是Cho等人在2014年提出的,是对LSTM的简化。它把遗忘门和输入门合并成一个”更新门”,同时合并了细胞状态和隐藏状态。

GRU vs LSTM

GRU只有两个门:更新门和重置门。相比LSTM:

  • 参数量更少(更少的矩阵运算)
  • 训练更快
  • 在很多任务上效果与LSTM相当
  • 长期依赖能力略弱于LSTM(但仍然很好)

选择建议:优先尝试LSTM;如果追求速度或数据少,用GRU;如果模型仍然欠拟合,考虑LSTM。

# GRU实现
class GRUCell(nn.Module):
    """手动实现GRU Cell"""
    def __init__(self, input_size, hidden_size):
        super().__init__()
        self.hidden_size = hidden_size
        
        self.W_z = nn.Parameter(torch.randn(input_size + hidden_size, hidden_size) * 0.1)
        self.W_r = nn.Parameter(torch.randn(input_size + hidden_size, hidden_size) * 0.1)
        self.W_h = nn.Parameter(torch.randn(input_size + hidden_size, hidden_size) * 0.1)
        self.b_z = nn.Parameter(torch.zeros(hidden_size))
        self.b_r = nn.Parameter(torch.zeros(hidden_size))
        self.b_h = nn.Parameter(torch.zeros(hidden_size))
    
    def forward(self, x, h_prev):
        combined = torch.cat([x, h_prev], dim=1)
        
        # 更新门
        z = torch.sigmoid(combined @ self.W_z + self.b_z)
        
        # 重置门
        r = torch.sigmoid(combined @ self.W_r + self.b_r)
        
        # 候选隐藏状态
        h_tilde = torch.tanh(
            torch.cat([x, r * h_prev], dim=1) @ self.W_h + self.b_h
        )
        
        # 新的隐藏状态
        h_new = (1 - z) * h_prev + z * h_tilde
        
        return h_new
 
# 对比LSTM和GRU
print("=" * 60)
print("LSTM vs GRU 对比")
print("=" * 60)
 
lstm = nn.LSTMCell(input_size=100, hidden_size=256)
gru = nn.GRUCell(input_size=100, hidden_size=256)
 
lstm_params = sum(p.numel() for p in lstm.parameters())
gru_params = sum(p.numel() for p in gru.parameters())
 
print(f"LSTM Cell 参数量: {lstm_params:,}")
print(f"GRU Cell 参数量: {gru_params:,}")
print(f"参数量减少: {(lstm_params - gru_params) / lstm_params * 100:.1f}%")
 
comparison_table = """
┌─────────────┬──────────────────────┬──────────────────────┐
│   特性      │        LSTM          │        GRU           │
├─────────────┼──────────────────────┼──────────────────────┤
│ 门数量      │ 3个 (遗忘/输入/输出)   │ 2个 (更新/重置)      │
│ 参数量      │ 多 (~4×4HH)          │ 少 (~3×4HH)          │
│ 长期记忆    │ 强                    │ 较强                  │
│ 训练速度    │ 较慢                  │ 较快                  │
│ 细胞状态    │ 有                    │ 无 (隐藏状态直接更新)  │
│ 适用场景    │ 长序列、复杂依赖       │ 中等长度序列          │
└─────────────┴──────────────────────┴──────────────────────┘
"""
print(comparison_table)

序列到序列(Seq2Seq)架构

Seq2Seq是处理序列到序列转换任务的经典架构,比如机器翻译、文本摘要、对话生成。核心思想是编码器-解码器模式。

编码器-解码器

编码器(Encoder):把输入序列编码成一个固定长度的向量(通常是最后一个隐藏状态)。这个向量包含了整个输入序列的信息。

解码器(Decoder):基于编码器的输出,逐步生成目标序列。解码通常从特殊起始符开始,直到生成结束符。

class Encoder(nn.Module):
    """Seq2Seq 编码器"""
    def __init__(self, vocab_size, embed_size, hidden_size, num_layers=1):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embed_size)
        self.lstm = nn.LSTM(embed_size, hidden_size, num_layers, batch_first=True)
        self.hidden_size = hidden_size
        self.num_layers = num_layers
    
    def forward(self, x):
        """
        x: (batch, seq_len) 输入序列
        """
        embedded = self.embedding(x)  # (batch, seq, embed)
        outputs, (hidden, cell) = self.lstm(embedded)
        # outputs: 所有时间步的隐藏状态
        # hidden, cell: 最后一层的隐藏状态 (用于初始化解码器)
        return outputs, hidden, cell
 
class Decoder(nn.Module):
    """Seq2Seq 解码器"""
    def __init__(self, vocab_size, embed_size, hidden_size, num_layers=1):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embed_size)
        self.lstm = nn.LSTM(embed_size, hidden_size, num_layers, batch_first=True)
        self.fc = nn.Linear(hidden_size, vocab_size)
        
    def forward(self, x, hidden, cell):
        """
        x: (batch, 1) 当前输入token
        hidden, cell: 编码器的最终状态
        """
        embedded = self.embedding(x)  # (batch, 1, embed)
        outputs, (hidden, cell) = self.lstm(embedded, (hidden, cell))
        outputs = self.fc(outputs.squeeze(1))  # (batch, vocab)
        return outputs, hidden, cell
 
class Seq2Seq(nn.Module):
    """完整的Seq2Seq模型"""
    def __init__(self, encoder, decoder):
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder
    
    def forward(self, source, target, teacher_forcing_ratio=0.5):
        """
        source: (batch, src_len) 源序列
        target: (batch, tgt_len) 目标序列
        """
        batch_size = source.size(0)
        tgt_len = target.size(1)
        tgt_vocab_size = self.decoder.fc.out_features
        
        outputs = torch.zeros(batch_size, tgt_len, tgt_vocab_size).to(source.device)
        
        # 编码
        _, hidden, cell = self.encoder(source)
        
        # 解码
        decoder_input = target[:, 0:1]  # 起始符
        for t in range(tgt_len):
            output, hidden, cell = self.decoder(decoder_input, hidden, cell)
            outputs[:, t] = output
            
            # Teacher Forcing: 以一定概率使用真实目标
            teacher_force = np.random.random() < teacher_forcing_ratio
            top1 = output.argmax(1)
            decoder_input = target[:, t:t+1] if teacher_force else top1.unsqueeze(1)
        
        return outputs
 
print("Seq2Seq 模型创建完成")

注意力机制

注意力机制是Seq2Seq的革命性改进。原始Seq2Seq的问题是所有信息必须压缩到一个固定长度的向量里,这限制了模型处理长序列的能力。注意力机制让解码器在每个时间步都能”看”到编码器的所有隐藏状态,并根据当前上下文动态分配权重。

注意力计算过程

  1. 计算Query-Key相似度:
  2. Softmax归一化得到注意力权重:
  3. 加权求和得到上下文向量:
  4. 上下文向量与当前解码器状态拼接后预测
class Attention(nn.Module):
    """Bahdanau Attention (Additive Attention)"""
    def __init__(self, hidden_size):
        super().__init__()
        self.attn = nn.Linear(hidden_size * 2, hidden_size)
        self.v = nn.Parameter(torch.randn(hidden_size))
    
    def forward(self, hidden, encoder_outputs):
        """
        hidden: (1, batch, hidden_size) 当前解码器隐藏状态
        encoder_outputs: (batch, src_len, hidden_size) 编码器所有隐藏状态
        """
        src_len = encoder_outputs.size(1)
        
        # 重复hidden到src_len
        hidden = hidden.squeeze(0).unsqueeze(1).repeat(1, src_len, 1)  # (batch, src_len, hidden)
        
        # 计算能量
        energy = torch.tanh(self.attn(torch.cat([hidden, encoder_outputs], dim=2)))  # (batch, src_len, hidden)
        energy = energy.permute(0, 2, 1)  # (batch, hidden, src_len)
        
        # 计算注意力权重
        v = self.v.unsqueeze(0).unsqueeze(0).repeat(encoder_outputs.size(0), 1, 1)  # (batch, 1, hidden)
        attention = torch.bmm(v, energy).squeeze(1)  # (batch, src_len)
        attention_weights = F.softmax(attention, dim=1)  # (batch, src_len)
        
        # 加权求和
        encoder_outputs = encoder_outputs.permute(0, 2, 1)  # (batch, hidden, src_len)
        context = torch.bmm(attention_weights.unsqueeze(1), encoder_outputs).squeeze(1)  # (batch, hidden)
        
        return context, attention_weights
 
class AttentionDecoder(nn.Module):
    """带注意力的解码器"""
    def __init__(self, vocab_size, embed_size, hidden_size, num_layers=1):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embed_size)
        self.attention = Attention(hidden_size)
        self.lstm = nn.LSTM(embed_size + hidden_size, hidden_size, num_layers, batch_first=True)
        self.fc = nn.Linear(hidden_size * 2, vocab_size)
    
    def forward(self, x, hidden, cell, encoder_outputs):
        """
        x: (batch, 1) 当前输入token
        hidden: (num_layers, batch, hidden)
        encoder_outputs: (batch, src_len, hidden)
        """
        embedded = self.embedding(x)  # (batch, 1, embed)
        
        # 计算注意力
        context, attention_weights = self.attention(hidden[-1:], encoder_outputs)
        context = context.unsqueeze(1)  # (batch, 1, hidden)
        
        # 拼接embedded和context
        lstm_input = torch.cat([embedded, context], dim=2)  # (batch, 1, embed+hidden)
        
        # LSTM前向
        outputs, (hidden, cell) = self.lstm(lstm_input, (hidden, cell))
        
        # 预测
        predictions = self.fc(torch.cat([outputs.squeeze(1), context.squeeze(1)], dim=1))
        
        return predictions, hidden, cell, attention_weights
 
# 注意力可视化
def visualize_attention():
    """可视化翻译任务中的注意力权重"""
    print("注意力机制说明:")
    attention_demo = """
    示例:德语 → 英语 翻译
    
    输入(德语): "Das katze sitzt auf dem tisch"
    输出(英语): "The cat sits on the table"
    
    翻译 "cat" 时的注意力权重:
    ┌────────────────────────────┐
    │  das      0.02            │
    │  katze    0.85  ← 最高!   │
    │  sitzt    0.05            │
    │  auf      0.03            │
    │  dem      0.02            │
    │  tisch    0.03            │
    └────────────────────────────┘
    模型正确地关注到了源句中的 "katze"!
    
    翻译 "sits" 时的注意力权重:
    ┌────────────────────────────┐
    │  das      0.01            │
    │  katze    0.08            │
    │  sitzt    0.88  ← 最高!   │
    │  auf      0.01            │
    │  dem      0.01            │
    │  tisch    0.01            │
    └────────────────────────────┘
    """
    print(attention_demo)
 
visualize_attention()

代码实战:LSTM文本分类

现在来一个完整的实战:用LSTM做文本分类。这个任务比翻译简单很多,适合作为入门练习。

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from collections import Counter
import re
 
# 文本分类数据准备
class TextClassificationDataset(Dataset):
    """简单的文本分类数据集"""
    def __init__(self, texts, labels, word2idx, max_len=100):
        self.texts = texts
        self.labels = labels
        self.word2idx = word2idx
        self.max_len = max_len
    
    def __len__(self):
        return len(self.texts)
    
    def __getitem__(self, idx):
        text = self.texts[idx]
        label = self.labels[idx]
        
        # Tokenize and convert to indices
        tokens = text.lower().split()
        indices = [self.word2idx.get(token, 0) for token in tokens[:self.max_len]]
        
        # Padding
        if len(indices) < self.max_len:
            indices += [0] * (self.max_len - len(indices))
        
        return torch.tensor(indices, dtype=torch.long), torch.tensor(label, dtype=torch.long)
 
class TextClassificationLSTM(nn.Module):
    """用于文本分类的LSTM模型"""
    def __init__(self, vocab_size, embed_size, hidden_size, num_layers, num_classes, dropout=0.5):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embed_size, padding_idx=0)
        self.lstm = nn.LSTM(embed_size, hidden_size, num_layers, 
                           batch_first=True, bidirectional=True, dropout=dropout if num_layers > 1 else 0)
        self.dropout = nn.Dropout(dropout)
        self.fc = nn.Linear(hidden_size * 2, num_classes)  # *2 for bidirectional
    
    def forward(self, x):
        # x: (batch, seq_len)
        embedded = self.embedding(x)  # (batch, seq_len, embed)
        
        # LSTM
        lstm_out, (hidden, cell) = self.lstm(embedded)
        
        # 使用最后时刻的隐藏状态(双向所以是hidden_size*2)
        # hidden: (num_layers*2, batch, hidden_size)
        hidden = torch.cat([hidden[-2], hidden[-1]], dim=1)  # (batch, hidden_size*2)
        
        hidden = self.dropout(hidden)
        output = self.fc(hidden)  # (batch, num_classes)
        
        return output
 
# 训练函数
def train_lstm_classifier():
    """LSTM文本分类器训练示例"""
    print("=" * 60)
    print("LSTM 文本分类实战")
    print("=" * 60)
    
    # 模拟数据
    texts = [
        "This movie is fantastic and I love it",
        "What a terrible film, completely wasted",
        "Great acting and wonderful storyline",
        "Boring and slow, fell asleep twice",
        "Best movie I have ever seen",
        "Horrible waste of time",
        "Enjoyed every minute of this film",
        "So bad, I want my money back"
    ] * 50  # 重复增加数据量
    
    labels = [1, 0, 1, 0, 1, 0, 1, 0] * 50
    
    # 构建词汇表
    word_counts = Counter()
    for text in texts:
        word_counts.update(text.lower().split())
    
    vocab = ['<PAD>', '<UNK>'] + [word for word, count in word_counts.most_common(5000)]
    word2idx = {word: idx for idx, word in enumerate(vocab)}
    
    print(f"词汇表大小: {len(word2idx)}")
    
    # 创建数据集
    dataset = TextClassificationDataset(texts, labels, word2idx, max_len=20)
    train_size = int(0.8 * len(dataset))
    val_size = len(dataset) - train_size
    train_dataset, val_dataset = torch.utils.data.random_split(dataset, [train_size, val_size])
    
    train_loader = DataLoader(train_dataset, batch_size=16, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=16)
    
    # 模型
    model = TextClassificationLSTM(
        vocab_size=len(word2idx),
        embed_size=128,
        hidden_size=64,
        num_layers=2,
        num_classes=2,
        dropout=0.3
    )
    
    print(f"模型参数量: {sum(p.numel() for p in model.parameters()):,}")
    
    # 训练配置
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = model.to(device)
    
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=0.001)
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=5)
    
    # 训练循环
    num_epochs = 20
    best_val_acc = 0
    
    for epoch in range(num_epochs):
        # 训练
        model.train()
        train_loss = 0
        train_correct = 0
        train_total = 0
        
        for inputs, targets in train_loader:
            inputs, targets = inputs.to(device), targets.to(device)
            
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, targets)
            loss.backward()
            optimizer.step()
            
            train_loss += loss.item()
            _, predicted = outputs.max(1)
            train_total += targets.size(0)
            train_correct += predicted.eq(targets).sum().item()
        
        # 验证
        model.eval()
        val_loss = 0
        val_correct = 0
        val_total = 0
        
        with torch.no_grad():
            for inputs, targets in val_loader:
                inputs, targets = inputs.to(device), targets.to(device)
                outputs = model(inputs)
                loss = criterion(outputs, targets)
                
                val_loss += loss.item()
                _, predicted = outputs.max(1)
                val_total += targets.size(0)
                val_correct += predicted.eq(targets).sum().item()
        
        train_acc = train_correct / train_total
        val_acc = val_correct / val_total
        
        scheduler.step(val_loss)
        
        if val_acc > best_val_acc:
            best_val_acc = val_acc
        
        if (epoch + 1) % 5 == 0:
            print(f"Epoch {epoch+1}/{num_epochs}")
            print(f"Train Loss: {train_loss/len(train_loader):.4f}, Train Acc: {train_acc:.4f}")
            print(f"Val Loss: {val_loss/len(val_loader):.4f}, Val Acc: {val_acc:.4f}")
            print()
    
    print(f"最佳验证准确率: {best_val_acc:.4f}")
    
    # 预测示例
    model.eval()
    test_texts = [
        "This film is absolutely wonderful",
        "What a complete waste of time"
    ]
    
    print("\n预测示例:")
    for text in test_texts:
        tokens = text.lower().split()
        indices = [word2idx.get(token, 1) for token in tokens[:20]]
        indices += [0] * (20 - len(indices))
        
        with torch.no_grad():
            x = torch.tensor([indices]).to(device)
            output = model(x)
            prob = torch.softmax(output, dim=1)
            pred = output.argmax(1).item()
        
        print(f"Text: '{text}'")
        print(f"Prediction: {'Positive' if pred == 1 else 'Negative'} ({prob[0][pred].item():.4f})")
        print()
 
train_lstm_classifier()

代码实战:机器翻译Seq2Seq+Attention

最后一个实战:实现一个完整的机器翻译模型,包含编码器、解码器和注意力机制。

def machine_translation_example():
    """机器翻译Seq2Seq+Attention实战"""
    print("=" * 60)
    print("机器翻译 Seq2Seq + Attention 实战")
    print("=" * 60)
    
    translation_code = '''
import torch
import torch.nn as nn
import torch.optim as optim
from torch.nn.utils.rnn import pad_sequence, pack_padded_sequence, pad_packed_sequence
import random
 
# 数据准备
class TranslationDataset(torch.utils.data.Dataset):
    def __init__(self, src_texts, tgt_texts, src_vocab, tgt_vocab, max_len=50):
        self.src_texts = src_texts
        self.tgt_texts = tgt_texts
        self.src_vocab = src_vocab
        self.tgt_vocab = tgt_vocab
        self.max_len = max_len
    
    def __len__(self):
        return len(self.src_texts)
    
    def __getitem__(self, idx):
        src = self.src_vocab.encode(self.src_texts[idx], self.max_len)
        tgt = self.tgt_vocab.encode(self.tgt_texts[idx], self.max_len + 1)
        return src, tgt
 
class Vocab:
    def __init__(self):
        self.word2idx = {"<PAD>": 0, "<SOS>": 1, "<EOS>": 2, "<UNK>": 3}
        self.idx2word = {v: k for k, v in self.word2idx.items()}
        self.n_words = 4
    
    def build_vocab(self, texts):
        for text in texts:
            for word in text.split():
                if word not in self.word2idx:
                    self.word2idx[word] = self.n_words
                    self.idx2word[self.n_words] = word
                    self.n_words += 1
    
    def encode(self, text, max_len):
        indices = [self.word2idx.get(w, 3) for w in text.split()[:max_len-1]]
        indices += [2]  # EOS
        return indices
    
    def decode(self, indices):
        words = []
        for idx in indices:
            if idx == 2:  # EOS
                break
            if idx not in [0, 1]:  # PAD, SOS
                words.append(self.idx2word.get(idx, "<UNK>"))
        return " ".join(words)
 
# 带注意力的Seq2Seq模型
class Encoder(nn.Module):
    def __init__(self, vocab_size, embed_size, hidden_size, num_layers, dropout):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embed_size, padding_idx=0)
        self.lstm = nn.LSTM(embed_size, hidden_size, num_layers, batch_first=True,
                           bidirectional=True, dropout=dropout if num_layers > 1 else 0)
        self.fc = nn.Linear(hidden_size * 2, hidden_size)
        self.dropout = nn.Dropout(dropout)
    
    def forward(self, x, lengths):
        embedded = self.dropout(self.embedding(x))
        packed = pack_padded_sequence(embedded, lengths.cpu(), batch_first=True, enforce_sorted=False)
        outputs, (hidden, cell) = self.lstm(packed)
        outputs, _ = pad_packed_sequence(outputs, batch_first=True)
        
        # 合并双向隐藏状态
        hidden = torch.cat([hidden[-2], hidden[-1]], dim=1)
        hidden = self.fc(hidden)
        cell = torch.cat([cell[-2], cell[-1]], dim=1)
        cell = self.fc(cell)
        
        return outputs, hidden.unsqueeze(0), cell.unsqueeze(0)
 
class Attention(nn.Module):
    def __init__(self, hidden_size):
        super().__init__()
        self.attn = nn.Linear(hidden_size * 3, hidden_size)
        self.v = nn.Linear(hidden_size, 1, bias=False)
    
    def forward(self, hidden, encoder_outputs, mask):
        batch_size = encoder_outputs.size(0)
        src_len = encoder_outputs.size(1)
        
        # hidden: (1, batch, hidden), encoder_outputs: (batch, src_len, hidden)
        hidden = hidden.repeat(1, src_len, 1)  # (batch, src_len, hidden)
        
        energy = torch.tanh(self.attn(torch.cat([hidden, encoder_outputs], dim=2)))
        attention = self.v(energy).squeeze(2)  # (batch, src_len)
        
        attention = attention.masked_fill(mask, float('-inf'))
        attention_weights = torch.softmax(attention, dim=1)
        
        context = torch.bmm(attention_weights.unsqueeze(1), encoder_outputs).squeeze(1)  # (batch, hidden)
        
        return context, attention_weights
 
class Decoder(nn.Module):
    def __init__(self, vocab_size, embed_size, hidden_size, num_layers, dropout):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embed_size, padding_idx=0)
        self.attention = Attention(hidden_size)
        self.lstm = nn.LSTM(embed_size + hidden_size, hidden_size, num_layers, batch_first=True)
        self.fc = nn.Linear(hidden_size * 2, vocab_size)
        self.dropout = nn.Dropout(dropout)
    
    def forward(self, x, hidden, cell, encoder_outputs, mask):
        x = x.unsqueeze(1)  # (batch, 1)
        embedded = self.dropout(self.embedding(x))
        
        context, attention_weights = self.attention(hidden, encoder_outputs, mask)
        
        lstm_input = torch.cat([embedded, context.unsqueeze(1)], dim=2)
        output, (hidden, cell) = self.lstm(lstm_input, (hidden, cell))
        
        output = self.fc(torch.cat([output.squeeze(1), context], dim=1))
        
        return output, hidden, cell, attention_weights
 
class Seq2Seq(nn.Module):
    def __init__(self, encoder, decoder):
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder
    
    def forward(self, src, src_lengths, tgt, teacher_forcing_ratio=0.5):
        batch_size = src.size(0)
        tgt_len = tgt.size(1)
        tgt_vocab_size = self.decoder.fc.out_features
        
        outputs = torch.zeros(batch_size, tgt_len, tgt_vocab_size).to(src.device)
        attention_weights_all = []
        
        # 编码
        encoder_outputs, hidden, cell = self.encoder(src, src_lengths)
        
        # 创建mask
        mask = (src == 0)
        
        # 解码
        decoder_input = tgt[:, 0]  # SOS
        for t in range(1, tgt_len):
            output, hidden, cell, attention_weights = self.decoder(
                decoder_input, hidden, cell, encoder_outputs, mask
            )
            outputs[:, t] = output
            attention_weights_all.append(attention_weights)
            
            teacher_force = random.random() < teacher_forcing_ratio
            top1 = output.argmax(1)
            decoder_input = tgt[:, t] if teacher_force else top1
        
        return outputs, attention_weights_all
    
    def translate(self, src, src_lengths, max_len=50, sos=1, eos=2):
        """推理时的翻译"""
        self.eval()
        with torch.no_grad():
            encoder_outputs, hidden, cell = self.encoder(src, src_lengths)
            mask = (src == 0)
            
            decoder_input = torch.tensor([sos]).to(src.device)
            translations = []
            
            for _ in range(max_len):
                output, hidden, cell, _ = self.decoder(
                    decoder_input, hidden, cell, encoder_outputs, mask
                )
                
                top1 = output.argmax(1).item()
                
                if top1 == eos:
                    break
                
                translations.append(top1)
                decoder_input = torch.tensor([top1]).to(src.device)
            
            return translations
 
# 训练函数
def train_seq2seq(model, train_loader, num_epochs=50, lr=0.001):
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = model.to(device)
    
    optimizer = optim.Adam(model.parameters(), lr=lr)
    criterion = nn.CrossEntropyLoss(ignore_index=0)
    
    for epoch in range(num_epochs):
        model.train()
        total_loss = 0
        
        for src, tgt in train_loader:
            src, tgt = src.to(device), tgt.to(device)
            src_lengths = (src != 0).sum(dim=1)
            
            optimizer.zero_grad()
            outputs, _ = model(src, src_lengths, tgt)
            
            # 忽略PAD和SOS
            outputs = outputs[:, 1:].contiguous().view(-1, outputs.size(-1))
            tgt = tgt[:, 1:].contiguous().view(-1)
            
            loss = criterion(outputs, tgt)
            loss.backward()
            optimizer.step()
            
            total_loss += loss.item()
        
        if (epoch + 1) % 10 == 0:
            print(f"Epoch {epoch+1}, Loss: {total_loss/len(train_loader):.4f}")
 
# 使用示例
print("Seq2Seq机器翻译模型创建完成")
print("""
使用流程:
1. 准备平行语料(中英对照句子)
2. 构建词汇表
3. 创建模型 (Encoder + Attention + Decoder)
4. 训练
5. 翻译新句子
""")
'''
    print(translation_code)
 
machine_translation_example()

总结

这篇文章深入讲解了循环神经网络的核心知识。

RNN通过隐藏状态实现了序列建模,但标准RNN存在梯度消失/爆炸问题,难以学习长期依赖。LSTM和GRU通过引入门控机制解决了这个问题。LSTM的遗忘门、输入门、输出门让网络能够选择性地记住或忘记信息;GRU则是LSTM的简化版本,用更新门和重置门达到类似效果。

Seq2Seq架构是处理序列到序列任务的基础,编码器把输入序列编码成向量,解码器基于这个向量生成输出。注意力机制的引入让解码器在每个时间步都能”看到”编码器的所有隐藏状态,大大提升了长序列处理能力。

实际应用中,Transformer已经很大程度上取代了RNN。但理解RNN和LSTM仍然重要,因为它们在某些场景下仍然有效,而且理解RNN有助于理解Attention机制的本质。

LSTM文本分类和机器翻译的实战代码展示了如何把理论付诸实践。建议你找一些公开数据集(如IMDB电影评论分类、WMT翻译数据)实际跑一跑代码,感受一下模型的行为。


本文为深度学习实战指南系列文章,主要涵盖循环神经网络与序列建模的核心知识点和实践技巧。