• 首页 首页 icon
  • 工具库 工具库 icon
    • IP查询 IP查询 icon
  • 内容库 内容库 icon
    • 快讯库 快讯库 icon
    • 精品库 精品库 icon
    • 问答库 问答库 icon
  • 更多 更多 icon
    • 服务条款 服务条款 icon

Pytorch统计网络参数计算工具、模型 FLOPs, MACs, MAdds 关系

武飞扬头像
李代数
帮助1

Pytorch统计网络参数

#网络参数数量
def get_parameter_number(net):
    total_num = sum(p.numel() for p in net.parameters())
    trainable_num = sum(p.numel() for p in net.parameters() if p.requires_grad)
    return {'Total': total_num, 'Trainable': trainable_num}
#查看网络参数
print(model.state_dict())

FLOPs, MACs, MAdds 关系

学新通
见文章:CNN模型复杂度(FLOPs、MAC)、参数量与运行速度

计算工具:

地址 备注
https://github.com/Lyken17/pytorch-OpCounter Pytorch
https://github.com/sovrasov/flops-counter.pytorch Pytorch
https://stackoverflow.com/questions/45085938/tensorflow-is-there-a-way-to-measure-flops-for-a-model TensorFlow: 自带tf.RunMetadata()

另:在PyTorch中,可以使用torchstat这个库来查看网络模型的一些信息,包括总的参数量params、MAdd、显卡内存占用量和FLOPs等。

!pip install torchstat
from torchstat import stat
from torchvision.models import resnet50, resnet101, resnet152, resnext101_32x8d
 
model = resnet50()
# stat打印完整信息
stat(model, (3, 224, 224))
# 模型的总参数量
total = sum([param.nelement() for param in model.parameters()])
print("Number of parameters: %.2fM" % (total/1e6))

也可以使用torchsummary

!pip install torchsummary
from torchsummary import summary
summary(model, input_size=(ch, h, w), batch_size=-1)
#ch是指输入张量的channel数量,h表示输入张量的高,w表示输入张量的宽。

这篇好文章是转载于:学新通技术网

  • 版权申明: 本站部分内容来自互联网,仅供学习及演示用,请勿用于商业和其他非法用途。如果侵犯了您的权益请与我们联系,请提供相关证据及您的身份证明,我们将在收到邮件后48小时内删除。
  • 本站站名: 学新通技术网
  • 本文地址: /boutique/detail/tanhgbfebc
系列文章
更多 icon
同类精品
更多 icon
继续加载