WSCL / models /ensemble_model.py
yhzhai's picture
release code
482ab8a
raw
history blame
1.08 kB
from typing import Dict, List
import torch
import torch.nn as nn
class EnsembleModel(nn.Module):
    """Run several per-modality sub-models on one input and fuse their outputs.

    Each sub-model is called as ``sub_model(image, seg_size)`` and its output is
    treated as a dict of tensors. ``forward`` returns every sub-model's output
    dict under its modality name, plus an ``"ensemble"`` entry that holds, for
    each key of the first modality's output, the weighted sum of that key over
    all modalities.

    Note: this is a weighted *sum*, not a normalized average. Each weight is
    checked to lie in [0, 1], but the weights are not required to sum to 1.
    """

    def __init__(self, models: Dict, mvc_single_weight: Dict):
        """Build the ensemble.

        Args:
            models: Mapping from modality name to sub-model. Wrapped in an
                ``nn.ModuleDict``; that dict's key order defines
                ``self.modality`` (the order sub-models are run and summed).
            mvc_single_weight: Mapping from modality name to its fusion weight.
                Every weight must lie in [0, 1], and every modality in
                ``models`` must have a weight.

        Raises:
            ValueError: If a weight is outside [0, 1], or if a modality in
                ``models`` has no entry in ``mvc_single_weight``.
        """
        super().__init__()
        self.sub_models = nn.ModuleDict(models)
        self.modality = list(self.sub_models.keys())
        self.mvc_single_weight = mvc_single_weight
        # Explicit raise instead of ``assert`` so the check survives ``python -O``.
        for k, v in self.mvc_single_weight.items():
            if not 0 <= v <= 1:
                raise ValueError(
                    "The weight of {} for {} is out of range".format(v, k)
                )
        # Fail fast here rather than with a bare KeyError inside forward(),
        # after every sub-model has already been run.
        missing = [m for m in self.modality if m not in self.mvc_single_weight]
        if missing:
            raise ValueError(
                "Missing ensemble weight for modalities: {}".format(missing)
            )

    def forward(self, image, seg_size=None):
        """Run every sub-model and add their weighted sum under ``"ensemble"``.

        Args:
            image: Input passed unchanged to every sub-model.
            seg_size: Passed positionally as the second argument to every
                sub-model.

        Returns:
            Dict mapping each modality name to that sub-model's output dict,
            plus ``"ensemble"``: a dict with the same keys as the first
            modality's output, where each value is
            ``sum(mvc_single_weight[m] * result[m][key] for m in modalities)``.
            NOTE(review): a sub-model registered under the name ``"ensemble"``
            would have its output overwritten by the fused result. With no
            sub-models at all, indexing the first modality raises IndexError.
        """
        result = {}
        for modality in self.modality:
            result[modality] = self.sub_models[modality](image, seg_size)

        # The first modality's output defines the set of keys that get fused;
        # every other modality must produce those same keys.
        first_output = result[self.modality[0]]
        avg_result = {}
        for k, ref in first_output.items():
            # Start from zeros shaped like the reference tensor, then accumulate.
            acc = torch.zeros_like(ref)
            for modality in self.modality:
                weight = self.mvc_single_weight[modality]
                acc = acc + weight * result[modality][k]
            avg_result[k] = acc
        result["ensemble"] = avg_result
        return result