Spaces:
Runtime error
Runtime error
import json
import zlib

import wandb
import torch
import torchmetrics
from torch import nn
import pytorch_lightning as pl
from torch.nn import functional as F
from timm import create_model as create_timm_model

from constants import INPUT_IMAGE_SIZE
# Seed Python, NumPy and torch RNGs so runs are reproducible.
# The builtin ``hash`` of a str is salted per interpreter process
# (PYTHONHASHSEED), so it yields a different seed on every run; and the old
# ``hash(...) % 2**32 - 1`` could evaluate to -1, outside the accepted range.
# CRC32 is deterministic and always lies in [0, 2**32 - 1].
pl.seed_everything(zlib.crc32(b"setting random seeds"))
class LitMLP(pl.LightningModule):
    """Linear classifier trained on top of a frozen, pretrained CNN trunk.

    The trunk comes from ``get_feature_extractor()`` and is always run in eval
    mode under ``torch.no_grad()``, so only ``self.classifier`` is trained.
    ``forward`` returns log-probabilities, which pair with ``F.nll_loss``.

    Args:
        batch_size: Batch size used for the dummy input of the ONNX exports.
        n_classes: Number of outputs of the linear head.
        lr: Adam learning rate. Stored by ``save_hyperparameters()`` and read
            back as ``self.hparams["lr"]`` in ``configure_optimizers``.
    """

    def __init__(self, batch_size, n_classes, lr=1e-3):
        super().__init__()
        self.batch_size = batch_size
        self.feature_extractor, num_filters = get_feature_extractor()
        self.classifier = nn.Linear(num_filters, n_classes)
        # Records batch_size, n_classes and lr in self.hparams. ``lr`` has to
        # be an __init__ argument: previously it was never saved, so the
        # self.hparams["lr"] lookup in configure_optimizers raised KeyError.
        self.save_hyperparameters()
        # NOTE(review): argument-less Accuracy() matches the torchmetrics API
        # before 0.11; newer releases require task="multiclass", num_classes=...
        self.train_acc = torchmetrics.Accuracy()
        self.valid_acc = torchmetrics.Accuracy()
        self.test_acc = torchmetrics.Accuracy()
        self.img_class_map = get_img_class_map()

    def forward(self, x):
        """Return per-class log-probabilities for the image batch ``x``.

        Presumably ``x`` is (batch, 3, INPUT_IMAGE_SIZE, INPUT_IMAGE_SIZE),
        matching the dummy ONNX inputs below — TODO confirm against the
        data pipeline.
        """
        # Force the trunk into eval mode on every call so layers such as
        # BatchNorm keep inference behaviour even while the Trainer has the
        # module in train mode; no_grad keeps the trunk frozen.
        self.feature_extractor.eval()
        with torch.no_grad():
            representations = self.feature_extractor(x).flatten(1)
        x = self.classifier(representations)
        x = F.log_softmax(x, dim=1)
        return x

    def predict_app(self, x):
        """Classify one image for the demo app.

        Switches the whole module to eval mode (this persists after the call).
        ``x`` must contain exactly one sample because of ``.item()``.

        Returns:
            dict with ``'class_id'`` (int) and ``'class_name'`` (the value stored
            under ``str(class_id)`` in the class map; JSON keys are strings).
        """
        self.eval()
        # Inference only: no autograd graph is needed for the returned values.
        with torch.no_grad():
            _, y_hat = self.forward(x).max(1)
        class_id = y_hat.item()
        return {'class_id': class_id, 'class_name': self.img_class_map[str(class_id)]}

    def loss(self, xs, ys):
        """Return ``(log_probs, nll_loss)`` for inputs ``xs`` and labels ``ys``.

        The first element is named ``logits`` at call sites but is actually the
        log-softmax output of ``forward``.
        """
        logits = self(xs)
        loss = F.nll_loss(logits, ys)
        return logits, loss

    def training_step(self, batch, batch_idx):
        """Compute the training loss for ``(xs, ys)`` and log loss/accuracy."""
        xs, ys = batch
        logits, loss = self.loss(xs, ys)
        preds = torch.argmax(logits, 1)
        self.log('train/loss', loss, on_epoch=True)
        self.train_acc(preds, ys)
        self.log('train/acc', self.train_acc, on_epoch=True)
        return loss

    def configure_optimizers(self):
        """Adam over all parameters with learning rate ``self.hparams["lr"]``.

        Only the classifier actually receives gradients (the trunk runs under
        no_grad in ``forward``).
        """
        return torch.optim.Adam(self.parameters(), lr=self.hparams["lr"])

    def test_step(self, batch, batch_idx):
        """Accumulate test accuracy and log epoch-level test loss/accuracy."""
        xs, ys = batch
        logits, loss = self.loss(xs, ys)
        preds = torch.argmax(logits, 1)
        self.test_acc(preds, ys)
        self.log("test/loss_epoch", loss, on_step=False, on_epoch=True)
        self.log("test/acc_epoch", self.test_acc, on_step=False, on_epoch=True)

    def test_epoch_end(self, test_step_outputs):  # args are defined as part of pl API
        """Export the final model to ``model_final.onnx`` and upload it to W&B."""
        dummy_input = torch.zeros((self.batch_size, 3, INPUT_IMAGE_SIZE, INPUT_IMAGE_SIZE),
                                  device=self.device)
        model_filename = "model_final.onnx"
        self.to_onnx(model_filename, dummy_input, export_params=True)
        wandb.save(model_filename)

    def validation_step(self, batch, batch_idx):
        """Log validation loss/accuracy and return the batch's log-probabilities."""
        xs, ys = batch
        logits, loss = self.loss(xs, ys)
        preds = torch.argmax(logits, 1)
        self.valid_acc(preds, ys)
        self.log("valid/loss_epoch", loss)
        self.log('valid/acc_epoch', self.valid_acc)
        return logits

    def validation_epoch_end(self, validation_step_outputs):
        """Export an ONNX snapshot and log a histogram of validation outputs.

        Args:
            validation_step_outputs: list of the tensors returned by
                ``validation_step`` during this validation epoch.
        """
        dummy_input = torch.zeros((self.batch_size, 3, INPUT_IMAGE_SIZE, INPUT_IMAGE_SIZE),
                                  device=self.device)
        model_filename = f"model_{str(self.global_step).zfill(5)}.onnx"
        # Upload exactly the file that gets written. Previously the export went
        # to 'latest_run' + model_filename while wandb.save() was given the bare
        # model_filename, a path that is never created.
        onnx_path = 'latest_run' + model_filename
        torch.onnx.export(self, dummy_input, onnx_path, opset_version=11,
                          input_names=['input'],
                          output_names=['output'],
                          # Keep the batch dimension dynamic so the exported
                          # graph is not tied to self.batch_size.
                          dynamic_axes={'input': {0: 'batch_size'},
                                        'output': {0: 'batch_size'}}
                          )
        wandb.save(onnx_path)
        flattened_logits = torch.flatten(torch.cat(validation_step_outputs))
        # Assumes the Trainer's logger is a WandbLogger (experiment is a wandb run).
        self.logger.experiment.log(
            {"valid/logits": wandb.Histogram(flattened_logits.to("cpu")),
             "global_step": self.global_step})
