PyTorch 构建 Transformer 模型
以下是一个使用 PyTorch 构建 Transformer 模型的详细实现,包含完整的代码示例,用于序列到序列任务(如机器翻译)。代码包括位置编码、Transformer 模型定义、数据准备和训练流程,适合初学者和需要快速上手的开发者。实现基于 torch.nn.Transformer
,并添加了必要的注释和说明。
import torch
import torch.nn as nn
import torch.optim as optim
import math
位置编码类,为输入序列添加位置信息
class PositionalEncoding(nn.Module):
def init(self, d_model, max_len=5000):
super(PositionalEncoding, self).init()
pe = torch.zeros(max_len, d_model)
position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
pe[:, 0::2] = torch.sin(position * div_term)
pe[:, 1::2] = torch.cos(position * div_term)
pe = pe.unsqueeze(0) # Shape: (1, max_len, d_model)
self.register_buffer(‘pe’, pe)
def forward(self, x):
# x: (batch_size, seq_len, d_model)
return x + self.pe[:, :x.size(1), :]
Transformer 模型
class TransformerModel(nn.Module):
def init(self, src_vocab_size, tgt_vocab_size, d_model=512, nhead=8, num_encoder_layers=6,
num_decoder_layers=6, dim_feedforward=2048, dropout=0.1):
super(TransformerModel, self).init()
# 词嵌入
self.src_embedding = nn.Embedding(src_vocab_size, d_model)
self.tgt_embedding = nn.Embedding(tgt_vocab_size, d_model)
# 位置编码
self.pos_encoder = PositionalEncoding(d_model)
# Transformer 主模块
self.transformer = nn.Transformer(
d_model=d_model,
nhead=nhead,
num_encoder_layers=num_encoder_layers,
num_decoder_layers=num_decoder_layers,
dim_feedforward=dim_feedforward,
dropout=dropout
)
# 输出层
self.fc_out = nn.Linear(d_model, tgt_vocab_size)
self.d_model = d_model
def generate_square_subsequent_mask(self, sz):
# 生成上三角掩码,防止解码器关注未来标记
mask = torch.triu(torch.ones(sz, sz) * float('-inf'), diagonal=1)
return mask
def forward(self, src, tgt, src_mask=None, tgt_mask=None, src_key_padding_mask=None, tgt_key_padding_mask=None):
# src/tgt: (seq_len, batch_size)
src_emb = self.src_embedding(src) * math.sqrt(self.d_model)
tgt_emb = self.tgt_embedding(tgt) * math.sqrt(self.d_model)
src_emb = self.pos_encoder(src_emb)
tgt_emb = self.pos_encoder(tgt_emb)
# Transformer 前向传播
output = self.transformer(
src_emb, tgt_emb,
src_mask=src_mask,
tgt_mask=tgt_mask,
src_key_padding_mask=src_key_padding_mask,
tgt_key_padding_mask=tgt_key_padding_mask
)
return self.fc_out(output)
数据准备
def generate_dummy_data(batch_size, seq_len, src_vocab_size, tgt_vocab_size):
src = torch.randint(1, src_vocab_size, (seq_len, batch_size)) # 源序列
tgt = torch.randint(1, tgt_vocab_size, (seq_len, batch_size)) # 目标序列
return src, tgt
训练函数
def train(model, dataloader, criterion, optimizer, device):
model.train()
total_loss = 0
for src, tgt in dataloader:
src, tgt = src.to(device), tgt.to(device)
optimizer.zero_grad()
# 生成目标序列掩码
tgt_input = tgt[:-1, :] # 去掉最后一个标记
tgt_mask = model.generate_square_subsequent_mask(tgt_input.size(0)).to(device)
# 前向传播
output = model(src, tgt_input, tgt_mask=tgt_mask)
loss = criterion(output.view(-1, output.size(-1)), tgt[1:, :].view(-1))
# 反向传播和优化
loss.backward()
optimizer.step()
total_loss += loss.item()
return total_loss / len(dataloader)
主程序
def main():
# 超参数
src_vocab_size = 1000
tgt_vocab_size = 1000
d_model = 512
nhead = 8
num_encoder_layers = 6
num_decoder_layers = 6
dim_feedforward = 2048
dropout = 0.1
batch_size = 32
seq_len = 10
epochs = 5
# 设备
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# 模型、损失函数和优化器
model = TransformerModel(
src_vocab_size, tgt_vocab_size, d_model, nhead,
num_encoder_layers, num_decoder_layers, dim_feedforward, dropout
).to(device)
criterion = nn.CrossEntropyLoss(ignore_index=0) # 忽略填充标记
optimizer = optim.Adam(model.parameters(), lr=0.0001)
# 模拟数据
src, tgt = generate_dummy_data(batch_size, seq_len, src_vocab_size, tgt_vocab_size)
dataset = torch.utils.data.TensorDataset(src, tgt)
dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size)
# 训练
for epoch in range(epochs):
loss = train(model, dataloader, criterion, optimizer, device)
print(f'Epoch {epoch+1}, Loss: {loss:.4f}')
if name == ‘main‘:
main()
说明
- 代码结构:
- 位置编码:
PositionalEncoding
类实现正弦/余弦位置编码,为序列添加位置信息。 - Transformer 模型:
TransformerModel
使用nn.Transformer
,包含源/目标嵌入、位置编码和输出层。 - 掩码:
generate_square_subsequent_mask
生成解码器的掩码,防止关注未来标记。 - 训练流程:包括数据准备、前向传播、损失计算和优化。
- 运行说明:
- 代码使用模拟数据(随机生成的整数序列),实际应用中需替换为真实数据集(如通过
torchtext
加载翻译数据)。 - 模型在 GPU 上运行(如可用),否则使用 CPU。
- 超参数(如
d_model
、nhead
)可根据任务调整。
- 扩展建议:
- 真实数据集:使用
torchtext
或datasets
加载真实翻译数据(如 Multi30k)。 - 填充掩码:为变长序列添加
src_key_padding_mask
和tgt_key_padding_mask
。 - 推理:实现
generate
函数以支持自回归生成。 - 预训练:考虑使用 Hugging Face 的
transformers
加载预训练模型。
- 注意事项:
- 确保
d_model
能被nhead
整除。 - 使用
ignore_index=0
处理填充标记(padding)。 - 长序列可能需要更大的内存,建议优化批大小或使用梯度累积。
- 参考资源:
- PyTorch 官方 Transformer 教程:https://pytorch.org/tutorials/beginner/transformer_tutorial.html
nn.Transformer
文档:https://pytorch.org/docs/stable/generated/torch.nn.Transformer.html- 《Attention is All You Need》论文:https://arxiv.org/abs/1706.03762
如果你需要针对特定任务(如翻译、分类)优化 Transformer,或需调试代码、添加特定功能(如注意力可视化),请提供更多细节,我可以进一步定制代码!