Shortcuts

Source code for mmaction.evaluation.metrics.ava_metric

# Copyright (c) OpenMMLab. All rights reserved.
import os
from datetime import datetime
from typing import Any, List, Optional, Sequence, Tuple

from mmengine.evaluator import BaseMetric

from mmaction.evaluation import ava_eval, results2csv
from mmaction.registry import METRICS
from mmaction.structures import bbox2result


@METRICS.register_module()
class AVAMetric(BaseMetric):
    """AVA evaluation metric.

    Collects per-sample detection outputs in ``process`` and, in
    ``compute_metrics``, dumps them to a temporary CSV file that is handed
    to ``ava_eval`` together with the annotation, label and exclusion files.

    Args:
        ann_file (str): Path of the annotation file passed to ``ava_eval``.
        exclude_file (str): Path of the exclusion file passed to
            ``ava_eval``.
        label_file (str): Path of the label file passed to ``ava_eval``.
        options (Tuple[str]): Metric names to evaluate. Exactly one entry is
            required; only ``options[0]`` is used. Defaults to ``('mAP', )``.
        action_thr (float): Threshold forwarded to ``bbox2result`` as
            ``thr``. Defaults to 0.002.
        num_classes (int): Number of classes forwarded to ``bbox2result``.
            Defaults to 81.
        custom_classes (List[int], optional): Subset of class ids to
            evaluate. When given, ``0`` is prepended before it is stored.
            Defaults to None.
        collect_device (str): Forwarded to ``BaseMetric``. Defaults to
            ``'cpu'``.
        prefix (str, optional): Forwarded to ``BaseMetric``. Defaults to
            None.
    """

    default_prefix: Optional[str] = 'mAP'

    def __init__(self,
                 ann_file: str,
                 exclude_file: str,
                 label_file: str,
                 options: Tuple[str] = ('mAP', ),
                 action_thr: float = 0.002,
                 num_classes: int = 81,
                 custom_classes: Optional[List[int]] = None,
                 collect_device: str = 'cpu',
                 prefix: Optional[str] = None):
        super().__init__(collect_device=collect_device, prefix=prefix)
        # Only options[0] is ever consumed by compute_metrics, so more than
        # one entry would be silently ignored; reject it up front.
        assert len(options) == 1, (
            f'AVAMetric supports exactly one option, got {options!r}')
        self.ann_file = ann_file
        self.exclude_file = exclude_file
        self.label_file = label_file
        self.num_classes = num_classes
        self.options = options
        self.action_thr = action_thr
        self.custom_classes = custom_classes
        if custom_classes is not None:
            # Prepend class id 0 (presumably the background class - confirm
            # against results2csv / ava_eval). Unpacking accepts any
            # iterable, not just a list.
            self.custom_classes = [0, *custom_classes]

    def process(self, data_batch: Sequence[Tuple[Any, dict]],
                data_samples: Sequence[dict]) -> None:
        """Process one batch of data samples and predictions.

        The processed results are appended to ``self.results``, which is
        used to compute the metrics once all batches have been processed.

        Args:
            data_batch (Sequence[Tuple[Any, dict]]): A batch of data from
                the dataloader. Not used by this metric.
            data_samples (Sequence[dict]): A batch of outputs from the
                model. Each item must provide ``'video_id'``,
                ``'timestamp'`` and ``'pred_instances'`` (with ``'bboxes'``
                and ``'scores'``).
        """
        for data_sample in data_samples:
            pred = data_sample['pred_instances']
            outputs = bbox2result(
                pred['bboxes'],
                pred['scores'],
                num_classes=self.num_classes,
                thr=self.action_thr)
            self.results.append(
                dict(
                    video_id=data_sample['video_id'],
                    timestamp=data_sample['timestamp'],
                    outputs=outputs))

    def compute_metrics(self, results: list) -> dict:
        """Compute the metrics from processed results.

        The results are written through ``results2csv`` to a timestamped
        CSV file (relative path), evaluated with ``ava_eval``, and the file
        is removed afterwards, including when evaluation raises.

        Args:
            results (list): The processed results of each batch.

        Returns:
            dict: The computed metrics. The keys are the names of the
            metrics, and the values are corresponding results.
        """
        time_now = datetime.now().strftime('%Y%m%d_%H%M%S')
        temp_file = f'AVA_{time_now}_result.csv'
        try:
            results2csv(results, temp_file, self.custom_classes)
            eval_results = ava_eval(
                temp_file,
                self.options[0],
                self.label_file,
                self.ann_file,
                self.exclude_file,
                ignore_empty_frames=True,
                custom_classes=self.custom_classes)
        finally:
            # Always clean up the intermediate CSV. The existence check
            # keeps a failure that happened before the file was created
            # from being masked by FileNotFoundError.
            if os.path.exists(temp_file):
                os.remove(temp_file)
        return eval_results