PyTorch 构建 Transformer 模型

以下是一个使用 PyTorch 构建 Transformer 模型的详细实现,包含完整的代码示例,用于序列到序列任务(如机器翻译)。代码包括位置编码、Transformer 模型定义、数据准备和训练流程,适合初学者和需要快速上手的开发者。实现基于 torch.nn.Transformer,并添加了必要的注释和说明。


import torch
import torch.nn as nn
import torch.optim as optim
import math

位置编码类,为输入序列添加位置信息

class PositionalEncoding(nn.Module):
def init(self, d_model, max_len=5000):
super(PositionalEncoding, self).init()
pe = torch.zeros(max_len, d_model)
position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
pe[:, 0::2] = torch.sin(position * div_term)
pe[:, 1::2] = torch.cos(position * div_term)
pe = pe.unsqueeze(0) # Shape: (1, max_len, d_model)
self.register_buffer(‘pe’, pe)

def forward(self, x):
    # x: (batch_size, seq_len, d_model)
    return x + self.pe[:, :x.size(1), :]

Transformer 模型

class TransformerModel(nn.Module):
def init(self, src_vocab_size, tgt_vocab_size, d_model=512, nhead=8, num_encoder_layers=6,
num_decoder_layers=6, dim_feedforward=2048, dropout=0.1):
super(TransformerModel, self).init()
# 词嵌入
self.src_embedding = nn.Embedding(src_vocab_size, d_model)
self.tgt_embedding = nn.Embedding(tgt_vocab_size, d_model)
# 位置编码
self.pos_encoder = PositionalEncoding(d_model)
# Transformer 主模块
self.transformer = nn.Transformer(
d_model=d_model,
nhead=nhead,
num_encoder_layers=num_encoder_layers,
num_decoder_layers=num_decoder_layers,
dim_feedforward=dim_feedforward,
dropout=dropout
)
# 输出层
self.fc_out = nn.Linear(d_model, tgt_vocab_size)
self.d_model = d_model

def generate_square_subsequent_mask(self, sz):
    # 生成上三角掩码,防止解码器关注未来标记
    mask = torch.triu(torch.ones(sz, sz) * float('-inf'), diagonal=1)
    return mask

def forward(self, src, tgt, src_mask=None, tgt_mask=None, src_key_padding_mask=None, tgt_key_padding_mask=None):
    # src/tgt: (seq_len, batch_size)
    src_emb = self.src_embedding(src) * math.sqrt(self.d_model)
    tgt_emb = self.tgt_embedding(tgt) * math.sqrt(self.d_model)
    src_emb = self.pos_encoder(src_emb)
    tgt_emb = self.pos_encoder(tgt_emb)

    # Transformer 前向传播
    output = self.transformer(
        src_emb, tgt_emb,
        src_mask=src_mask,
        tgt_mask=tgt_mask,
        src_key_padding_mask=src_key_padding_mask,
        tgt_key_padding_mask=tgt_key_padding_mask
    )
    return self.fc_out(output)

数据准备

def generate_dummy_data(batch_size, seq_len, src_vocab_size, tgt_vocab_size):
src = torch.randint(1, src_vocab_size, (seq_len, batch_size)) # 源序列
tgt = torch.randint(1, tgt_vocab_size, (seq_len, batch_size)) # 目标序列
return src, tgt

训练函数

def train(model, dataloader, criterion, optimizer, device):
model.train()
total_loss = 0
for src, tgt in dataloader:
src, tgt = src.to(device), tgt.to(device)
optimizer.zero_grad()

    # 生成目标序列掩码
    tgt_input = tgt[:-1, :]  # 去掉最后一个标记
    tgt_mask = model.generate_square_subsequent_mask(tgt_input.size(0)).to(device)

    # 前向传播
    output = model(src, tgt_input, tgt_mask=tgt_mask)
    loss = criterion(output.view(-1, output.size(-1)), tgt[1:, :].view(-1))

    # 反向传播和优化
    loss.backward()
    optimizer.step()
    total_loss += loss.item()
return total_loss / len(dataloader)

主程序

def main():
# 超参数
src_vocab_size = 1000
tgt_vocab_size = 1000
d_model = 512
nhead = 8
num_encoder_layers = 6
num_decoder_layers = 6
dim_feedforward = 2048
dropout = 0.1
batch_size = 32
seq_len = 10
epochs = 5

# 设备
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# 模型、损失函数和优化器
model = TransformerModel(
    src_vocab_size, tgt_vocab_size, d_model, nhead, 
    num_encoder_layers, num_decoder_layers, dim_feedforward, dropout
).to(device)
criterion = nn.CrossEntropyLoss(ignore_index=0)  # 忽略填充标记
optimizer = optim.Adam(model.parameters(), lr=0.0001)

# 模拟数据
src, tgt = generate_dummy_data(batch_size, seq_len, src_vocab_size, tgt_vocab_size)
dataset = torch.utils.data.TensorDataset(src, tgt)
dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size)

# 训练
for epoch in range(epochs):
    loss = train(model, dataloader, criterion, optimizer, device)
    print(f'Epoch {epoch+1}, Loss: {loss:.4f}')

if name == ‘main‘:
main()


说明

  1. 代码结构
  • 位置编码PositionalEncoding 类实现正弦/余弦位置编码,为序列添加位置信息。
  • Transformer 模型TransformerModel 使用 nn.Transformer,包含源/目标嵌入、位置编码和输出层。
  • 掩码generate_square_subsequent_mask 生成解码器的掩码,防止关注未来标记。
  • 训练流程:包括数据准备、前向传播、损失计算和优化。
  1. 运行说明
  • 代码使用模拟数据(随机生成的整数序列),实际应用中需替换为真实数据集(如通过 torchtext 加载翻译数据)。
  • 模型在 GPU 上运行(如可用),否则使用 CPU。
  • 超参数(如 d_modelnhead)可根据任务调整。
  1. 扩展建议
  • 真实数据集:使用 torchtextdatasets 加载真实翻译数据(如 Multi30k)。
  • 填充掩码:为变长序列添加 src_key_padding_masktgt_key_padding_mask
  • 推理:实现 generate 函数以支持自回归生成。
  • 预训练:考虑使用 Hugging Face 的 transformers 加载预训练模型。
  1. 注意事项
  • 确保 d_model 能被 nhead 整除。
  • 使用 ignore_index=0 处理填充标记(padding)。
  • 长序列可能需要更大的内存,建议优化批大小或使用梯度累积。
  1. 参考资源
  • PyTorch 官方 Transformer 教程:https://pytorch.org/tutorials/beginner/transformer_tutorial.html
  • nn.Transformer 文档:https://pytorch.org/docs/stable/generated/torch.nn.Transformer.html
  • 《Attention is All You Need》论文:https://arxiv.org/abs/1706.03762

如果你需要针对特定任务(如翻译、分类)优化 Transformer,或需调试代码、添加特定功能(如注意力可视化),请提供更多细节,我可以进一步定制代码!

类似文章

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注