from collections import OrderedDict

import torch
from torch import nn

try:
    import torchvision.models as models
except ImportError as _exc:
    # torchvision is only needed for the "image" branch; configurations that
    # use info branches only can still be built without it. The original
    # error is kept so it can be chained when an image branch is requested.
    models = None
    _TORCHVISION_IMPORT_ERROR = _exc
else:
    _TORCHVISION_IMPORT_ERROR = None


def get_linear_layers(dimensions):
    """Build ``(name, module)`` pairs for a ReLU-separated stack of linear layers.

    ``dimensions[0]`` is the input width and every following entry is the
    output width of one ``nn.Linear``. An ``nn.ReLU`` is inserted between
    consecutive linear layers but not after the last one. Layers are named
    ``linear1, active1, linear2, ...`` so the result can be fed straight into
    ``nn.Sequential(OrderedDict(...))``.

    Args:
        dimensions: sequence of layer widths, input width first.

    Returns:
        A list of ``(name, module)`` tuples. It is empty when fewer than two
        widths are given, because there is no layer to build; an empty
        ``nn.Sequential`` then passes its input through unchanged.
    """
    if len(dimensions) < 2:
        return []
    pairs = list(zip(dimensions[:-1], dimensions[1:]))
    layers = []
    for i, (d_in, d_out) in enumerate(pairs, start=1):
        layers.append((f"linear{i}", nn.Linear(d_in, d_out)))
        if i < len(pairs):  # no activation after the final linear layer
            layers.append((f"active{i}", nn.ReLU()))
    return layers


def num_flat_features(x):
    """Return the number of elements per sample of ``x``.

    This is the product of every dimension after the first (batch) one, and
    is 1 for a 1-D tensor.
    """
    return x.size()[1:].numel()


class Cholec80Model(nn.Module):
    """Multi-branch model that concatenates per-branch MLP features into 7 outputs.

    ``dimensions`` maps a branch name to the list of layer widths for that
    branch's MLP (see ``get_linear_layers``). The ``"image"`` branch is
    special: images first go through an ImageNet-pretrained ResNet-50 whose
    ``fc`` head is replaced by ``nn.Identity``, and the resulting features go
    into the ``"image"`` MLP. NOTE(review): ResNet-50's pooled feature width
    is 2048, so ``dimensions["image"][0]`` presumably has to be 2048.

    The last width of every branch is summed to size the final
    ``nn.Linear(..., 7)``. Its output passes through an element-wise
    ``nn.LogSigmoid`` (independent per output, not a log-softmax).
    Presumably the 7 outputs are the 7 Cholec80 classes — confirm.
    """

    def __init__(self, dimensions):
        super().__init__()
        self.dimensions = dimensions

        # CNN backbone for the image branch.
        if "image" in self.dimensions:
            if models is None:
                raise ImportError(
                    "torchvision is required for the 'image' branch of Cholec80Model"
                ) from _TORCHVISION_IMPORT_ERROR
            # NOTE(review): ``pretrained=True`` is deprecated in newer
            # torchvision in favour of ``weights=``; kept for compatibility.
            self.model = models.resnet50(pretrained=True)
            self.model.fc = nn.Identity()

        # One MLP per branch. Each is registered under its branch name, so its
        # parameters are tracked by the module (``.to()``, ``.parameters()``,
        # ``state_dict`` keys such as "<key>.linear1.weight"). The plain dict
        # is kept for lookup in ``forward``.
        self.submodels = {}
        for key, dims in self.dimensions.items():
            submodel = nn.Sequential(OrderedDict(get_linear_layers(dims)))
            self.submodels[key] = submodel
            self.add_module(key, submodel)

        # Width of the concatenated feature vector: sum of every branch's output width.
        dim_concat = sum(dims[-1] for dims in self.dimensions.values())
        self.last_layer = nn.Sequential(
            nn.Linear(dim_concat, 7),
            nn.LogSigmoid(),
        )

    def forward(self, img_tensor, info_tensors):
        """Run every branch, concatenate the features and apply the output layer.

        Args:
            img_tensor: input to the ResNet-50 backbone. It is only used when an
                ``"image"`` branch is configured.
            info_tensors: mapping from each non-image branch name to its input
                tensor. Its keys must be exactly the non-``"image"`` keys of
                ``self.dimensions``.

        Returns:
            ``(img_tensor, out_tensor)``: the input ``img_tensor`` unchanged
            (kept for caller compatibility) and the log-sigmoid outputs of
            shape ``(batch, 7)``.

        Raises:
            ValueError: if the keys of ``info_tensors`` do not match the
                configured branches, or no branch produced any features.
        """
        expected = [key for key in self.dimensions if key != "image"]
        missing = [key for key in expected if key not in info_tensors]
        unexpected = [key for key in info_tensors if key not in expected]
        if missing or unexpected:
            raise ValueError(
                "info_tensors keys do not match the configured branches: "
                f"missing={missing}, unexpected={unexpected}"
            )

        features = []

        # Image feature extraction.
        if "image" in self.dimensions:
            image_features = self.submodels["image"](self.model(img_tensor))
            features.append(
                image_features.view(-1, num_flat_features(image_features))
            )

        # NOTE(review): info branches are concatenated in the caller's dict
        # order, not in ``self.dimensions`` order. Callers must pass keys in a
        # consistent order, because the weights of ``last_layer`` depend on it.
        for key, t in info_tensors.items():
            t = self.submodels[key](t)
            features.append(t.view(-1, num_flat_features(t)))

        if not features:
            raise ValueError("Cholec80Model has no input branches configured")

        concat_tensor = torch.cat(features, dim=1)
        out_tensor = self.last_layer(concat_tensor)
        return img_tensor, out_tensor