
Source code for mmaction.datasets.ava_dataset

# Copyright (c) OpenMMLab. All rights reserved.
import os.path as osp
from collections import defaultdict
from typing import Callable, List, Optional, Union

import numpy as np
from mmengine.fileio import exists, list_from_file, load
from mmengine.logging import MMLogger

from mmaction.evaluation import read_labelmap
from mmaction.registry import DATASETS
from mmaction.utils import ConfigType
from .base import BaseActionDataset


@DATASETS.register_module()
class AVADataset(BaseActionDataset):
    """Dataset for spatio-temporal action detection (STAD).

    The dataset loads raw frames or video files, bounding boxes and
    proposals, and applies the specified transformations to return a dict
    containing the frame tensors and other information.

    This dataset can load information from the following files:

    .. code-block:: txt

        ann_file -> ava_{train, val}_{v2.1, v2.2}.csv
        exclude_file -> ava_{train, val}_excluded_timestamps_{v2.1, v2.2}.csv
        label_file -> ava_action_list_{v2.1, v2.2}.pbtxt /
                      ava_action_list_{v2.1, v2.2}_for_activitynet_2019.pbtxt
        proposal_file -> ava_dense_proposals_{train, val}.FAIR.recall_93.9.pkl

    In particular, the proposal_file is a pickle file keyed by ``img_key``
    (in the format ``{video_id},{timestamp}``). Example of a pickle file:

    .. code-block:: python

        {
            ...
            '0f39OWEqJ24,0902':
                array([[0.011, 0.157, 0.655, 0.983, 0.998163]]),
            '0f39OWEqJ24,0912':
                array([[0.054, 0.088, 0.91, 0.998, 0.068273],
                       [0.016, 0.161, 0.519, 0.974, 0.984025],
                       [0.493, 0.283, 0.981, 0.984, 0.983621]]),
            ...
        }

    Args:
        ann_file (str): Path to the annotation file like
            ``ava_{train, val}_{v2.1, v2.2}.csv``.
        exclude_file (str): Path to the excluded timestamp file like
            ``ava_{train, val}_excluded_timestamps_{v2.1, v2.2}.csv``.
        pipeline (List[Union[dict, ConfigDict, Callable]]): A sequence of
            data transforms.
        label_file (str): Path to the label file like
            ``ava_action_list_{v2.1, v2.2}.pbtxt`` or
            ``ava_action_list_{v2.1, v2.2}_for_activitynet_2019.pbtxt``.
            Defaults to None.
        filename_tmpl (str): Template for each filename.
            Defaults to ``'img_{:05}.jpg'``.
        start_index (int): Specify a start index for frames in consideration
            of different filename formats. It should be set to 1 for AVA,
            since frame indices start from 1 in the AVA dataset.
            Defaults to 1.
        proposal_file (str): Path to the proposal file like
            ``ava_dense_proposals_{train, val}.FAIR.recall_93.9.pkl``.
            Defaults to None.
        person_det_score_thr (float): The threshold of person detection
            scores; bboxes with scores above the threshold will be used.
            Note that 0 <= person_det_score_thr <= 1. If no proposal has a
            detection score larger than the threshold, the one with the
            largest detection score will be used. Defaults to 0.9.
        num_classes (int): The number of classes of the dataset.
            Defaults to 81. (AVA has 80 action classes; one extra dimension
            is reserved for potential usage.)
        custom_classes (List[int], optional): A subset of class ids from the
            origin dataset. Please note that 0 should NOT be selected, and
            ``num_classes`` should be equal to ``len(custom_classes) + 1``.
        data_prefix (dict or ConfigDict): Path to a directory where video
            frames are held. Defaults to ``dict(img='')``.
        test_mode (bool): Store True when building a test or validation
            dataset. Defaults to False.
        modality (str): Modality of data. Supports ``RGB`` and ``Flow``.
            Defaults to ``RGB``.
        num_max_proposals (int): Max number of proposals to store.
            Defaults to 1000.
        timestamp_start (int): The start point of included timestamps.
            Defaults to 900 (the official annotations cover seconds
            902-1798 of each 15-minute clip).
        timestamp_end (int): The end point of included timestamps.
            Defaults to 1800.
        use_frames (bool): Whether to use raw frames as input.
            Defaults to True.
        fps (int): Overrides the default FPS of the dataset. If set to 1,
            timestamps are counted by frame (e.g. the MultiSports dataset);
            otherwise by second. Defaults to 30.
        multilabel (bool): Whether this is a multilabel recognition task.
            Defaults to True.
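
    Examples:
        A minimal construction sketch (the paths are placeholders, and the
        single ``SampleAVAFrames`` transform is just one illustrative
        pipeline entry):

        .. code-block:: python

            dataset = AVADataset(
                ann_file='data/ava/annotations/ava_train_v2.2.csv',
                exclude_file='data/ava/annotations/'
                'ava_train_excluded_timestamps_v2.2.csv',
                label_file='data/ava/annotations/ava_action_list_v2.2.pbtxt',
                pipeline=[dict(type='SampleAVAFrames', clip_len=32)],
                proposal_file='data/ava/annotations/'
                'ava_dense_proposals_train.FAIR.recall_93.9.pkl',
                data_prefix=dict(img='data/ava/rawframes'))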
""" def __init__(self, ann_file: str, pipeline: List[Union[ConfigType, Callable]], exclude_file: Optional[str] = None, label_file: Optional[str] = None, filename_tmpl: str = 'img_{:05}.jpg', start_index: int = 1, proposal_file: str = None, person_det_score_thr: float = 0.9, num_classes: int = 81, custom_classes: Optional[List[int]] = None, data_prefix: ConfigType = dict(img=''), modality: str = 'RGB', test_mode: bool = False, num_max_proposals: int = 1000, timestamp_start: int = 900, timestamp_end: int = 1800, use_frames: bool = True, fps: int = 30, multilabel: bool = True, **kwargs) -> None: self._FPS = fps # Keep this as standard self.custom_classes = custom_classes if custom_classes is not None: assert num_classes == len(custom_classes) + 1 assert 0 not in custom_classes _, class_whitelist = read_labelmap(open(label_file)) assert set(custom_classes).issubset(class_whitelist) self.custom_classes = list([0] + custom_classes) self.exclude_file = exclude_file self.label_file = label_file self.proposal_file = proposal_file assert 0 <= person_det_score_thr <= 1, ( 'The value of ' 'person_det_score_thr should in [0, 1]. ') self.person_det_score_thr = person_det_score_thr self.timestamp_start = timestamp_start self.timestamp_end = timestamp_end self.num_max_proposals = num_max_proposals self.filename_tmpl = filename_tmpl self.use_frames = use_frames self.multilabel = multilabel super().__init__( ann_file, pipeline=pipeline, data_prefix=data_prefix, test_mode=test_mode, num_classes=num_classes, start_index=start_index, modality=modality, **kwargs) if self.proposal_file is not None: self.proposals = load(self.proposal_file) else: self.proposals = None
    def parse_img_record(self, img_records: List[dict]) -> tuple:
        """Merge image records of the same entity at the same time.

        Args:
            img_records (List[dict]): List of img_records (lines in AVA
                annotations).

        Returns:
            Tuple(list): A tuple consisting of lists of bboxes, action
                labels and entity_ids.
        """
        bboxes, labels, entity_ids = [], [], []
        while len(img_records) > 0:
            img_record = img_records[0]
            num_img_records = len(img_records)

            # Collect all records sharing this record's bounding box ...
            selected_records = [
                x for x in img_records
                if np.array_equal(x['entity_box'], img_record['entity_box'])
            ]
            num_selected_records = len(selected_records)
            # ... and keep the rest for the next iterations.
            img_records = [
                x for x in img_records if
                not np.array_equal(x['entity_box'], img_record['entity_box'])
            ]
            assert len(img_records) + num_selected_records == num_img_records

            bboxes.append(img_record['entity_box'])
            valid_labels = np.array([
                selected_record['label']
                for selected_record in selected_records
            ])

            # The multi-hot format can be directly used by BCELossWithLogits
            if self.multilabel:
                label = np.zeros(self.num_classes, dtype=np.float32)
                label[valid_labels] = 1.
            else:
                label = valid_labels
            labels.append(label)
            entity_ids.append(img_record['entity_id'])

        bboxes = np.stack(bboxes)
        labels = np.stack(labels)
        entity_ids = np.stack(entity_ids)
        return bboxes, labels, entity_ids
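
    # Illustrative trace of the merging above: two annotation rows that
    # share one entity_box, say labels 12 and 17 for entity_id 0, collapse
    # into a single bbox whose multi-hot vector has
    # label[12] = label[17] = 1.0, so one person box can carry several
    # concurrent actions.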
    def load_data_list(self) -> List[dict]:
        """Load AVA annotations."""
        exists(self.ann_file)
        data_list = []
        records_dict_by_img = defaultdict(list)
        fin = list_from_file(self.ann_file)
        for line in fin:
            line_split = line.strip().split(',')

            label = int(line_split[6])
            if self.custom_classes is not None:
                if label not in self.custom_classes:
                    continue
                label = self.custom_classes.index(label)

            video_id = line_split[0]
            timestamp = int(line_split[1])  # count by second or frame
            img_key = f'{video_id},{timestamp:04d}'

            entity_box = np.array(list(map(float, line_split[2:6])))
            entity_id = int(line_split[7])

            if self.use_frames:
                shot_info = (0, (self.timestamp_end - self.timestamp_start) *
                             self._FPS)
            else:
                # For video data, the shot info is fetched automatically
                # when decoding.
                shot_info = None

            video_info = dict(
                video_id=video_id,
                timestamp=timestamp,
                entity_box=entity_box,
                label=label,
                entity_id=entity_id,
                shot_info=shot_info)
            records_dict_by_img[img_key].append(video_info)

        for img_key in records_dict_by_img:
            video_id, timestamp = img_key.split(',')
            bboxes, labels, entity_ids = self.parse_img_record(
                records_dict_by_img[img_key])
            ann = dict(
                gt_bboxes=bboxes, gt_labels=labels, entity_ids=entity_ids)
            frame_dir = video_id
            if self.data_prefix['img'] is not None:
                frame_dir = osp.join(self.data_prefix['img'], frame_dir)
            video_info = dict(
                frame_dir=frame_dir,
                video_id=video_id,
                timestamp=int(timestamp),
                img_key=img_key,
                # shot_info left over from the parsing loop is reused here;
                # it is identical for every sample.
                shot_info=shot_info,
                fps=self._FPS,
                ann=ann)
            if not self.use_frames:
                video_info['filename'] = video_info.pop('frame_dir')
            data_list.append(video_info)

        return data_list
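
    # Layout of each annotation line parsed above (values illustrative,
    # matching the img_key examples in the class docstring):
    #   '0f39OWEqJ24,0902,0.031,0.162,0.670,0.995,12,0'
    # -> video_id='0f39OWEqJ24', timestamp=902,
    #    entity_box=array([0.031, 0.162, 0.670, 0.995]),
    #    label=12, entity_id=0, img_key='0f39OWEqJ24,0902'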
    def filter_data(self) -> List[dict]:
        """Filter out records listed in the exclude_file."""
        valid_indexes = []
        if self.exclude_file is None:
            valid_indexes = list(range(len(self.data_list)))
        else:
            # The exclude file is a two-column CSV with one
            # '<video_id>,<timestamp>' pair per line.
            exclude_video_infos = [
                x.strip().split(',') for x in open(self.exclude_file)
            ]
            for i, data_info in enumerate(self.data_list):
                valid_indexes.append(i)
                for video_id, timestamp in exclude_video_infos:
                    if (data_info['video_id'] == video_id
                            and data_info['timestamp'] == int(timestamp)):
                        valid_indexes.pop()
                        break
        logger = MMLogger.get_current_instance()
        logger.info(f'{len(valid_indexes)} out of {len(self.data_list)} '
                    'frames are valid.')
        data_list = [self.data_list[i] for i in valid_indexes]
        return data_list
    def get_data_info(self, idx: int) -> dict:
        """Get annotation by index."""
        data_info = super().get_data_info(idx)
        img_key = data_info['img_key']

        data_info['filename_tmpl'] = self.filename_tmpl
        data_info['timestamp_start'] = self.timestamp_start
        data_info['timestamp_end'] = self.timestamp_end

        if self.proposals is not None:
            if img_key not in self.proposals:
                # Fall back to a whole-image proposal with full confidence.
                data_info['proposals'] = np.array([[0, 0, 1, 1]])
                data_info['scores'] = np.array([1])
            else:
                proposals = self.proposals[img_key]
                assert proposals.shape[-1] in [4, 5]
                if proposals.shape[-1] == 5:
                    # Capping the threshold at the best available score
                    # guarantees that at least one proposal survives.
                    thr = min(self.person_det_score_thr,
                              max(proposals[:, 4]))
                    positive_inds = (proposals[:, 4] >= thr)
                    proposals = proposals[positive_inds]
                    proposals = proposals[:self.num_max_proposals]
                    data_info['proposals'] = proposals[:, :4]
                    data_info['scores'] = proposals[:, 4]
                else:
                    proposals = proposals[:self.num_max_proposals]
                    data_info['proposals'] = proposals

            assert data_info['proposals'].max() <= 1 and \
                data_info['proposals'].min() >= 0, \
                (f'relative proposals invalid: max value '
                 f'{data_info["proposals"].max()}, min value '
                 f'{data_info["proposals"].min()}')

        ann = data_info.pop('ann')
        data_info['gt_bboxes'] = ann['gt_bboxes']
        data_info['gt_labels'] = ann['gt_labels']
        data_info['entity_ids'] = ann['entity_ids']
        return data_info
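
# A small debug helper, not part of the upstream module: print the
# proposal array stored for one img_key, assuming the pickle layout
# documented in :class:`AVADataset` (each value is an (N, 4) or (N, 5)
# array of relative [x1, y1, x2, y2(, score)] rows).
def _inspect_proposals(proposal_file: str, img_key: str) -> None:
    """Print the raw proposals recorded for ``img_key``."""
    proposals = load(proposal_file)
    if img_key not in proposals:
        print(f'{img_key} has no precomputed proposals')
        return
    # e.g. array([[0.011, 0.157, 0.655, 0.983, 0.998163]])
    print(proposals[img_key])
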
@DATASETS.register_module()
class AVAKineticsDataset(BaseActionDataset):
    """AVA-Kinetics dataset for spatio-temporal action detection.

    Based on the official AVA annotation files, the dataset loads raw
    frames, bounding boxes and proposals, and applies the specified
    transformations to return a dict containing the frame tensors and
    other information.

    This dataset can load information from the following files:

    .. code-block:: txt

        ann_file -> ava_{train, val}_{v2.1, v2.2}.csv
        exclude_file -> ava_{train, val}_excluded_timestamps_{v2.1, v2.2}.csv
        label_file -> ava_action_list_{v2.1, v2.2}.pbtxt /
                      ava_action_list_{v2.1, v2.2}_for_activitynet_2019.pbtxt
        proposal_file -> ava_dense_proposals_{train, val}.FAIR.recall_93.9.pkl

    In particular, the proposal_file is a pickle file keyed by ``img_key``
    (in the format ``{video_id},{timestamp}``). Example of a pickle file:

    .. code-block:: python

        {
            ...
            '0f39OWEqJ24,0902':
                array([[0.011, 0.157, 0.655, 0.983, 0.998163]]),
            '0f39OWEqJ24,0912':
                array([[0.054, 0.088, 0.91, 0.998, 0.068273],
                       [0.016, 0.161, 0.519, 0.974, 0.984025],
                       [0.493, 0.283, 0.981, 0.984, 0.983621]]),
            ...
        }

    Args:
        ann_file (str): Path to the annotation file like
            ``ava_{train, val}_{v2.1, v2.2}.csv``.
        exclude_file (str): Path to the excluded timestamp file like
            ``ava_{train, val}_excluded_timestamps_{v2.1, v2.2}.csv``.
        pipeline (List[Union[dict, ConfigDict, Callable]]): A sequence of
            data transforms.
        label_file (str): Path to the label file like
            ``ava_action_list_{v2.1, v2.2}.pbtxt`` or
            ``ava_action_list_{v2.1, v2.2}_for_activitynet_2019.pbtxt``.
            Defaults to None.
        filename_tmpl (str): Template for each filename.
            Defaults to ``'img_{:05}.jpg'``.
        start_index (int): Specify a start index for frames in consideration
            of different filename formats. It should be set to 0 here, since
            frame indices start from 0. Defaults to 0.
        proposal_file (str): Path to the proposal file like
            ``ava_dense_proposals_{train, val}.FAIR.recall_93.9.pkl``.
            Defaults to None.
        person_det_score_thr (float): The threshold of person detection
            scores; bboxes with scores above the threshold will be used.
            Note that 0 <= person_det_score_thr <= 1. If no proposal has a
            detection score larger than the threshold, the one with the
            largest detection score will be used. Defaults to 0.9.
        num_classes (int): The number of classes of the dataset.
            Defaults to 81. (AVA has 80 action classes; one extra dimension
            is reserved for potential usage.)
        custom_classes (List[int], optional): A subset of class ids from the
            origin dataset. Please note that 0 should NOT be selected, and
            ``num_classes`` should be equal to ``len(custom_classes) + 1``.
        data_prefix (dict or ConfigDict): Path to a directory where video
            frames are held. Defaults to ``dict(img='')``.
        test_mode (bool): Store True when building a test or validation
            dataset. Defaults to False.
        modality (str): Modality of data. Supports ``RGB`` and ``Flow``.
            Defaults to ``RGB``.
        num_max_proposals (int): Max number of proposals to store.
            Defaults to 1000.
        timestamp_start (int): The start point of included timestamps.
            Defaults to 900 (the official annotations cover seconds
            902-1798 of each 15-minute AVA clip).
        timestamp_end (int): The end point of included timestamps.
            Defaults to 1800.
        fps (int): Overrides the default FPS of the dataset. Defaults to 30.
""" def __init__(self, ann_file: str, exclude_file: str, pipeline: List[Union[ConfigType, Callable]], label_file: str, filename_tmpl: str = 'img_{:05}.jpg', start_index: int = 0, proposal_file: str = None, person_det_score_thr: float = 0.9, num_classes: int = 81, custom_classes: Optional[List[int]] = None, data_prefix: ConfigType = dict(img=''), modality: str = 'RGB', test_mode: bool = False, num_max_proposals: int = 1000, timestamp_start: int = 900, timestamp_end: int = 1800, fps: int = 30, **kwargs) -> None: self._FPS = fps # Keep this as standard self.custom_classes = custom_classes if custom_classes is not None: assert num_classes == len(custom_classes) + 1 assert 0 not in custom_classes _, class_whitelist = read_labelmap(open(label_file)) assert set(custom_classes).issubset(class_whitelist) self.custom_classes = list([0] + custom_classes) self.exclude_file = exclude_file self.label_file = label_file self.proposal_file = proposal_file assert 0 <= person_det_score_thr <= 1, ( 'The value of ' 'person_det_score_thr should in [0, 1]. ') self.person_det_score_thr = person_det_score_thr self.timestamp_start = timestamp_start self.timestamp_end = timestamp_end self.num_max_proposals = num_max_proposals self.filename_tmpl = filename_tmpl super().__init__( ann_file, pipeline=pipeline, data_prefix=data_prefix, test_mode=test_mode, num_classes=num_classes, start_index=start_index, modality=modality, **kwargs) if self.proposal_file is not None: self.proposals = load(self.proposal_file) else: self.proposals = None
    def parse_img_record(self, img_records: List[dict]) -> tuple:
        """Merge image records of the same entity at the same time.

        Args:
            img_records (List[dict]): List of img_records (lines in AVA
                annotations).

        Returns:
            Tuple(list): A tuple consisting of lists of bboxes, action
                labels and entity_ids.
        """
        bboxes, labels, entity_ids = [], [], []
        while len(img_records) > 0:
            img_record = img_records[0]
            num_img_records = len(img_records)

            selected_records = [
                x for x in img_records
                if np.array_equal(x['entity_box'], img_record['entity_box'])
            ]
            num_selected_records = len(selected_records)
            img_records = [
                x for x in img_records if
                not np.array_equal(x['entity_box'], img_record['entity_box'])
            ]
            assert len(img_records) + num_selected_records == num_img_records

            bboxes.append(img_record['entity_box'])
            valid_labels = np.array([
                selected_record['label']
                for selected_record in selected_records
            ])

            # The multi-hot format can be directly used by BCELossWithLogits
            label = np.zeros(self.num_classes, dtype=np.float32)
            label[valid_labels] = 1.
            labels.append(label)
            entity_ids.append(img_record['entity_id'])

        bboxes = np.stack(bboxes)
        labels = np.stack(labels)
        entity_ids = np.stack(entity_ids)
        return bboxes, labels, entity_ids
    def filter_data(self) -> List[dict]:
        """Filter out records listed in the exclude_file."""
        valid_indexes = []
        if self.exclude_file is None:
            valid_indexes = list(range(len(self.data_list)))
        else:
            # The exclude file is a two-column CSV with one
            # '<video_id>,<timestamp>' pair per line.
            exclude_video_infos = [
                x.strip().split(',') for x in open(self.exclude_file)
            ]
            for i, data_info in enumerate(self.data_list):
                valid_indexes.append(i)
                for video_id, timestamp in exclude_video_infos:
                    if (data_info['video_id'] == video_id
                            and data_info['timestamp'] == int(timestamp)):
                        valid_indexes.pop()
                        break
        logger = MMLogger.get_current_instance()
        logger.info(f'{len(valid_indexes)} out of {len(self.data_list)} '
                    'frames are valid.')
        data_list = [self.data_list[i] for i in valid_indexes]
        return data_list
    def get_timestamp(self, video_id):
        """Return the (start, end) timestamps for a video id.

        Plain 11-character YouTube ids (AVA videos) use the dataset-level
        ``timestamp_start``/``timestamp_end``; Kinetics-style ids that
        embed a clip range as their last two ``_``-separated fields return
        that range instead.
        """
        if len(video_id) == 11:
            return self.timestamp_start, self.timestamp_end
        video_id = video_id.split('_')
        if len(video_id) >= 3:
            start = int(video_id[-2])
            end = int(video_id[-1])
            video_id = '_'.join(video_id[:-2])
            return start, end
        return self.timestamp_start, self.timestamp_end
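
    # Id conventions handled above (the second id is illustrative):
    #   '0f39OWEqJ24'               -> (self.timestamp_start,
    #                                   self.timestamp_end)
    #   'abcdefghijk_000902_000932' -> (902, 932)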
    def load_data_list(self) -> List[dict]:
        """Load AVA annotations."""
        exists(self.ann_file)
        data_list = []
        records_dict_by_img = defaultdict(list)
        fin = list_from_file(self.ann_file)
        for line in fin:
            line_split = line.strip().split(',')

            label = int(line_split[6])
            if self.custom_classes is not None:
                if label not in self.custom_classes:
                    continue
                label = self.custom_classes.index(label)

            video_id = line_split[0]
            timestamp = int(line_split[1])
            img_key = f'{video_id},{timestamp:04d}'

            entity_box = np.array(list(map(float, line_split[2:6])))
            entity_id = int(line_split[7])
            start, end = self.get_timestamp(video_id)
            shot_info = (1, (end - start) * self._FPS + 1)

            video_info = dict(
                video_id=video_id,
                timestamp=timestamp,
                entity_box=entity_box,
                label=label,
                entity_id=entity_id,
                shot_info=shot_info)
            records_dict_by_img[img_key].append(video_info)

        for img_key in records_dict_by_img:
            video_id, timestamp = img_key.split(',')
            start, end = self.get_timestamp(video_id)
            bboxes, labels, entity_ids = self.parse_img_record(
                records_dict_by_img[img_key])
            ann = dict(
                gt_bboxes=bboxes, gt_labels=labels, entity_ids=entity_ids)
            frame_dir = video_id
            if self.data_prefix['img'] is not None:
                frame_dir = osp.join(self.data_prefix['img'], frame_dir)
            video_info = dict(
                frame_dir=frame_dir,
                video_id=video_id,
                timestamp=int(timestamp),
                timestamp_start=start,
                timestamp_end=end,
                img_key=img_key,
                shot_info=shot_info,
                fps=self._FPS,
                ann=ann)
            data_list.append(video_info)

        return data_list
    def get_data_info(self, idx: int) -> dict:
        """Get annotation by index."""
        data_info = super().get_data_info(idx)
        img_key = data_info['img_key']

        data_info['filename_tmpl'] = self.filename_tmpl
        if 'timestamp_start' not in data_info:
            data_info['timestamp_start'] = self.timestamp_start
            data_info['timestamp_end'] = self.timestamp_end

        if self.proposals is not None:
            if len(img_key) == 16:
                # 11-character AVA video id plus ',' and a 4-digit
                # timestamp: the img_key can be used directly.
                proposal_key = img_key
            else:
                # Kinetics-style id: strip the '_start_end' suffix before
                # looking up the proposal pickle.
                video_id, timestamp = img_key.split(',')
                vid = '_'.join(video_id.split('_')[:-2])
                timestamp = int(timestamp)
                proposal_key = f'{vid},{timestamp:04d}'

            if proposal_key not in self.proposals:
                data_info['proposals'] = np.array([[0, 0, 1, 1]])
                data_info['scores'] = np.array([1])
            else:
                proposals = self.proposals[proposal_key]
                assert proposals.shape[-1] in [4, 5]
                if proposals.shape[-1] == 5:
                    thr = min(self.person_det_score_thr,
                              max(proposals[:, 4]))
                    positive_inds = (proposals[:, 4] >= thr)
                    proposals = proposals[positive_inds]
                    proposals = proposals[:self.num_max_proposals]
                    data_info['proposals'] = proposals[:, :4]
                    data_info['scores'] = proposals[:, 4]
                else:
                    proposals = proposals[:self.num_max_proposals]
                    data_info['proposals'] = proposals

        ann = data_info.pop('ann')
        data_info['gt_bboxes'] = ann['gt_bboxes']
        data_info['gt_labels'] = ann['gt_labels']
        data_info['entity_ids'] = ann['entity_ids']
        return data_info
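
# A minimal, hypothetical smoke test (not part of the upstream module):
# build a dataset instance directly and inspect one parsed record. The
# paths are placeholders that must point at real AVA annotation files, and
# the empty pipeline is only for inspection, not for training.
if __name__ == '__main__':
    dataset = AVADataset(
        ann_file='data/ava/annotations/ava_val_v2.2.csv',
        exclude_file=None,
        pipeline=[],
        proposal_file=None,
        data_prefix=dict(img='data/ava/rawframes'))
    # get_data_info flattens the stored annotation into gt_bboxes,
    # gt_labels and entity_ids; proposals are attached only when a
    # proposal_file is given.
    print(dataset.get_data_info(0))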