• 首页 首页 icon
  • 工具库 工具库 icon
    • IP查询 IP查询 icon
  • 内容库 内容库 icon
    • 快讯库 快讯库 icon
    • 精品库 精品库 icon
    • 问答库 问答库 icon
  • 更多 更多 icon
    • 服务条款 服务条款 icon

Pytorch机器学习十—— 目标检测k-means聚类方法生成锚框anchor

武飞扬头像
lzzzzzzm
帮助1

Pytorch机器学习(十)—— YOLO中k-means聚类方法生成锚框anchor




前言

前面文章说过有关锚框的一些知识,但有个坑一直没填,就是在YOLO中锚框的大小是如何确定出来的。其实在YOLOV3中就有采用k-means聚类方法计算锚框的方法,而在YOLOV5中作者在基于k-means聚类方法的结果之后,采用了遗传算法,进一步得到效果更好的锚框。

如果对锚框概念不理解的,可以看一下这篇文章

Pytorch机器学习(九)—— YOLO中对于锚框,预测框,产生候选区域及对候选区域进行标注详解


一、K-means聚类

在YOLOV3中,锚框大小的计算就是采用的k-means聚类的方法形成的。

从直观的理解,我们知道所有已经标注的bbox的长宽大小,而锚框则是对于预测这些bbox的潜在候选框,所以锚框的长宽形状应该越接近真实bbox越好。而又由于YOLO网络的预测层是包含3种尺度的信息的(分别对应3种感受野),每种尺度的anchor又是三种,所以我们就需要9种尺度的anchor,也即我们需要对所有的bbox的尺寸聚类成9种类别!!

聚类方法比较常用的是使用k-means聚类方法,其算法流程如下。

  • 从数据集中随机选取 K 个点作为初始聚类的中心,中心点为学新通 
  • 针对数据集中每个样本 xi,计算它们到各个聚类中心点的距离,到哪个聚类中心点的距离最小,就将其划分到对应聚类中心的类中
  • 针对每个类别 i ,重新计算该类别的聚类中心  学新通(其中 | ||i| 表示的是该类别数据的总个数)
  • 重复第二步和第三步,直到聚类中心的位置不再发生变化(我们也可以设置迭代次数)

 k-means代码

  1.  
    # 计算中心点和其他点直接的距离
  2.  
    def calc_distance(obs, guess_central_points):
  3.  
    """
  4.  
     
  5.  
    :param obs: 所有的观测点
  6.  
    :param guess_central_points: 中心点
  7.  
    :return:每个点对应中心点的距离
  8.  
    """
  9.  
    distances = []
  10.  
    for x, y in obs:
  11.  
    distance = []
  12.  
    for xc, yc in guess_central_points:
  13.  
    distance.append(math.dist((x, y), (xc, yc)))
  14.  
    distances.append(distance)
  15.  
     
  16.  
    return distances
  17.  
     
  18.  
     
  19.  
    def k_means(obs, k, dist=np.median):
  20.  
    """
  21.  
     
  22.  
    :param obs: 待观测点
  23.  
    :param k: 聚类数k
  24.  
    :param dist: 表征聚类中心函数
  25.  
    :return: guess_central_points中心点
  26.  
    current_cluster 分类结果
  27.  
    """
  28.  
    obs_num = obs.shape[0]
  29.  
    if k < 1:
  30.  
    raise ValueError("Asked for %d clusters." % k)
  31.  
    # 随机取中心点
  32.  
    guess_central_points = obs[np.random.choice(obs_num, size=k, replace=False)] # 初始化最大距离
  33.  
    last_cluster = np.zeros((obs_num, ))
  34.  
     
  35.  
    # 当小于一定值时聚类完成
  36.  
    while True:
  37.  
    # 关键是下面的calc_distance,来计算需要的距离
  38.  
    distances = calc_distance(obs, guess_central_points)
  39.  
    # 获得对应距离最小值的索引
  40.  
    current_cluster = np.argmin(distances, axis=1)
  41.  
    # 如果聚类类别没有改变, 则直接退出
  42.  
    if (last_cluster == current_cluster).all():
  43.  
    break
  44.  
     
  45.  
    # 计算新的中心
  46.  
    for i in range(k):
  47.  
    guess_central_points[i] = dist(obs[current_cluster == i], axis=0)
  48.  
     
  49.  
    last_cluster = current_cluster
  50.  
     
  51.  
    return guess_central_points, current_cluster
