PyTorch的matmul函数
PyTorch中的两个张量的乘法可以分为两种:
-
两个张量对应的元素相乘(element-wise),在PyTorch中可以通过
torch.mul
函数(或者 ∗ * ∗运算符)实现 -
两个张量矩阵相乘(Matrix product),在PyTorch中可以通过
torch.matmul
函数实现
本文主要介绍两个张量的矩阵相乘。
语法为:
torch.matmul(input, other, out = None)
函数对input和other两个张量进行矩阵相乘。为了方便后续的讲解,将input记为a,将other记为b。
点积在数学中,又称数量积,是指接受在实数R上的两个1D张量并返回一个实数值0D张量的二元运算。
若1D张量a=[1,2],1D张量b=[3,4],则:
a ⋅ \cdot ⋅b=1 × \times × 3 2 × \times × 4 = 11
- 若a为1D张量,b为1D张量,则返回两个张量的点积,则返回两个张量的点积(此时的torch.matmul不支持out参数)
举例如下:
import torch
a = torch.tensor([1, 2])
b = torch.tensor([3, 4])
result = torch.matmul(a, b)
print(result)
结果为:
(PyTorch) D:\Code Project>D:/Anaconda/envs/PyTorch/python.exe "d:/Code Project/demo.py"
tensor(11)
- 若a为2D张量,b为2D张量,则返回两个张量的矩阵乘积。
矩阵相乘最重要的方法是一般矩阵乘积,它只有在第一个2D张量(矩阵)的列数(column)和第二个2D张量(矩阵)的行数(row)相同时才有意义。
若2D张量a=[[1,2],[3,4]],2D张量b=[[5,6,7],[8,9,10]],则:
a × \times × b=[[21,24,27],[47,54,61]],2D张量a的形状为(2,2),而2D张量b的形状(2,3)。矩阵乘积的运算规则:
举例为:
import torch
a = torch.tensor([[1, 2],[3,4]])
b = torch.tensor([[5,6,7],[8,9,10]])
result = torch.matmul(a, b)
print(result)
结果展示为:
(PyTorch) D:\Code Project>D:/Anaconda/envs/PyTorch/python.exe "d:/Code Project/demo.py"
tensor([[21, 24, 27],
[47, 54, 61]])
- 若a为1D张量,b为2D张量,torch.matmul函数:
首先,在1D张量a的前面插入一个长度为1的新维度变成2D张量;
然后,在满足第一个2D张量(矩阵)的列数(column)和第二个2D张量(矩阵)的行数(row)相同的条件下,两个2D张量矩阵乘积,否则会抛出错误;
最后,将矩阵乘积结果中长度为1的维度(前面插入的长度为1的新维度)删除作为最终torch.matmul函数返回的结果。
import torch
a = torch.tensor([1, 2])
b = torch.tensor([[5, 6, 7],[8, 9, 10]])
result = torch.matmul(a, b)
print(result, result.shape)
结果为:
(PyTorch) D:\Code Project>D:/Anaconda/envs/PyTorch/python.exe "d:/Code Project/demo.py"
tensor([21, 24, 27]) torch.Size([3])
简单来说,先将1D张量a扩展成2D张量,满足矩阵乘积的条件下,将两个2D张量进行矩阵乘积的运算。
此时得到的形状是(1,3)的2D张量,最后将前面插入长度为1的新维度删除即为最终torch.matmul(a, b)函数返回的结果。
- 若a为2D张量,b为1D张量,torch.matmul函数:
首先,在1D张量b的后面插入一个长度为1的新维度变成2D张量;
然后,在满足第一个2D张量(矩阵)的列数(column)和第二个2D张量(矩阵)的行数(row)相同的条件下,两个2D张量矩阵乘积,否则会抛出错误;
最后,将矩阵乘积结果中长度为1的维度(后面插入的长度为1的新维度)删除作为最终torch.matmul函数返回的结果;
import torch
b = torch.tensor([1, 2, 3])
a = torch.tensor([[5, 6, 7],[8, 9, 10]])
result = torch.matmul(a, b)
print(result, result.shape)
结果展示为:
(PyTorch) D:\Code Project>D:/Anaconda/envs/PyTorch/python.exe "d:/Code Project/demo.py"
tensor([38, 56]) torch.Size([2])
其中:
38 = 15 26 3*7
56 = 18 29 3*10
这篇好文章是转载于:学新通技术网
- 版权申明: 本站部分内容来自互联网,仅供学习及演示用,请勿用于商业和其他非法用途。如果侵犯了您的权益请与我们联系,请提供相关证据及您的身份证明,我们将在收到邮件后48小时内删除。
- 本站站名: 学新通技术网
- 本文地址: /boutique/detail/tanhghbcfg
-
photoshop保存的图片太大微信发不了怎么办
PHP中文网 06-15 -
《学习通》视频自动暂停处理方法
HelloWorld317 07-05 -
word里面弄一个表格后上面的标题会跑到下面怎么办
PHP中文网 06-20 -
Android 11 保存文件到外部存储,并分享文件
Luke 10-12 -
photoshop扩展功能面板显示灰色怎么办
PHP中文网 06-14 -
微信公众号没有声音提示怎么办
PHP中文网 03-31 -
excel下划线不显示怎么办
PHP中文网 06-23 -
excel打印预览压线压字怎么办
PHP中文网 06-22 -
TikTok加速器哪个好免费的TK加速器推荐
TK小达人 10-01 -
怎样阻止微信小程序自动打开
PHP中文网 06-13