Pytorch:expand

    技术2022-07-10  131

    expand的含义:为1的维度可以变大维度或者维度数增多 tensor_1.expand(size):把tensor_1扩展成size的形状 tensor_1.expand_as(tensor_2) :把tensor_1扩展成和tensor_2一样的形状

    import torch #1 x = torch.randn(2, 1, 1)#为1可以扩展为3和4 x = x.expand(2, 3, 4) print('x :', x.size()) >>> x : torch.Size([2, 3, 4]) #2 #扩展一个新的维度必须在最前面,否则会报错 x = x.expand(2, 3, 4, 6) >>> RuntimeError: The expanded size of the tensor (3) must match the existing size (2) at non-singleton dimension 1. x = x.expand(6, 2, 3, 4) >>> x : torch.Size([6, 2, 3, 4]) #3 #某一个维度为-1表示不改变该维度的大小 x = x.expand(6, -1, -1, -1) >>> x : torch.Size([6, 2, 1, 1])

    import torch #1 x = torch.randn(2, 1, 1)#为1可以扩展为3和3 y = torch.randn(2, 3, 3) x = x.expand_as(y) print('x :', x.size()) >>> x : torch.Size([2, 3, 3]) #2 x = torch.randn(2, 2, 2)#为2不可以扩展为3和4 y = torch.randn(2, 3, 4) x = x.expand_as(y) print('x :', x.size()) >>> RuntimeError: The expanded size of the tensor (4) must match the existing size (2) at non-singleton dimension 2. Target sizes: [2, 3, 4].
    Processed: 0.127, SQL: 9