mmdet.apis¶
mmdet.evaluation¶
functional¶
- mmdet.evaluation.functional.average_precision(recalls, precisions, mode='area')[source]¶
计算平均精度(针对单个或多个尺度)。
- 参数
recalls (ndarray) – 形状 (num_scales, num_dets) 或 (num_dets, )
precisions (ndarray) – 形状 (num_scales, num_dets) 或 (num_dets, )
mode (str) – ‘area’ 或 ‘11points’,‘area’ 表示计算精确率-召回率曲线的面积,‘11points’ 表示计算召回率在 [0, 0.1, …, 1] 处的平均精度。
- 返回值
计算的平均精度
- 返回类型
float 或 ndarray
- mmdet.evaluation.functional.bbox_overlaps(bboxes1, bboxes2, mode='iou', eps=1e-06, use_legacy_coordinate=False)[source]¶
计算 bboxes1 中每个 bbox 与 bboxes2 中每个 bbox 之间的 ious。
- 参数
bboxes1 (ndarray) – 形状 (n, 4)
bboxes2 (ndarray) – 形状 (k, 4)
mode (str) – IOU(交集比并集)或 IOF(交集比前景)
use_legacy_coordinate (bool) – 是否使用 mmdet v1.x 中的坐标系。这意味着宽度和高度应分别计算为 'x2 - x1 + 1' 和 'y2 - y1 + 1'。注意,当函数在 VOCDataset 中使用时,它应该为 True 才能与官方实现 http://host.robots.ox.ac.uk/pascal/VOC/voc2012/VOCdevkit_18-May-2011.tar 对齐。默认值:False。
- 返回值
形状 (n, k)
- 返回类型
ious (ndarray)
- mmdet.evaluation.functional.eval_map(det_results, annotations, scale_ranges=None, iou_thr=0.5, ioa_thr=None, dataset=None, logger=None, tpfp_fn=None, nproc=4, use_legacy_coordinate=False, use_group_of=False, eval_mode='area')[source]¶
评估数据集的 mAP。
- 参数
det_results (list[list]) – [[cls1_det, cls2_det, …], …]。外层列表表示图像,内层列表表示每个类别的检测到的 bbox。
annotations (list[dict]) –
真实值标注,列表的每个项目表示一张图像。标注的键是
bboxes: 形状为 (n, 4) 的 numpy 数组
labels: 形状为 (n, ) 的 numpy 数组
bboxes_ignore (可选): 形状为 (k, 4) 的 numpy 数组
labels_ignore (可选): 形状为 (k, ) 的 numpy 数组
scale_ranges (list[tuple] | None) – 要评估的尺度范围,格式为 [(min1, max1), (min2, max2), …]。范围 (32, 64) 表示面积范围在 (32**2, 64**2) 之间。默认值:None。
iou_thr (float) – 被视为匹配的 IoU 阈值。默认值:0.5。
ioa_thr (float | None) – 被视为匹配的 IoA 阈值,仅用于 OpenImages 评估。默认值:None。
dataset (list[str] | str | None) – 数据集名称或数据集类别,不同数据集的指标略有差异,例如 “voc”、“imagenet_det” 等。默认值:None。
logger (logging.Logger | str | None) – 打印 mAP 摘要的方式。有关详细信息,请参见 mmengine.logging.print_log()。默认值:None。
tpfp_fn (可调用对象 | None) – 用于确定真阳性/假阳性的函数。如果为 None,则使用
tpfp_default()
作为默认值,除非数据集为 'det' 或 'vid'(在这种情况下使用tpfp_imagenet()
)。如果它被指定为一个函数,那么这个函数将用于评估 tp 和 fp。默认值为 None。nproc (int) – 用于计算 TP 和 FP 的进程数。默认为 4。
use_legacy_coordinate (bool) – 是否使用 mmdet v1.x 中的坐标系。这意味着宽度和高度应分别计算为 'x2 - x1 + 1` 和 'y2 - y1 + 1`。默认为 False。
use_group_of (bool) – 是否在计算 TP 和 FP 时使用 group of,这只在 OpenImages 评估中使用。默认为 False。
eval_mode (str) – 'area' 或 '11points','area' 表示计算精确率-召回率曲线下的面积,'11points' 表示计算召回率为 [0, 0.1, …, 1] 时,PASCAL VOC2007 使用 11points 作为默认评估模式,而其他模式为 'area'。默认为 'area'。
- 返回值
(mAP, [dict, dict, …])
- 返回类型
元组
- mmdet.evaluation.functional.eval_recalls(gts, proposals, proposal_nums=None, iou_thrs=0.5, logger=None, use_legacy_coordinate=False)[source]¶
计算召回率。
- 参数
gts (list[ndarray]) – 形状为 (n, 4) 的数组列表
proposals (list[ndarray]) – 形状为 (k, 4) 或 (k, 5) 的数组列表
proposal_nums (int | Sequence[int]) – 要评估的 Top N 个 proposals。
iou_thrs (float | Sequence[float]) – IoU 阈值。默认值:0.5。
logger (logging.Logger | str | None) – 打印召回率摘要的方式。有关详细信息,请参阅 mmengine.logging.print_log()。默认值:None。
use_legacy_coordinate (bool) – 是否使用 mmdet v1.x 中的坐标系。“1” 被添加到高度和宽度中,这意味着 w 和 h 应计算为 'x2 - x1 + 1` 和 'y2 - y1 + 1`。默认为 False。
- 返回值
不同 IoU 和 proposal 数量的召回率
- 返回类型
ndarray
- mmdet.evaluation.functional.evaluateImgLists(prediction_list: list, groundtruth_list: list, args: object, backend_args: Optional[dict] = None, dump_matches: bool = False) → dict[source]¶
obj 的包装器:``cityscapesscripts.evaluation.
evalInstanceLevelSemanticLabeling.evaluateImgLists``。支持从文件后端加载 groundtruth 图像。:param prediction_list: 预测 txt 文件列表。:type prediction_list: list :param groundtruth_list: groundtruth 图像文件列表。:type groundtruth_list: list :param args: obj 中的全局对象设置
obj:
cityscapesscripts.evaluation. evalInstanceLevelSemanticLabeling
- 参数
backend_args (dict, optional) – 实例化与后端相对应的 uri 的前缀的参数。默认为 None。
dump_matches (bool) – 是否转储 matches.json。默认为 False。
- 返回值
计算的指标。
- 返回类型
dict
- mmdet.evaluation.functional.plot_iou_recall(recalls, iou_thrs)[source]¶
绘制 IoU-召回率曲线。
- 参数
recalls (ndarray or list) – 形状为 (k,) 的数组
iou_thrs (ndarray or list) – 与 recalls 形状相同的数组
- mmdet.evaluation.functional.plot_num_recall(recalls, proposal_nums)[source]¶
绘制 Proposal_num-召回率曲线。
- 参数
recalls (ndarray or list) – 形状为 (k,) 的数组
proposal_nums (ndarray or list) – 与 recalls 形状相同的数组
- mmdet.evaluation.functional.pq_compute_multi_core(matched_annotations_list, gt_folder, pred_folder, categories, backend_args=None, nproc=32)[source]¶
使用多线程评估全景分割的指标。
与 panopticapi 中同名函数相同。
- 参数
matched_annotations_list (list) – 匹配的标注列表。每个元素都是同一图像标注的元组,格式为 (gt_anns, pred_anns)。
gt_folder (str) – ground truth 图像的路径。
pred_folder (str) – 预测图像的路径。
categories (str) – 数据集的类别。
backend_args (object) – 数据集的文件客户端。如果为 None,则后端将设置为 local。
nproc (int) – 用于全景质量计算的进程数。默认为 32。当 nproc 超过 CPU 内核数量时,将使用 CPU 内核数量。
- mmdet.evaluation.functional.pq_compute_single_core(proc_id, annotation_set, gt_folder, pred_folder, categories, backend_args=None, print_log=False)[source]¶
用于评估全景分割指标的单核函数。
与panopticapi中同名函数相同。只是加载图像的函数已更改为使用文件客户端。
- 参数
proc_id (int) – 微型进程的 ID。
gt_folder (str) – ground truth 图像的路径。
pred_folder (str) – 预测图像的路径。
categories (str) – 数据集的类别。
backend_args (object) – 数据集的后端。如果为 None,则后端将设置为 local。
print_log (bool) – 是否打印日志。默认为 False。
- mmdet.evaluation.functional.print_map_summary(mean_ap, results, dataset=None, scale_ranges=None, logger=None)[source]¶
打印 mAP 和每个类别的结果。
将打印一个表格,显示每个类别的 gts/dets/recall/AP 以及 mAP。
- 参数
mean_ap (float) – 从 eval_map() 中计算得出。
results (list[dict]) – 从 eval_map() 中计算得出。
dataset (list[str] | str | None) – 数据集名称或数据集类别。
scale_ranges (list[tuple] | None) – 要评估的尺度范围。
logger (logging.Logger | str | None) – 打印 mAP 摘要的方式。有关详细信息,请参见 mmengine.logging.print_log()。默认值:None。
- mmdet.evaluation.functional.print_recall_summary(recalls, proposal_nums, iou_thrs, row_idxs=None, col_idxs=None, logger=None)[source]¶
以表格形式打印召回率。
- 参数
recalls (ndarray) – 从 bbox_recalls 计算得出
proposal_nums (ndarray or list) – 前 N 个建议
iou_thrs (ndarray or list) – iou 阈值
row_idxs (ndarray) – 要打印的行(建议数)
col_idxs (ndarray) – 要打印的列(iou 阈值)
logger (logging.Logger | str | None) – 打印召回率摘要的方式。有关详细信息,请参阅 mmengine.logging.print_log()。默认值:None。
metrics¶
mmdet.models¶
backbones¶
data_preprocessors¶
dense_heads¶
detectors¶
layers¶
losses¶
necks¶
roi_heads¶
seg_heads¶
task_modules¶
test_time_augs¶
utils¶
mmdet.structures¶
structures¶
- class mmdet.structures.DetDataSample(*, metainfo: Optional[dict] = None, **kwargs)[source]¶
MMDetection 的数据结构接口。它们用作不同组件之间的接口。
DetDataSample
中的属性分为几个部分- ``proposals``(InstanceData): 用于两阶段检测器的区域建议。
detectors。
``gt_instances``(InstanceData): 实例标注的真实值。
``pred_instances``(InstanceData): 检测预测的实例。
- ``pred_track_instances``(InstanceData): 跟踪预测的实例。
predictions。
- ``ignored_instances``(InstanceData): 在训练/测试期间要忽略的实例。
training/testing。
- ``gt_panoptic_seg``(PixelData): 全景分割的真实值。
segmentation。
- ``pred_panoptic_seg``(PixelData): 全景分割的预测。
segmentation。
``gt_sem_seg``(PixelData): 语义分割的真实值。
``pred_sem_seg``(PixelData): 语义分割的预测。
Examples
>>> import torch >>> import numpy as np >>> from mmengine.structures import InstanceData >>> from mmdet.structures import DetDataSample
>>> data_sample = DetDataSample() >>> img_meta = dict(img_shape=(800, 1196), ... pad_shape=(800, 1216)) >>> gt_instances = InstanceData(metainfo=img_meta) >>> gt_instances.bboxes = torch.rand((5, 4)) >>> gt_instances.labels = torch.rand((5,)) >>> data_sample.gt_instances = gt_instances >>> assert 'img_shape' in data_sample.gt_instances.metainfo_keys() >>> len(data_sample.gt_instances) 5 >>> print(data_sample)
<DetDataSample(
META INFORMATION
DATA FIELDS gt_instances: <InstanceData(
META INFORMATION pad_shape: (800, 1216) img_shape: (800, 1196)
DATA FIELDS labels: tensor([0.8533, 0.1550, 0.5433, 0.7294, 0.5098]) bboxes: tensor([[9.7725e-01, 5.8417e-01, 1.7269e-01, 6.5694e-01],
[1.7894e-01, 5.1780e-01, 7.0590e-01, 4.8589e-01], [7.0392e-01, 6.6770e-01, 1.7520e-01, 1.4267e-01], [2.2411e-01, 5.1962e-01, 9.6953e-01, 6.6994e-01], [4.1338e-01, 2.1165e-01, 2.7239e-04, 6.8477e-01]])
) at 0x7f21fb1b9190>
- ) at 0x7f21fb1b9880>
>>> pred_instances = InstanceData(metainfo=img_meta) >>> pred_instances.bboxes = torch.rand((5, 4)) >>> pred_instances.scores = torch.rand((5,)) >>> data_sample = DetDataSample(pred_instances=pred_instances) >>> assert 'pred_instances' in data_sample
>>> pred_track_instances = InstanceData(metainfo=img_meta) >>> pred_track_instances.bboxes = torch.rand((5, 4)) >>> pred_track_instances.scores = torch.rand((5,)) >>> data_sample = DetDataSample( ... pred_track_instances=pred_track_instances) >>> assert 'pred_track_instances' in data_sample
>>> data_sample = DetDataSample() >>> gt_instances_data = dict( ... bboxes=torch.rand(2, 4), ... labels=torch.rand(2), ... masks=np.random.rand(2, 2, 2)) >>> gt_instances = InstanceData(**gt_instances_data) >>> data_sample.gt_instances = gt_instances >>> assert 'gt_instances' in data_sample >>> assert 'masks' in data_sample.gt_instances
>>> data_sample = DetDataSample() >>> gt_panoptic_seg_data = dict(panoptic_seg=torch.rand(2, 4)) >>> gt_panoptic_seg = PixelData(**gt_panoptic_seg_data) >>> data_sample.gt_panoptic_seg = gt_panoptic_seg >>> print(data_sample)
<DetDataSample(
META INFORMATION
DATA FIELDS _gt_panoptic_seg: <BaseDataElement(
META INFORMATION
DATA FIELDS panoptic_seg: tensor([[0.7586, 0.1262, 0.2892, 0.9341],
[0.3200, 0.7448, 0.1052, 0.5371]])
) at 0x7f66c2bb7730>
gt_panoptic_seg: <BaseDataElement(
META INFORMATION
DATA FIELDS panoptic_seg: tensor([[0.7586, 0.1262, 0.2892, 0.9341],
[0.3200, 0.7448, 0.1052, 0.5371]])
) at 0x7f66c2bb7730>
) at 0x7f66c2bb7280> >>> data_sample = DetDataSample() >>> gt_segm_seg_data = dict(segm_seg=torch.rand(2, 2, 2)) >>> gt_segm_seg = PixelData(**gt_segm_seg_data) >>> data_sample.gt_segm_seg = gt_segm_seg >>> assert ‘gt_segm_seg’ in data_sample >>> assert ‘segm_seg’ in data_sample.gt_segm_seg
- class mmdet.structures.ReIDDataSample(*, metainfo: Optional[dict] = None, **kwargs)[source]¶
ReID 任务的数据结构接口。
它用作不同组件之间的接口。
- Meta field
- img_shape (Tuple): 对应输入图像的形状。
用于可视化。
- ori_shape (Tuple): 对应图像的原始形状。
用于可视化。
- num_classes (int): 所有类别的数量。
用于标签格式转换。
- Data field
gt_label (LabelData): 真实值标签。pred_label (LabelData): 预测标签。scores (torch.Tensor): 模型的输出。
- set_gt_label(value: Union[numpy.ndarray, torch.Tensor, Sequence[numbers.Number], numbers.Number]) → mmdet.structures.reid_data_sample.ReIDDataSample[source]¶
设置
gt_label
的标签。
- set_gt_score(value: torch.Tensor) → mmdet.structures.reid_data_sample.ReIDDataSample[source]¶
设置
gt_label
的分数。
- class mmdet.structures.TrackDataSample(*, metainfo: Optional[dict] = None, **kwargs)[source]¶
MMDetection 中跟踪任务的数据结构接口。它用作不同组件之间的接口。
从某种程度上说,这种数据结构可以看作是多个 DetDataSample 的包装器。具体来说,它只包含一个属性:
video_data_samples
,它是一个 DetDataSample 列表,每个 DetDataSample 对应一帧。如果想获取单帧的属性,首先要通过索引获取相应的DetDataSample
,然后获取帧的属性,比如gt_instances
、pred_instances
等等。至于元信息,它与DetDataSample
不同,每个值对应元信息键是一个列表,其中每个元素对应单帧的信息。Examples
>>> import torch >>> from mmengine.structures import InstanceData >>> from mmdet.structures import DetDataSample, TrackDataSample >>> track_data_sample = TrackDataSample() >>> # set the 1st frame >>> frame1_data_sample = DetDataSample(metainfo=dict( ... img_shape=(100, 100), frame_id=0)) >>> frame1_gt_instances = InstanceData() >>> frame1_gt_instances.bbox = torch.zeros([2, 4]) >>> frame1_data_sample.gt_instances = frame1_gt_instances >>> # set the 2nd frame >>> frame2_data_sample = DetDataSample(metainfo=dict( ... img_shape=(100, 100), frame_id=1)) >>> frame2_gt_instances = InstanceData() >>> frame2_gt_instances.bbox = torch.ones([3, 4]) >>> frame2_data_sample.gt_instances = frame2_gt_instances >>> track_data_sample.video_data_samples = [frame1_data_sample, ... frame2_data_sample] >>> # set metainfo for track_data_sample >>> track_data_sample.set_metainfo(dict(key_frames_inds=[0])) >>> track_data_sample.set_metainfo(dict(ref_frames_inds=[1])) >>> print(track_data_sample) <TrackDataSample(
META INFORMATION key_frames_inds: [0] ref_frames_inds: [1]
DATA FIELDS video_data_samples: [<DetDataSample(
META INFORMATION img_shape: (100, 100)
DATA FIELDS gt_instances: <InstanceData(
META INFORMATION
DATA FIELDS bbox: tensor([[0., 0., 0., 0.],
[0., 0., 0., 0.]])
) at 0x7f639320dcd0>
) at 0x7f64bd223340>, <DetDataSample(
META INFORMATION img_shape: (100, 100)
DATA FIELDS gt_instances: <InstanceData(
META INFORMATION
DATA FIELDS bbox: tensor([[1., 1., 1., 1.],
[1., 1., 1., 1.], [1., 1., 1., 1.]])
) at 0x7f64bd128b20>
) at 0x7f64bd1346d0>]
) at 0x7f64bd2237f0> >>> print(len(track_data_sample)) 2 >>> key_data_sample = track_data_sample.get_key_frames() >>> print(key_data_sample[0].frame_id) 0 >>> ref_data_sample = track_data_sample.get_ref_frames() >>> print(ref_data_sample[0].frame_id) 1 >>> frame1_data_sample = track_data_sample[0] >>> print(frame1_data_sample.gt_instances.bbox) tensor([[0., 0., 0., 0.],
[0., 0., 0., 0.]])
>>> # Tensor-like methods >>> cuda_track_data_sample = track_data_sample.to('cuda') >>> cuda_track_data_sample = track_data_sample.cuda() >>> cpu_track_data_sample = track_data_sample.cpu() >>> cpu_track_data_sample = track_data_sample.to('cpu') >>> fp16_instances = cuda_track_data_sample.to( ... device=None, dtype=torch.float16, non_blocking=False, ... copy=False, memory_format=torch.preserve_format)
- clone() → mmengine.structures.base_data_element.BaseDataElement[source]¶
深度复制当前数据元素。
- 返回值
当前数据元素的副本。
- 返回类型
BaseDataElement
bbox¶
mask¶
mmdet.testing¶
mmdet.visualization¶
mmdet.utils¶
- class mmdet.utils.AvoidOOM(to_cpu=True, test=False)[source]¶
如果遇到 PyTorch 的 CUDA 内存不足错误,尝试将输入转换为 FP16 和 CPU。它将执行以下步骤
首先,在调用 torch.cuda.empty_cache() 后重试。
如果仍然失败,它将通过将输入
转换为 FP16 来重试。
如果仍然失败,尝试将输入转换为 CPU。
在这种情况下,它希望该函数分派到 CPU 实现。
- 参数
to_cpu (bool) – 如果遇到 OOM 错误,是否将输出转换为 CPU。这会显着降低代码速度。默认为 True。
test (bool) – 跳过 _ignore_torch_cuda_oom 操作,该操作可以在单元测试中使用轻量级数据,仅在单元测试中使用。默认为 False。
Examples
>>> from mmdet.utils.memory import AvoidOOM >>> AvoidCUDAOOM = AvoidOOM() >>> output = AvoidOOM.retry_if_cuda_oom( >>> some_torch_function)(input1, input2) >>> # To use as a decorator >>> # from mmdet.utils import AvoidCUDAOOM >>> @AvoidCUDAOOM.retry_if_cuda_oom >>> def function(*args, **kwargs): >>> return None
注意
- 即使输入在 GPU 上,输出也可能在 CPU 上。处理
在 CPU 上会显着降低代码速度。
- 在将输入转换为 CPU 时,它将只查看每个参数
并检查它是否具有 .device 和 .to 进行转换。不支持张量的嵌套结构。
- 由于该函数可能被调用多次,因此它必须是
无状态的。
- retry_if_cuda_oom(func)[source]¶
使函数在遇到 pytorch 的 CUDA OOM 错误后重试自身。
实现逻辑参考 https://github.com/facebookresearch/detectron2/blob/main/detectron2/utils/memory.py
- 参数
func – 一个无状态的可调用对象,它将张量状对象作为参数。
- 返回值
一个可调用对象,它在遇到 OOM 时重试 func。
- 返回类型
func
- mmdet.utils.all_reduce_dict(py_dict, op='sum', group=None, to_float=True)[source]¶
对 Python 字典对象应用所有归约函数。
代码修改自 https://github.com/Megvii- BaseDetection/YOLOX/blob/main/yolox/utils/allreduce_norm.py。
注意:确保不同等级的 py_dict 具有相同的键,并且值应该具有相同的形状。目前只支持 nccl 后端。
- 参数
py_dict (dict) – 要应用所有归约操作的字典。
op (str) – 操作符,可以是 ‘sum’ 或 ‘mean’。默认:‘sum’
group (
torch.distributed.group
, optional) – 分布式组,默认:None。to_float (bool) – 是否将字典的所有值转换为浮点数。默认:True。
- 返回值
已归约的 Python 字典对象。
- 返回类型
OrderedDict
- mmdet.utils.allreduce_grads(params, coalesce=True, bucket_size_mb=- 1)[source]¶
所有归约梯度。
- 参数
params (list[torch.Parameters]) – 模型的参数列表
coalesce (bool, optional) – 是否将所有归约参数作为一个整体。默认为 True。
bucket_size_mb (int, optional) – 桶的大小,单位为 MB。默认为 -1。
- mmdet.utils.find_latest_checkpoint(path, suffix='pth')[source]¶
从工作目录中找到最新的检查点。
- 参数
path (str) – 查找检查点的路径。
suffix (str) – 文件扩展名。默认为 pth。
- 返回值
最新检查点的文件路径。
- 返回类型
latest_path(str | None)
参考
- 1
https://github.com/microsoft/SoftTeacher /blob/main/ssod/utils/patch.py
- mmdet.utils.get_test_pipeline_cfg(cfg: Union[str, mmengine.config.config.ConfigDict]) → mmengine.config.config.ConfigDict[source]¶
从整个配置文件中获取测试数据集管道。
- 参数
cfg (str 或
ConfigDict
) – 整个配置文件。可以是配置文件或ConfigDict
。- 返回值
测试数据集的配置。
- 返回类型
ConfigDict
- mmdet.utils.imshow_mot_errors(*args, backend: str = 'cv2', **kwargs)[source]¶
在输入图像上显示错误的轨迹。
- 参数
backend (str, optional) – 可视化的后端。默认为 ‘cv2’。
- mmdet.utils.log_img_scale(img_scale, shape_order='hw', skip_square=False)[source]¶
记录图像大小。
- 参数
img_scale (tuple) – 要记录的图像大小。
shape_order (str, optional) – 图像形状的顺序。 ‘hw’ 代表 (高度,宽度), ‘wh’ 代表 (宽度,高度)。默认为 ‘hw’。
skip_square (bool, optional) – 是否跳过记录正方形 img_scale。默认为 False。
- 返回值
是否已完成记录。
- 返回类型
bool
- mmdet.utils.register_all_modules(init_default_scope: bool = True) → None[source]¶
将 mmdet 中的所有模块注册到注册表中。
- 参数
init_default_scope (bool) – 是否初始化 mmdet 默认范围。当 init_default_scope=True 时,全局默认范围将设置为 mmdet,并且所有注册表将从 mmdet 的注册表节点构建模块。要了解有关注册表的更多信息,请参考 https://github.com/open-mmlab/mmengine/blob/main/docs/en/tutorials/registry.md 默认为 True。
- mmdet.utils.replace_cfg_vals(ori_cfg)[source]¶
将字符串 “${key}” 替换为相应的 value。
将 “${key}” 替换为配置文件中 ori_cfg.key 的 value。并支持替换链式的 ${key}。例如,将 “${key0.key1}” 替换为 cfg.key0.key1 的 value。代码修改自 `vars.py < https://github.com/microsoft/SoftTeacher/blob/main/ssod/utils/vars.py>`_ # noqa: E501
- 参数
ori_cfg (mmengine.config.Config) – 从文件生成的带有 “${key}” 的原始配置。
- 返回值
将 “${key}” 替换为相应 value 的配置。
- 返回类型
updated_cfg [mmengine.config.Config]
- mmdet.utils.setup_cache_size_limit_of_dynamo()[source]¶
设置 dynamo 的缓存大小限制。
注意:由于目标检测算法中损失计算和后处理部分的动态形状,这些函数必须在每次运行时进行编译。为 torch._dynamo.config.cache_size_limit 设置较大的值可能会导致重复编译,从而降低训练和测试速度。因此,我们需要将 cache_size_limit 的默认值设置为更小。一个经验值是 4。
- mmdet.utils.split_batch(img, img_metas, kwargs)[source]¶
根据标签拆分 data_batch。
代码修改自 <https://github.com/microsoft/SoftTeacher/blob/main/ssod/utils/structure_utils.py> # noqa: E501
- 参数
img (Tensor) – 形状为 (N, C, H, W) 的输入图像编码。通常情况下,这些图像应该经过均值中心化和标准化缩放。
img_metas (list[dict]) – 图像信息字典列表,每个字典包含: ‘img_shape’, ‘scale_factor’, ‘flip’,以及可能包含 ‘filename’, ‘ori_shape’, ‘pad_shape’, ‘img_norm_cfg’。有关这些键值的详细信息,请参考
mmdet.datasets.pipelines.Collect
。kwargs (dict) – 特定于具体实现。
- 返回值
- 一个字典,其中 data_batch 按标签拆分,
例如 ‘sup’, ‘unsup_teacher’ 和 ‘unsup_student’。
- 返回类型
data_groups (dict)
- mmdet.utils.sync_random_seed(seed=None, device='cuda')[source]¶
确保不同的 rank 共享相同的种子。
所有 worker 必须调用此函数,否则会死锁。此方法通常用于 DistributedSampler,因为种子在分布式组中的所有进程中都应该是相同的。
在分布式采样中,不同的 rank 应该从数据集中采样不重叠的数据。因此,此函数用于确保每个 rank 基于相同的种子以相同的顺序对数据索引进行洗牌。然后,不同的 rank 可以使用不同的索引从同一个数据列表中选择不重叠的数据。
- 参数
seed (int, Optional) – 种子。默认为 None。
device (str) – 种子将放置的设备。默认为 ‘cuda’。
- 返回值
要使用的种子。
- 返回类型
int