File size: 2,568 Bytes
ea774f6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78

import torch
import torchvision.models as models
from torch import nn
from collections import OrderedDict


def get_linear_layers(dimensions):
    """Build the named layers of an MLP described by a list of widths.

    ``dimensions[0]`` is the input width; every later entry is the output
    width of one ``nn.Linear``. An ``nn.ReLU`` is inserted between consecutive
    linear layers, but not after the final one.

    Args:
        dimensions: sequence of ints ``[in_dim, d1, ..., dk]``.

    Returns:
        A list of ``(name, module)`` tuples named ``linear1``, ``active1``,
        ..., ``linear{k}``, ready for ``nn.Sequential(OrderedDict(...))``.
        Empty when ``dimensions`` contains only the input width.
    """
    in_features = dimensions[0]
    widths = dimensions[1:]
    if len(widths) == 0:
        return []

    last_index = len(widths)
    # Pair every layer's input width with its output width:
    # (in, d1), (d1, d2), ..., (d{k-1}, dk).
    fan_ins = [in_features, *widths[:-1]]
    named_layers = []
    for idx, (fan_in, fan_out) in enumerate(zip(fan_ins, widths), start=1):
        named_layers.append((f"linear{idx}", nn.Linear(fan_in, fan_out)))
        # No activation after the output layer.
        if idx < last_index:
            named_layers.append((f"active{idx}", nn.ReLU()))
    return named_layers


def num_flat_features(x):
    """Return the product of every dimension of ``x`` after dim 0.

    Dim 0 is treated as the batch dimension, so this is the per-row element
    count used to flatten ``x`` to 2-D. A 1-D ``x`` yields 1.
    """
    trailing_shape = x.size()[1:]
    return trailing_shape.numel()


class Cholec80Model(nn.Module):
    """Late-fusion model over an optional image branch and named info branches.

    For every entry ``key -> [in_dim, ..., out_dim]`` in ``dimensions`` an MLP
    is built with ``get_linear_layers`` and registered as a child module named
    ``key``. If ``"image"`` is a key, a ResNet-50 backbone (its ``fc`` replaced
    by ``nn.Identity``) feeds the ``"image"`` MLP, so ``dimensions["image"][0]``
    must equal the backbone's feature width (presumably 2048 for torchvision's
    resnet50 -- confirm). All MLP outputs are concatenated and mapped by
    ``last_layer`` to 7 log-sigmoid scores.
    """

    def __init__(self, dimensions):
        """Build the backbone, the per-input MLPs and the fusion head.

        Args:
            dimensions: mapping from input name to a list of layer widths
                ``[in_dim, hidden..., out_dim]``. The key ``"image"`` enables
                the ResNet-50 branch; every other key names a tensor expected
                in ``info_tensors`` at ``forward`` time.
        """
        super(Cholec80Model, self).__init__()
        # hyperparams
        self.dimensions = dimensions
        # CNN backbone, only when an image branch is configured.
        if "image" in self.dimensions:
            # NOTE(review): ``pretrained=True`` is deprecated in newer
            # torchvision releases in favour of ``weights=``; kept so the code
            # keeps working with the torchvision version this was written for.
            self.model = models.resnet50(pretrained=True)
            # Drop the classifier so the backbone outputs what used to feed fc.
            self.model.fc = nn.Identity()
        # ``submodels`` is a plain dict used for lookup in ``forward``; each MLP
        # is registered explicitly with ``add_module`` so its parameters are
        # tracked and state_dict keys stay ``"<key>.linearN.*"``. Do not switch
        # to nn.ModuleDict without migrating saved checkpoints.
        self.submodels = {}
        for key, dims in self.dimensions.items():
            submodel = nn.Sequential(OrderedDict(get_linear_layers(dims)))
            self.submodels[key] = submodel
            self.add_module(key, submodel)
        # The fusion head sees the concatenation of every branch's output.
        dim_concat = sum(ds[-1] for ds in self.dimensions.values())
        self.last_layer = nn.Sequential(
            nn.Linear(dim_concat, 7),
            nn.LogSigmoid()
        )

    def forward(self, img_tensor, info_tensors):
        """Run every branch, fuse the features and score them.

        Args:
            img_tensor: input to the ResNet-50 branch; only used when
                ``"image"`` is in ``self.dimensions``.
            info_tensors: mapping from each non-``"image"`` key of
                ``self.dimensions`` to the tensor fed to that key's MLP.

        Returns:
            ``(img_tensor, out_tensor)``: the unchanged input image tensor and
            the ``last_layer`` output (7 log-sigmoid scores per row).

        Raises:
            KeyError: if ``info_tensors`` contains a key with no info branch,
                or is missing a key the model was built with.
        """
        features = []
        # image feature extraction
        if "image" in self.dimensions:
            img_feat = self.submodels["image"](self.model(img_tensor))
            features.append(img_feat.view(-1, num_flat_features(img_feat)))
        # Concatenate info features in ``self.dimensions`` order, not in the
        # caller's dict order: last_layer's weights are tied to feature
        # positions, so the layout must not depend on how info_tensors was built.
        info_keys = [key for key in self.dimensions if key != "image"]
        unexpected = [key for key in info_tensors if key not in info_keys]
        if unexpected:
            raise KeyError(f"info_tensors has keys with no matching submodel: {unexpected}")
        for key in info_keys:
            t = self.submodels[key](info_tensors[key])
            features.append(t.view(-1, num_flat_features(t)))
        concat_tensor = torch.cat(features, dim=1)
        # last_layer
        out_tensor = self.last_layer(concat_tensor)
        # NOTE(review): this returns the *input* image tensor, not the backbone
        # features; kept as-is because callers may rely on it -- confirm whether
        # the features were intended.
        return img_tensor, out_tensor