• 首页 首页 icon
  • 工具库 工具库 icon
    • IP查询 IP查询 icon
  • 内容库 内容库 icon
    • 快讯库 快讯库 icon
    • 精品库 精品库 icon
    • 问答库 问答库 icon
  • 更多 更多 icon
    • 服务条款 服务条款 icon

CutMix原理和代码解读

武飞扬头像
00000cj
帮助1

paper:CutMix: Regularization Strategy to Train Strong Classifiers with Localizable Features

前言

之前的数据增强方法存在的问题:

mixup:混合后的图像在局部是模糊和不自然的,因此会混淆模型,尤其是在定位方面。

cutout:被cutout的部分通常用0或者随机噪声填充,这就导致在训练过程中这部分的信息被浪费掉了。

cutmix在cutout的基础上进行改进,cutout的部分用另一张图像上cutout的部分进行填充,这样即保留了cutout的优点:让模型从目标的部分视图去学习目标的特征,让模型更关注那些less discriminative的部分。同时比cutout更高效,cutout的部分用另一张图像的部分进行填充,让模型同时学习两个目标的特征。

从下图可以看出,虽然Mixup和Cutout都提升了模型的分类精度,但在若监督定位和目标检测性能上都有不同程度的下降,而CutMix则在各个任务上都获得了显著的性能提升。

学新通

CutMix

cutmix的具体过程如下

学新通

其中\(M\in\left \{ 0,1 \right \}^{W\times H}\)是一个binary mask表明从两张图中裁剪的patch的位置,和mixup一样,\(\lambda\)也是通过\(\beta(\alpha, \alpha)\)分布得到的,在文章中作者设置\(\alpha=1\),因此\(\lambda\)是从均匀分布\((0,1)\)中采样的。

为了得到mask,首先要确定cutmix的bounding box的坐标\(B=(r_{x},r_{y},r_{w},r_{h})\),其值通过下式得到

学新通

即 \(\lambda\) 确定了patch与原图的面积比,即A图cutout的面积越大,标签融合时A图的比例越小。

代码实现

