Source code for mmaction.engine.runner.multi_loop
# Copyright (c) OpenMMLab. All rights reserved.
import gc
from typing import Dict, List, Union
from mmengine.runner import EpochBasedTrainLoop
from torch.utils.data import DataLoader
from mmaction.registry import LOOPS
class EpochMultiLoader:
"""Multi loaders based on epoch."""
def __init__(self, dataloaders: List[DataLoader]):
self._dataloaders = dataloaders
self.iter_loaders = [iter(loader) for loader in self._dataloaders]
@property
def num_loaders(self):
"""The number of dataloaders."""
return len(self._dataloaders)
def __iter__(self):
"""Return self when executing __iter__."""
return self
def __next__(self):
"""Get the next iter's data of multiple loaders."""
data = tuple([next(loader) for loader in self.iter_loaders])
return data
def __len__(self):
"""Get the length of loader."""
return min([len(loader) for loader in self._dataloaders])
[docs]@LOOPS.register_module()
class MultiLoaderEpochBasedTrainLoop(EpochBasedTrainLoop):
"""EpochBasedTrainLoop with multiple dataloaders.
Args:
runner (Runner): A reference of runner.
dataloader (Dataloader or Dict): A dataloader object or a dict to
build a dataloader for training the model.
other_loaders (List of Dataloader or Dict): A list of other loaders.
Each item in the list is a dataloader object or a dict to build
a dataloader.
max_epochs (int): Total training epochs.
val_begin (int): The epoch that begins validating. Defaults to 1.
val_interval (int): Validation interval. Defaults to 1.
"""
def __init__(self,
runner,
dataloader: Union[Dict, DataLoader],
other_loaders: List[Union[Dict, DataLoader]],
max_epochs: int,
val_begin: int = 1,
val_interval: int = 1) -> None:
super().__init__(runner, dataloader, max_epochs, val_begin,
val_interval)
multi_loaders = [self.dataloader]
for loader in other_loaders:
if isinstance(loader, dict):
loader = runner.build_dataloader(loader, seed=runner.seed)
multi_loaders.append(loader)
self.multi_loaders = multi_loaders
[docs] def run_epoch(self) -> None:
"""Iterate one epoch."""
self.runner.call_hook('before_train_epoch')
self.runner.model.train()
gc.collect()
for loader in self.multi_loaders:
if hasattr(loader, 'sampler') and hasattr(loader.sampler,
'set_epoch'):
loader.sampler.set_epoch(self._epoch)
for idx, data_batch in enumerate(EpochMultiLoader(self.multi_loaders)):
self.run_iter(idx, data_batch)
self.runner.call_hook('after_train_epoch')
self._epoch += 1