【庖丁解牛】从零实现RetinaNet(终):更快的后处理、不同分辨率下RetinaNet的最终表现

    技术2024-07-28  78

    文章目录

    与论文分辨率的对标方式更快的后处理方式训练结果更高分辨率下的RetinaNet性能表现

    所有代码已上传到本人github repository:https://github.com/zgcr/pytorch-ImageNet-CIFAR-COCO-VOC-training 如果觉得有用,请点个star哟! 代码均在pytorch1.4版本中测试过,确认正确无误。

    与论文分辨率的对标方式

    之前有同学提出从零实现RetinaNet(二)中与原论文点数对标的换算方式不太合理,应该按照最终输入网络的图片大小来换算对标分辨率,因为只要最终输入的网络图片大小一样,那么计算量(这里指FLOPS)就完全一样。如果按照这种换算方式,那么我的resize方法与RetinaNet原始resize方法对标分辨率如下:

    retinanet_resize=400,my_resize=667 retinanet_resize=500,my_resize=833 retinanet_resize=600,my_resize=1000 retinanet_resize=700,my_resize=1166 retinanet_resize=800,my_resize=1333

    更快的后处理方式

    通过实验发现前面的后处理方式速度上有些慢,所以进行了一些优化。首先每个fpn层级上的Anchor按照分类置信度只保留前1000个Anchor。然后按照0.05的分类置信度阈值进行过滤。后面的NMS算法不再区分预测的类别,而是所有类别的预测框一起做NMS。另外,此时的NMS只需要保留max_detection_num个框就可以停止了,因为不区分类别,前max_detection_num个框一定是分类置信度最大的top max_detection_num个框,所以NMS之后也不需要再排序了。 完整decode.py代码实现如下:

    import torch import torch.nn as nn class RetinaDecoder(nn.Module): def __init__(self, image_w, image_h, top_n=1000, min_score_threshold=0.01, nms_threshold=0.5, max_detection_num=100): super(RetinaDecoder, self).__init__() self.image_w = image_w self.image_h = image_h self.top_n=1000 self.min_score_threshold = min_score_threshold self.nms_threshold = nms_threshold self.max_detection_num = max_detection_num def forward(self, cls_heads, reg_heads, batch_anchors): with torch.no_grad(): device = cls_heads[0].device filter_scores,filter_score_classes,filter_reg_heads,filter_batch_anchors=[],[],[],[] for per_level_cls_head,per_level_reg_head,per_level_anchor in zip(cls_heads, reg_heads, batch_anchors): scores, score_classes = torch.max(per_level_cls_head, dim=2) if scores.shape[1]>=1000: scores,indexes=torch.topk(scores, 1000, dim=1, largest=True, sorted=True) score_classes=torch.gather(score_classes, 1, indexes) per_level_reg_head=torch.gather(per_level_reg_head,1,indexes.unsqueeze(-1).repeat(1,1,4)) per_level_anchor =torch.gather(per_level_anchor,1,indexes.unsqueeze(-1).repeat(1,1,4)) filter_scores.append(scores) filter_score_classes.append(score_classes) filter_reg_heads.append(per_level_reg_head) filter_batch_anchors.append(per_level_anchor ) filter_scores = torch.cat(filter_scores, axis=1) filter_score_classes = torch.cat(filter_score_classes, axis=1) filter_reg_heads = torch.cat(filter_reg_heads, axis=1) filter_batch_anchors = torch.cat(filter_batch_anchors, axis=1) batch_scores, batch_classes, batch_pred_bboxes = [], [], [] for per_image_scores,per_image_score_classes, per_image_reg_heads, per_image_anchors in zip( filter_scores, filter_score_classes,filter_reg_heads, filter_batch_anchors): pred_bboxes = self.snap_tx_ty_tw_th_reg_heads_to_x1_y1_x2_y2_bboxes( per_image_reg_heads, per_image_anchors) score_classes = per_image_score_classes[ per_image_scores > self.min_score_threshold].float() pred_bboxes = pred_bboxes[ per_image_scores > self.min_score_threshold].float() scores = per_image_scores[per_image_scores > self.min_score_threshold].float() sorted_keep_scores, sorted_keep_classes, sorted_keep_pred_bboxes = self.nms( scores, score_classes, pred_bboxes) sorted_keep_scores = sorted_keep_scores.unsqueeze(0) sorted_keep_classes = sorted_keep_classes.unsqueeze(0) sorted_keep_pred_bboxes = sorted_keep_pred_bboxes.unsqueeze( 0) batch_scores.append(sorted_keep_scores) batch_classes.append(sorted_keep_classes) batch_pred_bboxes.append(sorted_keep_pred_bboxes) batch_scores = torch.cat(batch_scores, axis=0) batch_classes = torch.cat(batch_classes, axis=0) batch_pred_bboxes = torch.cat(batch_pred_bboxes, axis=0) # batch_scores shape:[batch_size,max_detection_num] # batch_classes shape:[batch_size,max_detection_num] # batch_pred_bboxes shape[batch_size,max_detection_num,4] return batch_scores, batch_classes, batch_pred_bboxes def nms(self, one_image_scores, one_image_classes, one_image_pred_bboxes): """ one_image_scores:[anchor_nums],4:classification predict scores one_image_classes:[anchor_nums],class indexes for predict scores one_image_pred_bboxes:[anchor_nums,4],4:x_min,y_min,x_max,y_max """ device=one_image_scores.device final_scores = (-1) * torch.ones( (self.max_detection_num, ), device=device) final_classes = (-1) * torch.ones( (self.max_detection_num, ), device=device) final_pred_bboxes = (-1) * torch.ones( (self.max_detection_num, 4), device=device) if one_image_scores.shape[0]==0: return final_scores, final_classes, final_pred_bboxes # Sort boxes sorted_one_image_scores, sorted_one_image_scores_indexes = torch.sort( one_image_scores, descending=True) sorted_one_image_classes = one_image_classes[ sorted_one_image_scores_indexes] sorted_one_image_pred_bboxes = one_image_pred_bboxes[ sorted_one_image_scores_indexes] sorted_pred_bboxes_w_h = sorted_one_image_pred_bboxes[:, 2:] - sorted_one_image_pred_bboxes[:, : 2] sorted_pred_bboxes_areas = sorted_pred_bboxes_w_h[:, 0] * sorted_pred_bboxes_w_h[:, 1] keep_scores, keep_classes, keep_pred_bboxes = [], [], [] while sorted_one_image_scores.numel() > 0: top1_score, top1_class, top1_pred_bbox = sorted_one_image_scores[ 0:1], sorted_one_image_classes[0:1], sorted_one_image_pred_bboxes[0:1] keep_scores.append(top1_score) keep_classes.append(top1_class) keep_pred_bboxes.append(top1_pred_bbox) top1_areas = sorted_pred_bboxes_areas[0] if len(keep_scores)>=self.max_detection_num: break if sorted_one_image_scores.numel() == 1: break sorted_one_image_scores = sorted_one_image_scores[1:] sorted_one_image_classes = sorted_one_image_classes[1:] sorted_one_image_pred_bboxes = sorted_one_image_pred_bboxes[1:] sorted_pred_bboxes_areas = sorted_pred_bboxes_areas[ 1:] overlap_area_top_left = torch.max( sorted_one_image_pred_bboxes[:, :2], top1_pred_bbox[:, :2]) overlap_area_bot_right = torch.min( sorted_one_image_pred_bboxes[:, 2:], top1_pred_bbox[:, 2:]) overlap_area_sizes = torch.clamp(overlap_area_bot_right - overlap_area_top_left, min=0) overlap_area = overlap_area_sizes[:, 0] * overlap_area_sizes[:, 1] # compute union_area union_area = top1_areas + sorted_pred_bboxes_areas - overlap_area union_area = torch.clamp(union_area, min=1e-4) # compute ious for top1 pred_bbox and the other pred_bboxes ious = overlap_area / union_area sorted_one_image_scores = sorted_one_image_scores[ ious < self.nms_threshold] sorted_one_image_classes = sorted_one_image_classes[ious < self.nms_threshold] sorted_one_image_pred_bboxes = sorted_one_image_pred_bboxes[ ious < self.nms_threshold] sorted_pred_bboxes_areas = sorted_pred_bboxes_areas[ ious < self.nms_threshold] keep_scores = torch.cat(keep_scores, axis=0) keep_classes = torch.cat(keep_classes, axis=0) keep_pred_bboxes = torch.cat(keep_pred_bboxes, axis=0) final_detection_num = min(self.max_detection_num, keep_scores.shape[0]) final_scores[ 0:final_detection_num] = keep_scores[ 0:final_detection_num] final_classes[ 0:final_detection_num] = keep_classes[ 0:final_detection_num] final_pred_bboxes[ 0:final_detection_num, :] = keep_pred_bboxes[ 0:final_detection_num, :] return final_scores, final_classes, final_pred_bboxes def snap_tx_ty_tw_th_reg_heads_to_x1_y1_x2_y2_bboxes( self, reg_heads, anchors): """ snap reg heads to pred bboxes reg_heads:[anchor_nums,4],4:[tx,ty,tw,th] anchors:[anchor_nums,4],4:[x_min,y_min,x_max,y_max] """ anchors_wh = anchors[:, 2:] - anchors[:, :2] anchors_ctr = anchors[:, :2] + 0.5 * anchors_wh device = anchors.device factor = torch.tensor([[0.1, 0.1, 0.2, 0.2]]).to(device) reg_heads = reg_heads * factor pred_bboxes_wh = torch.exp(reg_heads[:, 2:]) * anchors_wh pred_bboxes_ctr = reg_heads[:, :2] * anchors_wh + anchors_ctr pred_bboxes_x_min_y_min = pred_bboxes_ctr - 0.5 * pred_bboxes_wh pred_bboxes_x_max_y_max = pred_bboxes_ctr + 0.5 * pred_bboxes_wh pred_bboxes = torch.cat( [pred_bboxes_x_min_y_min, pred_bboxes_x_max_y_max], axis=1) pred_bboxes = pred_bboxes.int() pred_bboxes[:, 0] = torch.clamp(pred_bboxes[:, 0], min=0) pred_bboxes[:, 1] = torch.clamp(pred_bboxes[:, 1], min=0) pred_bboxes[:, 2] = torch.clamp(pred_bboxes[:, 2], max=self.image_w - 1) pred_bboxes[:, 3] = torch.clamp(pred_bboxes[:, 3], max=self.image_h - 1) # pred bboxes shape:[anchor_nums,4] return pred_bboxes if __name__ == '__main__': from retinanet import RetinaNet net = RetinaNet(resnet_type="resnet50") image_h, image_w = 640, 640 cls_heads, reg_heads, batch_anchors = net( torch.autograd.Variable(torch.randn(3, 3, image_h, image_w))) annotations = torch.FloatTensor([[[113, 120, 183, 255, 5], [13, 45, 175, 210, 2]], [[11, 18, 223, 225, 1], [-1, -1, -1, -1, -1]], [[-1, -1, -1, -1, -1], [-1, -1, -1, -1, -1]]]) decode = RetinaDecoder(image_w, image_h) batch_scores, batch_classes, batch_pred_bboxes = decode( cls_heads, reg_heads, batch_anchors) print("1111", batch_scores.shape, batch_classes.shape, batch_pred_bboxes.shape)

    训练结果

    根据新的分辨率对标方式,我使用my_resize=667和1000在COCO数据集上分别训练了不同分辨率下的RetinaNet。除了下表列出的超参数外,模型的其他超参数设置与从零实现RetinaNet(七)中ResNet50-RetinaNet-aug-iscrowd的参数设置完全一样。mAP为COCOeval stats[0]值,mAR为COCOeval stats[8]值。 模型表现如下:

    Networkbatchgpu-numapexsyncbnepoch5-mAP-mAR-lossepoch10-mAP-mAR-lossepoch12-mAP-mAR-lossResNet50-RetinaNet-myresize667242yesno0.264,_,0.610.298,_,0.510.302,_,0.49ResNet50-RetinaNet-myresize667-fastdecode242yesno0.253,0.361,0.610.287,0.398,0.510.293,0.401,0.49ResNet101-RetinaNet-myresize667-fastdecode162yesno0.254,0.362,0.600.290,0.398,0.510.296,0.402,0.48

    可以看到如果采用这种方式对标,我实现的RetinaNet和论文中的点数基本一致,在论文中resize=400的情况下只比论文报告点数(0.305)低了0.3个百分点。采用新的后处理方式后,mAP降低了不到1个百分点,但速度上要快的多。

    更高分辨率下的RetinaNet性能表现

    我又尝试了使用resize=1000(相当于RetinaNet论文中的600分辨率)来训练RetinaNet。如果直接从头开始训练,在高分辨率输入下网络的收敛速度会变慢,在12个epoch时网络的性能表现还不如ResNet50-RetinaNet-myresize667-fastdecode。因此我又尝试使用ResNet50-RetinaNet-myresize667-fastdecode在12个epoch保存的模型参数来初始化RetinaNet网络,再在resize=1000分辨率下继续训练,结果如下:

    Networkepoch5-mAP-mAR-lossepoch10-mAP-mAR-lossepoch12-mAP-mAR-lossepoch15-mAP-mAR-lossepoch20-mAP-mAR-lossepoch24-mAP-mAR-lossResNet50-RetinaNet-myresize10000.305,0.425,0.550.306,0.429,0.550.333,0.456,0.460.337,0.460,0.450.339,0.459,0.430.339,0.460,0.42

    训练时使用4张2080ti,总batch=16,使用apex但未使用syncbn。注意学习率从第10个epoch结束时自动衰减为1e-5,因此从epoch10到epoch12 loss值出现较大下降。最后训练mAP结果为0.339,与RetinaNet论文报告点数0.343相差0.4个百分点,还是很接近的。考虑到后面12个epoch训练的增益不大,其实只要训练12个epoch就可以了。

    Processed: 0.016, SQL: 9