Source code for mmaction.engine.hooks.visualization_hook
# Copyright (c) OpenMMLab. All rights reserved.
import math
import os.path as osp
from typing import Optional, Sequence
from mmengine import FileClient
from mmengine.hooks import Hook
from mmengine.runner import EpochBasedTrainLoop, Runner
from mmengine.visualization import Visualizer
from mmaction.registry import HOOKS
from mmaction.structures import ActionDataSample
[docs]@HOOKS.register_module()
class VisualizationHook(Hook):
"""Classification Visualization Hook. Used to visualize validation and
testing prediction results.
- If ``out_dir`` is specified, all storage backends are ignored
and save the image to the ``out_dir``.
- If ``show`` is True, plot the result image in a window, please
confirm you are able to access the graphical interface.
Args:
enable (bool): Whether to enable this hook. Defaults to False.
interval (int): The interval of samples to visualize. Defaults to 5000.
show (bool): Whether to display the drawn image. Defaults to False.
out_dir (str, optional): directory where painted images will be saved
in the testing process. If None, handle with the backends of the
visualizer. Defaults to None.
**kwargs: other keyword arguments of
:meth:`mmcls.visualization.ClsVisualizer.add_datasample`.
"""
def __init__(self,
enable=False,
interval: int = 5000,
show: bool = False,
out_dir: Optional[str] = None,
**kwargs):
self._visualizer: Visualizer = Visualizer.get_current_instance()
self.enable = enable
self.interval = interval
self.show = show
self.out_dir = out_dir
if out_dir is not None:
self.file_client = FileClient.infer_client(uri=out_dir)
else:
self.file_client = None
self.draw_args = {**kwargs, 'show': show}
def _draw_samples(self,
batch_idx: int,
data_batch: dict,
data_samples: Sequence[ActionDataSample],
step: int = 0) -> None:
"""Visualize every ``self.interval`` samples from a data batch.
Args:
batch_idx (int): The index of the current batch in the val loop.
data_batch (dict): Data from dataloader.
outputs (Sequence[:obj:`ActionDataSample`]): Outputs from model.
step (int): Global step value to record. Defaults to 0.
"""
if self.enable is False:
return
batch_size = len(data_samples)
videos = data_batch['inputs']
start_idx = batch_size * batch_idx
end_idx = start_idx + batch_size
# The first index divisible by the interval, after the start index
first_sample_id = math.ceil(start_idx / self.interval) * self.interval
for sample_id in range(first_sample_id, end_idx, self.interval):
video = videos[sample_id - start_idx]
# move channel to the last
video = video.permute(1, 2, 3, 0).numpy().astype('uint8')
data_sample = data_samples[sample_id - start_idx]
if 'filename' in data_sample:
# osp.basename works on different platforms even file clients.
sample_name = osp.basename(data_sample.get('filename'))
elif 'frame_dir' in data_sample:
sample_name = osp.basename(data_sample.get('frame_dir'))
else:
sample_name = str(sample_id)
draw_args = self.draw_args
if self.out_dir is not None:
draw_args['out_path'] = self.file_client.join_path(
self.out_dir, f'{sample_name}_{step}')
self._visualizer.add_datasample(
sample_name,
video=video,
data_sample=data_sample,
step=step,
**self.draw_args,
)
[docs] def after_val_iter(self, runner: Runner, batch_idx: int, data_batch: dict,
outputs: Sequence[ActionDataSample]) -> None:
"""Visualize every ``self.interval`` samples during validation.
Args:
runner (:obj:`Runner`): The runner of the validation process.
batch_idx (int): The index of the current batch in the val loop.
data_batch (dict): Data from dataloader.
outputs (Sequence[:obj:`ActionDataSample`]): Outputs from model.
"""
if isinstance(runner.train_loop, EpochBasedTrainLoop):
step = runner.epoch
else:
step = runner.iter
self._draw_samples(batch_idx, data_batch, outputs, step=step)
[docs] def after_test_iter(self, runner: Runner, batch_idx: int, data_batch: dict,
outputs: Sequence[ActionDataSample]) -> None:
"""Visualize every ``self.interval`` samples during test.
Args:
runner (:obj:`Runner`): The runner of the testing process.
batch_idx (int): The index of the current batch in the test loop.
data_batch (dict): Data from dataloader.
outputs (Sequence[:obj:`DetDataSample`]): Outputs from model.
"""
self._draw_samples(batch_idx, data_batch, outputs, step=0)