好的,我们来一场注意力机制 → Transformer 的实战向讲解。
目标人群定位:已经了解过神经网络基本概念(MLP、embedding、RNN/LSTM大概知道),但对Transformer还处于“听过很多遍但自己写不出来”的阶段。
我们今天尽量用最直白 + 可运行的方式,把核心讲透 + 写出来。
一、先把“注意力”这件事彻底想明白(核心直觉)
一句话总结注意力机制最本质的作用:
让序列中每一个位置都能“按需”去看全序列中所有其他位置的信息,而且看多少是由数据自己学出来的。
传统RNN/LSTM是按顺序一个一个看,注意力是“全局视野 + 软寻址”。
最经典的Scaled Dot-Product Attention 计算公式(一定要背下来):
Attention(Q, K, V) = softmax( (Q K^T) / √d_k ) V
- Q:Query 我现在关心什么?
- K:Key 别人介绍自己用的标签
- V:Value 别人真正携带的信息
类比一次人际关系版:
你在相亲(当前token是Q),对面100个人,每人举着一张标签(K),你根据标签跟你需求的相关度打分(Q·K),然后把打分softmax变成权重,最后把每个人的真心话(V)按权重加权求和 → 得到你听到的“综合意见”。
二、最小可运行的单头注意力(PyTorch)
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
class SimpleSelfAttention(nn.Module):
def __init__(self, d_model, d_k=None):
super().__init__()
self.d_model = d_model
self.d_k = d_k if d_k is not None else d_model
# 三个可学习投影矩阵
self.Wq = nn.Linear(d_model, self.d_k)
self.Wk = nn.Linear(d_model, self.d_k)
self.Wv = nn.Linear(d_model, d_model) # V的维度通常保持d_model
def forward(self, x, mask=None):
# x shape: [batch, seq_len, d_model]
Q = self.Wq(x) # [b, s, d_k]
K = self.Wk(x)
V = self.Wv(x) # [b, s, d_model]
# 核心计算:注意力分数
scores = torch.matmul(Q, K.transpose(-2, -1)) # [b, s, s]
scores = scores / math.sqrt(self.d_k) # Scaled
# mask(可选,用于因果/填充)
if mask is not None:
scores = scores.masked_fill(mask == 0, -1e9)
attn_weights = F.softmax(scores, dim=-1) # [b, s, s]
out = torch.matmul(attn_weights, V) # [b, s, d_model]
return out, attn_weights
测试一下:
torch.manual_seed(42)
x = torch.randn(2, 8, 64) # batch=2, seq=8, dim=64
attn = SimpleSelfAttention(64)
out, weights = attn(x)
print(weights.shape) # torch.Size([2, 8, 8])
print(out.shape) # torch.Size([2, 8, 64])
三、升级版:多头注意力(Multi-Head Attention)
为什么需要多头?一句话:让模型同时从多个子空间观察关系(语法、语义、指代、情感……)
实现方式:把d_model切成h份,每份独立算单头,最后拼接+线性投影。
class MultiHeadAttention(nn.Module):
def __init__(self, d_model, num_heads):
super().__init__()
assert d_model % num_heads == 0
self.d_model = d_model
self.num_heads = num_heads
self.d_k = d_model // num_heads
self.Wq = nn.Linear(d_model, d_model)
self.Wk = nn.Linear(d_model, d_model)
self.Wv = nn.Linear(d_model, d_model)
self.Wo = nn.Linear(d_model, d_model) # 最后的融合
def forward(self, x, mask=None):
b, s, d = x.shape
Q = self.Wq(x).view(b, s, self.num_heads, self.d_k).transpose(1,2)
K = self.Wk(x).view(b, s, self.num_heads, self.d_k).transpose(1,2)
V = self.Wv(x).view(b, s, self.num_heads, self.d_k).transpose(1,2)
# 现在形状:[b, h, s, d_k]
scores = torch.matmul(Q, K.transpose(-2,-1)) / math.sqrt(self.d_k)
if mask is not None:
# mask要广播到 [b, h, s, s]
mask = mask.unsqueeze(1)
scores = scores.masked_fill(mask == 0, -1e9)
attn = F.softmax(scores, dim=-1)
context = torch.matmul(attn, V) # [b,h,s,d_k]
# 合并多头
context = context.transpose(1,2).contiguous().view(b, s, d)
out = self.Wo(context)
return out
四、一个最简Transformer Encoder Layer(实战最常用)
class TransformerEncoderLayer(nn.Module):
def __init__(self, d_model=512, num_heads=8, d_ff=2048, dropout=0.1):
super().__init__()
self.self_attn = MultiHeadAttention(d_model, num_heads)
self.feed_forward = nn.Sequential(
nn.Linear(d_model, d_ff),
nn.ReLU(),
nn.Dropout(dropout),
nn.Linear(d_ff, d_model)
)
self.norm1 = nn.LayerNorm(d_model)
self.norm2 = nn.LayerNorm(d_model)
self.dropout = nn.Dropout(dropout)
def forward(self, x, mask=None):
# 子层1:自注意力 + 残差 + norm
attn_out = self.self_attn(x, mask)
x = self.norm1(x + self.dropout(attn_out))
# 子层2:前馈网络 + 残差 + norm
ff_out = self.feed_forward(x)
x = self.norm2(x + self.dropout(ff_out))
return x
六层堆叠就是经典Transformer Encoder了(BERT就是这种结构)。
五、快速对比三种最常见的注意力mask使用场景(2026年仍最主流)
| 场景 | mask类型 | 谁能看到谁 | 典型模型 |
|---|---|---|---|
| Encoder(BERT) | padding mask | 所有有效token互看 | BERT, RoBERTa |
| Decoder自回归生成 | causal mask | 只能看到自己及之前的token | GPT, LLaMA, Qwen |
| Encoder-Decoder | padding + cross | Decoder每个位置看Encoder全部 | T5, BART, Transformer翻译 |
因果mask(causal mask)代码实现最常用写法:
def generate_causal_mask(seq_len, device):
mask = torch.triu(torch.ones(seq_len, seq_len, device=device), diagonal=1)
return (mask == 0).float() # 上三角为0(不可见),下三角为1(可见)
六、下一阶段推荐动手方向(由浅入深)
- 用上面代码实现一个小型GPT(只用decoder,causal mask),在tiny Shakespeare上语言建模
- 加入位置编码(推荐用Rotary Embedding / RoPE,2025-2026主流)
- 实现Group Query Attention (GQA) 或 Multi-Query Attention (MQA)(Llama3、Qwen2常用)
- 用
torch.compile()或 flash-attention-2 加速训练(速度可提升2-4倍)
有想重点深挖的模块吗?
比如:
- 想直接跑一个最小GPT训练?
- 想看Rotary位置编码代码?
- 想看Flash Attention的原理与降内存方式?
- 还是想直接讨论2026年最新的高效注意力变体(MLA、Mamba-2、RetNet等)?
告诉我你的具体目标,我可以继续陪你写/调/测。