# Source code for mmaction.engine.optimizers.swin_optim_wrapper_constructor
# Copyright (c) OpenMMLab. All rights reserved.
from functools import reduce
from operator import mul
from typing import List
import torch.nn as nn
from mmengine.logging import print_log
from mmengine.optim import DefaultOptimWrapperConstructor
from mmaction.registry import OPTIM_WRAPPER_CONSTRUCTORS
@OPTIM_WRAPPER_CONSTRUCTORS.register_module()
class SwinOptimWrapperConstructor(DefaultOptimWrapperConstructor):
    """Optim wrapper constructor that emits one param group per parameter.

    ``add_params`` scales the base learning rate and weight decay of every
    trainable parameter by the multipliers found in ``self.paramwise_cfg``
    entries whose key is a substring of the parameter's dotted name.
    """

    def add_params(self,
                   params: List[dict],
                   module: nn.Module,
                   prefix: str = 'base',
                   **kwargs) -> None:
        """Recursively append a param group for each parameter of ``module``.

        Parameters that do not require grad get a bare ``{'params': [p]}``
        group. Every other parameter starts from ``self.base_lr`` (and
        ``self.base_wd`` when it is not ``None``); these values are then
        multiplied by the product of the ``lr_mult`` / ``decay_mult`` entries
        (default ``1.``) of all matching ``self.paramwise_cfg`` keys. The
        resulting options are logged through ``print_log``.

        Args:
            params (list[dict]): A list of param groups, it will be modified
                in place.
            module (nn.Module): The module whose parameters are added.
            prefix (str): Dotted name prefix of ``module``. Defaults to
                ``'base'``.
            **kwargs: Accepted but not used here and not forwarded to the
                recursive calls.
        """
        for param_name, parameter in module.named_parameters(recurse=False):
            group = {'params': [parameter]}

            if not parameter.requires_grad:
                # Frozen parameters carry no lr / weight_decay overrides and
                # are not logged.
                params.append(group)
                continue

            group['lr'] = self.base_lr
            has_decay = self.base_wd is not None
            if has_decay:
                group['weight_decay'] = self.base_wd

            # The name used for matching always contains the dot, even when
            # ``prefix`` is empty (unlike the name that gets logged below).
            lookup_name = f'{prefix}.{param_name}'
            matched_keys = [
                cfg_key for cfg_key in self.paramwise_cfg
                if cfg_key in lookup_name
            ]

            if matched_keys:
                # Multiply the multipliers together first, then apply the
                # product to the base value.
                lr_factors = [
                    self.paramwise_cfg[cfg_key].get('lr_mult', 1.)
                    for cfg_key in matched_keys
                ]
                group['lr'] *= reduce(mul, lr_factors)
                if has_decay:
                    decay_factors = [
                        self.paramwise_cfg[cfg_key].get('decay_mult', 1.)
                        for cfg_key in matched_keys
                    ]
                    group['weight_decay'] *= reduce(mul, decay_factors)

            params.append(group)

            # Report every option of the group except the parameter itself.
            full_name = f'{prefix}.{param_name}' if prefix else param_name
            for option, value in group.items():
                if option != 'params':
                    print_log(
                        f'paramwise_options -- '
                        f'{full_name}: {option} = {round(value, 8)}',
                        logger='current')

        # Descend into the direct children, extending the dotted prefix.
        for child_name, child_module in module.named_children():
            if prefix:
                child_prefix = f'{prefix}.{child_name}'
            else:
                child_prefix = child_name
            self.add_params(params, child_module, prefix=child_prefix)