• 首页 首页 icon
  • 工具库 工具库 icon
    • IP查询 IP查询 icon
  • 内容库 内容库 icon
    • 快讯库 快讯库 icon
    • 精品库 精品库 icon
    • 问答库 问答库 icon
  • 更多 更多 icon
    • 服务条款 服务条款 icon

Pytorch训练的.pth模型转为tensorflow的.pb模型的坑和解决方法

武飞扬头像
Anaconda_
帮助2

我用pytorch训练YOLOv4模型,训练时的log想转为tf的.pb以执行后续的工作。在转换过程中踩了不少坑,在此记录。

1. Yolo模型

网上很多转换的方法,都大同小异。但需要注意,yolo训练的log只是参数,没有网络架构,因此需要导入自己的YOLO_body,再把log的参数对应上。

  1.  
    import torch
  2.  
    from nets.yolo4 import YoloBody # 我直接在yolov4的文件夹下新建的文件,所以这里直接导入了
  3.  
     
  4.  
     
  5.  
    model = YoloBody(3,3) # YoloBody(num_anchors, number_classes) 修改成自己的锚框数和类别数
  6.  
     
  7.  
    model.load_state_dict(torch.load('model_data\yolo4_weights.pth')) # 把参数和架构对上
  8.  
     

如果你的.pth只是参数,而没有架构,会出现报错如:【TypeError: ‘collections.OrderedDict‘ object is not callable...】【 object has no attribute 'state_dict'】这类的,解决可参考https://blog.csdn.net/xiaoqiaoliushuiCC/article/details/114386432

如果是你自己训练的yolo模型,记得修改yolo.py中的那些路径和变量,和自己的数据对上,不然也会报错。

2.转ONNX

  1.  
    import torch
  2.  
    import torch.nn as nn
  3.  
    import torch.onnx
  4.  
    import onnx
  5.  
    #from onnx_tf.backend import prepare
  6.  
    import argparse
  7.  
    import os
  8.  
     
  9.  
    dummy_input = torch.randn(1, 3, 608, 608, device='cpu')
  10.  
     
  11.  
    input_names=['input1']
  12.  
    output_names=['output1']
  13.  
     
  14.  
    torch.onnx.export(model, dummy_input, "insight.onnx", verbose=True,
  15.  
    input_names=input_names, output_names=output_names)
学新通

若报错:Input type (torch.FloatTensor) and weight type (torch.cuda.FloatTensor) should be the same 若报错这个,则改dummy...中的device。这里一般是没啥问题。

3.转tensorflow

  1.  
    import onnx
  2.  
    from onnx_tf.backend import prepare
  3.  
    import tensorflow as tf
  4.  
     
  5.  
    model = onnx.load('./insight.onnx')
  6.  
    tf_model = prepare(model)
  7.  
    tf_model.export_graph('./insight.pb')

成功的化是一个.pb,而不是文件夹

那么变成文件夹很有可能是版本不对,tensorflow这点还是挺讨厌的。

我转换时候的版本如下:

tensorflow                    2.4.2
tensorflow-addons             0.15.0
tensorflow-estimator          2.4.0
tensorflow-gpu                2.5.0
tensorflow-gpu-estimator      2.2.0

onnx                          1.8.0
onnx-tf                       1.6.0
onnxruntime                   1.10.0

onnx-tf的1.6.0版本要去github上下载,百度一下就好

4.其他报错

期间在安装包等过程中遇到的报错。

1.【AttributeError: module 'tensorflow' has no attribute 'gfile'】这个因为tensorflow版本不一样,解决如下:

  1.  
    tf.compat.v1.GraphDef() # -> instead of tf.GraphDef()
  2.  
    tf.compat.v2.io.gfile.GFile() # -> instead of tf.gfile.GFile()

或者最简单的方法就是把tensorflow 2.x降版本,降到1.14(个人感觉最稳的)。但好像onnx需要tensorflow>2.2.0,所以就手动把版本不兼容的代码改了。

2.【ERROR: Could not install packages due to an EnvironmentError: [WinError 5] 拒绝访问】

pip install --user xxxxx 

若不行,升级pip。

若还是不行,把包卸载了重新装。

这篇好文章是转载于:学新通技术网

  • 版权申明: 本站部分内容来自互联网,仅供学习及演示用,请勿用于商业和其他非法用途。如果侵犯了您的权益请与我们联系,请提供相关证据及您的身份证明,我们将在收到邮件后48小时内删除。
  • 本站站名: 学新通技术网
  • 本文地址: /boutique/detail/tanhgbeibb
系列文章
更多 icon
同类精品
更多 icon
继续加载