组合之torch.cat()和 torch.stack()

    技术2023-08-07  84

    cat即concatenate的意思,是指沿着已有的数据的 某一维度进行拼接,操作后数据的总维数不变,在进行拼接时,除了拼 接的维度之外,其他维度必须相同。 而torch.stack()函数指新增维度,并按照指定的维度进行叠加

    import torch a = torch.randperm(10) a = a.reshape([2, 5]) b = torch.Tensor([[10, 11, 15, 12, 13], [5, 4, 3, 2, 1]]) a = a.type_as(b) # 将a,b数据类型转换一致 c = a + b # 不转换数据类型,print(a+b)会报错 torch.cat([a,b]) # torch.cat([a,b], 0) Out[5]: tensor([[ 6., 9., 8., 1., 4.], [ 5., 2., 3., 0., 7.], [10., 11., 15., 12., 13.], [ 5., 4., 3., 2., 1.]]) torch.cat([a,b],1) Out[6]: tensor([[ 6., 9., 8., 1., 4., 10., 11., 15., 12., 13.], [ 5., 2., 3., 0., 7., 5., 4., 3., 2., 1.]]) # 以第0维进行stack,叠加的基本单位为序列本身,即a与b,因此输出[a, b], torch.stack([a,b], 0) Out[7]: tensor([[[ 6., 9., 8., 1., 4.], [ 5., 2., 3., 0., 7.]], [[10., 11., 15., 12., 13.], [ 5., 4., 3., 2., 1.]]]) # 以第1维进行stack,叠加的基本单位为每一行 torch.stack([a,b], 1) Out[64]: tensor([[[ 6., 9., 8., 1., 4.], [10., 11., 15., 12., 13.]], [[ 5., 2., 3., 0., 7.], [ 5., 4., 3., 2., 1.]]]) # 以第2维进行stack,叠加的基本单位为每一行的每一个元素 torch.stack([a,b], 2) Out[67]: tensor([[[ 6., 10.], [ 9., 11.], [ 8., 15.], [ 1., 12.], [ 4., 13.]], [[ 5., 5.], [ 2., 4.], [ 3., 3.], [ 0., 2.], [ 7., 1.]]])
    Processed: 0.009, SQL: 10