Shortcuts

Source code for mmaction.engine.runner.retrieval_loop

# Copyright (c) OpenMMLab. All rights reserved.

import torch
from mmengine.model import is_model_wrapper
from mmengine.runner import TestLoop, ValLoop, autocast

from mmaction.registry import LOOPS


@LOOPS.register_module()
class RetrievalValLoop(ValLoop):
    """Validation loop for multimodal retrieval.

    Args:
        runner (Runner): A reference of runner.
        dataloader (Dataloader or dict): A dataloader object or a dict to
            build a dataloader.
        evaluator (Evaluator or dict or list): Used for computing metrics.
        fp16 (bool): Whether to run validation under fp16 autocast.
            Defaults to False.
    """

    def run(self) -> dict:
        """Run one validation pass and return the retrieval metrics.

        Features are extracted batch by batch with ``mode='tensor'``,
        concatenated per key, and handed to the model's ``predict_all``
        together with all collected data samples. The evaluator is then
        run twice, once on each set of returned data samples.

        Returns:
            dict: Metrics of both evaluations merged into one dict, with
            keys prefixed by ``i2t/`` and ``t2i/`` respectively.
        """
        runner = self.runner
        runner.call_hook('before_val')
        runner.call_hook('before_val_epoch')
        runner.model.eval()

        per_batch_feats = []
        all_samples = []

        for batch_idx, raw_batch in enumerate(self.dataloader):
            with torch.no_grad():
                runner.call_hook(
                    'before_val_iter',
                    batch_idx=batch_idx,
                    data_batch=raw_batch)
                with autocast(enabled=self.fp16):
                    model = runner.model
                    # The data preprocessor lives on the wrapped module
                    # when the model is wrapped.
                    inner = (
                        model.module if is_model_wrapper(model) else model)
                    preprocess = inner.data_preprocessor

                    # Ask for raw features rather than data samples; the
                    # predictions are computed once over all features.
                    batch = preprocess(raw_batch, False)
                    batch_feats = model._run_forward(batch, mode='tensor')
                    per_batch_feats.append(batch_feats)
                    all_samples.extend(batch['data_samples'])
                # The hook receives the preprocessed batch, not the raw one.
                runner.call_hook(
                    'after_val_iter',
                    batch_idx=batch_idx,
                    data_batch=batch,
                    outputs=batch_feats)

        # Join the per-batch features key by key.
        merged_feats = {}
        for name in per_batch_feats[0]:
            merged_feats[name] = torch.cat(
                [feats[name] for feats in per_batch_feats])
        # Release the per-batch list before running prediction.
        del per_batch_feats

        model = runner.model
        if is_model_wrapper(model):
            model = model.module
        predict_all = model.predict_all

        dataset = self.dataloader.dataset
        num_videos = dataset.num_videos
        num_texts = dataset.num_texts
        with torch.no_grad(), autocast(enabled=self.fp16):
            i2t_samples, t2i_samples = predict_all(
                merged_feats,
                all_samples,
                num_images=num_videos,
                num_texts=num_texts,
            )

        # Evaluate each direction separately and namespace its metrics.
        evaluator = self.evaluator
        evaluator.process(i2t_samples, None)
        metrics = {
            f'i2t/{name}': value
            for name, value in evaluator.evaluate(num_videos).items()
        }
        evaluator.process(t2i_samples, None)
        metrics.update({
            f't2i/{name}': value
            for name, value in evaluator.evaluate(num_texts).items()
        })

        runner.call_hook('after_val_epoch', metrics=metrics)
        runner.call_hook('after_val')
        return metrics
@LOOPS.register_module()
class RetrievalTestLoop(TestLoop):
    """Test loop for multimodal retrieval.

    Args:
        runner (Runner): A reference of runner.
        dataloader (Dataloader or dict): A dataloader object or a dict to
            build a dataloader.
        evaluator (Evaluator or dict or list): Used for computing metrics.
        fp16 (bool): Whether to run testing under fp16 autocast.
            Defaults to False.
    """

    def run(self) -> dict:
        """Run one test pass and return the retrieval metrics.

        Features are extracted batch by batch with ``mode='tensor'``,
        concatenated per key, and handed to the model's ``predict_all``
        together with all collected data samples. The evaluator is then
        run twice, once on each set of returned data samples.

        Returns:
            dict: Metrics of both evaluations merged into one dict, with
            keys prefixed by ``i2t/`` and ``t2i/`` respectively.
        """
        runner = self.runner
        runner.call_hook('before_test')
        runner.call_hook('before_test_epoch')
        runner.model.eval()

        per_batch_feats = []
        all_samples = []

        for batch_idx, raw_batch in enumerate(self.dataloader):
            with torch.no_grad():
                runner.call_hook(
                    'before_test_iter',
                    batch_idx=batch_idx,
                    data_batch=raw_batch)
                with autocast(enabled=self.fp16):
                    model = runner.model
                    # The data preprocessor lives on the wrapped module
                    # when the model is wrapped.
                    inner = (
                        model.module if is_model_wrapper(model) else model)
                    preprocess = inner.data_preprocessor

                    # Ask for raw features rather than data samples; the
                    # predictions are computed once over all features.
                    batch = preprocess(raw_batch, False)
                    batch_feats = model._run_forward(batch, mode='tensor')
                    per_batch_feats.append(batch_feats)
                    all_samples.extend(batch['data_samples'])
                # The hook receives the preprocessed batch, not the raw one.
                runner.call_hook(
                    'after_test_iter',
                    batch_idx=batch_idx,
                    data_batch=batch,
                    outputs=batch_feats)

        # Join the per-batch features key by key.
        merged_feats = {}
        for name in per_batch_feats[0]:
            merged_feats[name] = torch.cat(
                [feats[name] for feats in per_batch_feats])
        # Release the per-batch list before running prediction.
        del per_batch_feats

        model = runner.model
        if is_model_wrapper(model):
            model = model.module
        predict_all = model.predict_all

        dataset = self.dataloader.dataset
        num_videos = dataset.num_videos
        num_texts = dataset.num_texts
        with torch.no_grad(), autocast(enabled=self.fp16):
            i2t_samples, t2i_samples = predict_all(
                merged_feats,
                all_samples,
                num_images=num_videos,
                num_texts=num_texts,
            )

        # Evaluate each direction separately and namespace its metrics.
        evaluator = self.evaluator
        evaluator.process(i2t_samples, None)
        metrics = {
            f'i2t/{name}': value
            for name, value in evaluator.evaluate(num_videos).items()
        }
        evaluator.process(t2i_samples, None)
        metrics.update({
            f't2i/{name}': value
            for name, value in evaluator.evaluate(num_texts).items()
        })

        runner.call_hook('after_test_epoch', metrics=metrics)
        runner.call_hook('after_test')
        return metrics