File size: 1,079 Bytes
482ab8a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
from typing import Dict, List

import torch
import torch.nn as nn


class EnsembleModel(nn.Module):
    """Run several sub-models on the same input and blend their outputs.

    Each sub-model is called as ``sub_model(image, seg_size)``. Its return value
    is indexed by key and passed to ``torch.zeros_like``, so it is presumably a
    dict of tensors (TODO confirm against the sub-model implementations). For
    every key of the first modality's output, the per-modality values are
    combined as a weighted sum using ``mvc_single_weight``.
    """

    def __init__(self, models: Dict, mvc_single_weight: Dict):
        """Register the sub-models and validate the blend weights.

        Args:
            models: Mapping from modality name to ``nn.Module``. It is wrapped
                in an ``nn.ModuleDict`` so the sub-models' parameters are
                registered on this module.
            mvc_single_weight: Mapping from modality name to a blend weight in
                ``[0, 1]``. The weights are not required to sum to 1. Every
                modality in ``models`` needs an entry here, otherwise
                ``forward`` raises ``KeyError``.

        Raises:
            ValueError: If any weight lies outside ``[0, 1]``.
        """
        super().__init__()

        self.sub_models = nn.ModuleDict(models)
        # Fixed modality order used by forward(); taken from the ModuleDict keys.
        self.modality = list(self.sub_models.keys())
        self.mvc_single_weight = mvc_single_weight
        # Explicit check rather than `assert`, which is stripped under `python -O`.
        for k, v in self.mvc_single_weight.items():
            if not 0 <= v <= 1:
                raise ValueError(
                    "The weight of {} for {} is out of range".format(v, k)
                )

    def forward(self, image, seg_size=None):
        """Run every sub-model and add a weighted ensemble of their outputs.

        Args:
            image: Input forwarded unchanged to each sub-model.
            seg_size: Optional value forwarded positionally as the second
                argument to each sub-model.

        Returns:
            dict: ``{modality: sub_model_output, ..., "ensemble": blended}``
            where ``blended[k] = sum_m weight[m] * output[m][k]`` for every key
            ``k`` in the first modality's output. A sub-model registered under
            the name ``"ensemble"`` is overwritten by the blended result.
        """
        result = {
            modality: self.sub_models[modality](image, seg_size)
            for modality in self.modality
        }

        # The first modality's output decides which keys are blended; every
        # other sub-model must provide those keys as well.
        reference = result[self.modality[0]]
        # sum() starts from zeros and adds each weighted term in modality
        # order, matching an explicit accumulate loop.
        result["ensemble"] = {
            k: sum(
                (self.mvc_single_weight[m] * result[m][k] for m in self.modality),
                torch.zeros_like(reference[k]),
            )
            for k in reference
        }

        return result