• 首页 首页 icon
  • 工具库 工具库 icon
    • IP查询 IP查询 icon
  • 内容库 内容库 icon
    • 快讯库 快讯库 icon
    • 精品库 精品库 icon
    • 问答库 问答库 icon
  • 更多 更多 icon
    • 服务条款 服务条款 icon

Pytorch permute / transpose 和 view / reshape, flatten函数

武飞扬头像
ytusdc
帮助1

1、transpose与permute

transpose() 和 permute() 都是返回转置后矩阵,在pytorch中转置用的函数就只有这两个 ,这两个函数都是交换维度的操作

transpose用法:tensor.transpose(dim0, dim1) → Tensor
只能操作2D矩阵的转置, transpose每次只能交换两个维度, 这是相比于permute的一个不同点,每次输入两个index,实现转置,,参数顺序无所谓。
permute用法:tensor.permute(dim0, dim1, ...., dimn)
permute可以进行多维度转置, permute每次可以交换多个维度,且必须传入所有维度数,参数顺序表示交换结果是原值的哪个维。

permute操作可以有1至多步的Transpose操作实现

注意:使用transpose或permute之后,若要使用view,必须先contiguous()

  1.  
    # 创造二维数据x,dim=0时候2,dim=1时候3
  2.  
    x = torch.randn(2,3) 'x.shape → [2,3]'
  3.  
    # 创造三维数据y,dim=0时候2,dim=1时候3,dim=2时候4
  4.  
    y = torch.randn(2,3,4) 'y.shape → [2,3,4]'
  5.  
     
  6.  
    """
  7.  
    操作dim不同:
  8.  
    transpose()只能一次操作两个维度;permute()可以一次操作多维数据,
  9.  
    且必须传入所有维度数,因为permute()的参数是int*。
  10.  
    """
  11.  
     
  12.  
    # 对于transpose
  13.  
    x.transpose(0,1) 'shape→[3,2] '
  14.  
    x.transpose(1,0) 'shape→[3,2] '
  15.  
    y.transpose(0,1) 'shape→[3,2,4]'
  16.  
    y.transpose(0,2,1) 'error,操作不了多维'
  17.  
     
  18.  
    # 对于permute()
  19.  
    x.permute(0,1) 'shape→[2,3]'
  20.  
    x.permute(1,0) 'shape→[3,2], 注意返回的shape不同于x.transpose(1,0) '
  21.  
    y.permute(0,1) "error 没有传入所有维度数"
  22.  
    y.permute(1,0,2) 'shape→[3,2,4]'
  23.  
     
  24.  
    """
  25.  
    操作dim不同:
  26.  
    transpose()只能一次操作两个维度, 维度的顺序不影响结果;permute()可以一次操作多维数据,
  27.  
    且必须传入所有维度数,因为permute()的参数是int*。
  28.  
    """
  29.  
    # 对于transpose, (0,1) 和 (1,0) 都是指变换 维度 0 和 1, 输入顺序不影响
  30.  
    x1 = x.transpose(0,1) 'shape→[3,2] '
  31.  
    x2 = x.transpose(1,0) '也变换了,shape→[3,2] '
  32.  
     
  33.  
    # 对于permute(),
  34.  
    x1 = x.permute(0,1) '保持原理tensor不变, 不同transpose,shape→[2,3] '
  35.  
    x2 = x.permute(1,0) 'shape→[3,2] '
  36.  
     
  37.  
    y1 = y.permute(0,1,2) '保持不变,shape→[2,3,4] '
  38.  
    y2 = y.permute(1,0,2) 'shape→[3,2,4] '
  39.  
    y3 = y.permute(1,2,0) 'shape→[3,4,2] '
学新通

2、关于连续contiguous()

用view()函数改变通过转置后的数据结构,导致报错

RuntimeError: invalid argument 2: view size is not compatible with input tensor's....

这是因为tensor经过转置后数据的内存地址不连续导致的,也就是tensor . is_contiguous()==False。
虽然在torch里面,view函数相当于numpy的reshape,但是这时候reshape()可以改变该tensor结构,但是view()不可以

  1.  
    x = torch.rand(3,4)
  2.  
    x = x.transpose(0,1)
  3.  
    print(x.is_contiguous()) # 是否连续
  4.  
    'False'
  5.  
    # 会发现
  6.  
    x.view(3,4)
  7.  
    '''
  8.  
    RuntimeError: invalid argument 2: view size is not compatible with input tensor's....
  9.  
    就是不连续导致的
  10.  
    '''
  11.  
    # 但是这样是可以的。
  12.  
    x = x.contiguous()
  13.  
    x.view(3,4)
  14.  
     
  15.  
    x = torch.rand(3,4)
  16.  
    x = x.permute(1,0) # 等价x = x.transpose(0,1)
  17.  
    x.reshape(3,4)
  18.  
    '''这就不报错了
  19.  
    说明x.reshape(3,4) 这个操作
  20.  
    等于x = x.contiguous().view()
  21.  
    尽管如此,但是我们还是不推荐使用reshape
  22.  
    除非为了获取完全不同但是数据相同的克隆体
  23.  
    '''
