# Source code for mmaction.datasets.charades_sta_dataset
# Copyright (c) OpenMMLab. All rights reserved.
import os
from typing import Callable, List, Optional, Union
import mmengine
import numpy as np
import torch
from mmengine.fileio import exists
from mmaction.registry import DATASETS
from mmaction.utils import ConfigType
from .base import BaseActionDataset
# ``nltk`` is optional at import time; CharadesSTADataset checks this flag
# and raises a clear ImportError only when the dataset is actually built.
try:
    import nltk
except ImportError:
    nltk_imported = False
else:
    nltk_imported = True
@DATASETS.register_module()
class CharadesSTADataset(BaseActionDataset):
    """Charades-STA dataset built on pre-extracted per-video features.

    Each line of ``ann_file`` is expected to look like::

        <vid_name> <start_time> <end_time>##<query sentence>

    The query sentence has every ``'.'`` removed, is tokenized with
    ``nltk.word_tokenize`` and mapped to ids through ``word2id_file``.
    Every video is split into uniform proposals (see :meth:`get_proposals`),
    and :meth:`get_data_info` max-pools the stored feature sequence over
    each proposal.

    Args:
        ann_file (str): Path to the annotation file.
        pipeline (list[dict | callable]): Sequence of data transforms.
        word2id_file (str): File loadable by ``mmengine.load`` that maps each
            query word to a token id.
        fps_file (str): File mapping a video name to its fps.
        duration_file (str): File mapping a video name to its duration
            (assumed to be in the same unit as the annotation start/end
            times, presumably seconds - TODO confirm).
        num_frames_file (str): File mapping a video name to its frame count.
        window_size (int): Together with ``ft_overlap`` defines the frame
            stride between consecutive feature vectors,
            ``int(window_size * (1 - ft_overlap))``.
        ft_overlap (float): Presumably the fractional overlap between
            consecutive feature windows.
        data_prefix (dict, optional): Must contain the key ``'video'``, the
            directory holding ``<vid_name>.pt`` feature files. ``None`` means
            ``dict(video='')``. Defaults to None.
        test_mode (bool): Whether the dataset is in test mode.
            Defaults to False.
        **kwargs: Forwarded to :class:`BaseActionDataset`.

    Raises:
        ImportError: If ``nltk`` is not installed.
        ValueError: If the computed feature stride is not positive.
    """

    def __init__(self,
                 ann_file: str,
                 pipeline: List[Union[dict, Callable]],
                 word2id_file: str,
                 fps_file: str,
                 duration_file: str,
                 num_frames_file: str,
                 window_size: int,
                 ft_overlap: float,
                 data_prefix: Optional[ConfigType] = None,
                 test_mode: bool = False,
                 **kwargs):
        if not nltk_imported:
            raise ImportError('nltk is required for CharadesSTADataset')
        # Use a fresh dict per instance instead of a shared mutable default.
        if data_prefix is None:
            data_prefix = dict(video='')
        # Frame stride between consecutive feature vectors. It is used as a
        # divisor in ``load_data_list``, so it must be strictly positive.
        self.ft_interval = int(window_size * (1 - ft_overlap))
        if self.ft_interval <= 0:
            raise ValueError(
                'int(window_size * (1 - ft_overlap)) must be positive, got '
                f'window_size={window_size}, ft_overlap={ft_overlap}')
        self.fps_info = mmengine.load(fps_file)
        self.duration_info = mmengine.load(duration_file)
        self.num_frames = mmengine.load(num_frames_file)
        self.word2id = mmengine.load(word2id_file)
        super().__init__(
            ann_file,
            pipeline=pipeline,
            data_prefix=data_prefix,
            test_mode=test_mode,
            **kwargs)

    def load_data_list(self) -> List[dict]:
        """Load annotation file to get video information.

        Returns:
            list[dict]: One dict per non-empty annotation line with keys
            ``vid_name``, ``fps``, ``num_frames``, ``duration``,
            ``query_tokens``, ``query_length``, ``gt_start_time``,
            ``gt_end_time``, ``gt_bbox``, ``proposals``, ``num_proposals``
            and ``proposal_indexes``.

        Raises:
            FileNotFoundError: If ``self.ann_file`` does not exist.
        """
        # ``exists`` reports via its return value rather than raising.
        if not exists(self.ann_file):
            raise FileNotFoundError(
                f'Annotation file {self.ann_file} does not exist.')
        data_list = []
        with open(self.ann_file) as f:
            anno_database = f.readlines()
        for item in anno_database:
            item = item.strip()
            if not item:
                # Blank lines carry no annotation; skip instead of crashing.
                continue
            first_part, query_sentence = item.split('##')
            query_sentence = query_sentence.replace('.', '')
            query_words = nltk.word_tokenize(query_sentence)
            query_tokens = [self.word2id[word] for word in query_words]
            query_length = len(query_tokens)
            query_tokens = torch.from_numpy(np.array(query_tokens))

            vid_name, start_time, end_time = first_part.split()
            duration = float(self.duration_info[vid_name])
            fps = float(self.fps_info[vid_name])

            gt_start_time = float(start_time)
            gt_end_time = float(end_time)
            # Ground-truth span normalized by duration; the end is capped
            # at 1.
            gt_bbox = (gt_start_time / duration,
                       min(gt_end_time / duration, 1))

            num_frames = int(self.num_frames[vid_name])
            proposal_frames = self.get_proposals(num_frames)
            proposals = torch.from_numpy(proposal_frames / num_frames)
            # Frame indices -> feature-sequence indices (truncating division
            # by the feature stride).
            proposal_indexes = (proposal_frames /
                                self.ft_interval).astype(np.int32)

            info = dict(
                vid_name=vid_name,
                fps=fps,
                num_frames=num_frames,
                duration=duration,
                query_tokens=query_tokens,
                query_length=query_length,
                gt_start_time=gt_start_time,
                gt_end_time=gt_end_time,
                gt_bbox=gt_bbox,
                proposals=proposals,
                num_proposals=proposals.shape[0],
                proposal_indexes=proposal_indexes)
            data_list.append(info)
        return data_list

    def get_proposals(self, num_frames, num_segments: int = 32):
        """Split frames ``0 .. num_frames - 1`` into uniform proposals.

        Args:
            num_frames (int): Number of frames in the video.
            num_segments (int): Number of proposals to produce.
                Defaults to 32.

        Returns:
            np.ndarray: ``(num_segments, 2)`` int32 array of ``[start, end]``
            frame indices. Consecutive proposals share their boundary frame.
        """
        boundaries = (num_frames - 1) / num_segments * np.arange(
            num_segments + 1)
        boundaries = boundaries.astype(np.int32)
        return np.stack([boundaries[:-1], boundaries[1:]]).T

    def get_data_info(self, idx: int) -> dict:
        """Get annotation by index and attach max-pooled proposal features.

        Loads ``<data_prefix['video']>/<vid_name>.pt`` with ``torch.load``.
        For each proposal it max-pools the covered rows along dim 0, and
        stores the stacked result under ``'raw_feature'``.
        """
        data_info = super().get_data_info(idx)
        vid_name = data_info['vid_name']
        feature_path = os.path.join(self.data_prefix['video'],
                                    f'{vid_name}.pt')
        # NOTE(review): ``torch.load`` unpickles the file; only point
        # ``data_prefix`` at trusted feature files.
        # Assumes dim 0 of the loaded tensor indexes the feature sequence
        # (as the slicing below requires) - TODO confirm the full shape.
        vid_feature = torch.load(feature_path)
        # Proposal indexes may exceed the stored sequence length; clamp them
        # to the last valid row.
        proposal_indexes = data_info['proposal_indexes'].clip(
            max=vid_feature.shape[0] - 1)
        proposal_feats = [
            vid_feature[start:end + 1].max(dim=0).values
            for start, end in proposal_indexes
        ]
        data_info['raw_feature'] = torch.stack(proposal_feats)
        return data_info