PyTorch torch.nn 参考手册
torch.nn
是 PyTorch 的核心模块之一,用于构建和训练神经网络。它提供了模块化的接口,包括层、损失函数、激活函数等,方便用户定义和组合复杂的神经网络结构。以下是 torch.nn
模块的参考手册,涵盖核心功能、常用类和方法,以及使用示例,力求简洁且全面。
1. torch.nn
模块概述
torch.nn
提供构建神经网络的工具,主要功能包括:
- 模块化设计:通过
nn.Module
类定义网络层和模型。 - 层(Layers):如全连接层、卷积层、池化层等。
- 激活函数:如 ReLU、Sigmoid 等。
- 损失函数:如交叉熵、均方误差等。
- 容器:如
nn.Sequential
用于组合多个层。 - 初始化:提供权重初始化方法。
2. 核心组件与使用方法
(1) nn.Module
nn.Module
是所有神经网络模块的基类,自定义模型通常需要继承它。
- 定义自定义模型
import torch
import torch.nn as nn
class MyNet(nn.Module):
def __init__(self):
super(MyNet, self).__init__()
self.fc1 = nn.Linear(10, 5)
self.fc2 = nn.Linear(5, 2)
self.relu = nn.ReLU()
def forward(self, x):
x = self.relu(self.fc1(x))
x = self.fc2(x)
return x
model = MyNet()
- 关键方法
forward(x)
:定义前向传播逻辑。parameters()
:返回模型的可训练参数。to(device)
:将模型移动到指定设备(CPU/GPU)。eval()
/train()
:切换模型到评估/训练模式。
(2) 常用层
- 全连接层:
nn.Linear(in_features, out_features, bias=True)
linear = nn.Linear(10, 5) # 输入10维,输出5维
- 卷积层:
nn.Conv2d(in_channels, out_channels, kernel_size, stride=1, padding=0)
conv = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1) # 3通道输入,64通道输出,3x3卷积核
- 池化层:
nn.MaxPool2d(kernel_size, stride=None)
:最大池化nn.AvgPool2d(kernel_size, stride=None)
:平均池化
pool = nn.MaxPool2d(2, stride=2) # 2x2最大池化,步幅2
- 规范化层:
nn.BatchNorm2d(num_features)
:批归一化nn.LayerNorm(normalized_shape)
:层归一化
bn = nn.BatchNorm2d(64) # 对64通道特征进行批归一化
- Dropout:
nn.Dropout(p=0.5)
随机丢弃,防止过拟合
dropout = nn.Dropout(p=0.5) # 50%概率丢弃
(3) 激活函数
常用激活函数,位于 torch.nn
或 torch.nn.functional
:
nn.ReLU()
:max(0, x)
nn.Sigmoid()
:1 / (1 + exp(-x))
nn.Tanh()
:双曲正切nn.Softmax(dim=None)
:归一化指数函数
relu = nn.ReLU()
softmax = nn.Softmax(dim=1) # 沿第1维应用
(4) 损失函数
常用损失函数:
nn.CrossEntropyLoss()
:交叉熵损失(分类任务,包含 Softmax)
criterion = nn.CrossEntropyLoss()
nn.MSELoss()
:均方误差(回归任务)nn.BCELoss()
:二元交叉熵nn.L1Loss()
:L1 损失
mse = nn.MSELoss()
(5) 容器
nn.Sequential
:按顺序组合多个层
model = nn.Sequential(
nn.Linear(10, 5),
nn.ReLU(),
nn.Linear(5, 2)
)
nn.ModuleList
:存储模块列表,支持动态添加
layers = nn.ModuleList([nn.Linear(10, 10) for _ in range(3)])
nn.ModuleDict
:存储模块的字典
modules = nn.ModuleDict({'fc1': nn.Linear(10, 5), 'fc2': nn.Linear(5, 2)})
(6) 权重初始化
torch.nn.init
提供初始化方法:
import torch.nn.init as init
# 初始化线性层的权重
init.xavier_uniform_(model.fc1.weight)
init.zeros_(model.fc1.bias)
3. 完整示例:定义与训练模型
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset
# 简单数据集
data = torch.randn(100, 10)
labels = torch.randint(0, 2, (100,))
dataset = TensorDataset(data, labels)
dataloader = DataLoader(dataset, batch_size=32, shuffle=True)
# 定义模型
class Net(nn.Module):
def __init__(self):
super(Net, self).__init__()
self.layers = nn.Sequential(
nn.Linear(10, 5),
nn.ReLU(),
nn.Linear(5, 2)
)
def forward(self, x):
return self.layers(x)
model = Net()
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)
# 训练
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model.to(device)
for epoch in range(5):
for inputs, targets in dataloader:
inputs, targets = inputs.to(device), targets.to(device)
optimizer.zero_grad()
outputs = model(inputs)
loss = criterion(outputs, targets)
loss.backward()
optimizer.step()
print(f'Epoch {epoch+1}, Loss: {loss.item()}')
4. 进阶功能
- 自定义层:通过继承
nn.Module
实现自定义层。
class CustomLayer(nn.Module):
def __init__(self):
super(CustomLayer, self).__init__()
self.weight = nn.Parameter(torch.randn(1))
def forward(self, x):
return x * self.weight
- 动态网络:使用
nn.ModuleList
或条件逻辑支持动态结构。 - 模型保存与加载:
torch.save(model.state_dict(), 'model.pth') # 保存
model.load_state_dict(torch.load('model.pth')) # 加载
- 冻结部分参数:
for param in model.fc1.parameters():
param.requires_grad = False # 冻结 fc1 层的参数
5. 常用功能与注意事项
- 设备一致性:确保模型和数据在同一设备(
model.to(device)
和data.to(device)
)。 - Dropout 模式:训练时启用 Dropout(
model.train()
),推理时禁用(model.eval()
)。 - 损失函数输入:
nn.CrossEntropyLoss
期望模型输出未经过 Softmax(原始 logits),标签为整数。nn.BCELoss
要求输出和标签都在[0, 1]
范围内。- 参数管理:使用
model.parameters()
获取所有可训练参数,或named_parameters()
获取参数名。 - 性能优化:结合
torch.cuda.amp
进行混合精度训练以加速计算。
6. 参考资源
- 官方文档:https://pytorch.org/docs/stable/nn.html
nn.Module
:https://pytorch.org/docs/stable/generated/torch.nn.Module.html- 损失函数:https://pytorch.org/docs/stable/nn.html#loss-functions
- 激活函数:https://pytorch.org/docs/stable/nn.html#non-linear-activations
- 教程:https://pytorch.org/tutorials/beginner/basics/buildmodel_tutorial.html
- GitHub 仓库:https://github.com/pytorch/pytorch
- 社区论坛:https://discuss.pytorch.org/
7. 进一步帮助
如果你需要更详细的 torch.nn
功能说明(例如特定层的参数、自定义模型设计、或优化技巧),或者有具体的任务需求(如图像分类、序列建模等),请提供更多细节,我可以为你提供更精准的代码或建议!