pytorch中tensor进行reshape操作后原始数据的顺序

张开发
2026/4/6 21:16:26 15 分钟阅读

分享文章

pytorch中tensor进行reshape操作后原始数据的顺序
在pytorch中经常需要对tensor进行reshape操作使其符合特定网络的输入格式。在将网络的输出重新reshape回输入前的形状时tensor的特征是否还是按输入的顺序进行排列带着疑问做了下面的实验x1 torch.randn(2, 3) x2 torch.randn(2, 3) x3 torch.randn(2, 3) x4 torch.stack((x1, x2, x3), 0) shape x4.shape print(x4:, x4.shape) print(x4:\n, x4) x4 x4.reshape(x4.shape[0]*x4.shape[1], x4.shape[-1]) print(reshaped x4:, x4.shape) print(reshaped x4:\n, x4) x4 x4.reshape(shape[0], shape[1], shape[-1]) print(recovered x4:\n, x4, x4.shape) # print(x5:\n, x5)输出x4: torch.Size([3, 2, 3]) x4: tensor([[[-1.2061, 0.0617, 1.1632], [-1.5008, -1.5944, -0.0187]], [[-2.1325, -0.5270, -0.1021], [ 0.0099, -0.4454, -1.4976]], [[-0.9475, -0.6130, -0.1291], [-0.4107, 1.3931, -0.0984]]]) reshaped x4: torch.Size([6, 3]) reshaped x4: tensor([[-1.2061, 0.0617, 1.1632], [-1.5008, -1.5944, -0.0187], [-2.1325, -0.5270, -0.1021], [ 0.0099, -0.4454, -1.4976], [-0.9475, -0.6130, -0.1291], [-0.4107, 1.3931, -0.0984]]) recovered x4: tensor([[[-1.2061, 0.0617, 1.1632], [-1.5008, -1.5944, -0.0187]], [[-2.1325, -0.5270, -0.1021], [ 0.0099, -0.4454, -1.4976]], [[-0.9475, -0.6130, -0.1291], [-0.4107, 1.3931, -0.0984]]]) torch.Size([3, 2, 3])将x1, x2和x3三个tensor通过stack操作堆到一起后通过reshape操作改变维度的形状接着再将reshape完的tensor变回原来的形状发现输出数据的顺序和改变形状之前相同表明在reshape过程中tensor能够保持数据的顺序

更多文章