学新通

 聚类效果如下

学新通

k-means 算法

还有一种k-means 算法,是属于k-means算法的衍生吧,其主要解决的是k-means算法第一步,随机选择中心点的问题。

学新通

整个代码也十分简单,只需要把最先随机选取中心点用下面代码计算出来就可以。

  1.  
    # k_means 计算中心坐标
  2.  
    def calc_center(boxes):
  3.  
    box_number = boxes.shape[0]
  4.  
    # 随机选取第一个中心点
  5.  
    first_index = np.random.choice(box_number, size=1)
  6.  
    clusters = boxes[first_index]
  7.  
    # 计算每个样本距中心点的距离
  8.  
    dist_note = np.zeros(box_number)
  9.  
    dist_note = np.inf
  10.  
    for i in range(k):
  11.  
    # 如果已经找够了聚类中心,则退出
  12.  
    if i 1 == k:
  13.  
    break
  14.  
    # 计算当前中心点和其他点的距离
  15.  
    for j in range(box_number):
  16.  
    j_dist = single_distance(boxes[j], clusters[i])
  17.  
    if j_dist < dist_note[j]:
  18.  
    dist_note[j] = j_dist
  19.  
    # 转换为概率
  20.  
    dist_p = dist_note / dist_note.sum()
  21.  
    # 使用赌轮盘法选择下一个点
  22.  
    next_index = np.random.choice(box_number, 1, p=dist_p)
  23.  
    next_center = boxes[next_index]
  24.  
    clusters = np.vstack([clusters, next_center])
  25.  
    return clusters
学新通

但我自己在使用过程中,对于提升不大。主要因为其实bbox的尺度差异一般不会太大,所以这个中心点的选取,对于最后影响不大。


二、YOLO中使用k-means聚类生成anchor

下面重点说一下如何使用这个k-means算法来生成anchor,辅助我们训练,下面的代码和上面的有一点不一样,因为我们上面的代码是基于点的(x,y),而我们聚类中,是bbox的(w,h),下面代码都以VOC格式的训练集为例,如果是coco格式的,得麻烦你自己转一下格式了。

如果不想用我下面的代码但也想用k-means聚类,请读取自己数据集时,读取bbox和图片的(w,h)以列表的形式保存,确保自己n*2或者m*2的列表.

读取VOC格式数据集

我下面的代码,不仅读取了voc格式的数据集,还做了一些数据的统计,如果不想要,自己注释点就好,代码比较简单,也写了注释。

