# Copyright (c) OpenMMLab. All rights reserved.
import functools
import warnings

import torch
class OutputHook:
    """Output feature map of some layers.

    Forward hooks are registered on the requested layers at construction
    time. After each forward pass the captured values are available in
    ``layer_outputs``, keyed by layer name. The object can be used as a
    context manager; leaving the ``with`` block removes the hooks.

    Args:
        module (nn.Module): The whole module to get layers.
        outputs (tuple[str] | list[str]): Layer name to output. Dotted
            names (e.g. ``'backbone.layer1'``) address nested attributes.
            Any value that is not a list or tuple registers no hooks.
            Default: None.
        as_tensor (bool): Determine to return a tensor or a numpy array.
            Default: False.
    """

    def __init__(self, module, outputs=None, as_tensor=False):
        self.outputs = outputs
        self.as_tensor = as_tensor
        # Latest captured output per layer name, overwritten on every
        # forward pass.
        self.layer_outputs = {}
        # Handles returned by ``register_forward_hook``; kept so the hooks
        # can be detached in ``remove``.
        self.handles = []
        self.register(module)

    def register(self, module):
        """Attach a forward hook to every layer named in ``self.outputs``.

        Args:
            module (nn.Module): Root module in which the layer names are
                looked up.

        Raises:
            AttributeError: If a layer name cannot be resolved on
                ``module`` or the resolved object cannot take a forward
                hook.
        """

        def hook_wrapper(name):
            # A factory is needed so each hook closes over its own ``name``
            # rather than the loop variable.

            def hook(model, input, output):
                if not isinstance(output, torch.Tensor):
                    # Non-tensor outputs (whatever the layer returned) are
                    # stored untouched since they cannot be converted.
                    warnings.warn(f'Directly return the output from {name}, '
                                  f'since it is not a tensor')
                    self.layer_outputs[name] = output
                elif self.as_tensor:
                    self.layer_outputs[name] = output
                else:
                    self.layer_outputs[name] = output.detach().cpu().numpy()

            return hook

        if isinstance(self.outputs, (list, tuple)):
            for name in self.outputs:
                try:
                    layer = rgetattr(module, name)
                    h = layer.register_forward_hook(hook_wrapper(name))
                except AttributeError as err:
                    # Keep the original failure attached as the cause so the
                    # unresolved attribute is visible in the traceback.
                    raise AttributeError(
                        f'Module {name} not found') from err
                self.handles.append(h)

    def remove(self):
        """Detach every registered hook and forget its handle."""
        for h in self.handles:
            h.remove()
        # Drop the now-useless handles so a repeated call is a no-op.
        self.handles.clear()

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.remove()
# using wonder's beautiful simplification:
# https://stackoverflow.com/questions/31174295/getattr-and-setattr-on-nested-objects
def rgetattr(obj, attr, *args):
    """Recursive ``getattr`` that follows a dotted attribute path.

    Args:
        obj (object): Object the lookup starts from.
        attr (str): Attribute path; components are separated by ``'.'``,
            e.g. ``'backbone.layer1'``.
        *args: Optional default, forwarded to every underlying ``getattr``
            call. Note that when an intermediate component is missing, the
            default becomes the object the next component is looked up on.

    Returns:
        object: The value found at the end of the path (or the default).

    Raises:
        AttributeError: If a component is missing and no default is given.
    """

    def _getattr(obj, attr):
        return getattr(obj, attr, *args)

    # ``obj`` seeds the fold; each path component is resolved on the result
    # of the previous lookup.
    return functools.reduce(_getattr, attr.split('.'), obj)