本文档会持续更新

1 增加或删除1维

import torch
x = torch.rand(2, 3)
print(x.shape)
y1 = x.unsqueeze(1)
print(y1.shape)
y2 = x.squeeze(1)
print(y2.shape)

# output
torch.Size([2, 3])
torch.Size([2, 1, 3])
torch.Size([2, 3])