大家可以不用太纠结代码实现,记得改一下自己的图片路径即可。

  1.  
    from xml.dom.minidom import parse
  2.  
    import matplotlib.pyplot as plt
  3.  
    import cv2 as cv
  4.  
    import os
  5.  
    train_annotation_path = '/home/aistudio/data/train/Annotations' # 训练集annotation的路径
  6.  
    train_image_path = '/home/aistudio/data/train/JPEGImages' # 训练集图片的路径
  7.  
    # 展示图片的数目
  8.  
    show_num = 12
  9.  
    #打开xml文档
  10.  
     
  11.  
    def parase_xml(xml_path):
  12.  
    """
  13.  
    输入:xml路径
  14.  
    返回:image_name, width, height, bboxes
  15.  
    """
  16.  
    domTree = parse(xml_path)
  17.  
    rootNode = domTree.documentElement
  18.  
    # 得到object,sizem,图片名称属性
  19.  
    object_node = rootNode.getElementsByTagName("object")
  20.  
    shape_node = rootNode.getElementsByTagName("size")
  21.  
    image_node = rootNode.getElementsByTagName("filename")
  22.  
    image_name = image_node[0].childNodes[0].data
  23.  
    bboxes = []
  24.  
    # 解析图片的长宽
  25.  
    for size in shape_node:
  26.  
    width = int(size.getElementsByTagName('width')[0].childNodes[0].data)
  27.  
    height = int(size.getElementsByTagName('height')[0].childNodes[0].data)
  28.  
    # 解析图片object属性
  29.  
    for obj in object_node:
  30.  
    # 解析name属性,并统计类别数
  31.  
    class_name = obj.getElementsByTagName("name")[0].childNodes[0].data
  32.  
    # 解析bbox属性,并统计bbox的大小
  33.  
    bndbox = obj.getElementsByTagName("bndbox")
  34.  
     
  35.  
    for bbox in bndbox:
  36.  
    x1 = int(bbox.getElementsByTagName('xmin')[0].childNodes[0].data)
  37.  
    y1 = int(bbox.getElementsByTagName('ymin')[0].childNodes[0].data)
  38.  
    x2 = int(bbox.getElementsByTagName('xmax')[0].childNodes[0].data)
  39.  
    y2 = int(bbox.getElementsByTagName('ymax')[0].childNodes[0].data)
  40.  
    bboxes.append([class_name, x1, y1, x2, y2])
  41.  
    return image_name, width, height, bboxes
  42.  
     
  43.  
    def read_voc(train_annotation_path, train_image_path, show_num):
  44.  
    """
  45.  
    train_annotation_path:训练集annotation的路径
  46.  
    train_image_path:训练集图片的路径
  47.  
    show_num:展示图片的大小
  48.  
    """
  49.  
    # 用于统计图片的长宽
  50.  
    total_width, total_height = 0, 0
  51.  
    # 用于统计图片bbox长宽
  52.  
    bbox_total_width, bbox_total_height, bbox_num = 0, 0, 0
  53.  
    min_bbox_size = 40000
  54.  
    max_bbox_size = 0
  55.  
    # 用于统计聚类所用的图片长宽,bbox长宽
  56.  
    img_wh = []
  57.  
    bbox_wh = []
  58.  
    # 用于统计标签
  59.  
    total_size = []
  60.  
    class_static = {'crazing': 0, 'inclusion': 0, 'patches': 0, 'pitted_surface': 0, 'rolled-in_scale': 0, 'scratches': 0}
  61.  
    num_index = 0
  62.  
     
  63.  
    for root, dirs, files in os.walk(train_annotation_path):
  64.  
    for file in files:
  65.  
    num_index = 1
  66.  
    xml_path = os.path.join(root, file)
  67.  
    image_name, width, height, bboxes = parase_xml(xml_path)
  68.  
    image_path = os.path.join(train_image_path, image_name)
  69.  
    img_wh.append([width, height])
  70.  
    total_width = width
  71.  
    total_height = height
  72.  
     
  73.  
    # 如果需要展示,则读取图片
  74.  
    if num_index < show_num:
  75.  
    image_path = os.path.join(train_image_path, image_name)
  76.  
    image = cv.imread(image_path)
  77.  
    # 统计有关bbox的信息
  78.  
    wh = []
  79.  
    for bbox in bboxes:
  80.  
    class_name = bbox[0]
  81.  
    class_static[class_name] = 1
  82.  
    x1, y1, x2, y2 = bbox[1], bbox[2], bbox[3], bbox[4]
  83.  
    bbox_width = x2 - x1
  84.  
    bbox_height = y2 - y1
  85.  
    bbox_size = bbox_width*bbox_height
  86.  
    # 统计bbox的最大最小尺寸
  87.  
    if min_bbox_size > bbox_size:
  88.  
    min_bbox_size = bbox_size
  89.  
    if max_bbox_size < bbox_size:
  90.  
    max_bbox_size = bbox_size
  91.  
    total_size.append(bbox_size)
  92.  
    # 统计bbox平均尺寸
  93.  
    bbox_total_width = bbox_width
  94.  
    bbox_total_height = bbox_height
  95.  
    # 用于聚类使用
  96.  
    wh.append([bbox_width / width, bbox_height / height]) # 相对坐标
  97.  
    bbox_num = 1
  98.  
    # 如果需要展示,绘制方框
  99.  
    if num_index < show_num:
  100.  
    cv.rectangle(image, (x1, y1), (x2, y2), color=(255, 0, 0), thickness=2)
  101.  
    cv.putText(image, class_name, (x1, y1 10), cv.FONT_HERSHEY_SIMPLEX, fontScale=0.2, color=(0, 255, 0), thickness=1)
  102.  
    bbox_wh.append(wh)
  103.  
    # 如果需要展示
  104.  
    if num_index < show_num:
  105.  
    plt.figure()
  106.  
    plt.imshow(image)
  107.  
    plt.show()
  108.  
     
  109.  
     
  110.  
    # 去除2个检查文件
  111.  
    # num_index -= 2
  112.  
    print("total train num is: {}".format(num_index))
  113.  
    print("avg total_width is {}, avg total_height is {}".format((total_width / num_index), (total_height / num_index)))
  114.  
    print("avg bbox width is {}, avg bbox height is {} ".format((bbox_total_width / bbox_num), (bbox_total_height / bbox_num)))
  115.  
    print("min bbox size is {}, max bbox size is {}".format(min_bbox_size, max_bbox_size))
  116.  
    print("class_static show below:", class_static)
  117.  
     
  118.  
    return img_wh, bbox_wh
  119.  
     
  120.  
    img_wh, bbox_wh = read_voc(train_annotation_path, train_image_path, show_num)
