Pytorch训练的.pth模型转为tensorflow的.pb模型的坑和解决方法
我用pytorch训练YOLOv4模型,训练时的log想转为tf的.pb以执行后续的工作。在转换过程中踩了不少坑,在此记录。
1. Yolo模型
网上很多转换的方法,都大同小异。但需要注意,yolo训练的log只是参数,没有网络架构,因此需要导入自己的YOLO_body,再把log的参数对应上。
-
import torch
-
from nets.yolo4 import YoloBody # 我直接在yolov4的文件夹下新建的文件,所以这里直接导入了
-
-
-
model = YoloBody(3,3) # YoloBody(num_anchors, number_classes) 修改成自己的锚框数和类别数
-
-
model.load_state_dict(torch.load('model_data\yolo4_weights.pth')) # 把参数和架构对上
-
如果你的.pth只是参数,而没有架构,会出现报错如:【TypeError: ‘collections.OrderedDict‘ object is not callable...】【 object has no attribute 'state_dict'】这类的,解决可参考https://blog.csdn.net/xiaoqiaoliushuiCC/article/details/114386432。
如果是你自己训练的yolo模型,记得修改yolo.py中的那些路径和变量,和自己的数据对上,不然也会报错。
2.转ONNX
-
import torch
-
import torch.nn as nn
-
import torch.onnx
-
import onnx
-
#from onnx_tf.backend import prepare
-
import argparse
-
import os
-
-
dummy_input = torch.randn(1, 3, 608, 608, device='cpu')
-
-
input_names=['input1']
-
output_names=['output1']
-
-
torch.onnx.export(model, dummy_input, "insight.onnx", verbose=True,
-
input_names=input_names, output_names=output_names)
若报错:Input type (torch.FloatTensor) and weight type (torch.cuda.FloatTensor) should be the same 若报错这个,则改dummy...中的device。这里一般是没啥问题。
3.转tensorflow
-
import onnx
-
from onnx_tf.backend import prepare
-
import tensorflow as tf
-
-
model = onnx.load('./insight.onnx')
-
tf_model = prepare(model)
-
tf_model.export_graph('./insight.pb')
成功的化是一个.pb,而不是文件夹
那么变成文件夹很有可能是版本不对,tensorflow这点还是挺讨厌的。
我转换时候的版本如下:
tensorflow 2.4.2
tensorflow-addons 0.15.0
tensorflow-estimator 2.4.0
tensorflow-gpu 2.5.0
tensorflow-gpu-estimator 2.2.0
onnx 1.8.0
onnx-tf 1.6.0
onnxruntime 1.10.0
onnx-tf的1.6.0版本要去github上下载,百度一下就好
4.其他报错
期间在安装包等过程中遇到的报错。
1.【AttributeError: module 'tensorflow' has no attribute 'gfile'】这个因为tensorflow版本不一样,解决如下:
-
tf.compat.v1.GraphDef() # -> instead of tf.GraphDef()
-
tf.compat.v2.io.gfile.GFile() # -> instead of tf.gfile.GFile()
或者最简单的方法就是把tensorflow 2.x降版本,降到1.14(个人感觉最稳的)。但好像onnx需要tensorflow>2.2.0,所以就手动把版本不兼容的代码改了。
2.【ERROR: Could not install packages due to an EnvironmentError: [WinError 5] 拒绝访问】
pip install --user xxxxx
若不行,升级pip。
若还是不行,把包卸载了重新装。
这篇好文章是转载于:学新通技术网
- 版权申明: 本站部分内容来自互联网,仅供学习及演示用,请勿用于商业和其他非法用途。如果侵犯了您的权益请与我们联系,请提供相关证据及您的身份证明,我们将在收到邮件后48小时内删除。
- 本站站名: 学新通技术网
- 本文地址: /boutique/detail/tanhgbeibb
-
photoshop保存的图片太大微信发不了怎么办
PHP中文网 06-15 -
Android 11 保存文件到外部存储,并分享文件
Luke 10-12 -
word里面弄一个表格后上面的标题会跑到下面怎么办
PHP中文网 06-20 -
《学习通》视频自动暂停处理方法
HelloWorld317 07-05 -
photoshop扩展功能面板显示灰色怎么办
PHP中文网 06-14 -
微信公众号没有声音提示怎么办
PHP中文网 03-31 -
excel下划线不显示怎么办
PHP中文网 06-23 -
excel打印预览压线压字怎么办
PHP中文网 06-22 -
怎样阻止微信小程序自动打开
PHP中文网 06-13 -
TikTok加速器哪个好免费的TK加速器推荐
TK小达人 10-01