
Image Processing 8 - CNN Image Classification

<Hunter>

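Image Processing Series

Image Processing 1 - Classic spatial-domain enhancement: gray-level mapping

Image Processing 2 - Classic spatial-domain enhancement: histogram equalization

Image Processing 3 - Classic spatial-domain enhancement: spatial filtering

Image Processing 4 - Fourier transform of images

Image Processing 5 - Adding noise to images

Image Processing 6 - Image threshold segmentation with Otsu's method

Image Processing 7 - Image enhancement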

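I. Content

(1) Build a simple CNN with PyTorch for image classification and test its classification performance (for more steps, see https://www.stefanfiott.com/machine-learning/cifar-10-classifier-using-cnn-in-pytorch/);

(2) Modify the network model, train it again, and test its classification performance;

(3) Write the lab report.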

II. Image Classification with a Simple CNN

1. Importing packages

Use the code in Figure 2.1 to import the packages.

Figure 2.1 Importing packages

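2. Downloading, preprocessing, and splitting the data

Use the code in Figure 2.2 to download the data, preprocess it, and split it into training and test sets.

Figure 2.2 Downloading, preprocessing, and splitting the data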
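
The transform in the source listing combines `ToTensor` with `Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))`. As a rough illustration, the sketch below reproduces that per-channel arithmetic in plain Python so it runs without PyTorch. The helper names `to_unit_range` and `normalize` are hypothetical and exist only for this example.

```python
def to_unit_range(pixel):
    """Scale an 8-bit pixel value (0..255) to [0, 1], as ToTensor does for uint8 images."""
    return pixel / 255.0


def normalize(value, mean=0.5, std=0.5):
    """Apply (value - mean) / std, the per-channel formula used by Normalize."""
    return (value - mean) / std


# Darkest, mid-gray, and brightest 8-bit pixel values.
samples = [0, 128, 255]
normalized = [normalize(to_unit_range(p)) for p in samples]
print(normalized)
```

With the mean and standard deviation both set to 0.5, every channel ends up in the range [-1, 1].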

3. Defining the neural network

Use the code in Figure 2.3 to define a simple CNN.

Figure 2.3 Defining a simple CNN
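
The network in the source listing applies two 5x5 convolutions, each followed by 2x2 max pooling, and then feeds `16 * 5 * 5` features into the first fully connected layer. The sketch below traces where that number comes from for 32x32 CIFAR-10 images. It is plain Python, and the helper `conv_out` is a hypothetical name for the standard output-size formula (no dilation, floor rounding).

```python
def conv_out(size, kernel, stride=1, padding=0):
    """Spatial output size of a convolution or pooling layer (floor mode)."""
    return (size + 2 * padding - kernel) // stride + 1


size = 32                      # CIFAR-10 images are 32x32
size = conv_out(size, 5)       # conv1: 5x5 kernel -> 28
size = conv_out(size, 2, 2)    # 2x2 max pool, stride 2 -> 14
size = conv_out(size, 5)       # conv2: 5x5 kernel -> 10
size = conv_out(size, 2, 2)    # 2x2 max pool, stride 2 -> 5

flat_features = 16 * size * size   # conv2 outputs 16 channels
print(size, flat_features)
```

If the input resolution or kernel sizes change, the input width of `fc1` has to be recomputed the same way.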

4. Defining the optimizer

Use the code in Figure 2.4 to define the optimizer.

Figure 2.4 Defining the optimizer
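
The source listing uses `optim.SGD` with `lr=0.001` and `momentum=0.9`. The sketch below shows the corresponding update rule on a single scalar parameter, assuming PyTorch's defaults of zero dampening and no Nesterov momentum. The function `sgd_momentum_step` and the toy objective are hypothetical and serve only as an illustration.

```python
def sgd_momentum_step(param, grad, velocity, lr=0.001, momentum=0.9):
    """One scalar SGD-with-momentum update: v <- momentum * v + grad; p <- p - lr * v."""
    velocity = momentum * velocity + grad
    param = param - lr * velocity
    return param, velocity


# Toy problem (an assumption for this example): minimize f(p) = p**2, so grad = 2 * p.
param, velocity = 1.0, 0.0
for _ in range(3):
    grad = 2.0 * param
    param, velocity = sgd_momentum_step(param, grad, velocity)
print(param, velocity)
```

The velocity term accumulates past gradients, so consecutive steps in the same direction become progressively larger than plain SGD steps.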

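5. Training and saving the neural network

Use the code in Figure 2.5 to train the model defined above and save it. The output is shown in Figure 2.6.

Figure 2.5 Training and saving the neural network

Figure 2.6 Model training output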
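
To help read the training log, the sketch below works out how many mini-batches one epoch contains and at which steps the loop in the source listing prints the running loss. It assumes the standard CIFAR-10 training split of 50,000 images together with the listing's `batch_size=4` and its `i % 2000 == 1999` condition.

```python
train_images = 50000     # size of the CIFAR-10 training split
batch_size = 4           # value passed to the DataLoader in the listing
log_every = 2000         # the loop prints when i % 2000 == 1999

batches_per_epoch = train_images // batch_size
# 1-based step numbers at which the loop prints a loss line
log_steps = [i + 1 for i in range(batches_per_epoch) if i % log_every == log_every - 1]

print(batches_per_epoch)
print(log_steps)
```

Each printed value is the mean loss over the preceding 2000 mini-batches. The last 500 mini-batches of an epoch are never reported because `running_loss` is reset at the start of the next epoch.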

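6. Testing the neural network

a. Accuracy

Use the code in Figure 2.7 to compute the accuracy. The resulting accuracy is 62.17%.

Figure 2.7 Computing the model accuracy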

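b. Computing the accuracy of each class

Use the code in Figure 2.8 to compute the classification accuracy of each class. The results are shown in Figure 2.9.

Figure 2.8 Computing the per-class accuracy

Figure 2.9 Per-class accuracy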
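
The source listing derives per-class accuracy from a confusion matrix whose rows are actual classes and whose columns are predicted classes. The sketch below applies the same `r[i] / sum(r)` calculation to a small 3-class matrix. The class names and counts are made up for illustration and are not results from this experiment.

```python
# Hypothetical 3-class confusion matrix: rows are actual classes, columns are predictions.
classes = ('cat', 'dog', 'frog')
confusion_matrix = [
    [8, 1, 1],   # 10 actual cats, 8 predicted correctly
    [2, 6, 2],   # 10 actual dogs, 6 predicted correctly
    [0, 1, 9],   # 10 actual frogs, 9 predicted correctly
]

per_class = {}
for i, row in enumerate(confusion_matrix):
    # Diagonal entry = correct predictions; row sum = all samples of that class.
    per_class[classes[i]] = row[i] / sum(row) * 100

# Overall accuracy: all diagonal entries divided by the total number of samples.
correct = sum(confusion_matrix[i][i] for i in range(len(classes)))
total = sum(sum(row) for row in confusion_matrix)
overall = correct / total * 100

for name, acc in per_class.items():
    print('{0:10s} - {1:.1f}'.format(name, acc))
print('overall    - {0:.1f}'.format(overall))
```

Because each row is divided by its own sum, a class with few test samples weighs as much as any other in this table, while the overall accuracy weighs every sample equally.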

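c. Plotting actual versus predicted categories

Use the code in Figure 2.10 to plot the actual categories against the predicted categories (a confusion matrix) and to print each value. The printed values are shown in Figure 2.11 and the plot in Figure 2.12.

Figure 2.10 Code for plotting actual versus predicted categories

Figure 2.11 Printed values of actual versus predicted categories

Figure 2.12 Plot of actual versus predicted categories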

III. Image Classification with GoogLeNet

1. Importing packages

Use the code in Figure 3.1 to import the packages.

Figure 3.1 Importing packages

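2. Loading, preprocessing, and splitting the data

Use the code in Figure 3.2 to define the parameters and to load, preprocess, and split the data.

Figure 3.2 Defining parameters; loading, preprocessing, and splitting the data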

3. Defining the network

Use the code in Figure 3.3 to define the network.

Figure 3.3 Network definition
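
Each Inception block in the source listing concatenates four branches along the channel dimension, so its output width is the sum of the branch widths. The sketch below checks that every block's `in_channels` matches the output of the block before it and that the last block produces the 1024 features expected by the final linear layer. The helper `inception_out` is a hypothetical name. The channel numbers are copied from the listing.

```python
def inception_out(ch1x1, ch3x3red, ch3x3, ch5x5red, ch5x5, pool_proj):
    """Output channels of an Inception block: the four branch outputs are concatenated.

    The two *red arguments are 1x1 reduction widths inside a branch and do not
    contribute to the output channel count.
    """
    return ch1x1 + ch3x3 + ch5x5 + pool_proj


# (name, in_channels, branch configuration) copied from the GoogLeNet listing.
blocks = [
    ('inception3a', 192, (64, 96, 128, 16, 32, 32)),
    ('inception3b', 256, (128, 128, 192, 32, 96, 64)),
    ('inception4a', 480, (192, 96, 208, 16, 48, 64)),
    ('inception4b', 512, (160, 112, 224, 24, 64, 64)),
    ('inception4c', 512, (128, 128, 256, 24, 64, 64)),
    ('inception4d', 512, (112, 144, 288, 32, 64, 64)),
    ('inception4e', 528, (256, 160, 320, 32, 128, 128)),
    ('inception5a', 832, (256, 160, 320, 32, 128, 128)),
    ('inception5b', 832, (384, 192, 384, 48, 128, 128)),
]

channels = 192  # conv3 outputs 192 channels; pooling layers leave this unchanged
for name, in_channels, cfg in blocks:
    assert in_channels == channels, name  # each block must accept what precedes it
    channels = inception_out(*cfg)
    print(name, in_channels, '->', channels)
```

The same bookkeeping explains the auxiliary heads: `aux1` takes the 512 channels leaving `inception4a`, and `aux2` takes the 528 channels leaving `inception4d`.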

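4. Training the network

Use the code in Figure 3.4 to train the network. The output is shown in Figure 3.5.

Figure 3.4 Network training

Figure 3.5 Training output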
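
The training loop in the source listing divides the learning rate by 3 every 20 epochs. The sketch below simulates which rate is in effect during each epoch, assuming the decay check runs once at the end of each epoch. Note that the listing creates Adam with `lr=0.001`, while `curr_lr` starts from `learning_rate = 0.0006`, so the first 20 epochs use 0.001 rather than 0.0006. The variable names `optimizer_lr` and `lr_per_epoch` are hypothetical.

```python
num_epochs = 40
learning_rate = 0.0006   # starting value of curr_lr in the listing
optimizer_lr = 0.001     # lr passed to torch.optim.Adam in the listing

curr_lr = learning_rate
lr_per_epoch = []
for epoch in range(num_epochs):
    lr_per_epoch.append(optimizer_lr)   # rate in effect while this epoch trains
    if (epoch + 1) % 20 == 0:
        curr_lr /= 3                    # decay applied at the end of epochs 20 and 40
        optimizer_lr = curr_lr          # what update_lr(optimizer, curr_lr) does

print(lr_per_epoch[0], lr_per_epoch[19], lr_per_epoch[20], lr_per_epoch[39])
```

Under these assumptions the rate drops from 0.001 to 0.0002 after epoch 20. The second decay happens after the final epoch and therefore never affects training.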

5. Computing the accuracy

Use the code in Figure 3.6 to compute the accuracy, which is 80.43%.

Figure 3.6 Computing the accuracy

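IV. Summary

Using PyTorch, a simple CNN and a GoogLeNet network were built to classify the CIFAR-10 images, and the corresponding evaluation metrics were computed. The simple CNN reached a test accuracy of 62.17%, and GoogLeNet reached 80.43%.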

V. Source Code

1. Simple CNN source code

    '''
    Image classification with a simple CNN
    '''
    import torch
    import torchvision
    import torchvision.transforms as transforms
    import matplotlib.pyplot as plt
    import numpy as np
    import torch.nn as nn
    import torch.nn.functional as F
    import torch.optim as optim
    import os

    # Download the data, normalize it, and build the train/test loaders.
    # Note: with num_workers > 0 on Windows or macOS, the script body may need
    # to run under `if __name__ == '__main__':`.
    transform = transforms.Compose([transforms.ToTensor(),
                                    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
    trainset = torchvision.datasets.CIFAR10('./data', train=True, download=True, transform=transform)
    trainloader = torch.utils.data.DataLoader(trainset, batch_size=4, num_workers=4, shuffle=True)
    testset = torchvision.datasets.CIFAR10('./data', train=False, download=True, transform=transform)
    testloader = torch.utils.data.DataLoader(testset, batch_size=4, num_workers=4, shuffle=False)
    classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')
    dataiter = iter(trainloader)
    # next(dataiter) works across PyTorch versions; dataiter.next() fails on recent ones.
    images, labels = next(dataiter)

    # Define the network structure
    class Net(nn.Module):
        def __init__(self):
            super(Net, self).__init__()
            self.conv1 = nn.Conv2d(3, 6, 5)
            self.pool = nn.MaxPool2d(2, 2)
            self.conv2 = nn.Conv2d(6, 16, 5)
            self.fc1 = nn.Linear(16 * 5 * 5, 120)
            self.fc2 = nn.Linear(120, 84)
            self.fc3 = nn.Linear(84, 10)

        def forward(self, x):
            x = self.pool(F.relu(self.conv1(x)))
            x = self.pool(F.relu(self.conv2(x)))
            x = x.view(-1, 16 * 5 * 5)  # flatten to (batch, 400)
            x = F.relu(self.fc1(x))
            x = F.relu(self.fc2(x))
            x = self.fc3(x)
            return x

    net = Net()

    # Define the loss function and the optimizer
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.9)

    # Train the network, or load saved parameters if they already exist
    model_directory_path = './model/'
    model_path = model_directory_path + 'cifar-10-cnn-model-oralcnn.pt'
    if not os.path.exists(model_directory_path):
        os.makedirs(model_directory_path)
    if os.path.isfile(model_path):
        # load trained model parameters from disk
        net.load_state_dict(torch.load(model_path))
        print('Loaded model parameters from disk.')
    else:
        for epoch in range(10):  # loop over the dataset multiple times

            running_loss = 0.0
            for i, data in enumerate(trainloader, 0):
                # get the inputs
                inputs, labels = data

                # zero the parameter gradients
                optimizer.zero_grad()

                # forward + backward + optimize
                outputs = net(inputs)
                loss = criterion(outputs, labels)
                loss.backward()
                optimizer.step()

                # print statistics
                running_loss += loss.item()
                if i % 2000 == 1999:  # print every 2000 mini-batches
                    print('[%d, %5d] loss: %.3f' %
                          (epoch + 1, i + 1, running_loss / 2000))
                    running_loss = 0.0
        print('Finished Training.')
        torch.save(net.state_dict(), model_path)
        print('Saved model parameters to disk.')

    # Test the network: class probabilities for one batch
    dataiter = iter(testloader)
    images, labels = next(dataiter)
    outputs = net(images)
    sm = nn.Softmax(dim=1)
    sm_outputs = sm(outputs)
    probs, index = torch.max(sm_outputs, dim=1)

    # Overall accuracy and confusion matrix over the whole test set
    total_correct = 0
    total_images = 0
    confusion_matrix = np.zeros([10, 10], int)
    with torch.no_grad():
        for data in testloader:
            images, labels = data
            outputs = net(images)
            _, predicted = torch.max(outputs.data, 1)
            total_images += labels.size(0)
            total_correct += (predicted == labels).sum().item()
            for i, l in enumerate(labels):
                # rows: actual class, columns: predicted class
                confusion_matrix[l.item(), predicted[i].item()] += 1
    model_accuracy = total_correct / total_images * 100
    print('Model accuracy on {0} test images: {1:.2f}%'.format(total_images, model_accuracy))

    # Per-class accuracy: diagonal entry divided by the row sum
    print('{0:10s} - {1}'.format('Category', 'Accuracy'))
    for i, r in enumerate(confusion_matrix):
        print('{0:10s} - {1:.1f}'.format(classes[i], r[i] / np.sum(r) * 100))

    # Plot the confusion matrix
    fig, ax = plt.subplots(1, 1, figsize=(8, 6))
    ax.matshow(confusion_matrix, aspect='auto', vmin=0, vmax=1000, cmap=plt.get_cmap('Blues'))
    plt.ylabel('Actual Category')
    plt.yticks(range(10), classes)
    plt.xlabel('Predicted Category')
    plt.xticks(range(10), classes)
    plt.show()

    # Print the confusion matrix as counts, followed by row-normalized fractions
    print('actual/pred'.ljust(16), end='')
    for i, c in enumerate(classes):
        print(c.ljust(10), end='')
    print()
    for i, r in enumerate(confusion_matrix):
        print(classes[i].ljust(16), end='')
        for idx, p in enumerate(r):
            print(str(p).ljust(10), end='')
        print()

        r = r / np.sum(r)
        print(''.ljust(16), end='')
        for idx, p in enumerate(r):
            print(str(p).ljust(10), end='')
        print()

2. GoogLeNet source code

    '''
    Image classification with GoogLeNet
    '''
    import torch
    import torchvision
    import torchvision.transforms as transforms
    import matplotlib.pyplot as plt
    import numpy as np
    import torch.nn as nn
    import torch.nn.functional as F
    import torch.optim as optim
    import os
    import gc

    # Define the parameters
    num_epochs = 40
    batch_size = 100
    num_classes = 10
    learning_rate = 0.0006
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # Load the data, normalize it, and build the train/test loaders.
    # download=False assumes CIFAR-10 is already present in ./data.
    transform = transforms.Compose([transforms.ToTensor(),
                                    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
    train_dataset = torchvision.datasets.CIFAR10('./data', download=False, train=True, transform=transform)
    test_dataset = torchvision.datasets.CIFAR10('./data', download=False, train=False, transform=transform)
    train_loader = torch.utils.data.DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True)
    test_loader = torch.utils.data.DataLoader(dataset=test_dataset, batch_size=batch_size, shuffle=False)
    classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')
    dataiter = iter(train_loader)
    # next(dataiter) works across PyTorch versions; dataiter.next() fails on recent ones.
    images, labels = next(dataiter)

    # Define the network: convolution + batch norm + ReLU building block
    class BasicConv2d(torch.nn.Module):
        def __init__(self, in_channels, out_channels, **kwargs):
            super(BasicConv2d, self).__init__()
            self.conv = torch.nn.Conv2d(in_channels, out_channels, **kwargs)
            self.batchnorm = torch.nn.BatchNorm2d(out_channels)
            self.relu = torch.nn.ReLU(inplace=True)

        def forward(self, x):
            x = self.conv(x)
            x = self.batchnorm(x)
            x = self.relu(x)
            return x

    # Define InceptionAux (auxiliary classifier head).
    class InceptionAux(torch.nn.Module):
        def __init__(self, in_channels, num_classes):
            super(InceptionAux, self).__init__()
            self.avgpool = torch.nn.AvgPool2d(kernel_size=2, stride=2)
            self.conv = BasicConv2d(in_channels, 128, kernel_size=1)
            self.fc1 = torch.nn.Sequential(torch.nn.Linear(2 * 2 * 128, 256))
            self.fc2 = torch.nn.Linear(256, num_classes)

        def forward(self, x):
            out = self.avgpool(x)
            out = self.conv(out)
            out = out.view(out.size(0), -1)
            out = torch.nn.functional.dropout(out, 0.5, training=self.training)
            out = torch.nn.functional.relu(self.fc1(out), inplace=True)
            out = torch.nn.functional.dropout(out, 0.5, training=self.training)
            out = self.fc2(out)
            return out

    # Define the Inception block: four parallel branches concatenated on channels.
    class Inception(torch.nn.Module):
        def __init__(self, in_channels, ch1x1, ch3x3red, ch3x3, ch5x5red, ch5x5, pool_proj):
            super(Inception, self).__init__()
            self.branch1 = BasicConv2d(in_channels, ch1x1, kernel_size=1)
            self.branch2 = torch.nn.Sequential(BasicConv2d(in_channels, ch3x3red, kernel_size=1),
                                               BasicConv2d(ch3x3red, ch3x3, kernel_size=3, padding=1))
            self.branch3 = torch.nn.Sequential(BasicConv2d(in_channels, ch5x5red, kernel_size=1),
                                               BasicConv2d(ch5x5red, ch5x5, kernel_size=5, padding=2))
            self.branch4 = torch.nn.Sequential(torch.nn.MaxPool2d(kernel_size=3, stride=1, padding=1),
                                               BasicConv2d(in_channels, pool_proj, kernel_size=1))

        def forward(self, x):
            branch1 = self.branch1(x)
            branch2 = self.branch2(x)
            branch3 = self.branch3(x)
            branch4 = self.branch4(x)

            outputs = [branch1, branch2, branch3, branch4]
            return torch.cat(outputs, 1)

    # Define GoogLeNet.
    class GoogLeNet(torch.nn.Module):
        def __init__(self, num_classes, aux_logits=True, init_weights=False):
            super(GoogLeNet, self).__init__()
            self.aux_logits = aux_logits
            self.conv1 = BasicConv2d(3, 64, kernel_size=4, stride=2, padding=3)
            self.maxpool1 = torch.nn.MaxPool2d(kernel_size=2, stride=2, ceil_mode=True)
            self.conv2 = BasicConv2d(64, 64, kernel_size=1)
            self.conv3 = BasicConv2d(64, 192, kernel_size=3, padding=1)
            self.maxpool2 = torch.nn.MaxPool2d(kernel_size=2, stride=1, ceil_mode=True)
            self.inception3a = Inception(192, 64, 96, 128, 16, 32, 32)
            self.inception3b = Inception(256, 128, 128, 192, 32, 96, 64)
            self.maxpool3 = torch.nn.MaxPool2d(kernel_size=2, stride=2, ceil_mode=True)
            self.inception4a = Inception(480, 192, 96, 208, 16, 48, 64)
            self.inception4b = Inception(512, 160, 112, 224, 24, 64, 64)
            self.inception4c = Inception(512, 128, 128, 256, 24, 64, 64)
            self.inception4d = Inception(512, 112, 144, 288, 32, 64, 64)
            self.inception4e = Inception(528, 256, 160, 320, 32, 128, 128)
            self.maxpool4 = torch.nn.MaxPool2d(kernel_size=2, stride=2, ceil_mode=True)
            self.inception5a = Inception(832, 256, 160, 320, 32, 128, 128)
            self.inception5b = Inception(832, 384, 192, 384, 48, 128, 128)

            if self.aux_logits:
                self.aux1 = InceptionAux(512, num_classes)
                self.aux2 = InceptionAux(528, num_classes)

            self.avgpool = torch.nn.AdaptiveAvgPool2d((1, 1))
            self.dropout = torch.nn.Dropout(0.4)
            self.fc = torch.nn.Linear(1024, num_classes)
            if init_weights:
                self._initialize_weights()

        def forward(self, x):
            x = self.conv1(x)
            x = self.maxpool1(x)
            x = self.conv2(x)
            x = self.conv3(x)
            x = self.maxpool2(x)
            x = self.inception3a(x)
            x = self.inception3b(x)
            x = self.maxpool3(x)
            x = self.inception4a(x)
            if self.training and self.aux_logits:  # auxiliary heads are skipped in eval mode
                aux1 = self.aux1(x)
            x = self.inception4b(x)
            x = self.inception4c(x)
            x = self.inception4d(x)
            if self.training and self.aux_logits:  # auxiliary heads are skipped in eval mode
                aux2 = self.aux2(x)
            x = self.inception4e(x)
            x = self.maxpool4(x)
            x = self.inception5a(x)
            x = self.inception5b(x)
            x = self.avgpool(x)

            x = torch.flatten(x, 1)
            x = self.dropout(x)
            x = self.fc(x)
            if self.training and self.aux_logits:  # auxiliary heads are skipped in eval mode
                return x, aux2, aux1
            return x

        def _initialize_weights(self):
            for m in self.modules():
                if isinstance(m, torch.nn.Conv2d):
                    torch.nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
                    if m.bias is not None:
                        torch.nn.init.constant_(m.bias, 0)
                elif isinstance(m, torch.nn.Linear):
                    torch.nn.init.normal_(m.weight, 0, 0.01)
                    torch.nn.init.constant_(m.bias, 0)

    # aux_logits=False, so the network returns a single output tensor
    net = GoogLeNet(10, False, True).to(device)
    criterion = nn.CrossEntropyLoss()
    # Note: Adam starts at lr=0.001; learning_rate (0.0006) only takes effect
    # through update_lr() in the training loop below.
    optimizer = torch.optim.Adam(net.parameters(), lr=0.001, betas=(0.9, 0.999), eps=1e-08, weight_decay=0)

    def update_lr(optimizer, lr):
        for param_group in optimizer.param_groups:
            param_group['lr'] = lr

    # Train the network
    total_step = len(train_loader)
    curr_lr = learning_rate
    for epoch in range(num_epochs):
        gc.collect()
        torch.cuda.empty_cache()
        net.train()
        for i, (images, labels) in enumerate(train_loader):
            images = images.to(device)
            labels = labels.to(device)
            outputs = net(images)
            loss = criterion(outputs, labels)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            if (i + 1) % 100 == 0:
                print('Epoch [{}/{}], Step [{}/{}], Loss {:.4f}'.format(
                    epoch + 1, num_epochs, i + 1, total_step, loss.item()))
        # Divide the learning rate by 3 every 20 epochs
        if (epoch + 1) % 20 == 0:
            curr_lr /= 3
            update_lr(optimizer, curr_lr)
    os.makedirs('./model', exist_ok=True)  # make sure the output directory exists
    torch.save(net.state_dict(), './model/cifar-10-cnn-geoglenet.pt')

    # Compute the accuracy on the test set
    net.eval()
    with torch.no_grad():
        correct = 0
        total = 0
        for images, labels in test_loader:
            images = images.to(device)
            labels = labels.to(device)
            outputs = net(images)
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
        print('Accuracy of the model on the test images: {} %'.format(100 * correct / total))
