# Source code for mmaction.engine.hooks.output
# Copyright (c) OpenMMLab. All rights reserved.
import functools
import warnings
import torch
class OutputHook:
    """Capture the outputs of selected sub-layers via forward hooks.

    Hooks are registered on construction and stay attached until
    :meth:`remove` is called. When the instance is used as a context manager,
    ``remove`` is also called when the ``with`` block exits. Each hook stores
    the most recent output of its layer in ``layer_outputs``, keyed by the
    requested layer name.

    Args:
        module (nn.Module): The whole module to get layers.
        outputs (str | tuple[str] | list[str], optional): Dotted name(s) of
            the layers whose outputs are captured, e.g. ``'backbone.layer4'``.
            A single string is treated as a one-element list. Default: None,
            which registers no hooks.
        as_tensor (bool): If True, tensor outputs are stored as-is. Otherwise
            they are stored as detached CPU numpy arrays. Non-tensor outputs
            are always stored unchanged (with a warning). Default: False.

    Raises:
        AttributeError: If a requested layer cannot be resolved, or the
            resolved attribute does not support forward hooks. Any hooks
            registered earlier in the same ``register`` call are removed
            before the error is raised.
    """

    def __init__(self, module, outputs=None, as_tensor=False):
        self.outputs = outputs
        self.as_tensor = as_tensor
        # Maps layer name -> most recent output seen by that layer's hook.
        self.layer_outputs = {}
        # Handles returned by register_forward_hook, kept so remove() can
        # detach them.
        self.handles = []
        self.register(module)

    def register(self, module):
        """Register a forward hook on every layer named in ``self.outputs``.

        Args:
            module (nn.Module): Root module that the dotted names are
                resolved against.

        Raises:
            AttributeError: If a name cannot be resolved to a hookable layer.
        """

        def hook_wrapper(name):
            # Factory so each hook closes over its own ``name``. This avoids
            # the late-binding closure pitfall inside the loop below.
            def hook(model, inputs, output):
                if not isinstance(output, torch.Tensor):
                    warnings.warn(f'Directly return the output from {name}, '
                                  f'since it is not a tensor')
                    self.layer_outputs[name] = output
                elif self.as_tensor:
                    self.layer_outputs[name] = output
                else:
                    self.layer_outputs[name] = output.detach().cpu().numpy()

            return hook

        names = self.outputs
        if isinstance(names, str):
            names = [names]
        if not isinstance(names, (list, tuple)):
            # None (or any other type) means: capture nothing.
            return

        new_handles = []
        for name in names:
            try:
                layer = rgetattr(module, name)
                handle = layer.register_forward_hook(hook_wrapper(name))
            except AttributeError as err:
                # Roll back the hooks added by this call. Otherwise they stay
                # attached to the model, and the caller (e.g. a failed
                # __init__) has no object on which to call remove().
                for added in new_handles:
                    added.remove()
                raise AttributeError(f'Module {name} not found') from err
            new_handles.append(handle)
        self.handles.extend(new_handles)

    def remove(self):
        """Detach all registered hooks. Safe to call more than once."""
        for handle in self.handles:
            handle.remove()
        self.handles.clear()

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        # Always detach hooks. Returning None lets exceptions propagate.
        self.remove()
# Recursive getattr that resolves dotted paths such as 'backbone.layer1',
# based on wonder's simplification in:
# https://stackoverflow.com/questions/31174295/getattr-and-setattr-on-nested-objects
def rgetattr(obj, attr, *args):
    """Resolve a dotted attribute path on ``obj``.

    ``rgetattr(m, 'a.b.c')`` is equivalent to ``m.a.b.c``.

    Args:
        obj: Object to start the lookup from.
        attr (str): Attribute names separated by ``'.'``.
        *args: Optional default. It is forwarded unchanged to every
            ``getattr`` call along the path.

    Returns:
        The attribute reached at the end of the path.

    Raises:
        AttributeError: If a path component is missing and no default is given.
    """
    current = obj
    for part in attr.split('.'):
        current = getattr(current, part, *args)
    return current