学新通

调用contiguous()时,会强制拷贝一份tensor,让它的布局从头到尾创建的一毛一样。
只需要记住了,每次在使用view()之前,该tensor只要使用了transpose()和permute()这两个函数一定要contiguous().

transpose与permute会实实在在的根据需求(要交换的dim)把相应的Tensor元素的位置进行调整, 而view 会将Tensor所有维度拉平成一维 (即按行,这也是为什么view操作要求Tensor是contiguous的原因),然后再根据传入的的维度(只要保证各维度的乘积=总元素个数即可)信息重构出一个Tensor。

  1.  
    a = torch.Tensor([[[1,2,3,4,5], [6,7,8,9,10], [11,12,13,14,15]],
  2.  
    [[-1,-2,-3,-4,-5], [-6,-7,-8,-9,-10], [-11,-12,-13,-14,-15]]])
  3.  
    >>> a.shape
  4.  
    torch.Size([2, 3, 5])
  5.  
    # 还是上面的Tensor a
  6.  
    >>> print(a.shape)
  7.  
    torch.Size([2, 3, 5])
  8.  
    >>> print(a.view(2,5,3))
  9.  
    tensor([[[ 1., 2., 3.],
  10.  
    [ 4., 5., 6.],
  11.  
    [ 7., 8., 9.],
  12.  
    [ 10., 11., 12.],
  13.  
    [ 13., 14., 15.]],
  14.  
     
  15.  
    [[ -1., -2., -3.],
  16.  
    [ -4., -5., -6.],
  17.  
    [ -7., -8., -9.],
  18.  
    [-10., -11., -12.],
  19.  
    [-13., -14., -15.]]])
  20.  
    >>> c = a.transpose(1,2)
  21.  
    >>> print(c, c.shape)
  22.  
    (tensor([[[ 1., 6., 11.],
  23.  
    [ 2., 7., 12.],
  24.  
    [ 3., 8., 13.],
  25.  
    [ 4., 9., 14.],
  26.  
    [ 5., 10., 15.]],
  27.  
     
  28.  
    [[ -1., -6., -11.],
  29.  
    [ -2., -7., -12.],
  30.  
    [ -3., -8., -13.],
  31.  
    [ -4., -9., -14.],
  32.  
    [ -5., -10., -15.]]]),
  33.  
    torch.Size([2, 5, 3]))
学新通

即使view()transpose()最终得到的Tensor的shape是一样的,但二者内容并不相同。view函数只是按照给定的(2,5,3)的Tensor维度,将元素按顺序一个个填进去;而transpose函数,则的确是在进行第一个第二维度的转置

3、view与reshape的区别

view()具有跟reshape()相同的功能,都能去重塑矩阵的形状

不同点:

reshape()方法不受此限制;如果对 tensor 调用过 transpose, permute等操作的话会使该 tensor 在内存中变得不再连续。

view():

作用:将tensor转换为指定的shape,原始的data不改变。返回的tensor与原始的tensor共享存储区。view()方法只适用于满足连续性(contiguous)条件的tensor,并且该操作不会开辟新的内存空间,只是产生了对原存储空间的一个新别称和引用,返回值是视图。也就是说view不会改变原来数据的存放方式,并且,也不会产生数据的副本,view返回的是视图。

如果tensor 不满足连续性条件,需要先调用 contiguous()方法,但这种方法变换后的tensor就不是与原始tensor共享内存了,而是被重新开辟了一个空间。

view()可以通过在某一维度输入为-1,来动态调整这个矩阵的维度的size, 而 reshape且无动态调整的功能。而且 view()用于pytorch中对张量进行处理,

view方法可以调整tensor的形状,但必须保证调整前后元素总数一致。view不会修改自身的数据,返回的新tensor与源tensor共享内存,即更改其中一个,另外一个也会跟着改变

reshape():

作用:与view方法类似,将输入tensor转换为新的shape格式。

reshape方法更强大,可以认为a.reshape = a.view() a.contiguous().view()

reshape()方法的返回值既可以是视图,也可以是副本。即:在满足tensor连续性条件时,a.reshape返回的结果与a.view()相同,否则返回的结果与a.contiguous().view()相同。

4、torch.flatten()

torch.flatten(input, start_dim=0, end_dim=-1) → Tensor

input (Tensor) – 输入为Tensor 
start_dim (int) – 展平的开始维度 
end_dim (int) – 展平的结束维度

