|
from typing import Dict, List |
|
|
|
import torch |
|
import torch.nn as nn |
|
|
|
|
|
class EnsembleModel(nn.Module): |
|
def __init__(self, models: Dict, mvc_single_weight: Dict): |
|
super().__init__() |
|
|
|
self.sub_models = nn.ModuleDict(models) |
|
self.modality = list(self.sub_models.keys()) |
|
self.mvc_single_weight = mvc_single_weight |
|
for k, v in self.mvc_single_weight.items(): |
|
assert 0 <= v <= 1, "The weight of {} for {} is out of range".format(v, k) |
|
|
|
def forward(self, image, seg_size=None): |
|
result = {} |
|
for modality in self.modality: |
|
result[modality] = self.sub_models[modality](image, seg_size) |
|
|
|
avg_result = {} |
|
for k in result[self.modality[0]].keys(): |
|
avg_result[k] = torch.zeros_like(result[self.modality[0]][k]) |
|
for modality in self.modality: |
|
avg_result[k] = ( |
|
avg_result[k] |
|
+ self.mvc_single_weight[modality] * result[modality][k] |
|
) |
|
result["ensemble"] = avg_result |
|
|
|
return result |
|
|