def get_img_class_map(path='index_to_name.json'):
    """Load the class-index -> class-name mapping from a JSON file.

    Args:
        path: JSON file to read. Defaults to ``index_to_name.json``, resolved
            relative to the current working directory.

    Returns:
        The decoded JSON content (presumably an object keyed by the string
        form of each class index — JSON object keys are always strings).
    """
    # JSON is UTF-8 by specification; don't rely on the platform's locale
    # encoding, which can mis-decode non-ASCII class names (e.g. on Windows).
    with open(path, encoding='utf-8') as f:
        img_class_map = json.load(f)
    return img_class_map
def get_feature_extractor():
    """Build the image-feature trunk from a pretrained timm ``resnet50d``.

    Every top-level child module of the backbone except the last one is
    wrapped, in registration order, in an ``nn.Sequential``.

    Returns:
        tuple: ``(trunk, num_filters)``, where ``trunk`` is that
        ``nn.Sequential`` and ``num_filters`` is ``backbone.fc.in_features``
        (the input width of the backbone's ``fc`` layer).
    """
    backbone = create_timm_model('resnet50d', pretrained=True)
    # NOTE(review): assumes the last registered child is the ``fc`` head, so
    # the trunk ends at the pooled features — confirm if the model changes.
    *trunk_modules, _dropped_head = backbone.children()
    return nn.Sequential(*trunk_modules), backbone.fc.in_features