# Source code for mmaction.testing._utils (extracted from rendered docs)
# Copyright (c) OpenMMLab. All rights reserved.
import os.path as osp
import mmengine
import numpy as np
import torch
from mmengine.utils.dl_utils.parrots_wrapper import _BatchNorm
def check_norm_state(modules, train_state):
    """Check that all norm layers are in the expected train state.

    Args:
        modules (Iterable[nn.Module]): Modules to inspect, e.g. the result
            of ``model.modules()``.
        train_state (bool): Expected value of ``.training`` for every
            ``_BatchNorm`` layer found.

    Returns:
        bool: True if every batch-norm layer matches ``train_state``
        (vacuously True when there are none).
    """
    # all() short-circuits on the first mismatching layer, matching the
    # early-return behavior of an explicit loop.
    return all(mod.training == train_state
               for mod in modules if isinstance(mod, _BatchNorm))
def generate_backbone_demo_inputs(input_shape=(1, 3, 64, 64)):
    """Create a random input tensor suitable for running a backbone.

    Args:
        input_shape (tuple): Input batch dimensions.
            Defaults to ``(1, 3, 64, 64)``.

    Returns:
        torch.Tensor: A float32 tensor of shape ``input_shape`` with
        values drawn uniformly from ``[0, 1)``.
    """
    imgs = np.random.random(input_shape)
    return torch.FloatTensor(imgs)
# TODO Remove this API
def generate_recognizer_demo_inputs(
        input_shape=(1, 3, 3, 224, 224), model_type='2D'):
    """Create a superset of inputs needed to run test or train batches.

    Args:
        input_shape (tuple): Input batch dimensions, either 5-dim
            ``(N, L, C, H, W)`` or 6-dim ``(N, M, C, L, H, W)``.
            Defaults to ``(1, 3, 3, 224, 224)``.
        model_type (str): Model type for data generation, from
            ``{'2D', '3D', 'skeleton', 'audio'}``. Defaults to ``'2D'``.

    Returns:
        dict: ``imgs`` (float32 tensor of ``input_shape``) and
        ``gt_labels`` (int64 tensor of dummy class labels).

    Raises:
        ValueError: If ``input_shape`` is neither 5- nor 6-dim, or
            ``model_type`` is not one of the supported values.
    """
    if len(input_shape) == 5:
        N, L = input_shape[0], input_shape[1]
        M = None  # only defined for 6-dim inputs ('3D' models)
    elif len(input_shape) == 6:
        N, M, L = input_shape[0], input_shape[1], input_shape[3]
    else:
        # Previously an unsupported rank fell through to a confusing
        # NameError on N; fail explicitly instead.
        raise ValueError(
            f'input_shape must be 5- or 6-dim, got {len(input_shape)}-dim')

    imgs = np.random.random(input_shape)

    # Label count depends on which axis carries the "sample" dimension
    # for the given model type.
    if model_type in ('2D', 'skeleton'):
        gt_labels = torch.LongTensor([2] * N)
    elif model_type == '3D':
        gt_labels = torch.LongTensor([2] * M)
    elif model_type == 'audio':
        gt_labels = torch.LongTensor([2] * L)
    else:
        raise ValueError(f'Data type {model_type} is not available')

    inputs = {'imgs': torch.FloatTensor(imgs), 'gt_labels': gt_labels}
    return inputs
