PyTorch Loss 函数详解
本文详细介绍 PyTorch 中常用的各类 Loss 函数,包括数学原理、代码实现和避坑指南。
一、回归任务 (Regression)
1. MSE Loss (均方误差)
原理: 计算预测值与真实值差值的平方均值。对误差大的点惩罚极重。
-
数学公式: 其中 是样本数, 是真实值, 是预测值。
-
PyTorch 实现:
import torch
import torch.nn as nn
# 假设 Batch Size = 2
predictions = torch.tensor([2.5, 0.0], requires_grad=True)
targets = torch.tensor([3.0, -0.5])
criterion = nn.MSELoss()
loss = criterion(predictions, targets)
print(loss) # ((2.5-3.0)^2 + (0.0 - (-0.5))^2) / 2 = (0.25 + 0.25) / 2 = 0.25