Hui
requirements
8b98c60
import torch
import torchvision.models as models
from torch import nn
from collections import OrderedDict
def get_linear_layers(dimensions):
    """Build the named layers of an MLP described by a sequence of widths.

    ``dimensions`` is ``[in_dim, hidden_1, ..., out_dim]``. Every pair of
    consecutive widths becomes an ``nn.Linear``, and every Linear except the
    last one is followed by an ``nn.ReLU``. Layers are named ``linear1``,
    ``active1``, ``linear2``, ... so the result can be wrapped with
    ``nn.Sequential(OrderedDict(...))``.

    Args:
        dimensions: sequence of layer widths (list or tuple of ints).

    Returns:
        List of ``(name, module)`` tuples. It is empty when fewer than two
        widths are given, because there is no layer to build. An
        ``nn.Sequential`` built from it then passes its input through
        unchanged.
    """
    # Fewer than two widths (including an empty sequence) means no Linear
    # layer can be formed; previously an empty sequence raised IndexError.
    if len(dimensions) < 2:
        return []

    pairs = list(zip(dimensions[:-1], dimensions[1:]))
    layers = []
    for i, (d_in, d_out) in enumerate(pairs, start=1):
        layers.append((f"linear{i}", nn.Linear(d_in, d_out)))
        # No activation after the final Linear: the MLP outputs raw values.
        if i < len(pairs):
            layers.append((f"active{i}", nn.ReLU()))
    return layers
def num_flat_features(x):
    """Return how many elements each sample of ``x`` holds.

    The result is the product of every dimension of ``x.size()`` after the
    first, so ``x.view(-1, num_flat_features(x))`` flattens all trailing
    dimensions into one. A tensor with at most one dimension yields 1.
    """
    product = 1
    for extent in list(x.size())[1:]:
        product *= extent
    return product
class Cholec80Model(nn.Module):
    """Late-fusion model: one MLP per input branch, concatenated, then a 7-way head.

    Args:
        dimensions: mapping from branch name to a sequence of layer widths
            ``[in_dim, hidden..., out_dim]`` consumed by ``get_linear_layers``.
            If the key ``"image"`` is present, a pretrained torchvision
            ResNet-50 with its classifier replaced by ``nn.Identity`` is
            created, and its output feeds the ``"image"`` MLP. That entry's
            first width must therefore equal the backbone's feature size
            (presumably 2048 — TODO confirm). Every other key names a tensor
            that ``forward`` reads from ``info_tensors``.

    Each branch MLP is kept in ``self.submodels`` for keyed lookup. It is also
    registered as a child module under its branch name, so its parameters
    appear in ``parameters()``/``state_dict()`` (e.g. ``image.linear1.weight``).
    The branch outputs are flattened, concatenated, and passed through
    ``Linear(sum of final widths, 7)`` followed by ``LogSigmoid``.

    NOTE(review): a branch name equal to one of this class's own attributes
    (``"model"``, ``"last_layer"``, ``"dimensions"``, ``"submodels"``) clashes
    with it. ``add_module`` either raises or the module is silently replaced.
    Avoid such keys.
    """

    def __init__(self, dimensions):
        super().__init__()
        # hyperparams
        self.dimensions = dimensions

        # Optional image backbone. Registered first, so it precedes the branch
        # MLPs in parameters()/state_dict().
        # NOTE: torchvision >= 0.13 deprecates `pretrained=` in favour of
        # `weights=`; it is kept here for compatibility with older versions.
        if "image" in self.dimensions:
            self.model = models.resnet50(pretrained=True)
            self.model.fc = nn.Identity()

        # One MLP per branch (including "image" when present). The plain dict
        # is not tracked by nn.Module, so each MLP is registered explicitly.
        self.submodels = {}
        for key, dims in self.dimensions.items():
            self.submodels[key] = nn.Sequential(OrderedDict(get_linear_layers(dims)))
            self.add_module(key, self.submodels[key])

        # The fused vector's width is the sum of each branch's final width.
        dim_concat = sum(dims[-1] for dims in self.dimensions.values())
        self.last_layer = nn.Sequential(
            nn.Linear(dim_concat, 7),
            nn.LogSigmoid(),
        )

    def forward(self, img_tensor, info_tensors):
        """Run every branch, concatenate the results, and apply the fused head.

        Args:
            img_tensor: input batch for the ResNet-50 backbone. It is only used
                when ``"image"`` is in ``self.dimensions``.
            info_tensors: mapping from branch name to that branch's input
                tensor. Every key must have an entry in ``self.dimensions``.

        Returns:
            ``(img_tensor, out_tensor)``. ``out_tensor`` has shape ``(N, 7)``,
            where ``N`` is the leading dimension of the branch outputs.
            NOTE(review): the unmodified input ``img_tensor`` is returned, not
            the backbone features. Confirm this is intended.

        Raises:
            ValueError: if there is no image branch and ``info_tensors`` is
                empty, so there is nothing to fuse.
        """
        features = []

        # image feature extraction (no defensive clone needed: nothing below
        # modifies these tensors in place)
        if "image" in self.dimensions:
            img_features = self.submodels["image"](self.model(img_tensor))
            features.append(img_features.view(-1, num_flat_features(img_features)))

        # NOTE(review): features are concatenated in the iteration order of
        # info_tensors. last_layer's weights assume a fixed layout, so callers
        # must always pass the branches in the same order.
        for key, t in info_tensors.items():
            t = self.submodels[key](t)
            features.append(t.view(-1, num_flat_features(t)))

        if not features:
            raise ValueError(
                "Cholec80Model.forward: nothing to fuse "
                "(no 'image' branch and info_tensors is empty)"
            )

        # A single concatenation avoids re-copying the growing tensor on
        # every branch.
        concat_tensor = torch.cat(features, dim=1)

        out_tensor = self.last_layer(concat_tensor)
        return img_tensor, out_tensor