Source code for mmaction.datasets.msrvtt_datasets
# Copyright (c) OpenMMLab. All rights reserved.
import json
import os.path as osp
import re
from collections import Counter
from typing import Dict, List

from mmengine.fileio import exists

from mmaction.registry import DATASETS
from .base import BaseActionDataset


@DATASETS.register_module()
class MSRVTTVQA(BaseActionDataset):
    """MSR-VTT Video Question Answering dataset."""

    def load_data_list(self) -> List[Dict]:
        """Load annotation file to get video information."""
        exists(self.ann_file)
        data_list = []

        with open(self.ann_file) as f:
            data_lines = json.load(f)
            for data in data_lines:
                answers = data['answer']
                if isinstance(answers, str):
                    answers = [answers]
                count = Counter(answers)
                answer_weight = [i / len(answers) for i in count.values()]
                data_item = dict(
                    question_id=data['question_id'],
                    filename=osp.join(self.data_prefix['video'],
                                      data['video']),
                    question=pre_text(data['question']),
                    gt_answer=list(count.keys()),
                    gt_answer_weight=answer_weight)
                data_list.append(data_item)

        return data_list
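
# A minimal illustration of the answer weighting above (example data only,
# not part of the original module): for a record with
# ``answer = ['kitchen', 'kitchen', 'garage']``, ``Counter(answers)`` yields
# ``{'kitchen': 2, 'garage': 1}``, so the item is built with
# ``gt_answer = ['kitchen', 'garage']`` and
# ``gt_answer_weight = [2 / 3, 1 / 3]``.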


@DATASETS.register_module()
class MSRVTTVQAMC(BaseActionDataset):
    """MSR-VTT VQA multiple-choice dataset."""

    def load_data_list(self) -> List[Dict]:
        """Load annotation file to get video information."""
        exists(self.ann_file)
        data_list = []

        with open(self.ann_file) as f:
            data_lines = json.load(f)
            for data in data_lines:
                data_item = dict(
                    filename=osp.join(self.data_prefix['video'],
                                      data['video']),
                    label=data['answer'],
                    caption_options=[pre_text(c) for c in data['caption']])
                data_list.append(data_item)

        return data_list
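
# Sketch of one resulting sample (hypothetical annotation values, only for
# illustration): given ``data_prefix = dict(video='data/msrvtt/videos')`` and
# an entry ``{'video': 'video1.mp4', 'answer': 2, 'caption': [...]}``, the
# item stores ``filename='data/msrvtt/videos/video1.mp4'``, ``label=2`` and
# ``caption_options`` holding every candidate caption run through
# ``pre_text``.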


@DATASETS.register_module()
class MSRVTTRetrieval(BaseActionDataset):
    """MSR-VTT Retrieval dataset."""

    def load_data_list(self) -> List[Dict]:
        """Load annotation file to get video information."""
        exists(self.ann_file)
        data_list = []

        with open(self.ann_file) as f:
            data_lines = json.load(f)
            video_idx = 0
            text_idx = 0
            for data in data_lines:
                # one data item per video; every caption of the video becomes
                # a separate text entry paired with this video's index
                video_path = osp.join(self.data_prefix['video'],
                                      data['video'])
                data_item = dict(
                    filename=video_path,
                    text=[],
                    gt_video_id=[],
                    gt_text_id=[])
                if isinstance(data['caption'], str):
                    data['caption'] = [data['caption']]
                for text in data['caption']:
                    text = pre_text(text)
                    data_item['text'].append(text)
                    data_item['gt_video_id'].append(video_idx)
                    data_item['gt_text_id'].append(text_idx)
                    text_idx += 1
                video_idx += 1
                data_list.append(data_item)

        self.num_videos = video_idx
        self.num_texts = text_idx
        return data_list
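
# Index bookkeeping sketch (illustrative only, not from the original file):
# if the annotation file holds two videos with 2 and 3 captions respectively,
# the first item gets ``gt_video_id == [0, 0]`` and ``gt_text_id == [0, 1]``,
# the second gets ``gt_video_id == [1, 1, 1]`` and ``gt_text_id == [2, 3, 4]``,
# and the dataset ends up with ``num_videos == 2`` and ``num_texts == 5``.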


def pre_text(text, max_l=None):
    """Normalize a question or caption: lower-case it, strip punctuation,
    map '-' and '/' to spaces and '<person>' to 'person', collapse repeated
    whitespace, and optionally truncate to ``max_l`` words."""
    text = re.sub(r"([,.'!?\"()*#:;~])", '', text.lower())
    text = text.replace('-', ' ').replace('/', ' ').replace(
        '<person>', 'person')
    text = re.sub(r'\s{2,}', ' ', text)
    text = text.rstrip('\n').strip(' ')

    if max_l:  # truncate to at most ``max_l`` words
        words = text.split(' ')
        if len(words) > max_l:
            text = ' '.join(words[:max_l])
    return text
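
# A quick sanity check of ``pre_text`` (doctest-style sketch with made-up
# inputs, not part of the original module):
#
#   >>> pre_text('What is the Man doing?')
#   'what is the man doing'
#   >>> pre_text('a red car drives down a long road', max_l=3)
#   'a red car'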