from typing import Dict, List

import torch
import torch.nn as nn


class EnsembleModel(nn.Module):
    """Run several per-modality sub-models on one input and fuse their outputs.

    Each sub-model is called as ``sub_model(image, seg_size)`` and must return a
    dict. ``forward`` returns every sub-model's output under its modality name,
    plus an ``"ensemble"`` entry holding, for each key of the first modality's
    output, the weighted sum ``sum_m mvc_single_weight[m] * output_m[key]``.
    The weights are not normalised, so the fused value is a true average only
    when the weights of the modalities sum to 1.
    """

    # Key under which forward() stores the fused outputs; a modality with this
    # name would have its own output silently overwritten, so it is rejected.
    _ENSEMBLE_KEY = "ensemble"

    def __init__(
        self,
        models: Dict[str, nn.Module],
        mvc_single_weight: Dict[str, float],
    ):
        """Register the sub-models and validate their fusion weights.

        Args:
            models: Mapping from modality name to sub-model. It is wrapped in an
                ``nn.ModuleDict`` so the sub-models' parameters are registered
                on this module; its key order fixes the modality order.
            mvc_single_weight: Mapping from modality name to that modality's
                weight in the ensemble. Every modality in ``models`` needs an
                entry and every value must lie in ``[0, 1]``. The mapping is
                stored by reference, not copied.

        Raises:
            ValueError: if ``models`` is empty, a modality is named
                ``"ensemble"``, a modality has no weight, or any weight lies
                outside ``[0, 1]``.
        """
        super().__init__()
        self.sub_models = nn.ModuleDict(models)
        self.modality = list(self.sub_models.keys())
        self.mvc_single_weight = mvc_single_weight

        # forward() reads self.modality[0]; fail here instead of with an
        # IndexError on the first call.
        if not self.modality:
            raise ValueError("EnsembleModel needs at least one sub-model")
        if self._ENSEMBLE_KEY in self.sub_models:
            raise ValueError(
                f"Modality name {self._ENSEMBLE_KEY!r} is reserved for the "
                "fused output"
            )
        missing = [m for m in self.modality if m not in self.mvc_single_weight]
        if missing:
            raise ValueError(f"No ensemble weight given for modalities: {missing}")
        # Explicit raise rather than `assert`: assertions are stripped under
        # `python -O`, which would let invalid weights through silently.
        for k, v in self.mvc_single_weight.items():
            if not 0 <= v <= 1:
                raise ValueError(
                    "The weight of {} for {} is out of range".format(v, k)
                )

    def forward(self, image, seg_size=None):
        """Run every sub-model on ``image`` and append the weighted ensemble.

        Args:
            image: Input passed unchanged to every sub-model.
            seg_size: Forwarded positionally as the second argument of every
                sub-model call.

        Returns:
            A dict with one entry per modality (that sub-model's raw output
            dict), in modality order, followed by ``"ensemble"``: a dict with
            the same keys as the first modality's output, where each value is
            the weighted sum of the corresponding outputs over all modalities.
        """
        result = {m: self.sub_models[m](image, seg_size) for m in self.modality}

        # The first modality's output defines which keys are fused; every other
        # sub-model must return the same keys (presumably with broadcast-
        # compatible tensors -- not checked here).
        ensemble = {}
        for key, reference in result[self.modality[0]].items():
            fused = torch.zeros_like(reference)
            for m in self.modality:
                # Out-of-place add so dtype promotion works (e.g. integer
                # outputs times float weights), matching the original.
                fused = fused + self.mvc_single_weight[m] * result[m][key]
            ensemble[key] = fused
        result[self._ENSEMBLE_KEY] = ensemble
        return result