
Source code for mmaction.datasets.transforms.formatting

# Copyright (c) OpenMMLab. All rights reserved.
from typing import Dict, Optional, Sequence, Tuple

import numpy as np
import torch
from mmcv.transforms import BaseTransform, to_tensor
from mmengine.structures import InstanceData

from mmaction.registry import TRANSFORMS
from mmaction.structures import ActionDataSample


@TRANSFORMS.register_module()
class PackActionInputs(BaseTransform):
    """Pack the input data.

    Args:
        collect_keys (tuple[str], optional): The keys to be collected
            to ``packed_results['inputs']``. Defaults to ``None``.
        meta_keys (Sequence[str]): The meta keys to be saved in the
            ``metainfo`` of the ``data_sample``. Defaults to
            ``('img_shape', 'img_key', 'video_id', 'timestamp')``.
        algorithm_keys (Sequence[str]): The keys of custom elements to be
            used in the algorithm. Defaults to an empty tuple.
    """

    mapping_table = {
        'gt_bboxes': 'bboxes',
        'gt_labels': 'labels',
    }

    def __init__(
        self,
        collect_keys: Optional[Tuple[str]] = None,
        meta_keys: Sequence[str] = ('img_shape', 'img_key', 'video_id',
                                    'timestamp'),
        algorithm_keys: Sequence[str] = (),
    ) -> None:
        self.collect_keys = collect_keys
        self.meta_keys = meta_keys
        self.algorithm_keys = algorithm_keys

    def transform(self, results: Dict) -> Dict:
        """The transform function of :class:`PackActionInputs`.

        Args:
            results (dict): The result dict.

        Returns:
            dict: The result dict.
        """
        packed_results = dict()
        if self.collect_keys is not None:
            packed_results['inputs'] = dict()
            for key in self.collect_keys:
                packed_results['inputs'][key] = to_tensor(results[key])
        else:
            if 'imgs' in results:
                imgs = results['imgs']
                packed_results['inputs'] = to_tensor(imgs)
            elif 'heatmap_imgs' in results:
                heatmap_imgs = results['heatmap_imgs']
                packed_results['inputs'] = to_tensor(heatmap_imgs)
            elif 'keypoint' in results:
                keypoint = results['keypoint']
                packed_results['inputs'] = to_tensor(keypoint)
            elif 'audios' in results:
                audios = results['audios']
                packed_results['inputs'] = to_tensor(audios)
            elif 'text' in results:
                text = results['text']
                packed_results['inputs'] = to_tensor(text)
            else:
                raise ValueError(
                    'Cannot get `imgs`, `keypoint`, `heatmap_imgs`, '
                    '`audios` or `text` in the input dict of '
                    '`PackActionInputs`.')

        data_sample = ActionDataSample()

        if 'gt_bboxes' in results:
            instance_data = InstanceData()
            for key in self.mapping_table.keys():
                instance_data[self.mapping_table[key]] = to_tensor(
                    results[key])
            data_sample.gt_instances = instance_data

            if 'proposals' in results:
                data_sample.proposals = InstanceData(
                    bboxes=to_tensor(results['proposals']))

        if 'label' in results:
            data_sample.set_gt_label(results['label'])

        # Set custom algorithm keys
        for key in self.algorithm_keys:
            if key in results:
                data_sample.set_field(results[key], key)

        # Set meta keys
        img_meta = {k: results[k] for k in self.meta_keys if k in results}
        data_sample.set_metainfo(img_meta)
        packed_results['data_samples'] = data_sample
        return packed_results

    def __repr__(self) -> str:
        repr_str = self.__class__.__name__
        repr_str += f'(collect_keys={self.collect_keys}, '
        repr_str += f'meta_keys={self.meta_keys})'
        return repr_str
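
# The commented sketch below (not part of the library source) shows how
# `PackActionInputs` is typically used at the end of a recognition pipeline.
# The array shapes and the label value are illustrative assumptions only.
#
#     results = dict(
#         imgs=np.zeros((8, 224, 224, 3), dtype=np.float32),
#         img_shape=(224, 224),
#         label=1)
#     packed = PackActionInputs().transform(results)
#     packed['inputs'].shape            # torch.Size([8, 224, 224, 3])
#     packed['data_samples'].gt_label   # tensor([1])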

@TRANSFORMS.register_module()
class PackLocalizationInputs(BaseTransform):
    """Pack the input data for temporal action localization.

    Args:
        keys (tuple[str]): The annotation keys (e.g. ``'gt_bbox'``) to be
            packed into the ``data_sample``. Defaults to an empty tuple.
        meta_keys (tuple[str]): The meta keys to be saved in the
            ``metainfo`` of the ``data_sample``. Defaults to
            ``('video_name', )``.
    """

    def __init__(self, keys=(), meta_keys=('video_name', )):
        self.keys = keys
        self.meta_keys = meta_keys

    def transform(self, results):
        """Method to pack the input data.

        Args:
            results (dict): Result dict from the data pipeline.

        Returns:
            dict:

            - 'inputs' (obj:`torch.Tensor`): The forward data of models.
            - 'data_samples' (obj:`ActionDataSample`): The annotation info
              of the sample.
        """
        packed_results = dict()
        if 'raw_feature' in results:
            raw_feature = results['raw_feature']
            packed_results['inputs'] = to_tensor(raw_feature)
        elif 'bsp_feature' in results:
            packed_results['inputs'] = torch.tensor(0.)
        else:
            raise ValueError(
                'Cannot get "raw_feature" or "bsp_feature" in the input '
                'dict of `PackLocalizationInputs`.')

        data_sample = ActionDataSample()
        for key in self.keys:
            if key not in results:
                continue
            elif key == 'proposals':
                instance_data = InstanceData()
                instance_data[key] = to_tensor(results[key])
                data_sample.proposals = instance_data
            else:
                if hasattr(data_sample, 'gt_instances'):
                    data_sample.gt_instances[key] = to_tensor(results[key])
                else:
                    instance_data = InstanceData()
                    instance_data[key] = to_tensor(results[key])
                    data_sample.gt_instances = instance_data

        img_meta = {k: results[k] for k in self.meta_keys if k in results}
        data_sample.set_metainfo(img_meta)
        packed_results['data_samples'] = data_sample
        return packed_results

    def __repr__(self) -> str:
        repr_str = self.__class__.__name__
        repr_str += f'(meta_keys={self.meta_keys})'
        return repr_str
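
# A commented usage sketch (not part of the library source), assuming a
# BMN-style pipeline where per-video features and ground-truth segments are
# packed together. The shapes and the `gt_bbox` key are illustrative.
#
#     results = dict(
#         raw_feature=np.random.rand(400, 100).astype(np.float32),
#         gt_bbox=np.array([[0.1, 0.3]], dtype=np.float32),
#         video_name='v_test')
#     packed = PackLocalizationInputs(keys=('gt_bbox', )).transform(results)
#     packed['inputs'].shape                          # torch.Size([400, 100])
#     packed['data_samples'].gt_instances.gt_bbox     # tensor([[0.1000, 0.3000]])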

@TRANSFORMS.register_module()
class Transpose(BaseTransform):
    """Transpose image channels to a given order.

    Args:
        keys (Sequence[str]): Required keys to be converted.
        order (Sequence[int]): Image channel order.
    """

    def __init__(self, keys, order):
        self.keys = keys
        self.order = order

    def transform(self, results):
        """Performs the Transpose formatting.

        Args:
            results (dict): The resulting dict to be modified and passed
                to the next transform in pipeline.
        """
        for key in self.keys:
            results[key] = results[key].transpose(self.order)
        return results

    def __repr__(self):
        return (f'{self.__class__.__name__}('
                f'keys={self.keys}, order={self.order})')
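
# A commented sketch (not part of the library source): moving the channel
# axis of an image stack from last to second position. The shape is an
# illustrative assumption.
#
#     results = dict(imgs=np.zeros((8, 224, 224, 3)))
#     results = Transpose(keys=['imgs'], order=[0, 3, 1, 2]).transform(results)
#     results['imgs'].shape   # (8, 3, 224, 224)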

@TRANSFORMS.register_module()
class FormatShape(BaseTransform):
    """Format final imgs shape to the given input_format.

    Required Keys:

    - imgs (optional)
    - heatmap_imgs (optional)
    - modality (optional)
    - num_clips
    - clip_len

    Modified Keys:

    - imgs

    Added Keys:

    - input_shape
    - heatmap_input_shape (optional)

    Args:
        input_format (str): Define the final data format.
        collapse (bool): Whether to collapse the leading batch dimension
            when it equals 1, e.g. from ``NCTHW`` to ``CTHW``. Should be
            set to True when training and testing detectors.
            Defaults to False.
    """

    def __init__(self, input_format: str, collapse: bool = False) -> None:
        self.input_format = input_format
        self.collapse = collapse
        if self.input_format not in [
                'NCTHW', 'NCHW', 'NCTHW_Heatmap', 'NPTCHW'
        ]:
            raise ValueError(
                f'The input format {self.input_format} is invalid.')

    def transform(self, results: Dict) -> Dict:
        """Performs the FormatShape formatting.

        Args:
            results (dict): The resulting dict to be modified and passed
                to the next transform in pipeline.
        """
        if not isinstance(results['imgs'], np.ndarray):
            results['imgs'] = np.array(results['imgs'])

        # [M x H x W x C]
        # M = 1 * N_crops * N_clips * T
        if self.collapse:
            assert results['num_clips'] == 1

        if self.input_format == 'NCTHW':
            if 'imgs' in results:
                imgs = results['imgs']
                num_clips = results['num_clips']
                clip_len = results['clip_len']
                if isinstance(clip_len, dict):
                    clip_len = clip_len['RGB']

                imgs = imgs.reshape((-1, num_clips, clip_len) +
                                    imgs.shape[1:])
                # N_crops x N_clips x T x H x W x C
                imgs = np.transpose(imgs, (0, 1, 5, 2, 3, 4))
                # N_crops x N_clips x C x T x H x W
                imgs = imgs.reshape((-1, ) + imgs.shape[2:])
                # M' x C x T x H x W
                # M' = N_crops x N_clips
                results['imgs'] = imgs
                results['input_shape'] = imgs.shape

            if 'heatmap_imgs' in results:
                imgs = results['heatmap_imgs']
                num_clips = results['num_clips']
                clip_len = results['clip_len']
                # clip_len must be a dict
                clip_len = clip_len['Pose']

                imgs = imgs.reshape((-1, num_clips, clip_len) +
                                    imgs.shape[1:])
                # N_crops x N_clips x T x C x H x W
                imgs = np.transpose(imgs, (0, 1, 3, 2, 4, 5))
                # N_crops x N_clips x C x T x H x W
                imgs = imgs.reshape((-1, ) + imgs.shape[2:])
                # M' x C x T x H x W
                # M' = N_crops x N_clips
                results['heatmap_imgs'] = imgs
                results['heatmap_input_shape'] = imgs.shape

        elif self.input_format == 'NCTHW_Heatmap':
            num_clips = results['num_clips']
            clip_len = results['clip_len']
            imgs = results['imgs']

            imgs = imgs.reshape((-1, num_clips, clip_len) + imgs.shape[1:])
            # N_crops x N_clips x T x C x H x W
            imgs = np.transpose(imgs, (0, 1, 3, 2, 4, 5))
            # N_crops x N_clips x C x T x H x W
            imgs = imgs.reshape((-1, ) + imgs.shape[2:])
            # M' x C x T x H x W
            # M' = N_crops x N_clips
            results['imgs'] = imgs
            results['input_shape'] = imgs.shape

        elif self.input_format == 'NCHW':
            imgs = results['imgs']
            imgs = np.transpose(imgs, (0, 3, 1, 2))
            if 'modality' in results and results['modality'] == 'Flow':
                clip_len = results['clip_len']
                imgs = imgs.reshape((-1, clip_len * imgs.shape[1]) +
                                    imgs.shape[2:])
            # M x C x H x W
            results['imgs'] = imgs
            results['input_shape'] = imgs.shape

        elif self.input_format == 'NPTCHW':
            num_proposals = results['num_proposals']
            num_clips = results['num_clips']
            clip_len = results['clip_len']
            imgs = results['imgs']
            imgs = imgs.reshape((num_proposals, num_clips * clip_len) +
                                imgs.shape[1:])
            # P x M x H x W x C
            # M = N_clips x T
            imgs = np.transpose(imgs, (0, 1, 4, 2, 3))
            # P x M x C x H x W
            results['imgs'] = imgs
            results['input_shape'] = imgs.shape

        if self.collapse:
            assert results['imgs'].shape[0] == 1
            results['imgs'] = results['imgs'].squeeze(0)
            results['input_shape'] = results['imgs'].shape

        return results
def __repr__(self) -> str: repr_str = self.__class__.__name__ repr_str += f"(input_format='{self.input_format}')" return repr_str
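
# A commented walk-through (not part of the library source) of the 'NCTHW'
# branch, assuming 2 clips of 4 frames each and a single crop:
#
#     results = dict(
#         imgs=np.zeros((8, 32, 32, 3), dtype=np.float32),  # M x H x W x C
#         num_clips=2,
#         clip_len=4)
#     results = FormatShape(input_format='NCTHW').transform(results)
#     results['imgs'].shape   # (2, 3, 4, 32, 32), i.e. M' x C x T x H x W
#     results['input_shape']  # (2, 3, 4, 32, 32)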

@TRANSFORMS.register_module()
class FormatAudioShape(BaseTransform):
    """Format final audio shape to the given input_format.

    Required Keys:

    - audios

    Modified Keys:

    - audios

    Added Keys:

    - input_shape

    Args:
        input_format (str): Define the final audio format.
    """

    def __init__(self, input_format: str) -> None:
        self.input_format = input_format
        if self.input_format not in ['NCTF']:
            raise ValueError(
                f'The input format {self.input_format} is invalid.')

    def transform(self, results: Dict) -> Dict:
        """Performs the FormatAudioShape formatting.

        Args:
            results (dict): The resulting dict to be modified and passed
                to the next transform in pipeline.
        """
        audios = results['audios']
        # clip x sample x freq -> clip x channel x sample x freq
        clip, sample, freq = audios.shape
        audios = audios.reshape(clip, 1, sample, freq)
        results['audios'] = audios
        results['input_shape'] = audios.shape
        return results
def __repr__(self) -> str: repr_str = self.__class__.__name__ repr_str += f"(input_format='{self.input_format}')" return repr_str
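
# A commented sketch (not part of the library source): inserting the channel
# axis into a stack of spectrograms. The shape is an illustrative assumption.
#
#     results = dict(audios=np.zeros((2, 64, 80), dtype=np.float32))
#     results = FormatAudioShape(input_format='NCTF').transform(results)
#     results['audios'].shape   # (2, 1, 64, 80), i.e. N x C x T x F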

@TRANSFORMS.register_module()
class FormatGCNInput(BaseTransform):
    """Format final skeleton shape.

    Required Keys:

    - keypoint
    - keypoint_score (optional)
    - num_clips (optional)

    Modified Key:

    - keypoint

    Args:
        num_person (int): The maximum number of people. Defaults to 2.
        mode (str): The padding mode. Defaults to ``'zero'``.
    """

    def __init__(self, num_person: int = 2, mode: str = 'zero') -> None:
        self.num_person = num_person
        assert mode in ['zero', 'loop']
        self.mode = mode

    def transform(self, results: Dict) -> Dict:
        """The transform function of :class:`FormatGCNInput`.

        Args:
            results (dict): The result dict.

        Returns:
            dict: The result dict.
        """
        keypoint = results['keypoint']
        if 'keypoint_score' in results:
            keypoint = np.concatenate(
                (keypoint, results['keypoint_score'][..., None]), axis=-1)

        cur_num_person = keypoint.shape[0]
        if cur_num_person < self.num_person:
            pad_dim = self.num_person - cur_num_person
            pad = np.zeros(
                (pad_dim, ) + keypoint.shape[1:], dtype=keypoint.dtype)
            keypoint = np.concatenate((keypoint, pad), axis=0)
            if self.mode == 'loop' and cur_num_person == 1:
                for i in range(1, self.num_person):
                    keypoint[i] = keypoint[0]

        elif cur_num_person > self.num_person:
            keypoint = keypoint[:self.num_person]

        M, T, V, C = keypoint.shape
        nc = results.get('num_clips', 1)
        assert T % nc == 0
        keypoint = keypoint.reshape(
            (M, nc, T // nc, V, C)).transpose(1, 0, 2, 3, 4)
        results['keypoint'] = np.ascontiguousarray(keypoint)
        return results

    def __repr__(self) -> str:
        repr_str = (f'{self.__class__.__name__}('
                    f'num_person={self.num_person}, '
                    f'mode={self.mode})')
        return repr_str
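
# A commented sketch (not part of the library source): a single-person
# skeleton is zero-padded to two people and reshaped to
# (N_clips, M, T, V, C). The shapes are illustrative assumptions.
#
#     results = dict(
#         keypoint=np.zeros((1, 100, 17, 2), dtype=np.float32),
#         keypoint_score=np.ones((1, 100, 17), dtype=np.float32))
#     results = FormatGCNInput(num_person=2).transform(results)
#     results['keypoint'].shape   # (1, 2, 100, 17, 3)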