Source code for mmaction.utils.gradcam_utils
# Copyright (c) OpenMMLab. All rights reserved.
import torch
import torch.nn as nn
import torch.nn.functional as F
class GradCAM:
    """GradCAM class helps create visualization results.

    Visualization results are blended from heatmaps and input images.
    This class is modified from
    https://github.com/facebookresearch/SlowFast/blob/master/slowfast/visualization/gradcam_utils.py # noqa
    For more information about GradCAM, please visit:
    https://arxiv.org/pdf/1610.02391.pdf

    Args:
        model (nn.Module): the recognizer model to be used.
        target_layer_name (str): name of the convolutional layer to
            be used to get gradients and feature maps from for creating
            localization maps.
        colormap (str): matplotlib colormap used to create
            heatmap. Defaults to 'viridis'. For more information, please visit
            https://matplotlib.org/3.3.0/tutorials/colors/colormaps.html
    """

    def __init__(self,
                 model: nn.Module,
                 target_layer_name: str,
                 colormap: str = 'viridis') -> None:
        from ..models.recognizers import Recognizer2D, Recognizer3D
        if isinstance(model, Recognizer2D):
            self.is_recognizer2d = True
        elif isinstance(model, Recognizer3D):
            self.is_recognizer2d = False
        else:
            raise ValueError(
                'GradCAM utils only support Recognizer2D & Recognizer3D.')

        self.model = model
        self.model.eval()
        self.target_gradients = None
        self.target_activations = None

        import matplotlib.pyplot as plt
        self.colormap = plt.get_cmap(colormap)
        self._register_hooks(target_layer_name)

    def _register_hooks(self, layer_name: str) -> None:
        """Register forward and backward hooks to a layer, given layer_name,
        to obtain gradients and activations.

        Args:
            layer_name (str): name of the layer.
        """

        def get_gradients(module, grad_input, grad_output):
            self.target_gradients = grad_output[0].detach()

        def get_activations(module, input, output):
            self.target_activations = output.clone().detach()

        layer_ls = layer_name.split('/')
        prev_module = self.model
        for layer in layer_ls:
            prev_module = prev_module._modules[layer]

        target_layer = prev_module
        target_layer.register_forward_hook(get_activations)
        target_layer.register_backward_hook(get_gradients)
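
    # Note (added comment): the '/'-separated ``target_layer_name`` is walked
    # through ``model._modules``, so e.g. 'backbone/layer4' (a hypothetical
    # name) resolves to ``model.backbone.layer4``; with recent PyTorch this is
    # equivalent to ``model.get_submodule('backbone.layer4')``.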

    def _calculate_localization_map(self,
                                    data: dict,
                                    use_labels: bool,
                                    delta=1e-20) -> tuple:
        """Calculate localization map for all inputs with Grad-CAM.

        Args:
            data (dict): model inputs, generated by the test pipeline.
            use_labels (bool): Whether to use given labels to generate
                localization map.
            delta (float): used in localization map normalization,
                must be small enough. Please make sure
                `localization_map_max - localization_map_min >> delta`.

        Returns:
            localization_map (torch.Tensor): the localization map for
                input imgs.
            preds (torch.Tensor): Model predictions with shape
                (batch_size, num_classes).
        """
        inputs = data['inputs']

        # use score before softmax
        self.model.cls_head.average_clips = 'score'

        # model forward & backward
        results = self.model.test_step(data)
        preds = [result.pred_score for result in results]
        preds = torch.stack(preds)

        if use_labels:
            labels = [result.gt_label for result in results]
            labels = torch.stack(labels)
            score = torch.gather(preds, dim=1, index=labels)
        else:
            score = torch.max(preds, dim=-1)[0]
        self.model.zero_grad()
        score = torch.sum(score)
        score.backward()
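        # After the backward pass above, the registered hooks have filled
        # ``self.target_gradients`` and ``self.target_activations`` for the
        # chosen layer; everything below only reshapes and combines them.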
        imgs = torch.stack(inputs)

        if self.is_recognizer2d:
            # [batch_size, num_segments, 3, H, W]
            b, t, _, h, w = imgs.size()
        else:
            # [batch_size, num_crops*num_clips, 3, clip_len, H, W]
            b1, b2, _, t, h, w = imgs.size()
            b = b1 * b2

        gradients = self.target_gradients
        activations = self.target_activations
        if self.is_recognizer2d:
            # [B*Tg, C', H', W']
            b_tg, c, _, _ = gradients.size()
            tg = b_tg // b
        else:
            # source shape: [B, C', Tg, H', W']
            _, c, tg, _, _ = gradients.size()
            # target shape: [B, Tg, C', H', W']
            gradients = gradients.permute(0, 2, 1, 3, 4)
            activations = activations.permute(0, 2, 1, 3, 4)

        # calculate & resize to [B, 1, T, H, W]
        weights = torch.mean(gradients.view(b, tg, c, -1), dim=3)
        weights = weights.view(b, tg, c, 1, 1)
        activations = activations.view([b, tg, c] +
                                       list(activations.size()[-2:]))
        localization_map = torch.sum(
            weights * activations, dim=2, keepdim=True)
        localization_map = F.relu(localization_map)
        localization_map = localization_map.permute(0, 2, 1, 3, 4)
        localization_map = F.interpolate(
            localization_map,
            size=(t, h, w),
            mode='trilinear',
            align_corners=False)

        # Normalize the localization map.
        localization_map_min, localization_map_max = (
            torch.min(localization_map.view(b, -1), dim=-1, keepdim=True)[0],
            torch.max(localization_map.view(b, -1), dim=-1, keepdim=True)[0])
        localization_map_min = torch.reshape(
            localization_map_min, shape=(b, 1, 1, 1, 1))
        localization_map_max = torch.reshape(
            localization_map_max, shape=(b, 1, 1, 1, 1))
        localization_map = (localization_map - localization_map_min) / (
            localization_map_max - localization_map_min + delta)
        localization_map = localization_map.data
        return localization_map.squeeze(dim=1), preds
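
    # Shape sketch (illustrative, not from the original source): for a
    # TSN-style Recognizer2D with batch_size=1, num_segments=8 and 224x224
    # input, ``_calculate_localization_map`` returns a localization map of
    # shape [1, 8, 224, 224] and ``preds`` of shape [1, num_classes].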

    def _alpha_blending(self, localization_map: torch.Tensor,
                        input_imgs: torch.Tensor,
                        alpha: float) -> torch.Tensor:
        """Blend heatmaps and model input images to get visualization results.

        Args:
            localization_map (torch.Tensor): localization map for all inputs,
                generated with Grad-CAM.
            input_imgs (torch.Tensor): model inputs, raw images.
            alpha (float): transparency level of the heatmap,
                in the range [0, 1].

        Returns:
            torch.Tensor: blending results for localization map and input
                images, with shape [B, T, H, W, 3] and pixel values in
                RGB order within range [0, 1].
        """
        # localization_map shape [B, T, H, W]
        localization_map = localization_map.cpu()

        # heatmap shape [B, T, H, W, 3] in RGB order
        heatmap = self.colormap(localization_map.detach().numpy())
        heatmap = heatmap[..., :3]
        heatmap = torch.from_numpy(heatmap)

        input_imgs = torch.stack(input_imgs)
        # Permute input imgs to [B, T, H, W, 3], like heatmap
        if self.is_recognizer2d:
            # Recognizer2D input (B, T, C, H, W)
            curr_inp = input_imgs.permute(0, 1, 3, 4, 2)
        else:
            # Recognizer3D input (B', num_clips*num_crops, C, T, H, W)
            # B = B' * num_clips * num_crops
            curr_inp = input_imgs.view([-1] + list(input_imgs.size()[2:]))
            curr_inp = curr_inp.permute(0, 2, 3, 4, 1)

        # renormalize input imgs to [0, 1]
        curr_inp = curr_inp.cpu().float()
        curr_inp /= 255.

        # alpha blending
        blended_imgs = alpha * heatmap + (1 - alpha) * curr_inp
        return blended_imgs
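
    # Post-processing hint (an assumption, not part of the original code): the
    # blended result is float RGB in [0, 1], so frames could be exported with
    # something like ``(blended_imgs * 255).to(torch.uint8).numpy()``.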

    def __call__(self,
                 data: dict,
                 use_labels: bool = False,
                 alpha: float = 0.5) -> tuple:
        """Visualize the localization maps on their corresponding inputs as
        heatmaps, using Grad-CAM.

        Generate visualization results for **ALL CROPS**.
        For example, for an I3D model, if `clip_len=32, num_clips=10` and
        `ThreeCrop` is used in the test pipeline, then 960 (32*10*3) images
        are generated for every model input.

        Args:
            data (dict): model inputs, generated by the test pipeline.
            use_labels (bool): Whether to use given labels to generate
                localization map.
            alpha (float): transparency level of the heatmap,
                in the range [0, 1].

        Returns:
            blended_imgs (torch.Tensor): Visualization results, blended from
                localization maps and model inputs.
            preds (torch.Tensor): Model predictions for inputs.
        """
        # localization_map shape [B, T, H, W]
        # preds shape [batch_size, num_classes]
        localization_map, preds = self._calculate_localization_map(
            data, use_labels=use_labels)

        # blended_imgs shape [B, T, H, W, 3]
        blended_imgs = self._alpha_blending(localization_map, data['inputs'],
                                            alpha)

        # blended_imgs shape [B, T, H, W, 3]
        # preds shape [batch_size, num_classes]
        # Recognizer2D: B = batch_size, T = num_segments
        # Recognizer3D: B = batch_size * num_crops * num_clips, T = clip_len
        return blended_imgs, preds
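
# End-to-end sketch (illustrative only; the config/checkpoint paths, layer
# name, and the way ``data`` is built are assumptions, not part of this
# module). ``init_recognizer`` is the high-level mmaction API for building a
# recognizer from a config and checkpoint:
#
#     from mmaction.apis import init_recognizer
#
#     model = init_recognizer(config_file, checkpoint_file, device='cuda:0')
#     gradcam = GradCAM(model, target_layer_name='backbone/layer4')
#     # `data` must be a batch dict produced by the model's test pipeline,
#     # e.g. via composing cfg.test_pipeline and collating a single sample.
#     blended_imgs, preds = gradcam(data, use_labels=False, alpha=0.5)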