学新通

k-means聚类生成anchor

我这里的k-means代码集合了k-means 的实现,也集合了 太阳花的小绿豆这位博主提出用IOU作为评价指标来计算k-means而不是用欧拉距离的方法可以测试发现,使用IOU确实效果要比使用欧拉距离做为评价指标要好)

  1.  
    import numpy as np
  2.  
    # 这里IOU的概念更像是只是考虑anchor的长宽
  3.  
    def wh_iou(wh1, wh2):
  4.  
    # Returns the nxm IoU matrix. wh1 is nx2, wh2 is mx2
  5.  
    wh1 = wh1[:, None] # [N,1,2]
  6.  
    wh2 = wh2[None] # [1,M,2]
  7.  
    inter = np.minimum(wh1, wh2).prod(2) # [N,M]
  8.  
    return inter / (wh1.prod(2) wh2.prod(2) - inter) # iou = inter / (area1 area2 - inter)
  9.  
     
  10.  
    # 计算单独一个点和一个中心的距离
  11.  
    def single_distance(center, point):
  12.  
    center_x, center_y = center[0]/2 , center[1]/2
  13.  
    point_x, point_y = point[0]/2, point[1]/2
  14.  
    return np.sqrt((center_x - point_x)**2 (center_y - point_y)**2)
  15.  
     
  16.  
    # 计算中心点和其他点直接的距离
  17.  
    def calc_distance(boxes, clusters):
  18.  
    """
  19.  
     
  20.  
    :param obs: 所有的观测点
  21.  
    :param clusters: 中心点
  22.  
    :return:每个点对应中心点的距离
  23.  
    """
  24.  
    distances = []
  25.  
    for box in boxes:
  26.  
    # center_x, center_y = x/2, y/2
  27.  
    distance = []
  28.  
    for center in clusters:
  29.  
    # center_xc, cneter_yc = xc/2, yc/2
  30.  
    distance.append(single_distance(box, center))
  31.  
    distances.append(distance)
  32.  
     
  33.  
    return distances
  34.  
     
  35.  
    # k_means 计算中心坐标
  36.  
    def calc_center(boxes, k):
  37.  
    box_number = boxes.shape[0]
  38.  
    # 随机选取第一个中心点
  39.  
    first_index = np.random.choice(box_number, size=1)
  40.  
    clusters = boxes[first_index]
  41.  
    # 计算每个样本距中心点的距离
  42.  
    dist_note = np.zeros(box_number)
  43.  
    dist_note = np.inf
  44.  
    for i in range(k):
  45.  
    # 如果已经找够了聚类中心,则退出
  46.  
    if i 1 == k:
  47.  
    break
  48.  
    # 计算当前中心点和其他点的距离
  49.  
    for j in range(box_number):
  50.  
    j_dist = single_distance(boxes[j], clusters[i])
  51.  
    if j_dist < dist_note[j]:
  52.  
    dist_note[j] = j_dist
  53.  
    # 转换为概率
  54.  
    dist_p = dist_note / dist_note.sum()
  55.  
    # 使用赌轮盘法选择下一个点
  56.  
    next_index = np.random.choice(box_number, 1, p=dist_p)
  57.  
    next_center = boxes[next_index]
  58.  
    clusters = np.vstack([clusters, next_center])
  59.  
    return clusters
  60.  
     
  61.  
     
  62.  
    # k-means聚类,且评价指标采用IOU
  63.  
    def k_means(boxes, k, dist=np.median, use_iou=True, use_pp=False):
  64.  
    """
  65.  
    yolo k-means methods
  66.  
    Args:
  67.  
    boxes: 需要聚类的bboxes,bboxes为n*2包含w,h
  68.  
    k: 簇数(聚成几类)
  69.  
    dist: 更新簇坐标的方法(默认使用中位数,比均值效果略好)
  70.  
    use_iou:是否使用IOU做为计算
  71.  
    use_pp:是否是同k-means 算法
  72.  
    """
  73.  
    box_number = boxes.shape[0]
  74.  
    last_nearest = np.zeros((box_number,))
  75.  
    # 在所有的bboxes中随机挑选k个作为簇的中心
  76.  
    if not use_pp:
  77.  
    clusters = boxes[np.random.choice(box_number, k, replace=False)]
  78.  
    # k_means 计算初始值
  79.  
    else:
  80.  
    clusters = calc_center(boxes, k)
  81.  
     
  82.  
    # print(clusters)
  83.  
    while True:
  84.  
    # 计算每个bboxes离每个簇的距离 1-IOU(bboxes, anchors)
  85.  
    if use_iou:
  86.  
    distances = 1 - wh_iou(boxes, clusters)
  87.  
    else:
  88.  
    distances = calc_distance(boxes, clusters)
  89.  
    # 计算每个bboxes距离最近的簇中心
  90.  
    current_nearest = np.argmin(distances, axis=1)
  91.  
    # 每个簇中元素不在发生变化说明以及聚类完毕
  92.  
    if (last_nearest == current_nearest).all():
  93.  
    break # clusters won't change
  94.  
    for cluster in range(k):
  95.  
    # 根据每个簇中的bboxes重新计算簇中心
  96.  
    clusters[cluster] = dist(boxes[current_nearest == cluster], axis=0)
  97.  
     
  98.  
    last_nearest = current_nearest
  99.  
     
  100.  
    return clusters
