When working with tensors, you often need to change the dimensions of your data to make subsequent computation and processing easier. This article collects common dimension-transformation methods, with examples, for quick reference.

<>Checking dimensions: torch.Tensor.size()

>>> import torch
>>> a = torch.Tensor([[[1, 2], [3, 4], [5, 6]]])
>>> a.size()
torch.Size([1, 3, 2])
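As a small aside (this sketch is my own addition), size() has two related conveniences: .shape is an alias for .size(), and size() also accepts a dimension index, returning a plain int:

```python
import torch

a = torch.Tensor([[[1, 2], [3, 4], [5, 6]]])
print(a.shape)    # .shape is an alias for .size(): torch.Size([1, 3, 2])
print(a.size(1))  # size() also takes a dim index and returns an int: 3
print(a.dim())    # number of dimensions: 3
```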
<>Reshaping a tensor: torch.Tensor.view(*args) → Tensor

The tensor must be contiguous (contiguous()) before view() can be applied to it.

>>> x = torch.randn(2, 9)
>>> x.size()
torch.Size([2, 9])
>>> x
tensor([[-1.6833, -0.4100, -1.5534, -0.6229, -1.0310, -0.8038,  0.5166,  0.9774,  0.3455],
        [-0.2306,  0.4217,  1.2874, -0.3618,  1.7872, -0.9012,  0.8073, -1.1238, -0.3405]])
>>> y = x.view(3, 6)
>>> y.size()
torch.Size([3, 6])
>>> y
tensor([[-1.6833, -0.4100, -1.5534, -0.6229, -1.0310, -0.8038],
        [ 0.5166,  0.9774,  0.3455, -0.2306,  0.4217,  1.2874],
        [-0.3618,  1.7872, -0.9012,  0.8073, -1.1238, -0.3405]])
>>> z = x.view(2, 3, 3)
>>> z.size()
torch.Size([2, 3, 3])
>>> z
tensor([[[-1.6833, -0.4100, -1.5534],
         [-0.6229, -1.0310, -0.8038],
         [ 0.5166,  0.9774,  0.3455]],

        [[-0.2306,  0.4217,  1.2874],
         [-0.3618,  1.7872, -0.9012],
         [ 0.8073, -1.1238, -0.3405]]])
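To make the contiguity requirement concrete, here is a minimal sketch (my own illustration): a transposed tensor is a non-contiguous view, so view() raises until the data is copied with contiguous(); reshape() performs that copy automatically when needed.

```python
import torch

x = torch.randn(2, 9)
t = x.t()                      # transpose returns a non-contiguous view, shape (9, 2)
print(t.is_contiguous())       # False
try:
    t.view(3, 6)               # view() requires contiguous memory
except RuntimeError as err:
    print("view() failed:", err)
y = t.contiguous().view(3, 6)  # copy into contiguous memory first, then view
z = t.reshape(3, 6)            # reshape() falls back to a copy automatically
```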

<>Squeezing / unsqueezing a tensor: torch.squeeze(), torch.unsqueeze()

* torch.squeeze(input, dim=None, out=None)

Removes dimensions of size 1 from the shape of input. When dim is given, only that dimension is considered: for an input of shape (A×1×B), squeeze(input, 0) leaves the tensor unchanged, while squeeze(input, 1) changes the shape to (A×B).

>>> x = torch.randn(3, 1, 2)
>>> x
tensor([[[-0.1986,  0.4352]],

        [[ 0.0971,  0.2296]],

        [[ 0.8339, -0.5433]]])
>>> x.squeeze().size()  # without arguments, removes every dimension of size 1
torch.Size([3, 2])
>>> x.squeeze()
tensor([[-0.1986,  0.4352],
        [ 0.0971,  0.2296],
        [ 0.8339, -0.5433]])
>>> torch.squeeze(x, 0).size()  # dim 0 has size 3, not 1, so nothing changes
torch.Size([3, 1, 2])
>>> torch.squeeze(x, 1).size()  # dim 1 has size exactly 1, so it is removed
torch.Size([3, 2])

* torch.unsqueeze(input, dim, out=None)

>>> x.unsqueeze(0).size()
torch.Size([1, 3, 1, 2])
>>> x.unsqueeze(0)
tensor([[[[-0.1986,  0.4352]],

         [[ 0.0971,  0.2296]],

         [[ 0.8339, -0.5433]]]])
>>> x.unsqueeze(-1).size()
torch.Size([3, 1, 2, 1])
>>> x.unsqueeze(-1)
tensor([[[[-0.1986],
          [ 0.4352]]],

        [[[ 0.0971],
          [ 0.2296]]],

        [[[ 0.8339],
          [-0.5433]]]])
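One common use of unsqueeze, sketched below with made-up values, is inserting a size-1 axis so that two vectors broadcast against each other:

```python
import torch

row = torch.tensor([1., 2., 3.])   # shape (3,)
col = torch.tensor([10., 20.])     # shape (2,)
# (2, 1) + (1, 3) broadcasts to (2, 3)
table = col.unsqueeze(1) + row.unsqueeze(0)
print(table.shape)  # torch.Size([2, 3])
print(table)        # row i holds col[i] + row
```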

<>Expanding a tensor: torch.Tensor.expand(*sizes) → Tensor

Returns a new view of the tensor with singleton dimensions expanded to a larger size. Passing -1 leaves that dimension unchanged. The tensor can also be expanded to a higher number of dimensions, with the new dimensions prepended at the front. Any dimension of size 1 can be expanded to an arbitrary size without allocating new memory.

>>> x = torch.Tensor([[1], [2], [3]])
>>> x.size()
torch.Size([3, 1])
>>> x.expand(3, 4)
tensor([[1., 1., 1., 1.],
        [2., 2., 2., 2.],
        [3., 3., 3., 3.]])
>>> x.expand(3, -1)
tensor([[1.],
        [2.],
        [3.]])

Dimensions whose size is not 1 must be passed in with their original size (or as -1).
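Because expand() allocates no new memory, the result is a view onto the original storage: writing through the base tensor shows up in every expanded column. A small sketch of my own:

```python
import torch

x = torch.Tensor([[1], [2], [3]])  # shape (3, 1)
e = x.expand(3, 4)                 # a stride-0 view, no copy is made
x[0, 0] = 100
print(e[0])  # all four columns alias the same element: tensor([100., 100., 100., 100.])
```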

<>Repeating a tensor: torch.Tensor.repeat(*sizes)

>>> x = torch.Tensor([1, 2, 3])
>>> x.size()
torch.Size([3])
>>> x.repeat(4, 2)
tensor([[1., 2., 3., 1., 2., 3.],
        [1., 2., 3., 1., 2., 3.],
        [1., 2., 3., 1., 2., 3.],
        [1., 2., 3., 1., 2., 3.]])
>>> x.repeat(4, 2).size()
torch.Size([4, 6])
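Unlike expand(), repeat() copies the data, so the result owns its own memory. A quick sketch (my own illustration) of the difference:

```python
import torch

x = torch.Tensor([1, 2, 3])
r = x.repeat(4, 2)  # shape (4, 6); the data is copied, not viewed
x[0] = 100
print(r[0])         # unchanged: tensor([1., 2., 3., 1., 2., 3.])
```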

<>Matrix transpose: torch.t(input, out=None) → Tensor

>>> x = torch.randn(3, 5)
>>> x
tensor([[-1.0752, -0.9706, -0.8770, -0.4224,  0.9776],
        [ 0.2489, -0.2986, -0.7816, -0.0823,  1.1811],
        [-1.1124,  0.2160, -0.8446,  0.1762, -0.5164]])
>>> x.t()
tensor([[-1.0752,  0.2489, -1.1124],
        [-0.9706, -0.2986,  0.2160],
        [-0.8770, -0.7816, -0.8446],
        [-0.4224, -0.0823,  0.1762],
        [ 0.9776,  1.1811, -0.5164]])
>>> torch.t(x)  # alternative call form
tensor([[-1.0752,  0.2489, -1.1124],
        [-0.9706, -0.2986,  0.2160],
        [-0.8770, -0.7816, -0.8446],
        [-0.4224, -0.0823,  0.1762],
        [ 0.9776,  1.1811, -0.5164]])

<>Permuting dimensions: torch.transpose(), torch.Tensor.permute()

* torch.transpose(input, dim0, dim1, out=None) → Tensor

>>> x = torch.randn(2, 4, 3)
>>> x
tensor([[[-1.2502, -0.7363,  0.5534],
         [-0.2050,  3.1847, -1.6729],
         [-0.2591, -0.0860,  0.4660],
         [-1.2189, -1.1206,  0.0637]],

        [[ 1.4791, -0.7569,  2.5017],
         [ 0.0098, -1.0217,  0.8142],
         [-0.2414, -0.1790,  2.3506],
         [-0.6860, -0.2363,  1.0481]]])
>>> torch.transpose(x, 1, 2).size()
torch.Size([2, 3, 4])
>>> torch.transpose(x, 1, 2)
tensor([[[-1.2502, -0.2050, -0.2591, -1.2189],
         [-0.7363,  3.1847, -0.0860, -1.1206],
         [ 0.5534, -1.6729,  0.4660,  0.0637]],

        [[ 1.4791,  0.0098, -0.2414, -0.6860],
         [-0.7569, -1.0217, -0.1790, -0.2363],
         [ 2.5017,  0.8142,  2.3506,  1.0481]]])
>>> torch.transpose(x, 0, 1).size()
torch.Size([4, 2, 3])
>>> torch.transpose(x, 0, 1)
tensor([[[-1.2502, -0.7363,  0.5534],
         [ 1.4791, -0.7569,  2.5017]],

        [[-0.2050,  3.1847, -1.6729],
         [ 0.0098, -1.0217,  0.8142]],

        [[-0.2591, -0.0860,  0.4660],
         [-0.2414, -0.1790,  2.3506]],

        [[-1.2189, -1.1206,  0.0637],
         [-0.6860, -0.2363,  1.0481]]])

* torch.Tensor.permute(dims)

>>> x.size()
torch.Size([2, 4, 3])
>>> x.permute(2, 0, 1).size()
torch.Size([3, 2, 4])
>>> x.permute(2, 0, 1)
tensor([[[-1.2502, -0.2050, -0.2591, -1.2189],
         [ 1.4791,  0.0098, -0.2414, -0.6860]],

        [[-0.7363,  3.1847, -0.0860, -1.1206],
         [-0.7569, -1.0217, -0.1790, -0.2363]],

        [[ 0.5534, -1.6729,  0.4660,  0.0637],
         [ 2.5017,  0.8142,  2.3506,  1.0481]]])
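A permute is equivalent to a chain of pairwise transposes; the sketch below (my own illustration) checks that, and notes that both operations return non-contiguous views:

```python
import torch

x = torch.randn(2, 4, 3)
p = x.permute(2, 0, 1)                 # shape (3, 2, 4)
q = x.transpose(1, 2).transpose(0, 1)  # same reordering via two pairwise swaps
print(torch.equal(p, q))               # True
print(p.is_contiguous())               # False: permute returns a view,
# so call .contiguous() before view(), or use .reshape()
```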
