PyTorch: view vs reshape 与连续性
1. 什么是连续性(Contiguous)
核心概念
连续性 = tensor在内存中的存储 顺序和逻辑顺序一致
直观理解
import torch
# 创建一个 2x3 的tensor
x = torch.tensor([[1, 2, 3],
[4, 5, 6]])
print("逻辑视图(你看到的):")
# [[1, 2, 3],
# [4, 5, 6]]
print("物理内存(实际存储):")
# [1, 2, 3, 4, 5, 6] ← 连续存储,一个接一个
这就是连续的:从左到右、从上到下读取时,内存中的数据也是这个顺序。
转置后变成不连续
y = x.t() # 转置
print("逻辑视图:")
# [[1, 4],
# [2, 5],
# [3, 6]]
print("物理内存(未改变!):")
# [1, 2, 3, 4, 5, 6] ← 还是原来的顺序
print(y.is_contiguous()) # False
关键点:转置后,PyTorch并没有重新排列内存中的数据,而是通过改变**stride(步长)**来改变访问方式。
Stride(步长)详解
x = torch.tensor([[1, 2, 3],
[4, 5, 6]])
print(x.stride()) # (3, 1)
# 含义:
# - 行方向移动1步 → 内存地址 +3
# - 列方向移动1步 → 内存地址 +1
y = x.t()
print(y.stride()) # (1, 3)
# 含义:
# - 行方向移动1步 → 内存地址 +1
# - 列方向移动1步 → 内存地址 +3
图示
原始 x (连续):
逻辑: 内存:
[1 2 3] [1][2][3][4][5][6]
[4 5 6] ↓ ↓ ↓ ↓ ↓ ↓
顺序访问
转置 y (不连续):
逻辑: 内存:
[1 4] [1][2][3][4][5][6]
[2 5] ↓ ↓ (跳着访问)
[3 6] ↓ ↓
↓ ↓
哪些操作会导致不连续
x = torch.randn(2, 3, 4)
# 导致不连续:
y1 = x.transpose(0, 1) # 转置
y2 = x.permute(2, 0, 1) # 重排维度
y3 = x[:, :, ::2] # 跳步切片
y4 = x.narrow(1, 0, 2) # narrow操作
# 保持连续:
z1 = x + 1 # 数学运算(创建新tensor)
z2 = x.clone() # 克隆
z3 = x.reshape(...) # reshape会自动处理
2. view vs reshape
核心区别
| 特性 | view | reshape |
|---|---|---|
| 连续性要求 | 必须连续,否则报错 | 自动处理,不连续时会复制 |
| 返回值 | 总是返回视图(共享内存) | 可能返回视图或副本 |
| 速度 | 快(不复制数据) | 不连续时较慢(需复制) |
| 安全性 | 严格,问题立即暴露 | 宽松,可能隐藏性能问题 |
代码示例
示例1:连续tensor
x = torch.randn(2, 3)
# 两者都可以工作,且都返回视图
y1 = x.view(6)
y2 = x.reshape(6)
# 都共享内存
y1[0] = 999
print(x[0, 0]) # 999
y2[0] = 888
print(x[0, 0]) # 888
示例2:不连续tensor
x = torch.randn(2, 3)
y = x.t() # 转置,不连续
# view: 报错
try:
z = y.view(6)
except RuntimeError as e:
print("view报错:", e)
# reshape: 自动复制数据,成功
z = y.reshape(6) # OK,但复制了数据
print(z.shape) # torch.Size([6])