# NOTE(review): non-Python paste residue preserved as comments — the original
# lines would raise a SyntaxError at import time:
# Spaces:
# Runtime error
# Runtime error
import torch
import torchvision.models as models
from torch import nn
from collections import OrderedDict
def get_linear_layers(dimensions):
    """Build (name, module) pairs describing an MLP for nn.Sequential/OrderedDict.

    ``dimensions[0]`` is the input width; every later entry is the output width
    of one ``nn.Linear``.  Each layer except the last is followed by a ReLU,
    and layers are named ``linear1``, ``active1``, ``linear2``, ... in order.
    Returns an empty list when no output widths are given.
    """
    in_dim = dimensions[0]
    out_dims = dimensions[1:]
    if not out_dims:
        return []
    # (source width, destination width) for every Linear layer, in order
    pairs = list(zip([in_dim] + list(out_dims[:-1]), out_dims))
    layers = []
    for idx, (src, dst) in enumerate(pairs, start=1):
        layers.append((f"linear{idx}", nn.Linear(src, dst)))
        # no activation after the final projection
        if idx < len(pairs):
            layers.append((f"active{idx}", nn.ReLU()))
    return layers
def num_flat_features(x):
    """Return the per-sample feature count of *x* when flattened.

    The first (batch) dimension is excluded; the remaining dimensions are
    multiplied together.  A 1-D tensor therefore yields 1.
    """
    count = 1
    for dim_size in x.size()[1:]:
        count *= dim_size
    return count
class Cholec80Model(nn.Module):
    """Multi-branch network: one MLP per input modality, fused by a linear head.

    ``dimensions`` maps a modality name (e.g. ``"image"``, tool/phase info
    keys) to the list of layer widths for that modality's MLP.  When an
    ``"image"`` key is present, a pretrained ResNet-50 (classifier removed)
    extracts features before the image MLP.  Branch outputs are concatenated
    and projected to 7 log-sigmoid scores.
    """

    def __init__(self, dimensions):
        super(Cholec80Model, self).__init__()
        # hyperparameters: modality name -> list of MLP layer widths
        self.dimensions = dimensions
        # CNN backbone for the image branch; fc replaced so it emits features
        if "image" in self.dimensions:
            self.model = models.resnet50(pretrained=True)
            self.model.fc = nn.Identity()
        # one MLP per modality (plain dict — not auto-registered by nn.Module)
        self.submodels = {
            name: nn.Sequential(OrderedDict(get_linear_layers(widths)))
            for name, widths in self.dimensions.items()
        }
        # register each submodel explicitly so its parameters are tracked
        for name in self.submodels:
            self.add_module(name, self.submodels[name])
        # fusion head: width is the sum of every branch's final layer width
        concat_width = sum(widths[-1] for widths in self.dimensions.values())
        self.last_layer = nn.Sequential(
            nn.Linear(concat_width, 7),
            nn.LogSigmoid()
        )

    def forward(self, img_tensor, info_tensors):
        """Run every branch, concatenate, and score.

        Returns ``(img_tensor, out_tensor)`` — the input image tensor is
        passed through untouched alongside the (batch, 7) prediction.
        """
        concat_tensor = None
        # image branch: backbone features -> image MLP -> flatten per sample
        if "image" in self.dimensions:
            features = self.model(img_tensor)
            features = self.submodels["image"](features.clone())
            concat_tensor = features.view(-1, num_flat_features(features))
        # auxiliary info branches, appended along the feature dimension
        for name, tensor in info_tensors.items():
            branch_out = self.submodels[name](tensor)
            branch_out = branch_out.view(-1, num_flat_features(branch_out))
            if concat_tensor is None:
                concat_tensor = branch_out
            else:
                concat_tensor = torch.cat((concat_tensor, branch_out), dim=1)
        # fused prediction over all branches
        out_tensor = self.last_layer(concat_tensor)
        return img_tensor, out_tensor