Source code for mmaction.engine.runner.retrieval_loop
# Copyright (c) OpenMMLab. All rights reserved.
import torch
from mmengine.model import is_model_wrapper
from mmengine.runner import TestLoop, ValLoop, autocast

from mmaction.registry import LOOPS


@LOOPS.register_module()
class RetrievalValLoop(ValLoop):
    """Loop for multimodal retrieval val.
    Args:
        runner (Runner): A reference of runner.
        dataloader (Dataloader or dict): A dataloader object or a dict to
            build a dataloader.
        evaluator (Evaluator or dict or list): Used for computing metrics.
        fp16 (bool): Whether to enable fp16 valing. Defaults to
            False.
    """

    def run(self) -> dict:
        """Launch val."""
        self.runner.call_hook('before_val')
        self.runner.call_hook('before_val_epoch')
        self.runner.model.eval()
        feats_local = []
        data_samples_local = []
        for idx, data_batch in enumerate(self.dataloader):
            with torch.no_grad():
                self.runner.call_hook(
                    'before_val_iter', batch_idx=idx, data_batch=data_batch)
                # predictions should be sequence of BaseDataElement
                with autocast(enabled=self.fp16):
                    if is_model_wrapper(self.runner.model):
                        data_preprocessor = self.runner.model.module.data_preprocessor  # noqa: E501
                    else:
                        data_preprocessor = self.runner.model.data_preprocessor
                    # get features for retrieval instead of data samples
                    data_batch = data_preprocessor(data_batch, False)
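                    # the positional ``False`` is ``training=False``; the
                    # ``'tensor'`` mode below returns raw feature tensors
                    # instead of packed prediction data samples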
                    feats = self.runner.model._run_forward(
                        data_batch, mode='tensor')
                    feats_local.append(feats)
                    data_samples_local.extend(data_batch['data_samples'])
                self.runner.call_hook(
                    'after_val_iter',
                    batch_idx=idx,
                    data_batch=data_batch,
                    outputs=feats)
        # concatenate different features
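        # e.g. [{'video_feat': (B1, C)}, {'video_feat': (B2, C)}] becomes
        # {'video_feat': (B1 + B2, C)}; the key name here is illustrative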
        feats_local = {
            k: torch.cat([dic[k] for dic in feats_local])
            for k in feats_local[0]
        }
        # get predictions
        if is_model_wrapper(self.runner.model):
            predict_all_fn = self.runner.model.module.predict_all
        else:
            predict_all_fn = self.runner.model.predict_all
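        # dataset-level totals: ``predict_all`` presumably scores every video
        # against every text, so it needs the full matrix dimensions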
        num_videos = self.dataloader.dataset.num_videos
        num_texts = self.dataloader.dataset.num_texts
        with torch.no_grad():
            with autocast(enabled=self.fp16):
                i2t_data_samples, t2i_data_samples = predict_all_fn(
                    feats_local,
                    data_samples_local,
                    num_images=num_videos,
                    num_texts=num_texts,
                )
        # process in evaluator and compute metrics
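        # each ``evaluate`` call computes metrics over the stored results and
        # clears them, so one evaluator handles both directions in turn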
        self.evaluator.process(i2t_data_samples, None)
        i2t_metrics = self.evaluator.evaluate(num_videos)
        i2t_metrics = {f'i2t/{k}': v for k, v in i2t_metrics.items()}
        self.evaluator.process(t2i_data_samples, None)
        t2i_metrics = self.evaluator.evaluate(num_texts)
        t2i_metrics = {f't2i/{k}': v for k, v in t2i_metrics.items()}
        metrics = {**i2t_metrics, **t2i_metrics}
        self.runner.call_hook('after_val_epoch', metrics=metrics)
        self.runner.call_hook('after_val')
        return metrics
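
# Usage sketch (illustrative, not part of this module): an MMAction2 config
# selects this loop by name via ``val_cfg``. The evaluator shown is an
# assumption; any retrieval-style metric with the same interface would do.
#
#     val_cfg = dict(type='RetrievalValLoop', fp16=True)
#     val_dataloader = dict(...)  # dataset must expose num_videos/num_texts
#     val_evaluator = dict(type='RetrievalMetric')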


@LOOPS.register_module()
class RetrievalTestLoop(TestLoop):
    """Loop for multimodal retrieval test.
    Args:
        runner (Runner): A reference of runner.
        dataloader (Dataloader or dict): A dataloader object or a dict to
            build a dataloader.
        evaluator (Evaluator or dict or list): Used for computing metrics.
        fp16 (bool): Whether to enable fp16 testing. Defaults to
            False.
    """

    def run(self) -> dict:
        """Launch test."""
        self.runner.call_hook('before_test')
        self.runner.call_hook('before_test_epoch')
        self.runner.model.eval()
        feats_local = []
        data_samples_local = []
        for idx, data_batch in enumerate(self.dataloader):
            with torch.no_grad():
                self.runner.call_hook(
                    'before_test_iter', batch_idx=idx, data_batch=data_batch)
                # predictions should be sequence of BaseDataElement
                with autocast(enabled=self.fp16):
                    if is_model_wrapper(self.runner.model):
                        data_preprocessor = self.runner.model.module.data_preprocessor  # noqa: E501
                    else:
                        data_preprocessor = self.runner.model.data_preprocessor
                    # get features for retrieval instead of data samples
                    data_batch = data_preprocessor(data_batch, False)
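                    # as in the val loop: ``False`` is ``training=False``, and
                    # ``'tensor'`` mode returns raw feature tensors instead of
                    # packed prediction data samples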
                    feats = self.runner.model._run_forward(
                        data_batch, mode='tensor')
                    feats_local.append(feats)
                    data_samples_local.extend(data_batch['data_samples'])
                self.runner.call_hook(
                    'after_test_iter',
                    batch_idx=idx,
                    data_batch=data_batch,
                    outputs=feats)
        # concatenate different features
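        # e.g. [{'video_feat': (B1, C)}, {'video_feat': (B2, C)}] becomes
        # {'video_feat': (B1 + B2, C)}; the key name here is illustrative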
        feats_local = {
            k: torch.cat([dic[k] for dic in feats_local])
            for k in feats_local[0]
        }
        # get predictions
        if is_model_wrapper(self.runner.model):
            predict_all_fn = self.runner.model.module.predict_all
        else:
            predict_all_fn = self.runner.model.predict_all
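        # dataset-level totals: ``predict_all`` presumably scores every video
        # against every text, so it needs the full matrix dimensions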
        num_videos = self.dataloader.dataset.num_videos
        num_texts = self.dataloader.dataset.num_texts
        with torch.no_grad():
            with autocast(enabled=self.fp16):
                i2t_data_samples, t2i_data_samples = predict_all_fn(
                    feats_local,
                    data_samples_local,
                    num_images=num_videos,
                    num_texts=num_texts,
                )
        # process in evaluator and compute metrics
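        # each ``evaluate`` call computes metrics over the stored results and
        # clears them, so one evaluator handles both directions in turn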
        self.evaluator.process(i2t_data_samples, None)
        i2t_metrics = self.evaluator.evaluate(num_videos)
        i2t_metrics = {f'i2t/{k}': v for k, v in i2t_metrics.items()}
        self.evaluator.process(t2i_data_samples, None)
        t2i_metrics = self.evaluator.evaluate(num_texts)
        t2i_metrics = {f't2i/{k}': v for k, v in t2i_metrics.items()}
        metrics = {**i2t_metrics, **t2i_metrics}
        self.runner.call_hook('after_test_epoch', metrics=metrics)
        self.runner.call_hook('after_test')
        return metrics
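
# Usage sketch (illustrative): the test-time counterpart mirrors the val
# config; ``RetrievalMetric`` is again an assumed, representative choice.
#
#     test_cfg = dict(type='RetrievalTestLoop', fp16=True)
#     test_evaluator = dict(type='RetrievalMetric')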