KyanChen's picture
Upload 303 files
4d0eb62
raw
history blame
1.18 kB
# Copyright (c) OpenMMLab. All rights reserved.
from typing import List
from mmengine.model import BaseTTAModel
from mmpretrain.registry import MODELS
from mmpretrain.structures import DataSample
@MODELS.register_module()
class AverageClsScoreTTA(BaseTTAModel):
    """Test-time-augmentation model that averages classification scores.

    Every group of predictions obtained from the enhanced (augmented)
    versions of the data is reduced to a single prediction whose
    ``pred_score`` is the arithmetic mean of the group's ``pred_score``
    values.
    """

    def merge_preds(
        self,
        data_samples_list: List[List[DataSample]],
    ) -> List[DataSample]:
        """Collapse each group of enhanced-data predictions into one.

        Args:
            data_samples_list (List[List[DataSample]]): List of prediction
                groups. Each inner list holds the predictions of all
                enhanced data that should be merged together.

        Returns:
            List[DataSample]: One merged prediction per inner list, in the
            same order as the input.
        """
        return [
            self._merge_single_sample(group) for group in data_samples_list
        ]

    def _merge_single_sample(self, data_samples):
        """Average the ``pred_score`` of one group of predictions.

        Args:
            data_samples: Non-empty sequence of predictions belonging to
                the same group. An empty sequence fails on the first
                element access.

        Returns:
            The sample produced by calling ``new()`` on the first
            prediction, with its prediction score set to the mean score of
            the group.
        """
        # The output is derived from the first prediction via ``new()``.
        # NOTE(review): presumably this carries over the first sample's
        # meta information - confirm against the data element API.
        template = data_samples[0]
        merged: DataSample = template.new()

        # Plain arithmetic mean over every prediction in the group.
        score_total = sum(pred.pred_score for pred in data_samples)
        mean_score = score_total / len(data_samples)

        merged.set_pred_score(mean_score)
        return merged