学新通

使用我下面的auot_anchor代码注意!!(这里代码也是借鉴的太阳花的小绿豆博主的,他把里面的torch函数改为np函数后,使得代码移植性变强了!)

传入的参数中img_wh和bbox_wh即读取voc数据集中图片的长宽和bbox的长宽,为n*2和m*2的列表 !!

这里我还加入了YOLOV5中的遗传算法,具体细节就不展开了。

  1.  
    from tqdm import tqdm
  2.  
    import random
  3.  
    # 计算聚类和遗传算法出来的anchor和真实bbox之间的重合程度
  4.  
    def anchor_fitness(k: np.ndarray, wh: np.ndarray, thr: float): # mutation fitness
  5.  
    """
  6.  
    输入:k:聚类完后的结果,且排列为升序
  7.  
    wh:包含bbox中w,h的集合,且转换为绝对坐标
  8.  
    thr:bbox中和k聚类的框重合阈值
  9.  
    """
  10.  
    r = wh[:, None] / k[None]
  11.  
    x = np.minimum(r, 1. / r).min(2) # ratio metric
  12.  
    best = x.max(1)
  13.  
    f = (best * (best > thr).astype(np.float32)).mean() # fitness
  14.  
    bpr = (best > thr).astype(np.float32).mean() # best possible recall
  15.  
    return f, bpr
  16.  
     
  17.  
     
  18.  
    def auto_anchor(img_size, n, thr, gen, img_wh, bbox_wh):
  19.  
    """
  20.  
    输入:img_size:图片缩放的大小
  21.  
    n:聚类数
  22.  
    thr:fitness的阈值
  23.  
    gen:遗传算法迭代次数
  24.  
    img_wh:图片的长宽集合
  25.  
    bbox_wh:bbox的长框集合
  26.  
    """
  27.  
    # 最大边缩放到img_size
  28.  
    img_wh = np.array(img_wh, dtype=np.float32)
  29.  
    shapes = (img_size * img_wh / img_wh).max(1, keepdims=True)
  30.  
    wh0 = np.concatenate([l * s for s, l in zip(shapes, bbox_wh)]) # wh
  31.  
     
  32.  
    i = (wh0 < 3.0).any(1).sum()
  33.  
    if i:
  34.  
    print(f'WARNING: Extremely small objects found. {i} of {len(wh0)} labels are < 3 pixels in size.')
  35.  
    wh = wh0[(wh0 >= 2.0).any(1)] # 只保留wh都大于等于2个像素的box
  36.  
    # k_means 聚类计算anchor
  37.  
    k = k_means(wh, n, use_iou=True, use_pp=False)
  38.  
    k = k[np.argsort(k.prod(1))] # sort small to large
  39.  
    f, bpr = anchor_fitness(k, wh, thr)
  40.  
    print("kmeans: " " ".join([f"[{int(i[0])}, {int(i[1])}]" for i in k]))
  41.  
    print(f"fitness: {f:.5f}, best possible recall: {bpr:.5f}")
  42.  
     
  43.  
    # YOLOV5改进遗传算法
  44.  
    npr = np.random
  45.  
    f, sh, mp, s = anchor_fitness(k, wh, thr)[0], k.shape, 0.9, 0.1 # fitness, generations, mutation prob, sigma
  46.  
    pbar = tqdm(range(gen), desc=f'Evolving anchors with Genetic Algorithm:') # progress bar
  47.  
    for _ in pbar:
  48.  
    v = np.ones(sh)
  49.  
    while (v == 1).all(): # mutate until a change occurs (prevent duplicates)
  50.  
    v = ((npr.random(sh) < mp) * random.random() * npr.randn(*sh) * s 1).clip(0.3, 3.0)
  51.  
    kg = (k.copy() * v).clip(min=2.0)
  52.  
    fg, bpr = anchor_fitness(kg, wh, thr)
  53.  
    if fg > f:
  54.  
    f, k = fg, kg.copy()
  55.  
    pbar.desc = f'Evolving anchors with Genetic Algorithm: fitness = {f:.4f}'
  56.  
     
  57.  
    # 按面积排序
  58.  
    k = k[np.argsort(k.prod(1))] # sort small to large
  59.  
    print("genetic: " " ".join([f"[{int(i[0])}, {int(i[1])}]" for i in k]))
  60.  
    print(f"fitness: {f:.5f}, best possible recall: {bpr:.5f}")
  61.  
     
  62.  
    auto_anchor(img_size=416, n=9, thr=0.25, gen=1000, img_wh=img_wh, bbox_wh=bbox_wh)
