Faster RCNN --> GeneralizedRCNN --> nn.Module
GeneralizedRCNN is composed of four parts: transform, backbone, rpn, and roi_heads.

The transform stage does two main things:
- normalization: image pixels go from [0, 255] uint8 to normalized float32 (mean/std normalization);
- resize: both images and targets are resized.

```python
# GeneralizedRCNN.forward(...)
original_image_sizes = []
for img in images:
    val = img.shape[-2:]
    assert len(val) == 2
    original_image_sizes.append((val[0], val[1]))
images, targets = self.transform(images, targets)
```

Only after this resizing does the actual network pipeline begin. It consists of 4 steps (a condensed sketch of the complete forward pass follows them):
1. The transformed images are fed into the backbone module to extract feature maps:
```python
# GeneralizedRCNN.forward(...)
features = self.backbone(images.tensors)
```

2. The rpn module then generates the proposals and proposal_losses:
```python
# GeneralizedRCNN.forward(...)
proposals, proposal_losses = self.rpn(images, features, targets)
```

3. Next comes the roi_heads module (i.e., RoI pooling + classification):
```python
# GeneralizedRCNN.forward(...)
detections, detector_losses = self.roi_heads(features, proposals, images.image_sizes, targets)
```

4. Finally, transform.postprocess maps the detected boxes back to the original images via original_image_sizes (NMS itself has already been applied inside roi_heads):
```python
# GeneralizedRCNN.forward(...)
detections = self.transform.postprocess(detections, images.image_sizes, original_image_sizes)
```
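Putting the four steps together, here is a condensed sketch of GeneralizedRCNN.forward, simplified from torchvision (input validation and some bookkeeping are omitted):

```python
from collections import OrderedDict
import torch

# Condensed sketch of GeneralizedRCNN.forward.
def forward(self, images, targets=None):
    # remember the pre-resize sizes so boxes can be mapped back later
    original_image_sizes = [img.shape[-2:] for img in images]

    images, targets = self.transform(images, targets)   # resize + normalize
    features = self.backbone(images.tensors)            # feature maps
    if isinstance(features, torch.Tensor):
        features = OrderedDict([('0', features)])       # unify the non-FPN case
    proposals, proposal_losses = self.rpn(images, features, targets)
    detections, detector_losses = self.roi_heads(
        features, proposals, images.image_sizes, targets)
    detections = self.transform.postprocess(
        detections, images.image_sizes, original_image_sizes)

    losses = {}
    losses.update(detector_losses)
    losses.update(proposal_losses)
    # training returns the loss dict; inference returns the detections
    return losses if self.training else detections
```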
FasterRCNN subclasses GeneralizedRCNN and supplies the transform, backbone, rpn, and roi_heads components:

```python
class FasterRCNN(GeneralizedRCNN):
    def __init__(self, backbone, num_classes=None,
                 # transform parameters
                 min_size=800, max_size=1333,
                 image_mean=None, image_std=None,
                 # RPN parameters
                 rpn_anchor_generator=None, rpn_head=None,
                 rpn_pre_nms_top_n_train=2000, rpn_pre_nms_top_n_test=1000,
                 rpn_post_nms_top_n_train=2000, rpn_post_nms_top_n_test=1000,
                 rpn_nms_thresh=0.7,
                 rpn_fg_iou_thresh=0.7, rpn_bg_iou_thresh=0.3,
                 rpn_batch_size_per_image=256, rpn_positive_fraction=0.5,
                 # Box parameters
                 box_roi_pool=None, box_head=None, box_predictor=None,
                 box_score_thresh=0.05, box_nms_thresh=0.5,
                 box_detections_per_img=100,
                 box_fg_iou_thresh=0.5, box_bg_iou_thresh=0.5,
                 box_batch_size_per_image=512, box_positive_fraction=0.25,
                 bbox_reg_weights=None):
        out_channels = backbone.out_channels

        if rpn_anchor_generator is None:
            anchor_sizes = ((32,), (64,), (128,), (256,), (512,))
            aspect_ratios = ((0.5, 1.0, 2.0),) * len(anchor_sizes)
            rpn_anchor_generator = AnchorGenerator(
                anchor_sizes, aspect_ratios
            )
        if rpn_head is None:
            rpn_head = RPNHead(
                out_channels, rpn_anchor_generator.num_anchors_per_location()[0]
            )

        rpn_pre_nms_top_n = dict(training=rpn_pre_nms_top_n_train,
                                 testing=rpn_pre_nms_top_n_test)
        rpn_post_nms_top_n = dict(training=rpn_post_nms_top_n_train,
                                  testing=rpn_post_nms_top_n_test)

        rpn = RegionProposalNetwork(
            rpn_anchor_generator, rpn_head,
            rpn_fg_iou_thresh, rpn_bg_iou_thresh,
            rpn_batch_size_per_image, rpn_positive_fraction,
            rpn_pre_nms_top_n, rpn_post_nms_top_n, rpn_nms_thresh)

        roi_heads = RoIHeads(
            # Box
            box_roi_pool, box_head, box_predictor,
            box_fg_iou_thresh, box_bg_iou_thresh,
            box_batch_size_per_image, box_positive_fraction,
            bbox_reg_weights,
            box_score_thresh, box_nms_thresh, box_detections_per_img)

        if image_mean is None:
            image_mean = [0.485, 0.456, 0.406]
        if image_std is None:
            image_std = [0.229, 0.224, 0.225]
        transform = GeneralizedRCNNTransform(min_size, max_size, image_mean, image_std)

        super(FasterRCNN, self).__init__(backbone, rpn, roi_heads, transform)
```

The transform component is implemented by GeneralizedRCNNTransform. The argument names make its two concerns obvious:
- resize-related parameters: min_size + max_size
- normalization-related parameters: image_mean + image_std

```python
# FasterRCNN.__init__(...)
if image_mean is None:
    image_mean = [0.485, 0.456, 0.406]
if image_std is None:
    image_std = [0.229, 0.224, 0.225]
transform = GeneralizedRCNNTransform(min_size, max_size, image_mean, image_std)
```
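To make the resize rule concrete, here is a small sketch; compute_scale_factor is a hypothetical helper, but the rule mirrors GeneralizedRCNNTransform.resize: the shorter side is scaled to min_size unless that would push the longer side past max_size.

```python
def compute_scale_factor(image_shape, min_size=800, max_size=1333):
    # Scale so the shorter side becomes min_size; if the longer side would
    # then exceed max_size, let the longer side cap the scale instead.
    h, w = image_shape
    scale = min_size / min(h, w)
    if max(h, w) * scale > max_size:
        scale = max_size / max(h, w)
    return scale

# e.g. a 600x1400 image: 800/600 ≈ 1.33 would make the long side 1867 > 1333,
# so the factor becomes 1333/1400 ≈ 0.95.
print(compute_scale_factor((600, 1400)))
```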
For the backbone, Faster RCNN uses a ResNet50 + FPN structure:

```python
# detection/faster_rcnn.py (291...)
def fasterrcnn_resnet50_fpn(pretrained=False, progress=True,
                            num_classes=91, pretrained_backbone=True, **kwargs):
    if pretrained:
        # no need to download the backbone if pretrained is set
        pretrained_backbone = False
    backbone = resnet_fpn_backbone('resnet50', pretrained_backbone)
    model = FasterRCNN(backbone, num_classes, **kwargs)
    if pretrained:
        state_dict = load_state_dict_from_url(model_urls['fasterrcnn_resnet50_fpn_coco'],
                                              progress=progress)
        model.load_state_dict(state_dict)
    return model
```
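A minimal usage sketch of this builder (assuming a torchvision version where the pretrained flag is available):

```python
import torch
import torchvision

model = torchvision.models.detection.fasterrcnn_resnet50_fpn(pretrained=True)
model.eval()  # inference mode: forward returns detections, not losses

images = [torch.rand(3, 480, 640)]  # list of CHW tensors in [0, 1]
with torch.no_grad():
    outputs = model(images)

# one dict per image, with 'boxes' (N, 4), 'labels' (N,), 'scores' (N,)
print(outputs[0].keys())
```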
The backbone is the feature extraction network, generally VGG, ResNet, MobileNet, or similar:

| model | backbone |
| --- | --- |
| Faster RCNN | ResNet50 + FPN |
| YOLOv3 | Darknet53 |
| SSD | VGG |

Next, we focus on the implementation of the rpn component:
```python
# FasterRCNN.__init__(194)
rpn = RegionProposalNetwork(
    rpn_anchor_generator, rpn_head,
    rpn_fg_iou_thresh, rpn_bg_iou_thresh,
    rpn_batch_size_per_image, rpn_positive_fraction,
    rpn_pre_nms_top_n, rpn_post_nms_top_n, rpn_nms_thresh)
```

First, rpn_anchor_generator:
- Purpose: generate the base anchors.
- Input: three aspect ratios (0.5, 1.0, 2.0) and five anchor sizes.
- Output: with FPN, each of the five feature levels is given one anchor size and the three aspect ratios, i.e., 3 base anchors per location per level (15 anchor shapes across all levels).

```python
# FasterRCNN.__init__(...)
if rpn_anchor_generator is None:
    anchor_sizes = ((32,), (64,), (128,), (256,), (512,))
    aspect_ratios = ((0.5, 1.0, 2.0),) * len(anchor_sizes)
    rpn_anchor_generator = AnchorGenerator(
        anchor_sizes, aspect_ratios
    )
```

As mentioned earlier, because of the FPN the rpn receives multiple feature maps. For clarity, the discussion below follows a single feature map; the others are handled the same way.
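To see what one level's base anchors look like, here is a sketch; base_anchors_for is a hypothetical helper, but the ratio math mirrors AnchorGenerator's anchor construction:

```python
import torch

def base_anchors_for(size, aspect_ratios=(0.5, 1.0, 2.0)):
    # For each ratio r, height/width = r while the anchor area
    # stays roughly size**2 (area-preserving reshaping).
    scales = torch.tensor([float(size)])
    ratios = torch.tensor(aspect_ratios)
    h_ratios = torch.sqrt(ratios)
    w_ratios = 1.0 / h_ratios
    ws = (w_ratios[:, None] * scales[None, :]).view(-1)
    hs = (h_ratios[:, None] * scales[None, :]).view(-1)
    # (x1, y1, x2, y2) anchors centered at the origin
    return (torch.stack([-ws, -hs, ws, hs], dim=1) / 2).round()

print(base_anchors_for(32))
# tensor([[-23., -11.,  23.,  11.],   # ratio 0.5: wide box
#         [-16., -16.,  16.,  16.],   # ratio 1.0: square
#         [-11., -23.,  11.,  23.]])  # ratio 2.0: tall box
```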
Suppose we have an $h \times w$ feature map. First, the downsampling factor (stride) of this feature map relative to the input image is computed:
$$stride = \frac{image\_size}{feature\_size}$$
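For example, an 800×800 input paired with a 200×200 feature map (the highest-resolution level of ResNet50-FPN) gives $stride = 800 / 200 = 4$; the sizes here are just for illustration.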
Then an $h \times w$ grid is laid over the input image, with each cell of side length stride, as the following code shows:
```python
# AnchorGenerator.grid_anchors(...)
shifts_x = torch.arange(0, grid_width, dtype=torch.float32, device=device) * stride_width
shifts_y = torch.arange(0, grid_height, dtype=torch.float32, device=device) * stride_height
shift_y, shift_x = torch.meshgrid(shifts_y, shifts_x)
```

The centers of the base_anchors are then shifted from the origin (0, 0) to the grid points, with one full set of base_anchors placed at every grid point. This populates the current feature_map with a large number of anchors.
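Concretely, grid_anchors combines these shifts with the base anchors by broadcasting; the snippet below is condensed from the same function and continues the code above (shift_x, shift_y, base_anchors come from the preceding steps):

```python
# one shifted copy of the base anchors per grid point
shifts = torch.stack([shift_x.reshape(-1), shift_y.reshape(-1),
                      shift_x.reshape(-1), shift_y.reshape(-1)], dim=1)
# (num_grid_points, 1, 4) + (1, num_base_anchors, 4) -> all anchors
anchors = (shifts.view(-1, 1, 4) + base_anchors.view(1, -1, 4)).reshape(-1, 4)
```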
A special note: the stride reflects the network's receptive field, and the network cannot detect boxes packed more densely than the feature_map allows! That is why anchors are only placed at the grid points (conversely, placing an anchor between two grid points would correspond to half a cell in the feature_map, which clearly makes no sense).
With the anchors in place, the next step is to give the network output heads that can judge, for each anchor, whether it contains an object, and also produce the four values $(d_x, d_y, d_w, d_h)$ needed for bounding box regression.
```python
class RPNHead(nn.Module):
    def __init__(self, in_channels, num_anchors):
        super(RPNHead, self).__init__()
        # 3x3 conv shared by both heads
        self.conv = nn.Conv2d(
            in_channels, in_channels, kernel_size=3, stride=1, padding=1
        )
        # objectness score: one logit per anchor per location
        self.cls_logits = nn.Conv2d(in_channels, num_anchors, kernel_size=1, stride=1)
        # box regression: 4 offsets per anchor per location
        self.bbox_pred = nn.Conv2d(
            in_channels, num_anchors * 4, kernel_size=1, stride=1
        )

    def forward(self, x):
        logits = []
        bbox_reg = []
        for feature in x:  # one entry per FPN level
            t = F.relu(self.conv(feature))
            logits.append(self.cls_logits(t))
            bbox_reg.append(self.bbox_pred(t))
        return logits, bbox_reg
```
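For each feature map, cls_logits has shape (N, num_anchors, H, W) and bbox_pred has shape (N, num_anchors * 4, H, W). The four regression values are turned into boxes with the standard Faster R-CNN parameterization, the same rule torchvision's BoxCoder applies; a sketch (clamping of $d_w, d_h$ is omitted):

```python
import torch

def decode_single(rel_codes, anchors):
    # rel_codes: (N, 4) predicted (dx, dy, dw, dh); anchors: (N, 4) in xyxy.
    widths = anchors[:, 2] - anchors[:, 0]
    heights = anchors[:, 3] - anchors[:, 1]
    ctr_x = anchors[:, 0] + 0.5 * widths
    ctr_y = anchors[:, 1] + 0.5 * heights

    dx, dy, dw, dh = rel_codes.unbind(dim=1)
    pred_ctr_x = dx * widths + ctr_x      # shift the anchor center
    pred_ctr_y = dy * heights + ctr_y
    pred_w = torch.exp(dw) * widths       # rescale width/height
    pred_h = torch.exp(dh) * heights

    # back to (x1, y1, x2, y2)
    return torch.stack([pred_ctr_x - 0.5 * pred_w,
                        pred_ctr_y - 0.5 * pred_h,
                        pred_ctr_x + 0.5 * pred_w,
                        pred_ctr_y + 0.5 * pred_h], dim=1)
```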