Spaces:
Runtime error
Runtime error
# Copyright (c) OpenMMLab. All rights reserved. | |
from typing import List | |
from mmengine.model import BaseTTAModel | |
from mmpretrain.registry import MODELS | |
from mmpretrain.structures import DataSample | |
class AverageClsScoreTTA(BaseTTAModel): | |
def merge_preds( | |
self, | |
data_samples_list: List[List[DataSample]], | |
) -> List[DataSample]: | |
"""Merge predictions of enhanced data to one prediction. | |
Args: | |
data_samples_list (List[List[DataSample]]): List of predictions | |
of all enhanced data. | |
Returns: | |
List[DataSample]: Merged prediction. | |
""" | |
merged_data_samples = [] | |
for data_samples in data_samples_list: | |
merged_data_samples.append(self._merge_single_sample(data_samples)) | |
return merged_data_samples | |
def _merge_single_sample(self, data_samples): | |
merged_data_sample: DataSample = data_samples[0].new() | |
merged_score = sum(data_sample.pred_score | |
for data_sample in data_samples) / len(data_samples) | |
merged_data_sample.set_pred_score(merged_score) | |
return merged_data_sample | |