使用单阶段检测器作为 RPN¶
区域提议网络 (RPN) 是 Faster R-CNN 中的一个子模块,它为 Faster R-CNN 的第二阶段生成提议。 MMDetection 中的大多数两阶段检测器使用 RPNHead
来生成提议作为 RPN。 但是,任何单阶段检测器都可以作为 RPN,因为它们的边界框预测也可以被视为区域提议,因此可以在 R-CNN 中进行细化。 因此,MMDetection v3.0 支持这一点。
为了说明整个过程,这里我们给出一个如何使用无锚单阶段模型 FCOS 作为 Faster R-CNN 中 RPN 的示例。
本教程的大纲如下
在 Faster R-CNN 中使用
FCOSHead
作为RPNHead
评估提议
使用预训练的 FCOS 训练自定义的 Faster R-CNN
在 Faster R-CNN 中使用 FCOSHead
作为 RPNHead
¶
为了将 FCOSHead
设置为 Faster R-CNN 中的 RPNHead
,我们应该创建一个名为 configs/faster_rcnn/faster-rcnn_r50_fpn_fcos-rpn_1x_coco.py
的新配置文件,并将 rpn_head
的设置替换为 configs/fcos/fcos_r50-caffe_fpn_gn-head_1x_coco.py
中 bbox_head
的设置。 此外,我们仍然使用 FCOS 的颈部设置,步长为 [8, 16, 32, 64, 128]
,并将 bbox_roi_extractor
的 featmap_strides
更新为 [8, 16, 32, 64, 128]
。 为了避免损失变为 NAN,我们在前 1000 次迭代而不是前 500 次迭代中应用预热,这意味着学习率增加得更慢。 配置如下
_base_ = [
'../_base_/models/faster-rcnn_r50_fpn.py',
'../_base_/datasets/coco_detection.py',
'../_base_/schedules/schedule_1x.py', '../_base_/default_runtime.py'
]
model = dict(
# copied from configs/fcos/fcos_r50-caffe_fpn_gn-head_1x_coco.py
neck=dict(
start_level=1,
add_extra_convs='on_output', # use P5
relu_before_extra_convs=True),
rpn_head=dict(
_delete_=True, # ignore the unused old settings
type='FCOSHead',
num_classes=1, # num_classes = 1 for rpn, if num_classes > 1, it will be set to 1 in TwoStageDetector automatically
in_channels=256,
stacked_convs=4,
feat_channels=256,
strides=[8, 16, 32, 64, 128],
loss_cls=dict(
type='FocalLoss',
use_sigmoid=True,
gamma=2.0,
alpha=0.25,
loss_weight=1.0),
loss_bbox=dict(type='IoULoss', loss_weight=1.0),
loss_centerness=dict(
type='CrossEntropyLoss', use_sigmoid=True, loss_weight=1.0)),
roi_head=dict( # update featmap_strides due to the strides in neck
bbox_roi_extractor=dict(featmap_strides=[8, 16, 32, 64, 128])))
# learning rate
param_scheduler = [
dict(
type='LinearLR', start_factor=0.001, by_epoch=False, begin=0,
end=1000), # Slowly increase lr, otherwise loss becomes NAN
dict(
type='MultiStepLR',
begin=0,
end=12,
by_epoch=True,
milestones=[8, 11],
gamma=0.1)
]
然后,我们可以使用以下命令来训练我们的自定义模型。 有关更多训练命令,请参考 这里.
# training with 8 GPUS
bash tools/dist_train.sh configs/faster_rcnn/faster-rcnn_r50_fpn_fcos-rpn_1x_coco.py \
8 \
--work-dir ./work_dirs/faster-rcnn_r50_fpn_fcos-rpn_1x_coco
评估提议¶
提议的质量对检测器的性能至关重要,因此,我们也提供了一种评估提议的方法。 与上面相同,创建一个名为 configs/rpn/fcos-rpn_r50_fpn_1x_coco.py
的新配置文件,并将 rpn_head
的设置替换为 configs/fcos/fcos_r50-caffe_fpn_gn-head_1x_coco.py
中 bbox_head
的设置。
_base_ = [
'../_base_/models/rpn_r50_fpn.py', '../_base_/datasets/coco_detection.py',
'../_base_/schedules/schedule_1x.py', '../_base_/default_runtime.py'
]
val_evaluator = dict(metric='proposal_fast')
test_evaluator = val_evaluator
model = dict(
# copied from configs/fcos/fcos_r50-caffe_fpn_gn-head_1x_coco.py
neck=dict(
start_level=1,
add_extra_convs='on_output', # use P5
relu_before_extra_convs=True),
rpn_head=dict(
_delete_=True, # ignore the unused old settings
type='FCOSHead',
num_classes=1, # num_classes = 1 for rpn, if num_classes > 1, it will be set to 1 in RPN automatically
in_channels=256,
stacked_convs=4,
feat_channels=256,
strides=[8, 16, 32, 64, 128],
loss_cls=dict(
type='FocalLoss',
use_sigmoid=True,
gamma=2.0,
alpha=0.25,
loss_weight=1.0),
loss_bbox=dict(type='IoULoss', loss_weight=1.0),
loss_centerness=dict(
type='CrossEntropyLoss', use_sigmoid=True, loss_weight=1.0)))
假设我们在训练后有检查点 ./work_dirs/faster-rcnn_r50_fpn_fcos-rpn_1x_coco/epoch_12.pth
,那么我们可以使用以下命令评估提议的质量。
# testing with 8 GPUs
bash tools/dist_test.sh \
configs/rpn/fcos-rpn_r50_fpn_1x_coco.py \
./work_dirs/faster-rcnn_r50_fpn_fcos-rpn_1x_coco/epoch_12.pth \
8
使用预训练的 FCOS 训练自定义的 Faster R-CNN¶
预训练不仅可以加快训练收敛速度,还可以提高检测器的性能。 因此,这里我们给出一个示例来说明如何使用预训练的 FCOS 作为 RPN 来加速训练并提高准确性。 假设我们想在 Faster R-CNN 中使用 FCOSHead
作为 rpn head 并使用预训练的 fcos_r50-caffe_fpn_gn-head_1x_coco
进行训练。 名为 configs/faster_rcnn/faster-rcnn_r50-caffe_fpn_fcos-rpn_1x_coco.py
的配置文件的内容如下。 请注意,fcos_r50-caffe_fpn_gn-head_1x_coco
使用了 caffe 版本的 ResNet50,因此需要更新 data_preprocessor
中的像素均值和标准差。
_base_ = [
'../_base_/models/faster-rcnn_r50_fpn.py',
'../_base_/datasets/coco_detection.py',
'../_base_/schedules/schedule_1x.py', '../_base_/default_runtime.py'
]
model = dict(
data_preprocessor=dict(
mean=[103.530, 116.280, 123.675],
std=[1.0, 1.0, 1.0],
bgr_to_rgb=False),
backbone=dict(
norm_cfg=dict(type='BN', requires_grad=False),
style='caffe',
init_cfg=None), # the checkpoint in ``load_from`` contains the weights of backbone
neck=dict(
start_level=1,
add_extra_convs='on_output', # use P5
relu_before_extra_convs=True),
rpn_head=dict(
_delete_=True, # ignore the unused old settings
type='FCOSHead',
num_classes=1, # num_classes = 1 for rpn, if num_classes > 1, it will be set to 1 in TwoStageDetector automatically
in_channels=256,
stacked_convs=4,
feat_channels=256,
strides=[8, 16, 32, 64, 128],
loss_cls=dict(
type='FocalLoss',
use_sigmoid=True,
gamma=2.0,
alpha=0.25,
loss_weight=1.0),
loss_bbox=dict(type='IoULoss', loss_weight=1.0),
loss_centerness=dict(
type='CrossEntropyLoss', use_sigmoid=True, loss_weight=1.0)),
roi_head=dict( # update featmap_strides due to the strides in neck
bbox_roi_extractor=dict(featmap_strides=[8, 16, 32, 64, 128])))
load_from = 'https://download.openmmlab.com/mmdetection/v2.0/fcos/fcos_r50_caffe_fpn_gn-head_1x_coco/fcos_r50_caffe_fpn_gn-head_1x_coco-821213aa.pth'
训练命令如下。
bash tools/dist_train.sh \
configs/faster_rcnn/faster-rcnn_r50-caffe_fpn_fcos-rpn_1x_coco.py \
8 \
--work-dir ./work_dirs/faster-rcnn_r50-caffe_fpn_fcos-rpn_1x_coco