下面是torchvision的官方实现

  1.  
    class RandomCutmix(torch.nn.Module):
  2.  
    """Randomly apply Cutmix to the provided batch and targets.
  3.  
    The class implements the data augmentations as described in the paper
  4.  
    `"CutMix: Regularization Strategy to Train Strong Classifiers with Localizable Features"
  5.  
    <https://arxiv.org/abs/1905.04899>`_.
  6.  
     
  7.  
    Args:
  8.  
    num_classes (int): number of classes used for one-hot encoding.
  9.  
    p (float): probability of the batch being transformed. Default value is 0.5.
  10.  
    alpha (float): hyperparameter of the Beta distribution used for cutmix.
  11.  
    Default value is 1.0.
  12.  
    inplace (bool): boolean to make this transform inplace. Default set to False.
  13.  
    """
  14.  
     
  15.  
    def __init__(self, num_classes: int, p: float = 0.5, alpha: float = 1.0, inplace: bool = False) -> None:
  16.  
    super().__init__()
  17.  
    if num_classes < 1:
  18.  
    raise ValueError("Please provide a valid positive value for the num_classes.")
  19.  
    if alpha <= 0:
  20.  
    raise ValueError("Alpha param can't be zero.")
  21.  
     
  22.  
    self.num_classes = num_classes
  23.  
    self.p = p
  24.  
    self.alpha = alpha
  25.  
    self.inplace = inplace
  26.  
     
  27.  
    def forward(self, batch: Tensor, target: Tensor) -> Tuple[Tensor, Tensor]:
  28.  
    """
  29.  
    Args:
  30.  
    batch (Tensor): Float tensor of size (B, C, H, W)
  31.  
    target (Tensor): Integer tensor of size (B, )
  32.  
     
  33.  
    Returns:
  34.  
    Tensor: Randomly transformed batch.
  35.  
    """
  36.  
    if batch.ndim != 4:
  37.  
    raise ValueError(f"Batch ndim should be 4. Got {batch.ndim}")
  38.  
    if target.ndim != 1:
  39.  
    raise ValueError(f"Target ndim should be 1. Got {target.ndim}")
  40.  
    if not batch.is_floating_point():
  41.  
    raise TypeError(f"Batch dtype should be a float tensor. Got {batch.dtype}.")
  42.  
    if target.dtype != torch.int64:
  43.  
    raise TypeError(f"Target dtype should be torch.int64. Got {target.dtype}")
  44.  
     
  45.  
    if not self.inplace:
  46.  
    batch = batch.clone()
  47.  
    target = target.clone()
  48.  
     
  49.  
    if target.ndim == 1:
  50.  
    target = torch.nn.functional.one_hot(target, num_classes=self.num_classes).to(dtype=batch.dtype)
  51.  
     
  52.  
    if torch.rand(1).item() >= self.p:
  53.  
    return batch, target
  54.  
     
  55.  
    # It's faster to roll the batch by one instead of shuffling it to create image pairs
  56.  
    batch_rolled = batch.roll(1, 0)
  57.  
    target_rolled = target.roll(1, 0)
  58.  
     
  59.  
    # Implemented as on cutmix paper, page 12 (with minor corrections on typos).
  60.  
    lambda_param = float(torch._sample_dirichlet(torch.tensor([self.alpha, self.alpha]))[0])
  61.  
    _, H, W = F.get_dimensions(batch)
  62.  
     
  63.  
    r_x = torch.randint(W, (1,))
  64.  
    r_y = torch.randint(H, (1,))
  65.  
     
  66.  
    r = 0.5 * math.sqrt(1.0 - lambda_param)
  67.  
    r_w_half = int(r * W)
  68.  
    r_h_half = int(r * H)
  69.  
     
  70.  
    x1 = int(torch.clamp(r_x - r_w_half, min=0))
  71.  
    y1 = int(torch.clamp(r_y - r_h_half, min=0))
  72.  
    x2 = int(torch.clamp(r_x r_w_half, max=W))
  73.  
    y2 = int(torch.clamp(r_y r_h_half, max=H))
  74.  
     
  75.  
    batch[:, :, y1:y2, x1:x2] = batch_rolled[:, :, y1:y2, x1:x2]
  76.  
    lambda_param = float(1.0 - (x2 - x1) * (y2 - y1) / (W * H))
  77.  
     
  78.  
    target_rolled.mul_(1.0 - lambda_param)
  79.  
    target.mul_(lambda_param).add_(target_rolled)
  80.  
     
  81.  
    return batch, target
  82.  
     
  83.  
    def __repr__(self) -> str:
  84.  
    s = (
  85.  
    f"{self.__class__.__name__}("
  86.  
    f"num_classes={self.num_classes}"
  87.  
    f", p={self.p}"
  88.  
    f", alpha={self.alpha}"
  89.  
    f", inplace={self.inplace}"
  90.  
    f")"
  91.  
    )
  92.  
    return s
学新通

实验结果

从下图可以看出,CutMix在ImageNet上的精度超过了Cutout和Mixup等数据增强方法

学新通

在若监督目标定位方面,CutMix也超过了Mixup和Cutout

学新通

当作为预训练模型迁移到其它下游任务比如目标检测和图像描述时,CutMix也取得了最好的效果

学新通

这篇好文章是转载于:学新通技术网

  • 版权申明: 本站部分内容来自互联网,仅供学习及演示用,请勿用于商业和其他非法用途。如果侵犯了您的权益请与我们联系,请提供相关证据及您的身份证明,我们将在收到邮件后48小时内删除。
  • 本站站名: 学新通技术网
  • 本文地址: /boutique/detail/tanhfiijai
系列文章
更多 icon
同类精品
更多 icon
继续加载