展平一个连续范围的维度,输出类型为Tensor, flatten函数就是对tensor类型进行扁平化处理,也就是在不同维度上进行堆叠操作,a.flatten(m),这个意思是将a这个tensor,从第m维度开始堆叠,一直堆叠到最后一个维度

  1.  
    import torch
  2.  
    # t 是三维张量 torch.Size([3, 2, 2])
  3.  
    t = torch.tensor([[[1, 2],
  4.  
    [3, 4]],
  5.  
    [[5, 6],
  6.  
    [7, 8]],
  7.  
    [[9, 10],
  8.  
    [11, 12]]])
  9.  
    #如果不传入参数,默认开始维度为0,最后维度为-1,展开为一维
  10.  
    result_0 = torch.flatten(t)
  11.  
    print(result_0)
  12.  
    '''
  13.  
    tensor([ 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12])
  14.  
    '''
  15.  
     
  16.  
    #当开始维度为1,最后维度为-1,展开为3x4,也就是说第一维度不变,后面的压缩
  17.  
    result_1 = torch.flatten(t, start_dim=1)
  18.  
    print(result_1)
  19.  
    '''
  20.  
    tensor([[ 1, 2, 3, 4],
  21.  
    [ 5, 6, 7, 8],
  22.  
    [ 9, 10, 11, 12]])
  23.  
    '''
  24.  
     
  25.  
    torch.flatten(t, start_dim=1).size()
  26.  
    # torch.Size([3, 4])
  27.  
    #下面的和上面进行对比应该就能看出是,当锁定最后的维度的时候
  28.  
    #前面的就会合并
  29.  
    result_3 = torch.flatten(t, start_dim=0, end_dim=1)
  30.  
    print(result_3)
  31.  
    '''
  32.  
    tensor([[ 1, 2],
  33.  
    [ 3, 4],
  34.  
    [ 5, 6],
  35.  
    [ 7, 8],
  36.  
    [ 9, 10],
  37.  
    [11, 12]])
  38.  
    '''
  39.  
     
  40.  
    torch.flatten(t, start_dim=0, end_dim=1).size()
  41.  
    # torch.Size([6, 2])
学新通