学新通

如果有兴趣代码细节的,可以看里面的注释,如果还有不懂的,可以私信我交流。

最后计算出来的结果如下,可以看到计算出来的anchor的是长方型的,这是因为我的bbox中长方型的anchor居多,符合我的预期。我们只需要把下面的anchor,替换掉默认的anchor即可!

学新通

 学新通学新通

最后说明一下,用聚类算法算出来的anchor并不一定比初始值即coco上的anchor要好,原因是目标检测大部分基于迁移学习,backbone网络的训练参数是基于coco上的anchor学习的,所以其实大部分情况用这个聚类效果并没有直接使用coco上的好!!,而且聚类效果跟数据集的数量有很大关系,一两千张图片,聚类出来效果可能不会很好



总结

整个算法思路其实不难,但代码有一些冗余和长,主要也是结合自己在学习和使用过程中,发现很多博主没有说明白如何使用这些代码。

这篇好文章是转载于:学新通技术网

  • 版权申明: 本站部分内容来自互联网,仅供学习及演示用,请勿用于商业和其他非法用途。如果侵犯了您的权益请与我们联系,请提供相关证据及您的身份证明,我们将在收到邮件后48小时内删除。
  • 本站站名: 学新通技术网
  • 本文地址: /boutique/detail/tanhghfgib
系列文章
更多 icon
同类精品
更多 icon
继续加载