Knowledge Distillation for Object Detection: SSD as an Example
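Knowledge distillation is a model-compression method that does not change the network structure. This kind of "compression" should be distinguished from quantization and pruning; strictly speaking, it is not compression at all. The distillation covered here is logit distillation, a form of offline distillation [feature distillation will be covered in a later post]. It is a commonly used method: an already-trained teacher model is used to distill a student model.
The teacher model is a model with strong accuracy, while the student model is usually less accurate but faster at inference. If you use this kind of distillation, note that the two models need similar network structures [so a model before and after an improvement can be paired this way]. In practice, both must produce predictions over the same set of prior boxes and classes, because the distillation loss compares their outputs prior by prior. The most important part of the implementation is building the distillation loss function.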
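Object detection has two tasks: classification and bounding-box regression. Distilling the former is relatively easy; the latter is the key difficulty in distillation.
Let's first walk through the MultiBoxLoss part of the SSD code in detail.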
MultiBoxLoss
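SSD uses cross-entropy for the classification loss and Smooth L1 for the box regression loss. For the exact formulas and how the algorithm works, see the paper; I won't repeat them here.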
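Loss parameters:
self.use_gpu: whether to train on the GPU
self.num_classes: number of classes [in SSD, num_classes = the number of your own classes + 1 for background]
self.threshold: IoU threshold for matching, default 0.5
self.background_label: background class label, default 0
self.encode_target: target encoding
self.use_prior_for_matching: use the prior boxes for matching
self.do_neg_mining: True, hard negative mining
self.negpos_ratio: negative-sample ratio, set to 3 [positive:negative = 1:3]
self.variance: variances used when encoding box offsets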
```python
import torch
import torch.nn as nn
import torch.nn.functional as F

# Config is the project's configuration dict (defined elsewhere, not shown);
# Config['variance'] holds the box-encoding variances.


class MultiBoxLoss(nn.Module):
    """SSD Weighted Loss Function
    Compute Targets:
        1) Produce Confidence Target Indices by matching ground truth boxes
           with (default) 'priorboxes' that have jaccard index > threshold parameter
           (default threshold: 0.5).
        2) Produce localization target by 'encoding' variance into offsets of ground
           truth boxes and their matched 'priorboxes'.
        3) Hard negative mining to filter the excessive number of negative examples
           that comes with using a large number of default bounding boxes.
           (default negative:positive ratio 3:1)
    Objective Loss:
        L(x,c,l,g) = (Lconf(x, c) + αLloc(x,l,g)) / N
        Where, Lconf is the CrossEntropy Loss and Lloc is the SmoothL1 Loss
        weighted by α which is set to 1 by cross val.
        Args:
            c: class confidences,
            l: predicted boxes,
            g: ground truth boxes
            N: number of matched default boxes
        See: https://arxiv.org/pdf/1512.02325.pdf for more details.
    """

    def __init__(self, num_classes, overlap_thresh, prior_for_matching,
                 bkg_label, neg_mining, neg_pos, neg_overlap, encode_target,
                 use_gpu=True, negatives_for_hard=100.0):
        super(MultiBoxLoss, self).__init__()
        self.use_gpu = use_gpu
        self.num_classes = num_classes
        self.threshold = overlap_thresh
        self.background_label = bkg_label
        self.encode_target = encode_target
        self.use_prior_for_matching = prior_for_matching
        self.do_neg_mining = neg_mining
        self.negpos_ratio = neg_pos
        self.neg_overlap = neg_overlap
        self.negatives_for_hard = negatives_for_hard
        self.variance = Config['variance']
```
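The forward pass of the loss
predictions: a tuple holding the network outputs: the location predictions, the class-confidence predictions, and the prior boxes.
predictions[0] has shape [batch, 8732, 4]
predictions[1] has shape [batch, 8732, num_classes]
predictions[2] has shape [8732, 4]
Note: 8732 is for a 300×300 input, where the 6 detection heads produce 8732 prior boxes in total:
8732 = 38*38*4 + 19*19*6 + 10*10*6 + 5*5*6 + 3*3*4 + 1*1*4
targets: the ground-truth box coordinates and labels from the annotated dataset. It is a list whose length equals the batch size; each element has shape [num_objs, 5], where num_objs is the number of annotated objects in that image and 5 = 4 box coordinates + 1 label.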
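As a quick check of the arithmetic, the snippet below recomputes the prior count from the SSD300 feature-map sizes and boxes per location listed above. It is plain Python with illustrative names (SSD300_HEADS, count_priors) that are not part of the SSD code.

```python
# SSD300 head layout: (feature-map side length, default boxes per location)
SSD300_HEADS = [(38, 4), (19, 6), (10, 6), (5, 6), (3, 4), (1, 4)]


def count_priors(heads):
    """Total number of default (prior) boxes over all detection heads."""
    return sum(size * size * boxes for size, boxes in heads)


per_head = [size * size * boxes for size, boxes in SSD300_HEADS]
print(per_head)                    # [5776, 2166, 600, 150, 36, 4]
print(count_priors(SSD300_HEADS))  # 8732
```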
```python
def forward(self, predictions, targets):
    """Multibox Loss
    Args:
        predictions (tuple): A tuple containing loc preds, conf preds,
        and prior boxes from SSD net.
            conf shape: torch.size(batch_size,num_priors,num_classes)
            loc shape: torch.size(batch_size,num_priors,4)
            priors shape: torch.size(num_priors,4)
        targets (tensor): Ground truth boxes and labels for a batch,
            shape: [batch_size,num_objs,5] (last idx is the label).
    """
    #--------------------------------------------------#
    # Unpack the three outputs: box regression, confidences, priors
    #--------------------------------------------------#
    loc_data, conf_data, priors = predictions
```
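Create two all-zero tensors to hold the results of matching the priors to the ground-truth boxes; num equals batch_size:
```python
num = loc_data.size(0)       # batch size
num_priors = priors.size(0)  # 8732 for a 300x300 input
loc_t = torch.zeros(num, num_priors, 4).type(torch.FloatTensor)
conf_t = torch.zeros(num, num_priors).long()
```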
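Loop over the batch: idx is the index within the batch, truths holds the ground-truth boxes, and labels holds the class of each object in the current image. For example: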
truths:tensor([[0.2333, 0.2067, 0.6967, 1.0000]], device='cuda:0')
labels:tensor([0.], device='cuda:0')
```python
for idx in range(num):
    # Ground-truth boxes and labels
    truths = targets[idx][:, :-1]
    labels = targets[idx][:, -1]

    if len(truths) == 0:
        continue

    # Prior boxes
    defaults = priors
    #--------------------------------------------------#
    # Match the ground-truth boxes against the priors.
    # A prior that overlaps a ground-truth box strongly enough counts as matched
    # and becomes responsible for detecting that box.
    #--------------------------------------------------#
    match(self.threshold, truths, defaults, self.variance, labels, loc_t, conf_t, idx)
```
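defaults are the prior boxes; next they are matched against the ground-truth boxes to assign labels.
The label-matching function match
Arguments: threshold, truths [ground-truth boxes], defaults [prior boxes], variance [variances], labels [ground-truth labels], loc_t [the all-zero tensor created above, which receives the matched box targets], conf_t [receives the matched class targets], idx [index of the current image in the batch].
1. First compute the overlap between every prior and every ground-truth box.
box_a is the truths above and box_b is the priors [note that priors are stored as (center_x, center_y, w, h) and must first be converted to corner form, (xmin, ymin, xmax, ymax)]. From these the IoU can be computed.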
```python
def jaccard(box_a, box_b):
    #-------------------------------------#
    # inter has shape [A, B]: the intersection area of
    # every ground-truth box with every prior
    # (intersect is a helper from the same box utilities, not shown)
    #-------------------------------------#
    inter = intersect(box_a, box_b)
    #-------------------------------------#
    # Areas of the ground-truth boxes and of the priors
    #-------------------------------------#
    area_a = ((box_a[:, 2]-box_a[:, 0]) *
              (box_a[:, 3]-box_a[:, 1])).unsqueeze(1).expand_as(inter)  # [A,B]
    area_b = ((box_b[:, 2]-box_b[:, 0]) *
              (box_b[:, 3]-box_b[:, 1])).unsqueeze(0).expand_as(inter)  # [A,B]

    union = area_a + area_b - inter
    #-------------------------------------#
    # IoU of every ground-truth box with every prior, [A,B]
    #-------------------------------------#
    return inter / union
```
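The resulting overlaps holds the IoU between every ground-truth box and every prior; with a single ground-truth box, as in this example, its shape is [1, 8732].
```python
overlaps = jaccard(
    truths,
    point_form(priors)
)
```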
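To make the corner-form conversion and the IoU formula concrete, here is a plain-Python sketch for a handful of boxes. point_form, iou and iou_matrix are illustrative stand-ins for the tensor versions used above, and coordinates are assumed to be normalized to [0, 1].

```python
def point_form(box):
    """Convert a prior from (cx, cy, w, h) to (xmin, ymin, xmax, ymax)."""
    cx, cy, w, h = box
    return (cx - w / 2, cy - h / 2, cx + w / 2, cy + h / 2)


def iou(a, b):
    """IoU of two corner-form boxes."""
    iw = max(0.0, min(a[2], b[2]) - max(a[0], b[0]))
    ih = max(0.0, min(a[3], b[3]) - max(a[1], b[1]))
    inter = iw * ih
    area_a = (a[2] - a[0]) * (a[3] - a[1])
    area_b = (b[2] - b[0]) * (b[3] - b[1])
    return inter / (area_a + area_b - inter)


def iou_matrix(truths, priors_cxcywh):
    """[A][B] table: each corner-form ground truth against every prior."""
    priors = [point_form(p) for p in priors_cxcywh]
    return [[iou(t, p) for p in priors] for t in truths]


table = iou_matrix([(0.0, 0.0, 0.5, 0.5)],
                   [(0.25, 0.25, 0.5, 0.5), (0.5, 0.5, 0.5, 0.5)])
print(table)  # [[1.0, 0.1428...]]  (the second value is 1/7)
```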
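Next, max is used to find, among the 8732 priors, the one that best matches each ground-truth box, along with its index [this best prior is always assigned to that ground truth, whatever its IoU].
Here the highest IoU is 0.6904, at prior No. 8711: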
best_prior_overlap:tensor([[0.6904]], device='cuda:0')
best_prior_idx:tensor([[8711]], device='cuda:0')
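The loop below ensures that every ground-truth box is matched to at least one prior. best_truth_idx and best_truth_overlap hold, for each prior, the index of its best-matching ground-truth box and the corresponding IoU; each ground truth's best prior is forced to point back to it, and its overlap is set to 2 so that it always passes the threshold.
```python
for j in range(best_prior_idx.size(0)):
    best_truth_idx[best_prior_idx[j]] = j
best_truth_overlap.index_fill_(0, best_prior_idx, 2)
```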
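Index truths with best_truth_idx to get, for each of the 8732 priors, the ground-truth box it is matched to:
```python
matches = truths[best_truth_idx]
```
Get the labels; the +1 shifts the class indices so that 0 is reserved for background:
```python
conf = labels[best_truth_idx] + 1
```
Mark as background (0) every prior whose best IoU is below the threshold:
```python
conf[best_truth_overlap < threshold] = 0
```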
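The matching steps above (best truth per prior, forcing each ground truth's best prior, the +1 label shift and the threshold) can be sketched in plain Python for a single image. This is a simplified, hypothetical re-implementation working on an IoU table, not the match function itself, and it skips writing loc_t.

```python
def match_labels(overlaps, labels, threshold=0.5):
    """Assign a class target to every prior for one image.

    overlaps: [num_truths][num_priors] IoU table
    labels:   0-based class index of each ground-truth box
    Returns (best_truth_idx, conf); conf[p] == 0 means background.
    """
    num_truths, num_priors = len(overlaps), len(overlaps[0])
    # For every prior: the ground truth it overlaps most, and that IoU.
    best_truth_idx = [max(range(num_truths), key=lambda t: overlaps[t][p])
                      for p in range(num_priors)]
    best_truth_overlap = [overlaps[best_truth_idx[p]][p] for p in range(num_priors)]
    # For every ground truth: force its best prior to keep it, with overlap 2
    # so that the threshold below can never turn it into background.
    for t in range(num_truths):
        best_prior = max(range(num_priors), key=lambda p: overlaps[t][p])
        best_truth_idx[best_prior] = t
        best_truth_overlap[best_prior] = 2.0
    # +1 so that class 0 is reserved for background, then apply the threshold.
    conf = [labels[best_truth_idx[p]] + 1 for p in range(num_priors)]
    conf = [0 if best_truth_overlap[p] < threshold else c for p, c in enumerate(conf)]
    return best_truth_idx, conf


print(match_labels([[0.7, 0.3, 0.1], [0.2, 0.4, 0.45]], [0, 2]))
# ([0, 1, 1], [1, 0, 3])
```

Ground truth 1 never reaches the 0.5 threshold (its best IoU is 0.45), yet it still gets prior 2 through the forced assignment, while prior 1 (IoU 0.4) falls back to background.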
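Encode the boxes [i.e. express each matched ground-truth box as offsets relative to its prior].
```python
def encode(matched, priors, variances):
    # Center offset of the matched box from the prior center,
    # scaled by the prior size and the first variance
    g_cxcy = (matched[:, :2] + matched[:, 2:])/2 - priors[:, :2]
    g_cxcy /= (variances[0] * priors[:, 2:])

    # Log ratio of the matched box size to the prior size,
    # scaled by the second variance
    g_wh = (matched[:, 2:] - matched[:, :2]) / priors[:, 2:]
    g_wh = torch.log(g_wh) / variances[1]
    return torch.cat([g_cxcy, g_wh], 1)
```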
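The resulting loc has shape [8732, 4]:
```python
loc = encode(matches, priors, variances)
```
The encoded loc is then written into the loc_t defined earlier (loc_t[idx] = loc), and conf into conf_t in the same way (conf_t[idx] = conf).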
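The sketch below writes encode for a single box in plain Python, together with an inverse decode that turns offsets back into a box, so the two can be checked against each other. The variances (0.1, 0.2) are an assumption (the common SSD setting); the real values come from Config['variance']. The prior used in the check is made up.

```python
import math

VARIANCES = (0.1, 0.2)  # assumed; the actual values come from Config['variance']


def encode_box(matched, prior, variances=VARIANCES):
    """matched: corner-form GT box; prior: (cx, cy, w, h). Returns 4 offsets."""
    xmin, ymin, xmax, ymax = matched
    pcx, pcy, pw, ph = prior
    gcx = ((xmin + xmax) / 2 - pcx) / (variances[0] * pw)
    gcy = ((ymin + ymax) / 2 - pcy) / (variances[0] * ph)
    gw = math.log((xmax - xmin) / pw) / variances[1]
    gh = math.log((ymax - ymin) / ph) / variances[1]
    return (gcx, gcy, gw, gh)


def decode_box(loc, prior, variances=VARIANCES):
    """Inverse of encode_box: offsets -> corner-form box."""
    pcx, pcy, pw, ph = prior
    cx = pcx + loc[0] * variances[0] * pw
    cy = pcy + loc[1] * variances[0] * ph
    w = pw * math.exp(loc[2] * variances[1])
    h = ph * math.exp(loc[3] * variances[1])
    return (cx - w / 2, cy - h / 2, cx + w / 2, cy + h / 2)


gt = (0.2333, 0.2067, 0.6967, 1.0)   # the example truth box shown earlier
prior = (0.45, 0.6, 0.5, 0.8)        # a made-up prior
print(encode_box(gt, prior))
print(decode_box(encode_box(gt, prior), prior))  # recovers gt up to float rounding
```

A box identical to its prior encodes to (approximately) all zeros, which is why the regression targets stay small for well-placed priors.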
Get the positive samples.
```python
# Wherever conf_t > 0, the prior contains an object
pos = conf_t > 0
```
pos now looks like this, with shape [batch, 8732]:
tensor([[False, False, False, ..., False, False, True],
[False, False, False, ..., False, False, False]], device='cuda:0')
Summing over dim 1 gives the number of positives in each image, i.e. how many of that image's 8732 priors contain an object.
```python
num_pos = pos.sum(dim=1, keepdim=True)
```
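Loss computation
Compute the loss over all positive samples.
Expand the positive mask to the shape of loc_data; the result is a boolean (True/False) mask.
```python
pos_idx = pos.unsqueeze(pos.dim()).expand_as(loc_data)
```
Use the mask to select from loc_data [the predicted locations] the predictions of the positive priors, loc_p:
```python
loc_p = loc_data[pos_idx].view(-1, 4)
```
Apply the same mask to loc_t [the encoded ground truth] to get the targets of the positives:
```python
loc_t = loc_t[pos_idx].view(-1, 4)
```
Box regression loss:
Call smooth_l1_loss directly [loc_p is the prediction, loc_t the target]:
```python
# reduction='sum' replaces the deprecated size_average=False
loss_l = F.smooth_l1_loss(loc_p, loc_t, reduction='sum')
```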
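For reference, Smooth L1 with PyTorch's default beta = 1 is 0.5·x² when |x| < 1 and |x| - 0.5 otherwise. A plain-Python version of the summed loss over the positive priors might look like this (the function names are illustrative):

```python
def smooth_l1(x, beta=1.0):
    """Elementwise Smooth L1: quadratic near zero, linear further out."""
    ax = abs(x)
    return 0.5 * ax * ax / beta if ax < beta else ax - 0.5 * beta


def smooth_l1_sum(pred, target, beta=1.0):
    """Sum over all positive priors and all 4 coordinates (reduction='sum')."""
    return sum(smooth_l1(p - t, beta)
               for p_box, t_box in zip(pred, target)
               for p, t in zip(p_box, t_box))


print(smooth_l1_sum([[0.5, -2.0, 0.0, 1.0]], [[0.0, 0.0, 0.0, 0.0]]))  # 2.125
```

The linear branch keeps badly placed boxes from producing huge gradients, which is the usual reason for preferring it over a plain L2 loss for box regression.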
Classification loss:
Reshape the predicted confidences so that each row holds the class scores of one prior in the batch, giving shape [batch*8732, num_classes]:
```python
batch_conf = conf_data.view(-1, self.num_classes)
```
Next, hard negative mining picks the negatives with the highest classification loss; conf_p at the end is the prediction restricted to the positives plus these selected hard negatives.
```python
# Find the priors that are hard to classify:
# log_sum_exp(x) - x[target] is the per-prior cross-entropy
# (log_sum_exp is a helper from the box utilities, not shown)
loss_c = log_sum_exp(batch_conf) - batch_conf.gather(1, conf_t.view(-1, 1))
loss_c = loss_c.view(num, -1)

# Positives are excluded; only hard negatives are considered
loss_c[pos] = 0
#--------------------------------------------------#
# loss_idx (num, num_priors)
# idx_rank (num, num_priors)
#--------------------------------------------------#
_, loss_idx = loss_c.sort(1, descending=True)
_, idx_rank = loss_idx.sort(1)
#--------------------------------------------------#
# Number of positives in each image
# num_pos (num, )
# neg (num, num_priors)
#--------------------------------------------------#
num_pos = pos.long().sum(1, keepdim=True)
# Limit the number of negatives
num_neg = torch.clamp(self.negpos_ratio * num_pos, max = pos.size(1) - 1)
num_neg[num_neg.eq(0)] = self.negatives_for_hard
neg = idx_rank < num_neg.expand_as(idx_rank)

#--------------------------------------------------#
# Expand the masks to the shape of conf_data
# pos_idx (num, num_priors, num_classes)
# neg_idx (num, num_priors, num_classes)
#--------------------------------------------------#
pos_idx = pos.unsqueeze(2).expand_as(conf_data)
neg_idx = neg.unsqueeze(2).expand_as(conf_data)

# Select the positives and negatives used for training (union of the two
# boolean masks) and compute the loss
conf_p = conf_data[pos_idx | neg_idx].view(-1, self.num_classes)
targets_weighted = conf_t[pos | neg]
loss_c = F.cross_entropy(conf_p, targets_weighted, reduction='sum')
```
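The hard negative mining above is easy to lose in the tensor indexing, so here it is for one image in plain Python. The per-prior loss is cross-entropy written as log-sum-exp minus the target logit, as in the code above, and the "sort twice" trick becomes an explicit rank table. hard_negative_mask is a hypothetical helper whose defaults mirror negpos_ratio=3 and negatives_for_hard=100.

```python
import math


def log_sum_exp(row):
    """Numerically stable log(sum(exp(row)))."""
    m = max(row)
    return m + math.log(sum(math.exp(v - m) for v in row))


def hard_negative_mask(conf_logits, conf_t, negpos_ratio=3, negatives_for_hard=100):
    """Pick the negatives with the highest classification loss for one image.

    conf_logits: [num_priors][num_classes] raw class scores
    conf_t:      [num_priors] matched labels, 0 = background
    Returns (pos, neg) as lists of booleans.
    """
    num_priors = len(conf_t)
    pos = [c > 0 for c in conf_t]
    # Per-prior cross-entropy; positives are zeroed so they never rank as hard negatives.
    loss_c = [log_sum_exp(row) - row[c] for row, c in zip(conf_logits, conf_t)]
    loss_c = [0.0 if p else l for p, l in zip(pos, loss_c)]
    # Rank of each prior when sorted by loss, highest first (the double sort).
    loss_idx = sorted(range(num_priors), key=lambda i: loss_c[i], reverse=True)
    idx_rank = [0] * num_priors
    for rank, i in enumerate(loss_idx):
        idx_rank[i] = rank
    num_neg = min(negpos_ratio * sum(pos), num_priors - 1)
    if num_neg == 0:
        num_neg = negatives_for_hard  # no positives: fall back to a fixed count
    neg = [r < num_neg for r in idx_rank]
    return pos, neg


logits = [[0, 5], [5, 0], [0, 3], [0, 1], [0, 2], [2, 0]]
print(hard_negative_mask(logits, [1, 0, 0, 0, 0, 0]))
# ([True, False, False, False, False, False], [False, False, True, True, True, False])
```

With one positive, num_neg = 3, so the three background priors with the largest loss (priors 2, 4 and 3) are kept and the easy ones are dropped.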
In this run, the final total loss is:
loss: 8.0996
MultiBoxLoss_KD
This adds a soft-target (distillation) loss on top of the original loss.
```python
class MultiBoxLoss_KD(nn.Module):

    def __init__(self, num_classes, overlap_thresh, prior_for_matching,
                 bkg_label, neg_mining, neg_pos, neg_overlap, encode_target,
                 use_gpu=True, negatives_for_hard=100.0, neg_w=1.5, pos_w=1.0, Temp=1., reg_m=0.):
        super(MultiBoxLoss_KD, self).__init__()
        self.use_gpu = use_gpu
        self.num_classes = num_classes                    # 21
        self.threshold = overlap_thresh                   # 0.5
        self.background_label = bkg_label                 # 0
        self.encode_target = encode_target                # False
        self.use_prior_for_matching = prior_for_matching  # True
        self.do_neg_mining = neg_mining                   # True
        self.negpos_ratio = neg_pos                       # 3
        self.neg_overlap = neg_overlap                    # 0.5
        self.variance = Config['variance']
        self.negatives_for_hard = negatives_for_hard

        # soft-target loss
        self.neg_w = neg_w    # weight for negative (background) samples
        self.pos_w = pos_w    # weight for positive samples
        self.Temp = Temp      # temperature
        self.reg_m = reg_m    # margin m of the teacher-bounded regression loss
```
In forward, the arguments are predictions [the student's outputs], pred_t [the teacher's outputs] and targets [the ground truth].
```python
def forward(self, predictions, pred_t, targets):
    """Multibox Loss
    Args:
        predictions (tuple): A tuple containing loc preds, conf preds,
        and prior boxes from SSD net.
            conf shape: torch.size(batch_size,num_priors,num_classes)
            loc shape: torch.size(batch_size,num_priors,4)
            priors shape: torch.size(num_priors,4)
        pred_t (tuple): teacher's predictions
        targets (tensor): Ground truth boxes and labels for a batch,
            shape: [batch_size,num_objs,5] (last idx is the label).
    """
```
KD for loc regression
The loc distillation uses an L2 loss that is applied only where the student is worse than the teacher (the teacher-bounded regression loss).
```python
# teach1: "s" = student, "t" in l2_dis_t = teacher; loc_t holds the encoded ground truth.
# loc_teach1 is the teacher's loc output, unpacked from pred_t (not shown).
loc_teach1_p = loc_teach1[pos_idx].view(-1, 4)   # teacher loc for the positive priors, [num_pos, 4]
l2_dis_s = (loc_p - loc_t).pow(2).sum(1)         # student: Σ(loc_p - loc_t)² per positive prior, [num_pos]
l2_dis_s_m = l2_dis_s + self.reg_m               # student distance plus margin m, [num_pos]
l2_dis_t = (loc_teach1_p - loc_t).pow(2).sum(1)  # teacher: Σ(loc_teach1_p - loc_t)², [num_pos]
l2_num = l2_dis_s_m > l2_dis_t                   # True where student (+ margin) is farther from the ground truth than the teacher, bool [num_pos]
l2_loss_teach1 = l2_dis_s[l2_num].sum()          # Lb(Rs, Rt, y): Σ(loc_p - loc_t)² where l2_num is True, 0 elsewhere
                                                 # (Lb is the paper's teacher bounded regression loss), scalar
l2_loss = l2_loss_teach1
```
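The bounded regression rule can be isolated from the indexing in a few lines of plain Python. This sketch assumes the three inputs are already restricted to the positive priors (as loc_p, loc_teach1_p and loc_t are above); margin plays the role of reg_m, and the function name is illustrative.

```python
def teacher_bounded_l2(loc_s, loc_teacher, loc_gt, margin=0.0):
    """Sum of the student's squared L2 error over the priors where
    student_err + margin > teacher_err; other priors contribute 0."""
    total = 0.0
    for s, t, g in zip(loc_s, loc_teacher, loc_gt):
        dist_s = sum((a - b) ** 2 for a, b in zip(s, g))
        dist_t = sum((a - b) ** 2 for a, b in zip(t, g))
        if dist_s + margin > dist_t:
            total += dist_s
    return total


gt = [[0, 0, 0, 0], [0, 0, 0, 0]]
student = [[1, 0, 0, 0], [0.1, 0, 0, 0]]
teacher = [[0.5, 0, 0, 0], [0.5, 0, 0, 0]]
print(teacher_bounded_l2(student, teacher, gt))  # 1.0: only prior 0 is worse than the teacher
```

Once the student beats the teacher on a prior, that prior stops contributing, so the teacher acts as an upper bound rather than a target to copy exactly.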
KD for classification (conf)
conf_p is the student's prediction; ps is the student's softened class distribution and pt the teacher's, and the loss is computed between the two.
```python
# Soft loss for knowledge distillation
# teach1: conf_teach1 is the teacher's conf output, unpacked from pred_t (not shown)
conf_p_teach = conf_teach1[pos_idx | neg_idx].view(-1, self.num_classes)
pt = F.softmax(conf_p_teach / self.Temp, dim=1)
if self.neg_w > 1.:
    # KL_div is the author's weighted KL helper (pos_w / neg_w), not shown
    ps = F.softmax(conf_p / self.Temp, dim=1)
    soft_loss1 = KL_div(pt, ps, self.pos_w, self.neg_w) * (self.Temp ** 2)
else:
    ps = F.log_softmax(conf_p / self.Temp, dim=1)
    soft_loss1 = nn.KLDivLoss(reduction='sum')(ps, pt) * (self.Temp ** 2)
soft_loss = soft_loss1
```
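The unweighted branch (neg_w <= 1) reduces, per prior, to a temperature-softened KL divergence scaled by T². The plain-Python sketch below computes that quantity for a single prior; it does not reproduce the author's weighted KL_div helper, which the source does not show. Multiplying by T² is the usual way to keep the soft-target gradients on a comparable scale as the temperature changes.

```python
import math


def softmax(logits, temp=1.0):
    """Softmax of logits / temp (higher temp gives a flatter distribution)."""
    scaled = [v / temp for v in logits]
    m = max(scaled)
    exps = [math.exp(v - m) for v in scaled]
    total = sum(exps)
    return [e / total for e in exps]


def soft_target_kl(student_logits, teacher_logits, temp=1.0):
    """T^2 * KL(p_teacher || p_student) for one prior, both softened by temp.

    Matches one row of nn.KLDivLoss(reduction='sum')(log p_s, p_t) * T**2.
    """
    pt = softmax(teacher_logits, temp)
    ps = softmax(student_logits, temp)
    kl = sum(t * (math.log(t) - math.log(s)) for t, s in zip(pt, ps) if t > 0)
    return kl * temp ** 2


print(softmax([1.0, 2.0, 3.0], 1.0))  # sharper
print(softmax([1.0, 2.0, 3.0], 4.0))  # flatter: more "dark knowledge" in the small classes
print(soft_target_kl([0.0, 0.0, 0.0], [3.0, 0.0, 0.0], temp=2.0))
```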
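Four losses are returned: soft_loss is the classification KD loss, l2_loss is the loc distillation loss, and loss_c and loss_l are the student's own hard losses against the ground truth.
```python
# N: total number of positive priors in the batch (computed earlier in forward, not shown)
loss_l /= N
loss_c /= N
l2_loss /= N
soft_loss /= N
return soft_loss, l2_loss, loss_c, loss_l
```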
You can then weight the losses according to your own needs during training:
```python
soft_loss, l2_loss, loss_c, loss_l = criterion(out_student, teacher_out, targets)  # KD loss function
# loss_l, loss_c = criterion1(out_student, targets)  # criterion1: the original loss function

loss = (0.3 * soft_loss + 0.7 * loss_c) + (0.5 * l2_loss + loss_l)
```
Training looks like this (for a quick demo I trained on only 100 images):
```text
Epoch 9/10: 100%|██████████| 54/54 [00:19<00:00, 2.80it/s, conf_loss=2.37, loc_loss=0.752, lr=7.16e-5]
Start Teacher Validation
Epoch 9/10: 100%|██████████| 6/6 [00:02<00:00, 2.01it/s, conf_loss=2.24, loc_loss=0.669, lr=7.16e-5]
Finish Teacher Validation
Epoch:9/10
Total Loss: 3.0658 || Val Loss: 2.4932
Saving state, iter: 9
Start Teacher Train
Epoch 10/10: 100%|██████████| 54/54 [00:19<00:00, 2.79it/s, conf_loss=2.28, loc_loss=0.73, lr=6.59e-5]
Epoch 10/10: 0%| | 0/6 [00:00<?, ?it/s<class 'dict'>]Start Teacher Validation
Epoch 10/10: 100%|██████████| 6/6 [00:03<00:00, 1.97it/s, conf_loss=2.14, loc_loss=0.691, lr=6.59e-5]
Finish Teacher Validation
Epoch:10/10
Total Loss: 2.9564 || Val Loss: 2.4282
Saving state, iter: 10
Start distillation training
Loading weights into state dict...
Finished!
Epoch 1/2: 0%| | 0/54 [00:00<?, ?it/s<class 'dict'>]Start teacher2student_KD Train
Epoch 1/2: 100%|██████████| 54/54 [00:20<00:00, 2.60it/s, conf_loss=3.19, l2_loss=8.29, loc_loss=2.91, lr=0.0005, soft_loss=3.9]
Start Teacher2student_KD Validation
Epoch 1/2: 100%|██████████| 6/6 [00:02<00:00, 2.12it/s, conf_loss=2.75, loc_loss=2.56, lr=0.0005]
Finish teacher2student_KD Validation
Epoch:1/2
Total Loss: 17.9524 || Val Loss: 4.5457
Saving state, iter: 1
Start teacher2student_KD Train
Epoch 2/2: 100%|██████████| 54/54 [00:20<00:00, 2.58it/s, conf_loss=2.99, l2_loss=6.52, loc_loss=2.53, lr=0.00046, soft_loss=3.41]
Start Teacher2student_KD Validation
Epoch 2/2: 100%|██████████| 6/6 [00:02<00:00, 2.14it/s, conf_loss=2.65, loc_loss=2.68, lr=0.00046]
Finish teacher2student_KD Validation
Epoch:2/2
Total Loss: 15.1709 || Val Loss: 4.5685
Saving state, iter: 2
Epoch 3/4: 0%| | 0/54 [00:00<?, ?it/s<class 'dict'>]Start teacher2student_KD Train
Epoch 3/4: 100%|██████████| 54/54 [00:25<00:00, 2.12it/s, conf_loss=2.66, l2_loss=6.92, loc_loss=2.63, lr=0.0001, soft_loss=3.05]
Epoch 3/4: 0%| | 0/6 [00:00<?, ?it/s<class 'dict'>]Start Teacher2student_KD Validation
Epoch 3/4: 100%|██████████| 6/6 [00:02<00:00, 2.21it/s, conf_loss=2.39, loc_loss=2.66, lr=0.0001]
Finish teacher2student_KD Validation
Epoch:3/4
Total Loss: 14.9715 || Val Loss: 4.3286
Saving state, iter: 3
Start teacher2student_KD Train
Epoch 4/4: 100%|██████████| 54/54 [00:25<00:00, 2.12it/s, conf_loss=2.44, l2_loss=6.46, loc_loss=2.54, lr=9.2e-5, soft_loss=2.84]
Start Teacher2student_KD Validation
Epoch 4/4: 100%|██████████| 6/6 [00:02<00:00, 2.15it/s, conf_loss=2.38, loc_loss=2.57, lr=9.2e-5]
Finish teacher2student_KD Validation
Epoch:4/4
Total Loss: 14.0245 || Val Loss: 4.2435
Saving state, iter: 4
```
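Note: offline distillation also places demands on the teacher model. My teacher here is just the original model with a few arbitrary changes, trained only for this demonstration; real improvements require your own experimentation. How well KD works therefore depends on both models.
You can also try other distillation methods. If you have questions, leave a comment~~ Thanks for your support.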
This article was reposted from 学新通技术网.
Article address: /boutique/detail/tanhggieig