人工智能:注意力机制与Transformer模型实战

好的,我们来一场注意力机制 → Transformer 的实战向讲解。

目标人群定位:已经了解过神经网络基本概念(MLP、embedding、RNN/LSTM大概知道),但对Transformer还处于“听过很多遍但自己写不出来”的阶段。

我们今天尽量用最直白 + 可运行的方式,把核心讲透 + 写出来。

一、先把“注意力”这件事彻底想明白(核心直觉)

一句话总结注意力机制最本质的作用:

让序列中每一个位置都能“按需”去看全序列中所有其他位置的信息,而且看多少是由数据自己学出来的。

传统RNN/LSTM是按顺序一个一个看,注意力是“全局视野 + 软寻址”。

最经典的Scaled Dot-Product Attention 计算公式(一定要背下来):

Attention(Q, K, V) = softmax( (Q K^T) / √d_k ) V
  • Q:Query 我现在关心什么?
  • K:Key 别人介绍自己用的标签
  • V:Value 别人真正携带的信息

类比一次人际关系版:

你在相亲(当前token是Q),对面100个人,每人举着一张标签(K),你根据标签跟你需求的相关度打分(Q·K),然后把打分softmax变成权重,最后把每个人的真心话(V)按权重加权求和 → 得到你听到的“综合意见”。

二、最小可运行的单头注意力(PyTorch)

import torch
import torch.nn as nn
import torch.nn.functional as F
import math

class SimpleSelfAttention(nn.Module):
    def __init__(self, d_model, d_k=None):
        super().__init__()
        self.d_model = d_model
        self.d_k = d_k if d_k is not None else d_model

        # 三个可学习投影矩阵
        self.Wq = nn.Linear(d_model, self.d_k)
        self.Wk = nn.Linear(d_model, self.d_k)
        self.Wv = nn.Linear(d_model, d_model)  # V的维度通常保持d_model

    def forward(self, x, mask=None):
        # x shape: [batch, seq_len, d_model]
        Q = self.Wq(x)     # [b, s, d_k]
        K = self.Wk(x)
        V = self.Wv(x)     # [b, s, d_model]

        # 核心计算:注意力分数
        scores = torch.matmul(Q, K.transpose(-2, -1))   # [b, s, s]
        scores = scores / math.sqrt(self.d_k)           # Scaled

        # mask(可选,用于因果/填充)
        if mask is not None:
            scores = scores.masked_fill(mask == 0, -1e9)

        attn_weights = F.softmax(scores, dim=-1)        # [b, s, s]
        out = torch.matmul(attn_weights, V)             # [b, s, d_model]

        return out, attn_weights

测试一下:

torch.manual_seed(42)
x = torch.randn(2, 8, 64)          # batch=2, seq=8, dim=64
attn = SimpleSelfAttention(64)
out, weights = attn(x)

print(weights.shape)         # torch.Size([2, 8, 8])
print(out.shape)             # torch.Size([2, 8, 64])

三、升级版:多头注意力(Multi-Head Attention)

为什么需要多头?一句话:让模型同时从多个子空间观察关系(语法、语义、指代、情感……)

实现方式:把d_model切成h份,每份独立算单头,最后拼接+线性投影。

class MultiHeadAttention(nn.Module):
    def __init__(self, d_model, num_heads):
        super().__init__()
        assert d_model % num_heads == 0

        self.d_model = d_model
        self.num_heads = num_heads
        self.d_k = d_model // num_heads

        self.Wq = nn.Linear(d_model, d_model)
        self.Wk = nn.Linear(d_model, d_model)
        self.Wv = nn.Linear(d_model, d_model)
        self.Wo = nn.Linear(d_model, d_model)   # 最后的融合

    def forward(self, x, mask=None):
        b, s, d = x.shape

        Q = self.Wq(x).view(b, s, self.num_heads, self.d_k).transpose(1,2)
        K = self.Wk(x).view(b, s, self.num_heads, self.d_k).transpose(1,2)
        V = self.Wv(x).view(b, s, self.num_heads, self.d_k).transpose(1,2)
        # 现在形状:[b, h, s, d_k]

        scores = torch.matmul(Q, K.transpose(-2,-1)) / math.sqrt(self.d_k)

        if mask is not None:
            # mask要广播到 [b, h, s, s]
            mask = mask.unsqueeze(1)
            scores = scores.masked_fill(mask == 0, -1e9)

        attn = F.softmax(scores, dim=-1)
        context = torch.matmul(attn, V)                     # [b,h,s,d_k]

        # 合并多头
        context = context.transpose(1,2).contiguous().view(b, s, d)
        out = self.Wo(context)

        return out

四、一个最简Transformer Encoder Layer(实战最常用)

class TransformerEncoderLayer(nn.Module):
    def __init__(self, d_model=512, num_heads=8, d_ff=2048, dropout=0.1):
        super().__init__()
        self.self_attn = MultiHeadAttention(d_model, num_heads)
        self.feed_forward = nn.Sequential(
            nn.Linear(d_model, d_ff),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(d_ff, d_model)
        )
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, mask=None):
        # 子层1:自注意力 + 残差 + norm
        attn_out = self.self_attn(x, mask)
        x = self.norm1(x + self.dropout(attn_out))

        # 子层2:前馈网络 + 残差 + norm
        ff_out = self.feed_forward(x)
        x = self.norm2(x + self.dropout(ff_out))

        return x

六层堆叠就是经典Transformer Encoder了(BERT就是这种结构)。

五、快速对比三种最常见的注意力mask使用场景(2026年仍最主流)

场景mask类型谁能看到谁典型模型
Encoder(BERT)padding mask所有有效token互看BERT, RoBERTa
Decoder自回归生成causal mask只能看到自己及之前的tokenGPT, LLaMA, Qwen
Encoder-Decoderpadding + crossDecoder每个位置看Encoder全部T5, BART, Transformer翻译

因果mask(causal mask)代码实现最常用写法:

def generate_causal_mask(seq_len, device):
    mask = torch.triu(torch.ones(seq_len, seq_len, device=device), diagonal=1)
    return (mask == 0).float()   # 上三角为0(不可见),下三角为1(可见)

六、下一阶段推荐动手方向(由浅入深)

  1. 用上面代码实现一个小型GPT(只用decoder,causal mask),在tiny Shakespeare上语言建模
  2. 加入位置编码(推荐用Rotary Embedding / RoPE,2025-2026主流)
  3. 实现Group Query Attention (GQA)Multi-Query Attention (MQA)(Llama3、Qwen2常用)
  4. torch.compile() 或 flash-attention-2 加速训练(速度可提升2-4倍)

有想重点深挖的模块吗?
比如:

  • 想直接跑一个最小GPT训练?
  • 想看Rotary位置编码代码?
  • 想看Flash Attention的原理与降内存方式?
  • 还是想直接讨论2026年最新的高效注意力变体(MLA、Mamba-2、RetNet等)?

告诉我你的具体目标,我可以继续陪你写/调/测。

文章已创建 4915

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注

相关文章

开始在上面输入您的搜索词,然后按回车进行搜索。按ESC取消。

返回顶部