示例:

  1.  
    import torch
  2.  
    # 随机产生了一个tensor,它的Batchsize是2,C是3,H是2,W是3
  3.  
    a=torch.rand(2,3,2,3)
  4.  
    print(a)
  5.  
    '''
  6.  
    tensor([[[[0.5521, 0.2547, 0.5242],
  7.  
    [0.8248, 0.4500, 0.2413]],
  8.  
     
  9.  
    [[0.7759, 0.1261, 0.0090],
  10.  
    [0.0197, 0.6191, 0.0422]],
  11.  
     
  12.  
    [[0.0896, 0.1731, 0.5484],
  13.  
    [0.7927, 0.0752, 0.2176]]],
  14.  
     
  15.  
     
  16.  
    [[[0.0118, 0.3865, 0.9587],
  17.  
    [0.6599, 0.2464, 0.0728]],
  18.  
     
  19.  
    [[0.2858, 0.3772, 0.8215],
  20.  
    [0.3267, 0.2859, 0.4329]],
  21.  
     
  22.  
    [[0.7329, 0.4436, 0.4246],
  23.  
    [0.4162, 0.8688, 0.5286]]]])
  24.  
    '''
  25.  
    ##########################################################################
  26.  
    result_0 = a.flatten(0)
  27.  
    print(result_0.shape)
  28.  
    print(result_0)
  29.  
    '''
  30.  
    torch.Size([36])
  31.  
    tensor([0.5521, 0.2547, 0.5242, 0.8248, 0.4500, 0.2413, 0.7759, 0.1261, 0.0090,
  32.  
    0.0197, 0.6191, 0.0422, 0.0896, 0.1731, 0.5484, 0.7927, 0.0752, 0.2176,
  33.  
    0.0118, 0.3865, 0.9587, 0.6599, 0.2464, 0.0728, 0.2858, 0.3772, 0.8215,
  34.  
    0.3267, 0.2859, 0.4329, 0.7329, 0.4436, 0.4246, 0.4162, 0.8688, 0.5286])
  35.  
    '''
  36.  
    ##########################################################################
  37.  
    result_1 = a.flatten(1)
  38.  
    print(result_1.shape)
  39.  
    print(result_1)
  40.  
    '''
  41.  
    torch.Size([2, 18])
  42.  
    tensor([[0.5521, 0.2547, 0.5242, 0.8248, 0.4500, 0.2413, 0.7759, 0.1261, 0.0090,
  43.  
    0.0197, 0.6191, 0.0422, 0.0896, 0.1731, 0.5484, 0.7927, 0.0752, 0.2176],
  44.  
    [0.0118, 0.3865, 0.9587, 0.6599, 0.2464, 0.0728, 0.2858, 0.3772, 0.8215,
  45.  
    0.3267, 0.2859, 0.4329, 0.7329, 0.4436, 0.4246, 0.4162, 0.8688, 0.5286]])
  46.  
    '''
  47.  
    ##########################################################################
  48.  
    result_2 = a.flatten(2)
  49.  
    print(result_2.shape)
  50.  
    print(result_2)
  51.  
    '''
  52.  
    torch.Size([2, 3, 6])
  53.  
    tensor([[[0.5521, 0.2547, 0.5242, 0.8248, 0.4500, 0.2413],
  54.  
    [0.7759, 0.1261, 0.0090, 0.0197, 0.6191, 0.0422],
  55.  
    [0.0896, 0.1731, 0.5484, 0.7927, 0.0752, 0.2176]],
  56.  
     
  57.  
    [[0.0118, 0.3865, 0.9587, 0.6599, 0.2464, 0.0728],
  58.  
    [0.2858, 0.3772, 0.8215, 0.3267, 0.2859, 0.4329],
  59.  
    [0.7329, 0.4436, 0.4246, 0.4162, 0.8688, 0.5286]]])
  60.  
    '''
  61.  
    ##########################################################################
  62.  
    result_3 = a.flatten(3)
  63.  
    print(result_3.shape)
  64.  
    print(result_3)
  65.  
    '''
  66.  
    torch.Size([2, 3, 2, 3])
  67.  
    tensor([[[[0.5521, 0.2547, 0.5242],
  68.  
    [0.8248, 0.4500, 0.2413]],
  69.  
     
  70.  
    [[0.7759, 0.1261, 0.0090],
  71.  
    [0.0197, 0.6191, 0.0422]],
  72.  
     
  73.  
    [[0.0896, 0.1731, 0.5484],
  74.  
    [0.7927, 0.0752, 0.2176]]],
  75.  
     
  76.  
     
  77.  
    [[[0.0118, 0.3865, 0.9587],
  78.  
    [0.6599, 0.2464, 0.0728]],
  79.  
     
  80.  
    [[0.2858, 0.3772, 0.8215],
  81.  
    [0.3267, 0.2859, 0.4329]],
  82.  
     
  83.  
    [[0.7329, 0.4436, 0.4246],
  84.  
    [0.4162, 0.8688, 0.5286]]]])
  85.  
    '''
  86.  
     
  87.  
    ##########################################################################
  88.  
    result_4 = a.flatten(0, 1)
  89.  
    print(result_4.shape)
  90.  
    print(result_4)
  91.  
     
  92.  
    '''
  93.  
    torch.Size([6, 2, 3])
  94.  
    tensor([[[0.5521, 0.2547, 0.5242],
  95.  
    [0.8248, 0.4500, 0.2413]],
  96.  
     
  97.  
    [[0.7759, 0.1261, 0.0090],
  98.  
    [0.0197, 0.6191, 0.0422]],
  99.  
     
  100.  
    [[0.0896, 0.1731, 0.5484],
  101.  
    [0.7927, 0.0752, 0.2176]],
  102.  
     
  103.  
    [[0.0118, 0.3865, 0.9587],
  104.  
    [0.6599, 0.2464, 0.0728]],
  105.  
     
  106.  
    [[0.2858, 0.3772, 0.8215],
  107.  
    [0.3267, 0.2859, 0.4329]],
  108.  
     
  109.  
    [[0.7329, 0.4436, 0.4246],
  110.  
    [0.4162, 0.8688, 0.5286]]])
  111.  
    '''
学新通

a.flatten(0)的意思就是从batchsize这个维度开始堆叠,直到W结束,那最后就是成一维的了,也就是只剩W这个维度,那当然就是只有一条这样子

a.flatten(1)的意思就是从C(channel)这个维度开始堆叠,直到W结束,Batchsize这个维度没有参与运算,因此还是有B这个维度的,这样的话就是相当于将三维的数据堆叠成只有一个维度W的数据,那当然就变成了两条

a.flatten(2)的意思就是从H(Height)这个维度开始堆叠,直到W结束,B和C这两个维度都没有参与运算,因此将H这个维度堆叠到W上去,就是将原本的平面变成了一个长条

最后a.flatten(3)的意思就是将H这个维度堆叠到H这个维度上去,自己堆叠自己就是没有堆叠

a.flatten(0,1), 将B的维度叠加到C的维度上,就是将两个batch叠加合并了

5、flatten函数的用法及其与reshape函数的区别

这篇好文章是转载于:学新通技术网

  • 版权申明: 本站部分内容来自互联网,仅供学习及演示用,请勿用于商业和其他非法用途。如果侵犯了您的权益请与我们联系,请提供相关证据及您的身份证明,我们将在收到邮件后48小时内删除。
  • 本站站名: 学新通技术网
  • 本文地址: /boutique/detail/tanhgcechf
系列文章
更多 icon
同类精品
更多 icon
继续加载