学生极快上手向python图片分类识别器
本文着重讲不学无术的大学生如何快速上手跑出结果。本项目基于resnet34识别四类示意图,由cat vs dog项目改写而来。文末会说明如何快速把它改成你想要的项目(图片二分类等)。
项目代码、数据集下载:ht删tps://p除an.bai中du.c文om/s/1F打aI6hKNPB_0w_oed9H开0STg 提取码: z5v5
1.各文件/文件夹作用
自上到下:
checkpoints 储存每个epoch训练后的模型
datasets 储存训练集、测试集
image 用来给数据集做重命名,后面会提到
result 似乎没用过?
图片分类结果 手动分类的数据集。将示意图分四类,每类约150张
config 储存模型相关参数完全不用修改
dataset 数据集预处理等工作。
rename 数据集图片重命名用,后面会讲
test_model是从checkpoints里取出来训练好的模型改个名,文件夹里是我们的模型
test 测试程序,train 训练程序。
2.如何运行项目
先自己看import哪些库,装好库
①图片重命名
我使用的数据集存在图片分类结果文件夹了,你也可以不用它。
把分类好的四类图片中任一类(如sketch1)全部放入image/raw。
将rename.py中的label = 'sketch4'改成label = 'sketch1'
index_list = [i for i in range(52, imgs_num 52)]也要根据图片数量做调整相信废物大学生也能看得懂
运行rename.py会在image/processed生成重命名好的图片。格式为sktech1.0.jpg、sktech1.1.jpg、sktech1.2.jpg等。将这些图片二八分开分别放入datasets/test和datasets/train
四类图片都要这样处理。
需要注意的是,最后无论是test文件夹还是train文件夹,图片的id不能重复,比如sktech1.0.jpg里0就是id。不能同时存在sktech1.0.jpg和sktech2.0.jpg 。
②运行train.py训练模型。
此时checkpoints文件夹里会多出来很多模型,同时shell会输出正确率。当你认为正确率够高就可以停了,从checkpoints拿出最新的模型改名为test_model,拿到主目录替换我们的模型。
③运行test.py输出正确率。
此时项目运行完成。
3.Q&A
①老师的要求是分类其他类型的图片,不是你给的示意图。怎么办?
答:用你自己的数据集即可。不知道怎么找数据集可以评论区问。
②老师的要求是图片的二/三分类,怎么修改代码?
答:以二分类为例。修改以下代码:
datasets.py:第60行
从四类改两类。
rename.py:重命名图片跟着上面步骤做。
test_modification.py:
29行的model.fc = nn.Linear(512, 4) 把4改成2.
48行(下图)改2类
72行同理:
train.py:
30行model.fc = nn.Linear(512,4) 把4改成2
110行confusion_matrix = meter.ConfusionMeter(4) 把4改2
120行accuracy = 100.* (cm_value[0][0] cm_value[1][1] cm_value[2][2] cm_value[3][3]) / (cm_value.sum()) 把cm_value[2][2] cm_value[3][3])删掉,只留两类。
应该就这些,改不好来评论区问。
③你这项目没做可视化啊?
答:确实。
本文结束
以下代码无关本文,仅充数用
-
# coding=utf-8
-
-
""" test
-
使用测试集测试模型结果
-
"""
-
-
from config import _setting_
-
import os
-
import torch as t
-
from dataset import NatureSketchClassification
-
from torch.utils.data import DataLoader
-
from torchnet import meter
-
from torch.autograd import Variable
-
from torchvision import models
-
from torch import nn
-
import time
-
import csv
-
-
-
""""""
-
def test(**kwargs):
-
# set data
-
test_data = NatureSketchClassification(_setting_.test_data_root, test=True)
-
test_dataloader = DataLoader(test_data, batch_size=_setting_.batch_size, shuffle=False, num_workers=_setting_.num_workers)
-
results = []
-
-
# set model
-
model = models.resnet34(weights=models.ResNet34_Weights.IMAGENET1K_V1)
-
model.fc = nn.Linear(512, 4)
-
model.load_state_dict(t.load('./test_model.pth', map_location='cpu'))
-
model.eval()
-
-
for id, (data, path) in enumerate(test_dataloader):
-
# input = Variable(data,volatile=True)
-
-
with t.no_grad():
-
input = Variable(data)
-
-
score = model(input)
-
print('score=',score)#检验score
-
path = path.numpy().tolist()
-
_,predicted = t.max(score.data,1)
-
#Modification
-
predicted = predicted.data.cpu().numpy().tolist()
-
res = ""
-
print('predicted=',predicted)#检验predicted
-
#Modification
-
for (i, j) in zip(path, predicted):
-
if j == 0:
-
res = "sketch1"
-
elif j == 1:
-
res = "sketch2"
-
elif j == 2:
-
res = "sketch3"
-
elif j == 3:
-
res = "sketch4"
-
print('res=',res)#检验res(result)
-
results.append([i,"".join(res)])
-
-
-
res = []
-
truth = ""
-
compare = ""
-
imgs = [os.path.join(_setting_.test_data_root,img) for img in os.listdir(_setting_.test_data_root)] #获取root路径下所有图片的地址
-
imgs_num = len(imgs) # 图片数量
-
NumofCorrect = 0
-
imgs = sorted(imgs,key=lambda x: int(x.split('.')[-2].split('/')[-1])) # 按序号排序
-
for image in imgs:
-
id = int(image.split('.')[-2].split('/')[-1]) # 获取id
-
#Modification
-
-
if 'sketch1' in image.split('/')[-1]:
-
truth = 'sketch1'
-
elif 'sketch2' in image.split('/')[-1]:
-
truth = 'sketch2'
-
elif 'sketch3' in image.split('/')[-1]:
-
truth = 'sketch3'
-
else:
-
truth = 'sketch4'
-
print('truth=',truth)
-
#truth = 'nature' if 'nature' in image.split('/')[-1] else 'sketch' # 获取图片的真实分类
-
compare = 'true' if truth == results[id - 1][1] else 'false'
-
if compare == 'true':
-
NumofCorrect = NumofCorrect 1
-
res.append([results[id - 1][0], results[id - 1][1], "".join(truth), compare])
-
-
Accuracy = NumofCorrect / imgs_num * 100
-
round(Accuracy, 2)
-
write_csv(res, _setting_.result_file, Accuracy)
-
-
for id, label, truth, compare in res:
-
if compare == 'false':
-
print("number: " str(id) ", res: " label ", truth: " truth ", IsCorrect: " compare)
-
print("Accuracy: " str(Accuracy))
-
return results
-
-
-
""""""
-
def write_csv(results, file_name, acc):
-
Accuracy = []
-
Accuracy.append([" ", "Accuracy", "".join(str(acc))])
-
with open(file_name, "w") as f:
-
writer = csv.writer(f)
-
writer.writerow(['id', 'label', 'truth', 'IsCorrect'])
-
writer.writerows(results)
-
writer.writerows(Accuracy)
-
-
if __name__ == '__main__':
-
test()
这篇好文章是转载于:学新通技术网
- 版权申明: 本站部分内容来自互联网,仅供学习及演示用,请勿用于商业和其他非法用途。如果侵犯了您的权益请与我们联系,请提供相关证据及您的身份证明,我们将在收到邮件后48小时内删除。
- 本站站名: 学新通技术网
- 本文地址: /boutique/detail/tanhfikcif
-
photoshop保存的图片太大微信发不了怎么办
PHP中文网 06-15 -
Android 11 保存文件到外部存储,并分享文件
Luke 10-12 -
word里面弄一个表格后上面的标题会跑到下面怎么办
PHP中文网 06-20 -
《学习通》视频自动暂停处理方法
HelloWorld317 07-05 -
微信公众号没有声音提示怎么办
PHP中文网 03-31 -
photoshop扩展功能面板显示灰色怎么办
PHP中文网 06-14 -
怎样阻止微信小程序自动打开
PHP中文网 06-13 -
excel下划线不显示怎么办
PHP中文网 06-23 -
excel打印预览压线压字怎么办
PHP中文网 06-22 -
photoshop蒙版画笔没反应怎么办
PHP中文网 06-24