import torch
from torch import nn
from torchvision import transforms, models


class ActionClassifier(nn.Module):
    """ResNet-50 backbone with a small fully connected head for 15 human actions."""

    def __init__(self, train_last_nlayer, hidden_size, dropout, ntargets):
        super().__init__()
        resnet = models.resnet50(weights=models.ResNet50_Weights.DEFAULT, progress=True)
        modules = list(resnet.children())[:-1]  # drop the final fully connected layer
        self.resnet = nn.Sequential(*modules)

        # Freeze everything except the last `train_last_nlayer` child modules.
        # A plain [:-train_last_nlayer] slice silently freezes nothing when
        # train_last_nlayer == 0 (Python reads [:-0] as [:0]), so compute the
        # cutoff explicitly.
        cutoff = len(self.resnet) - train_last_nlayer
        for param in self.resnet[:cutoff].parameters():
            param.requires_grad = False

        self.fc = nn.Sequential(
            nn.Flatten(),
            nn.BatchNorm1d(resnet.fc.in_features),
            nn.Dropout(dropout),
            nn.Linear(resnet.fc.in_features, hidden_size),
            nn.ReLU(),
            nn.BatchNorm1d(hidden_size),
            nn.Dropout(dropout),
            nn.Linear(hidden_size, ntargets),
            nn.Sigmoid()
        )

    def forward(self, x):
        x = self.resnet(x)  # pooled backbone features, shape (N, 2048, 1, 1)
        x = self.fc(x)      # per-class scores in [0, 1], shape (N, ntargets)
        return x


def get_transform():
    # ResNet50_Weights.DEFAULT.transforms() bundles the preprocessing the
    # pretrained weights expect (resize, center crop, tensor conversion and
    # ImageNet normalization); the explicit Resize first forces a fixed
    # 224x224 input.
    transform = transforms.Compose([
        transforms.Resize([224, 224]),
        models.ResNet50_Weights.DEFAULT.transforms()
    ])
    return transform


def get_model():
    # train_last_nlayer=0 freezes the entire backbone; the head maps the 2048
    # pooled ResNet features -> 512 hidden units -> 15 action classes.
    model = ActionClassifier(0, 512, 0.2, 15)
    model.load_state_dict(torch.load('./model_weights.pth',
                                     map_location=torch.device('cpu')))
    return model


def get_class(index):
    # Map a class index to its action label.
    ind2cat = [
        'calling', 'clapping', 'cycling', 'dancing', 'drinking',
        'eating', 'fighting', 'hugging', 'laughing', 'listening_to_music',
        'running', 'sitting', 'sleeping', 'texting', 'using_laptop'
    ]
    return ind2cat[index]


if __name__ == "__main__":
    from PIL import Image

    transform = get_transform()
    model = get_model()
    model.eval()

    img = Image.open('./inputs/Image_102.jpg').convert('RGB')
    img = transform(img).unsqueeze(dim=0)  # add a batch dimension: (1, 3, 224, 224)

    with torch.no_grad():
        out = model(img).squeeze()  # per-class scores, shape (15,)
    # The head already ends in a Sigmoid, so argmax over the raw scores picks
    # the predicted action directly.
    print(get_class(torch.argmax(out).item()))
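
    # A minimal sanity check, a sketch rather than part of the original
    # script: with train_last_nlayer=0 the whole ResNet backbone is frozen,
    # so the only parameters still requiring gradients should belong to the
    # classifier head (registered as `fc`).
    trainable = [name for name, p in model.named_parameters() if p.requires_grad]
    assert all(name.startswith('fc.') for name in trainable), trainable
    print(f'{len(trainable)} trainable parameter tensors (classifier head only)')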