def generate_detector_demo_inputs(
        input_shape=(1, 3, 4, 224, 224), num_classes=81, train=True,
        device='cpu'):
    """Create a superset of inputs needed to run a spatio-temporal detector.

    Args:
        input_shape (tuple): Input batch dimensions ``(N, C, T, H, W)``.
            Defaults to ``(1, 3, 4, 224, 224)``.
        num_classes (int): Number of detection classes; index 0 is treated
            as background and always zeroed in the labels. Defaults to 81.
        train (bool): If True, return train-mode inputs with ground truth;
            otherwise return test-mode inputs (batch size must be 1).
        device (str): ``'cpu'`` or ``'cuda'``. Defaults to ``'cpu'``.

    Returns:
        dict: Train mode: ``img``, ``proposals``, ``gt_bboxes``,
        ``gt_labels``, ``img_metas``. Test mode: ``img``, ``proposals``
        and ``img_metas``, each wrapped in a single-element list.
    """
    num_samples = input_shape[0]
    # Test-mode forward only supports single-sample batches.
    if not train:
        assert num_samples == 1

    def random_box(n):
        # Random (x1, y1, x2, y2) boxes scaled to the frame size; the
        # +0.5 offset on the second pair guarantees x2 > x1 and y2 > y1.
        box = torch.rand(n, 4) * 0.5
        box[:, 2:] += 0.5
        box[:, 0::2] *= input_shape[3]
        box[:, 1::2] *= input_shape[4]
        if device == 'cuda':
            box = box.cuda()
        return box

    def random_label(n):
        # Sparse random multi-hot label vectors; the background class
        # (index 0) is always cleared.
        label = torch.randn(n, num_classes)
        label = (label > 0.8).type(torch.float32)
        label[:, 0] = 0
        if device == 'cuda':
            label = label.cuda()
        return label

    img = torch.FloatTensor(np.random.random(input_shape))
    if device == 'cuda':
        img = img.cuda()
    proposals = [random_box(2) for _ in range(num_samples)]
    gt_bboxes = [random_box(2) for _ in range(num_samples)]
    gt_labels = [random_label(2) for _ in range(num_samples)]
    img_metas = [dict(img_shape=input_shape[-2:]) for _ in range(num_samples)]

    if train:
        return dict(
            img=img,
            proposals=proposals,
            gt_bboxes=gt_bboxes,
            gt_labels=gt_labels,
            img_metas=img_metas)
    return dict(img=[img], proposals=[proposals], img_metas=[img_metas])
def get_cfg(config_type, fname):
    """Grab configs necessary to create a recognizer.

    These are deep copied to allow for safe modification of parameters
    without influencing other tests.

    Args:
        config_type (str): Config category, one of ``recognition``,
            ``recognition_audio``, ``localization``, ``detection``,
            ``skeleton``, ``retrieval``.
        fname (str): Config file name under ``configs/<config_type>``.

    Returns:
        mmengine.Config: The loaded config object.
    """
    config_types = ('recognition', 'recognition_audio', 'localization',
                    'detection', 'skeleton', 'retrieval')
    assert config_type in config_types

    # Repo root is three levels above this file (mmaction/testing/_utils.py).
    repo_dpath = osp.dirname(osp.dirname(osp.dirname(__file__)))
    config_fpath = osp.join(repo_dpath, 'configs', config_type, fname)
    # Check the resolved file, not just its directory, so a misspelled
    # fname fails here with a clear message rather than inside mmengine.
    if not osp.exists(config_fpath):
        raise Exception('Cannot find config path')
    config = mmengine.Config.fromfile(config_fpath)
    return config
def get_recognizer_cfg(fname):
    """Load a recognition config by file name under ``configs/recognition``."""
    return get_cfg('recognition', fname)
def get_audio_recognizer_cfg(fname):
    """Load an audio recognition config under ``configs/recognition_audio``."""
    return get_cfg('recognition_audio', fname)
def get_localizer_cfg(fname):
    """Load a localization config by file name under ``configs/localization``."""
    return get_cfg('localization', fname)
def get_detector_cfg(fname):
    """Load a detection config by file name under ``configs/detection``."""
    return get_cfg('detection', fname)
def get_skeletongcn_cfg(fname):
    """Load a skeleton-based config by file name under ``configs/skeleton``."""
    return get_cfg('skeleton', fname)
def get_similarity_cfg(fname):
    """Load a retrieval config by file name under ``configs/retrieval``."""
    return get_cfg('retrieval', fname)