<>torch.nn.flatten

torch.nn.flatten是一个类，作用为将连续的几个维度展平成一个tensor（将一些维度合并）

*

* 开始维度默认为 1。因为其被用在神经网络中，输入为一批数据，第 0

* 结束维度默认为 -1，也就是一直合并到最后一维

*

x = torch.ones(2, 2, 2, 2) F = torch.nn.Flatten() y = F(x) print(y) print(y.
shape) >>tensor([[1., 1., 1., 1., 1., 1., 1., 1.], [1., 1., 1., 1., 1., 1., 1.,
1.]]) >>torch.Size([2, 8])
*

x = torch.ones(2, 2, 2, 2) F = torch.nn.Flatten(2) y = F(x) print(y) print(y.
shape) >>tensor([[[1., 1., 1., 1.], [1., 1., 1., 1.]], [[1., 1., 1., 1.], [1., 1
., 1., 1.]]]) >>torch.Size([2, 2, 4])
*

x = torch.ones(2, 2, 2, 2) F = torch.nn.Flatten(1, 2) y = F(x) print(y) print(y
.shape) >>tensor([[[1., 1.], [1., 1.], [1., 1.], [1., 1.]], [[1., 1.], [1., 1.],
[1., 1.], [1., 1.]]]) >>torch.Size([2, 4, 2])
<>torch.flatten

t = torch.tensor([[[1, 2], [3, 4]], [[5, 6], [7, 8]]]) print(t.shape) >>torch.
Size([2, 2, 2]) print(torch.flatten(t)) >>tensor([1, 2, 3, 4, 5, 6, 7, 8]) print
(torch.flatten(t, 1)) >>tensor([[1, 2, 3, 4], [5, 6, 7, 8]]) print(torch.flatten
(t, 0, 1).shape) >>torch.Size([4, 2])

t = torch.tensor(1) print("before flatten:") print(t) print(t.shape) >>before
flatten: tensor(1) torch.Size([]) print("\n") print("after flatten:") print(
torch.flatten(t)) print(torch.flatten(t).shape) >>after flatten: tensor([1])
torch.Size([1])

GitHub

Gitee