Shortcuts

Source code for mmaction.datasets.pose_dataset

# Copyright (c) OpenMMLab. All rights reserved.
import os.path as osp
from typing import Callable, Dict, List, Optional, Union

import mmengine
from mmengine.logging import MMLogger

from mmaction.registry import DATASETS
from .base import BaseActionDataset


@DATASETS.register_module()
class PoseDataset(BaseActionDataset):
    """Pose dataset for action recognition.

    The dataset loads pose and applies the specified transforms to return a
    dict containing pose information. The ann_file is a pickle file. When
    ``split`` is None its content is used directly as the list of
    annotations; otherwise it must be a dict with an ``'annotations'`` list
    and a ``'split'`` mapping from split name to the identifiers
    (``filename`` or ``frame_dir``) belonging to that split. The fields of an
    annotation include frame_dir (video_id), total_frames, label, kp and
    kpscore.

    Args:
        ann_file (str): Path to the annotation file (must end with
            ``.pkl``).
        pipeline (list[dict | callable]): A sequence of data transforms.
        split (str, optional): The dataset split used. For UCF101 and
            HMDB51, allowed choices are 'train1', 'test1', 'train2',
            'test2', 'train3', 'test3'. For NTURGB+D, allowed choices
            are 'xsub_train', 'xsub_val', 'xview_train', 'xview_val'.
            For NTURGB+D 120, allowed choices are 'xsub_train',
            'xsub_val', 'xset_train', 'xset_val'. For FineGYM,
            allowed choices are 'train', 'val'. Defaults to None.
        valid_ratio (float, optional): The valid_ratio for videos in
            KineticsPose. For a video with n frames, it is a valid
            training sample only if n * valid_ratio frames have human
            pose. Filtering is applied only when this is a positive
            float; None means not applicable (only applicable to
            Kinetics Pose). Defaults to None.
        box_thr (float): The threshold for human proposals. Only boxes
            with confidence score larger than or equal to ``box_thr`` are
            kept (used only when ``valid_ratio`` is a positive float).
            Allowed choices are 0.5, 0.6, 0.7, 0.8, 0.9. Defaults to 0.5.
    """

    def __init__(self,
                 ann_file: str,
                 pipeline: List[Union[Dict, Callable]],
                 split: Optional[str] = None,
                 valid_ratio: Optional[float] = None,
                 box_thr: float = 0.5,
                 **kwargs) -> None:
        self.split = split
        self.box_thr = box_thr
        assert box_thr in [.5, .6, .7, .8, .9], \
            f'box_thr must be one of 0.5, 0.6, 0.7, 0.8, 0.9, got {box_thr}'
        self.valid_ratio = valid_ratio
        # These attributes are read by load_data_list / filter_data, so they
        # are set before the base initializer (which presumably triggers
        # data loading unless lazy init is requested -- confirm in base).
        super().__init__(
            ann_file, pipeline=pipeline, modality='Pose', **kwargs)

    def load_data_list(self) -> List[Dict]:
        """Load annotation file to get skeleton information.

        Returns:
            list[dict]: The annotations. When ``self.split`` is set, only the
            annotations whose identifier (``'filename'`` if the first
            annotation has it, else ``'frame_dir'``) appears in
            ``split[self.split]`` are kept. If ``self.data_prefix`` has a
            ``'video'`` entry, it is prepended to every ``filename`` /
            ``frame_dir`` field.

        Raises:
            FileNotFoundError: If ``self.ann_file`` does not exist.
        """
        assert self.ann_file.endswith('.pkl')
        # Fail early with a clear message instead of ignoring the check.
        if not mmengine.exists(self.ann_file):
            raise FileNotFoundError(
                f'Annotation file {self.ann_file} does not exist.')
        data_list = mmengine.load(self.ann_file)

        if self.split is not None:
            split, annos = data_list['split'], data_list['annotations']
            # Guard against an empty annotation list before peeking at the
            # first entry to decide which key identifies a sample.
            identifier = ('filename' if annos and 'filename' in annos[0]
                          else 'frame_dir')
            split = set(split[self.split])
            data_list = [x for x in annos if x[identifier] in split]

        # Sometimes we may need to load video from the file
        if 'video' in self.data_prefix:
            video_root = self.data_prefix['video']
            for item in data_list:
                if 'filename' in item:
                    item['filename'] = osp.join(video_root, item['filename'])
                if 'frame_dir' in item:
                    item['frame_dir'] = osp.join(video_root,
                                                 item['frame_dir'])
        return data_list

    def filter_data(self) -> List[Dict]:
        """Filter out invalid samples.

        When ``valid_ratio`` is a positive float, drop every sample whose
        ``item['valid'][box_thr] / item['total_frames']`` is below
        ``valid_ratio``. Zero-frame samples are dropped as well. Each
        remaining sample then gets an ``'anno_inds'`` entry computed as
        ``item['box_score'] >= box_thr``.

        Returns:
            list[dict]: The (possibly filtered) data list, which is also
            stored in ``self.data_list``.
        """
        if isinstance(self.valid_ratio, float) and self.valid_ratio > 0:
            # A sample with zero frames cannot reach a positive valid ratio;
            # drop it rather than raising ZeroDivisionError.
            self.data_list = [
                x for x in self.data_list if x['total_frames'] > 0 and
                x['valid'][self.box_thr] / x['total_frames'] >=
                self.valid_ratio
            ]
            for item in self.data_list:
                assert 'box_score' in item, \
                    'if valid_ratio is a positive number, ' \
                    'item should have field `box_score`'
                # box_score is presumably an array, making this an
                # element-wise mask of kept boxes -- confirm with the
                # annotation format.
                item['anno_inds'] = item['box_score'] >= self.box_thr

        logger = MMLogger.get_current_instance()
        logger.info(
            f'{len(self.data_list)} videos remain after valid thresholding')

        return self.data_list

    def get_data_info(self, idx: int) -> Dict:
        """Get annotation by index.

        If ``self.data_prefix`` has a ``'skeleton'`` entry, the pickle
        ``<skeleton prefix>/<identifier>.pkl`` is loaded (identifier being
        ``filename`` if present, else ``frame_dir``). Its entries are merged
        into the returned dict and override keys with the same name.
        """
        data_info = super().get_data_info(idx)
        # Sometimes we may need to load skeleton from the file
        if 'skeleton' in self.data_prefix:
            identifier = ('filename' if 'filename' in data_info else
                          'frame_dir')
            ske_name = data_info[identifier]
            ske_path = osp.join(self.data_prefix['skeleton'],
                                ske_name + '.pkl')
            data_info.update(mmengine.load(ske_path))
        return data_info