Transformer 模型
Transformer 模型是近年来深度学习领域的重要突破,最初由 Vaswani 等人在论文《Attention is All You Need》(2017)中提出,主要用于自然语言处理(NLP),现已广泛应用于计算机视觉、语音处理等领域。PyTorch 提供了便捷的方式通过 torch.nn
模块实现 Transformer 模型。以下是关于 Transformer 模型的详细说明,涵盖其架构、PyTorch 实现、核心组件及使用示例。
1. Transformer 模型概述
Transformer 模型摒弃了传统的循环神经网络(RNN)和卷积神经网络(CNN),完全基于注意力机制(Attention Mechanism),通过自注意力(Self-Attention)捕捉序列中元素间的关系。其核心特点包括:
- 并行化:不像 RNN 按序处理,Transformer 支持并行计算,加速训练。
- 长距离依赖:通过注意力机制有效捕捉远距离的上下文关系。
- 模块化:编码器(Encoder)和解码器(Decoder)堆叠组成,适用于序列到序列任务。
Transformer 常用于:
- 机器翻译(如英译中)
- 文本生成(如 GPT 系列)
- 分类任务(如 BERT)
- 视觉任务(如 Vision Transformer, ViT)
2. Transformer 架构
Transformer 模型由编码器(Encoder)和解码器(Decoder)组成,每部分包含多个相同的层(Layer)。
(1) 编码器(Encoder)
- 输入:序列(如单词、标记)经过嵌入(Embedding)后,添加位置编码(Positional Encoding)。
- 结构:N 个编码器层,每个层包含:
- 多头自注意力(Multi-Head Self-Attention):捕捉输入序列中各元素间的关系。
- 前馈神经网络(Feed-Forward Neural Network, FFN):逐位置应用全连接层。
- 残差连接与层归一化(Add & Norm):每子层后添加残差连接和 LayerNorm。
- 输出:编码后的特征表示,用于解码器或直接用于任务(如分类)。
(2) 解码器(Decoder)
- 输入:目标序列(偏移一位)或掩码输入。
- 结构:N 个解码器层,每个层包含:
- 掩码多头自注意力:防止关注未来的标记(用于自回归任务)。
- 多头注意力(Encoder-Decoder Attention):结合编码器输出。
- 前馈神经网络:与编码器类似。
- 残差连接与层归一化。
- 输出:生成目标序列的概率分布。
(3) 关键组件
- 注意力机制:
- Scaled Dot-Product Attention:计算查询(Query)、键(Key)、值(Value)之间的点积,缩放后通过 Softmax 归一化。
[
\text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right)V
] - 多头注意力:将注意力分成多个并行头,增强表达能力。
- 位置编码:为序列添加位置信息(固定正弦/余弦函数或可学习参数)。
- 前馈网络:每个位置独立应用全连接层,通常为两层 MLP。
- 层归一化:稳定训练,加速收敛。
3. PyTorch 中的 Transformer
PyTorch 的 torch.nn
模块提供了 Transformer 的核心组件,包括 nn.Transformer
、nn.MultiheadAttention
等,方便直接使用或自定义。
(1) 核心类和函数
nn.Transformer
:完整的 Transformer 模型,包含编码器和解码器。
nn.Transformer(d_model=512, nhead=8, num_encoder_layers=6, num_decoder_layers=6)
d_model
:模型维度(嵌入和注意力维度)。nhead
:多头注意力的头数。num_encoder_layers
/num_decoder_layers
:编码器/解码器层数。nn.MultiheadAttention
:多头注意力机制。
nn.MultiheadAttention(embed_dim=512, num_heads=8)
nn.TransformerEncoder
/nn.TransformerDecoder
:单独的编码器或解码器模块。
encoder_layer = nn.TransformerEncoderLayer(d_model=512, nhead=8)
transformer_encoder = nn.TransformerEncoder(encoder_layer, num_layers=6)
- 位置编码:需手动实现或使用第三方库。
(2) 简单实现:Transformer 模型
以下是一个简单的 Transformer 模型示例,用于序列到序列任务(如翻译):
import torch
import torch.nn as nn
import math
# 位置编码
class PositionalEncoding(nn.Module):
def __init__(self, d_model, max_len=5000):
super(PositionalEncoding, self).__init__()
pe = torch.zeros(max_len, d_model)
position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
pe[:, 0::2] = torch.sin(position * div_term)
pe[:, 1::2] = torch.cos(position * div_term)
pe = pe.unsqueeze(0) # Shape: (1, max_len, d_model)
self.register_buffer('pe', pe)
def forward(self, x):
return x + self.pe[:, :x.size(1), :]
# Transformer 模型
class SimpleTransformer(nn.Module):
def __init__(self, src_vocab_size, tgt_vocab_size, d_model=512, nhead=8, num_layers=6):
super(SimpleTransformer, self).__init__()
self.src_embedding = nn.Embedding(src_vocab_size, d_model)
self.tgt_embedding = nn.Embedding(tgt_vocab_size, d_model)
self.pos_encoder = PositionalEncoding(d_model)
self.transformer = nn.Transformer(d_model, nhead, num_layers, num_layers)
self.fc = nn.Linear(d_model, tgt_vocab_size)
self.d_model = d_model
def forward(self, src, tgt):
src = self.src_embedding(src) * math.sqrt(self.d_model)
tgt = self.tgt_embedding(tgt) * math.sqrt(self.d_model)
src = self.pos_encoder(src)
tgt = self.pos_encoder(tgt)
output = self.transformer(src, tgt)
output = self.fc(output)
return output
# 示例数据
src = torch.randint(0, 100, (10, 32)) # (seq_len, batch_size)
tgt = torch.randint(0, 100, (10, 32))
model = SimpleTransformer(src_vocab_size=100, tgt_vocab_size=100)
output = model(src, tgt)
(3) 训练流程
import torch.optim as optim
# 模型、损失函数和优化器
model = SimpleTransformer(src_vocab_size=100, tgt_vocab_size=100).to('cuda')
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.0001)
# 训练循环
model.train()
for epoch in range(5):
optimizer.zero_grad()
output = model(src.to('cuda'), tgt.to('cuda')[:-1]) # 目标序列偏移
loss = criterion(output.view(-1, 100), tgt[1:].view(-1).to('cuda'))
loss.backward()
optimizer.step()
print(f'Epoch {epoch+1}, Loss: {loss.item()}')
4. 注意事项
- 掩码(Mask):
- 源序列掩码:处理填充(padding)部分,
nn.Transformer
的src_mask
和src_key_padding_mask
。 - 目标序列掩码:防止解码器关注未来标记,
tgt_mask
使用上三角矩阵。python def generate_square_subsequent_mask(sz): mask = torch.triu(torch.ones(sz, sz) * float('-inf'), diagonal=1) return mask
- 批处理:输入形状为
(seq_len, batch_size, d_model)
,需注意转置。 - 位置编码:必须添加,否则 Transformer 无法区分序列中元素的位置。
- 超参数:
d_model
需被nhead
整除。- 适当调整
num_layers
和dropout
(默认 0.1)以平衡性能和过拟合。
5. 进阶用法
- 预训练模型:使用 Hugging Face 的
transformers
库加载预训练 Transformer(如 BERT、GPT):
from transformers import BertModel, BertTokenizer
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
model = BertModel.from_pretrained('bert-base-uncased')
- Vision Transformer (ViT):将图像分块(patch)后输入 Transformer 编码器。
- 自定义注意力:通过继承
nn.MultiheadAttention
实现变种注意力机制。 - 混合精度训练:
from torch.cuda.amp import autocast
with autocast():
output = model(src, tgt)
6. 常见问题
- 内存占用:Transformer 模型参数量大,需优化批大小或使用梯度累积。
- 训练速度慢:启用 GPU 或混合精度训练(
torch.cuda.amp
)。 - 序列长度限制:长序列可能导致内存溢出,可使用稀疏注意力或截断序列。
- 过拟合:增加 Dropout 或正则化(如权重衰减)。
7. 参考资源
- 论文:《Attention is All You Need》(https://arxiv.org/abs/1706.03762)
- PyTorch 官方文档:
nn.Transformer
:https://pytorch.org/docs/stable/generated/torch.nn.Transformer.htmlnn.MultiheadAttention
:https://pytorch.org/docs/stable/generated/torch.nn.MultiheadAttention.html- 教程:https://pytorch.org/tutorials/beginner/transformer_tutorial.html
- Hugging Face Transformers:https://huggingface.co/docs/transformers/index
- 社区论坛:https://discuss.pytorch.org/
8. 进一步帮助
如果你需要更详细的 Transformer 实现(例如特定任务的模型、优化技巧、或调试代码),或者想深入探讨某个部分(如注意力机制、位置编码),请提供具体需求,我可以为你定制代码或提供更详细的解释!