Transformer 模型

Transformer 模型是近年来深度学习领域的重要突破,最初由 Vaswani 等人在论文《Attention is All You Need》(2017)中提出,主要用于自然语言处理(NLP),现已广泛应用于计算机视觉、语音处理等领域。PyTorch 提供了便捷的方式通过 torch.nn 模块实现 Transformer 模型。以下是关于 Transformer 模型的详细说明,涵盖其架构、PyTorch 实现、核心组件及使用示例。


1. Transformer 模型概述

Transformer 模型摒弃了传统的循环神经网络(RNN)和卷积神经网络(CNN),完全基于注意力机制(Attention Mechanism),通过自注意力(Self-Attention)捕捉序列中元素间的关系。其核心特点包括:

  • 并行化:不像 RNN 按序处理,Transformer 支持并行计算,加速训练。
  • 长距离依赖:通过注意力机制有效捕捉远距离的上下文关系。
  • 模块化:编码器(Encoder)和解码器(Decoder)堆叠组成,适用于序列到序列任务。

Transformer 常用于:

  • 机器翻译(如英译中)
  • 文本生成(如 GPT 系列)
  • 分类任务(如 BERT)
  • 视觉任务(如 Vision Transformer, ViT)

2. Transformer 架构

Transformer 模型由编码器(Encoder)和解码器(Decoder)组成,每部分包含多个相同的层(Layer)。

(1) 编码器(Encoder)

  • 输入:序列(如单词、标记)经过嵌入(Embedding)后,添加位置编码(Positional Encoding)。
  • 结构:N 个编码器层,每个层包含:
  • 多头自注意力(Multi-Head Self-Attention):捕捉输入序列中各元素间的关系。
  • 前馈神经网络(Feed-Forward Neural Network, FFN):逐位置应用全连接层。
  • 残差连接与层归一化(Add & Norm):每子层后添加残差连接和 LayerNorm。
  • 输出:编码后的特征表示,用于解码器或直接用于任务(如分类)。

(2) 解码器(Decoder)

  • 输入:目标序列(偏移一位)或掩码输入。
  • 结构:N 个解码器层,每个层包含:
  • 掩码多头自注意力:防止关注未来的标记(用于自回归任务)。
  • 多头注意力(Encoder-Decoder Attention):结合编码器输出。
  • 前馈神经网络:与编码器类似。
  • 残差连接与层归一化
  • 输出:生成目标序列的概率分布。

(3) 关键组件

  • 注意力机制
  • Scaled Dot-Product Attention:计算查询(Query)、键(Key)、值(Value)之间的点积,缩放后通过 Softmax 归一化。
    [
    \text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right)V
    ]
  • 多头注意力:将注意力分成多个并行头,增强表达能力。
  • 位置编码:为序列添加位置信息(固定正弦/余弦函数或可学习参数)。
  • 前馈网络:每个位置独立应用全连接层,通常为两层 MLP。
  • 层归一化:稳定训练,加速收敛。

3. PyTorch 中的 Transformer

PyTorch 的 torch.nn 模块提供了 Transformer 的核心组件,包括 nn.Transformernn.MultiheadAttention 等,方便直接使用或自定义。

(1) 核心类和函数

  • nn.Transformer:完整的 Transformer 模型,包含编码器和解码器。
  nn.Transformer(d_model=512, nhead=8, num_encoder_layers=6, num_decoder_layers=6)
  • d_model:模型维度(嵌入和注意力维度)。
  • nhead:多头注意力的头数。
  • num_encoder_layers / num_decoder_layers:编码器/解码器层数。
  • nn.MultiheadAttention:多头注意力机制。
  nn.MultiheadAttention(embed_dim=512, num_heads=8)
  • nn.TransformerEncoder / nn.TransformerDecoder:单独的编码器或解码器模块。
  encoder_layer = nn.TransformerEncoderLayer(d_model=512, nhead=8)
  transformer_encoder = nn.TransformerEncoder(encoder_layer, num_layers=6)
  • 位置编码:需手动实现或使用第三方库。

(2) 简单实现:Transformer 模型

以下是一个简单的 Transformer 模型示例,用于序列到序列任务(如翻译):

import torch
import torch.nn as nn
import math

# 位置编码
class PositionalEncoding(nn.Module):
    def __init__(self, d_model, max_len=5000):
        super(PositionalEncoding, self).__init__()
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        pe = pe.unsqueeze(0)  # Shape: (1, max_len, d_model)
        self.register_buffer('pe', pe)

    def forward(self, x):
        return x + self.pe[:, :x.size(1), :]

