Shortcuts

Source code for mmaction.datasets.msrvtt_datasets

# Copyright (c) OpenMMLab. All rights reserved.
import json
import os.path as osp
import re
from collections import Counter
from typing import Dict, List

from mmengine.fileio import exists

from mmaction.registry import DATASETS
from .base import BaseActionDataset


@DATASETS.register_module()
class MSRVTTVQA(BaseActionDataset):
    """MSR-VTT Video Question Answering dataset."""

    def load_data_list(self) -> List[Dict]:
        """Load annotation file to get video information.

        The annotation file is read as JSON and iterated as a sequence of
        entries. Each entry is expected to provide the keys
        ``'question_id'``, ``'video'``, ``'question'`` and ``'answer'``
        (a single string or a list of strings).

        Returns:
            List[Dict]: One dict per entry with keys ``question_id``,
            ``filename`` (``data_prefix['video']`` joined with the entry's
            ``'video'``), ``question`` (normalized by ``pre_text``),
            ``gt_answer`` (unique answers in first-seen order) and
            ``gt_answer_weight`` (fraction of answer entries equal to each
            unique answer, aligned with ``gt_answer``).

        Raises:
            FileNotFoundError: If ``self.ann_file`` does not exist.
        """
        # ``exists`` only returns a bool; the result must be checked for
        # the guard to have any effect.
        if not exists(self.ann_file):
            raise FileNotFoundError(
                f'Annotation file {self.ann_file} does not exist.')
        # JSON text is UTF-8 by specification; do not rely on the locale.
        with open(self.ann_file, encoding='utf-8') as f:
            data_lines = json.load(f)

        data_list = []
        for data in data_lines:
            answers = data['answer']
            if isinstance(answers, str):
                answers = [answers]
            # Counter preserves first-insertion order, so keys() and
            # values() stay aligned between gt_answer and its weights.
            count = Counter(answers)
            answer_weight = [i / len(answers) for i in count.values()]
            data_item = dict(
                question_id=data['question_id'],
                filename=osp.join(self.data_prefix['video'], data['video']),
                question=pre_text(data['question']),
                gt_answer=list(count.keys()),
                gt_answer_weight=answer_weight)
            data_list.append(data_item)
        return data_list
@DATASETS.register_module()
class MSRVTTVQAMC(BaseActionDataset):
    """MSR-VTT VQA multiple choices dataset."""

    def load_data_list(self) -> List[Dict]:
        """Load annotation file to get video information.

        The annotation file is read as JSON and iterated as a sequence of
        entries. Each entry is expected to provide the keys ``'video'``,
        ``'answer'`` and ``'caption'`` (an iterable of option strings).

        Returns:
            List[Dict]: One dict per entry with keys ``filename``
            (``data_prefix['video']`` joined with the entry's ``'video'``),
            ``label`` (the entry's ``'answer'`` value, unchanged) and
            ``caption_options`` (each option normalized by ``pre_text``).

        Raises:
            FileNotFoundError: If ``self.ann_file`` does not exist.
        """
        # ``exists`` only returns a bool; the result must be checked for
        # the guard to have any effect.
        if not exists(self.ann_file):
            raise FileNotFoundError(
                f'Annotation file {self.ann_file} does not exist.')
        # JSON text is UTF-8 by specification; do not rely on the locale.
        with open(self.ann_file, encoding='utf-8') as f:
            data_lines = json.load(f)

        data_list = []
        for data in data_lines:
            data_item = dict(
                filename=osp.join(self.data_prefix['video'], data['video']),
                label=data['answer'],
                caption_options=[pre_text(c) for c in data['caption']])
            data_list.append(data_item)
        return data_list
@DATASETS.register_module()
class MSRVTTRetrieval(BaseActionDataset):
    """MSR-VTT Retrieval dataset."""

    def load_data_list(self) -> List[Dict]:
        """Load annotation file to get video information.

        The annotation file is read as JSON and iterated as a sequence of
        entries. Each entry is expected to provide ``'video'`` and
        ``'caption'`` (a single string or a list of strings).

        Every entry receives its own consecutive video index, and every
        caption receives a consecutive text index across the whole file.
        After loading, ``self.num_videos`` and ``self.num_texts`` hold the
        total counts.

        Returns:
            List[Dict]: One dict per entry with keys ``filename``,
            ``text`` (captions normalized by ``pre_text``), ``gt_video_id``
            (this entry's video index, repeated once per caption) and
            ``gt_text_id`` (the text index of each caption).

        Raises:
            FileNotFoundError: If ``self.ann_file`` does not exist.
        """
        # ``exists`` only returns a bool; the result must be checked for
        # the guard to have any effect.
        if not exists(self.ann_file):
            raise FileNotFoundError(
                f'Annotation file {self.ann_file} does not exist.')
        # JSON text is UTF-8 by specification; do not rely on the locale.
        with open(self.ann_file, encoding='utf-8') as f:
            data_lines = json.load(f)

        data_list = []
        video_idx = 0
        text_idx = 0
        for data in data_lines:
            # Each annotation entry is treated as a distinct video (no
            # de-duplication of repeated files); all of its captions map
            # back to this entry's video index.
            video_path = osp.join(self.data_prefix['video'], data['video'])
            captions = data['caption']
            # Normalize a bare caption string into a list locally instead
            # of mutating the loaded annotation.
            if isinstance(captions, str):
                captions = [captions]
            data_item = dict(
                filename=video_path, text=[], gt_video_id=[], gt_text_id=[])
            for caption in captions:
                data_item['text'].append(pre_text(caption))
                data_item['gt_video_id'].append(video_idx)
                data_item['gt_text_id'].append(text_idx)
                text_idx += 1
            video_idx += 1
            data_list.append(data_item)
        self.num_videos = video_idx
        self.num_texts = text_idx
        return data_list
def pre_text(text, max_l=None):
    """Normalize a caption or question string.

    Lower-cases ``text``, deletes the characters
    ``, . ' ! ? " ( ) * # : ; ~``, turns ``-`` and ``/`` into spaces,
    maps ``<person>`` to ``person``, collapses every run of two or more
    whitespace characters into one space, then trims a trailing newline
    and surrounding spaces.

    Args:
        text (str): Raw input text.
        max_l (int, optional): When truthy, keep at most this many
            space-separated words. Defaults to None (no truncation).

    Returns:
        str: The normalized (and possibly truncated) text.
    """
    cleaned = re.sub(r"([,.'!?\"()*#:;~])", '', text.lower())
    # Applied in this exact order, one after another.
    for old, new in (('-', ' '), ('/', ' '), ('<person>', 'person')):
        cleaned = cleaned.replace(old, new)
    cleaned = re.sub(r'\s{2,}', ' ', cleaned).rstrip('\n').strip(' ')

    if not max_l:
        return cleaned
    tokens = cleaned.split(' ')
    return ' '.join(tokens[:max_l]) if len(tokens) > max_l else cleaned