# Transformer 模型
class SimpleTransformer(nn.Module):
    def __init__(self, src_vocab_size, tgt_vocab_size, d_model=512, nhead=8, num_layers=6):
        super(SimpleTransformer, self).__init__()
        self.src_embedding = nn.Embedding(src_vocab_size, d_model)
        self.tgt_embedding = nn.Embedding(tgt_vocab_size, d_model)
        self.pos_encoder = PositionalEncoding(d_model)
        self.transformer = nn.Transformer(d_model, nhead, num_layers, num_layers)
        self.fc = nn.Linear(d_model, tgt_vocab_size)
        self.d_model = d_model

    def forward(self, src, tgt):
        src = self.src_embedding(src) * math.sqrt(self.d_model)
        tgt = self.tgt_embedding(tgt) * math.sqrt(self.d_model)
        src = self.pos_encoder(src)
        tgt = self.pos_encoder(tgt)
        output = self.transformer(src, tgt)
        output = self.fc(output)
        return output

# 示例数据
src = torch.randint(0, 100, (10, 32))  # (seq_len, batch_size)
tgt = torch.randint(0, 100, (10, 32))
model = SimpleTransformer(src_vocab_size=100, tgt_vocab_size=100)
output = model(src, tgt)

(3) 训练流程

import torch.optim as optim

# 模型、损失函数和优化器
model = SimpleTransformer(src_vocab_size=100, tgt_vocab_size=100).to('cuda')
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.0001)

# 训练循环
model.train()
for epoch in range(5):
    optimizer.zero_grad()
    output = model(src.to('cuda'), tgt.to('cuda')[:-1])  # 目标序列偏移
    loss = criterion(output.view(-1, 100), tgt[1:].view(-1).to('cuda'))
    loss.backward()
    optimizer.step()
    print(f'Epoch {epoch+1}, Loss: {loss.item()}')

4. 注意事项

  • 掩码(Mask)
  • 源序列掩码:处理填充(padding)部分,nn.Transformersrc_masksrc_key_padding_mask
  • 目标序列掩码:防止解码器关注未来标记,tgt_mask 使用上三角矩阵。
    python def generate_square_subsequent_mask(sz): mask = torch.triu(torch.ones(sz, sz) * float('-inf'), diagonal=1) return mask
  • 批处理:输入形状为 (seq_len, batch_size, d_model),需注意转置。
  • 位置编码:必须添加,否则 Transformer 无法区分序列中元素的位置。
  • 超参数
  • d_model 需被 nhead 整除。
  • 适当调整 num_layersdropout(默认 0.1)以平衡性能和过拟合。

5. 进阶用法

  • 预训练模型:使用 Hugging Face 的 transformers 库加载预训练 Transformer(如 BERT、GPT):
  from transformers import BertModel, BertTokenizer
  tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
  model = BertModel.from_pretrained('bert-base-uncased')
  • Vision Transformer (ViT):将图像分块(patch)后输入 Transformer 编码器。
  • 自定义注意力:通过继承 nn.MultiheadAttention 实现变种注意力机制。
  • 混合精度训练
  from torch.cuda.amp import autocast
  with autocast():
      output = model(src, tgt)

6. 常见问题

  • 内存占用:Transformer 模型参数量大,需优化批大小或使用梯度累积。
  • 训练速度慢:启用 GPU 或混合精度训练(torch.cuda.amp)。
  • 序列长度限制:长序列可能导致内存溢出,可使用稀疏注意力或截断序列。
  • 过拟合:增加 Dropout 或正则化(如权重衰减)。

7. 参考资源

  • 论文:《Attention is All You Need》(https://arxiv.org/abs/1706.03762)
  • PyTorch 官方文档
  • nn.Transformer:https://pytorch.org/docs/stable/generated/torch.nn.Transformer.html
  • nn.MultiheadAttention:https://pytorch.org/docs/stable/generated/torch.nn.MultiheadAttention.html
  • 教程:https://pytorch.org/tutorials/beginner/transformer_tutorial.html
  • Hugging Face Transformers:https://huggingface.co/docs/transformers/index
  • 社区论坛:https://discuss.pytorch.org/

8. 进一步帮助

如果你需要更详细的 Transformer 实现(例如特定任务的模型、优化技巧、或调试代码),或者想深入探讨某个部分(如注意力机制、位置编码),请提供具体需求,我可以为你定制代码或提供更详细的解释!

类似文章

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注