# birds-israel / lit_mlp.py
# Author: Tom Beer
# Commit: "add training" (fdb3349)
import json
import zlib

import pytorch_lightning as pl
import torch
import torchmetrics
import wandb
from timm import create_model as create_timm_model
from torch import nn
from torch.nn import functional as F

from constants import INPUT_IMAGE_SIZE
# Seed the global RNGs (via Lightning) from a fixed phrase so runs are
# reproducible. The builtin hash() of a str is salted per interpreter process
# (PYTHONHASHSEED), so the previous ``hash(...) % 2**32 - 1`` gave a different
# seed on every run, and could even be -1. CRC32 is deterministic and always
# lies in [0, 2**32 - 1].
pl.seed_everything(zlib.crc32(b"setting random seeds"))
class LitMLP(pl.LightningModule):
    """Linear classification head trained on top of a pretrained, frozen backbone.

    ``forward`` runs ``self.feature_extractor`` in eval mode under
    ``torch.no_grad()``, so only ``self.classifier`` is trained. At the end of
    every validation and test epoch the model is exported to ONNX and the file
    is uploaded with ``wandb.save``.

    Args:
        batch_size: Training batch size; also the leading dimension of the
            dummy input used to trace ONNX exports.
        n_classes: Number of output classes.
        lr: Adam learning rate used by ``configure_optimizers``. Defaults to
            ``1e-3`` (the ``torch.optim.Adam`` default).
    """

    def __init__(self, batch_size, n_classes, lr=1e-3):
        super().__init__()
        self.batch_size = batch_size
        self.feature_extractor, num_filters = get_feature_extractor()
        self.classifier = nn.Linear(num_filters, n_classes)
        # Stores every __init__ argument (batch_size, n_classes, lr) in
        # self.hparams. ``lr`` is an argument so the self.hparams["lr"] lookup
        # in configure_optimizers can succeed; without it that key was never
        # populated by this class.
        self.save_hyperparameters()
        self.train_acc = torchmetrics.Accuracy()
        self.valid_acc = torchmetrics.Accuracy()
        self.test_acc = torchmetrics.Accuracy()
        # Class-index -> class-name lookup used by predict_app.
        self.img_class_map = get_img_class_map()

    def _dummy_input(self):
        """Zero tensor (batch_size, 3, INPUT_IMAGE_SIZE, INPUT_IMAGE_SIZE) used to trace ONNX exports."""
        return torch.zeros((self.batch_size, 3, INPUT_IMAGE_SIZE, INPUT_IMAGE_SIZE),
                           device=self.device)

    def forward(self, x):
        """Return per-class log-probabilities of shape (batch, n_classes).

        ``x`` is presumably a (batch, 3, INPUT_IMAGE_SIZE, INPUT_IMAGE_SIZE)
        image tensor, matching the ONNX dummy input — TODO confirm.
        """
        # Keep the backbone frozen: eval() puts its layers (e.g. batch-norm) in
        # inference mode and no_grad() keeps gradients away from its weights.
        self.feature_extractor.eval()
        with torch.no_grad():
            representations = self.feature_extractor(x).flatten(1)
        x = self.classifier(representations)
        x = F.log_softmax(x, dim=1)
        return x

    def predict_app(self, x):
        """Classify a single input and return its class id and class name.

        Switches the whole module to eval mode. ``x`` must hold exactly one
        sample: ``.item()`` raises for a larger batch.
        """
        self.eval()
        # Pure inference: no autograd graph is needed for the classifier either.
        with torch.no_grad():
            _, y_hat = self.forward(x).max(1)
        class_id = y_hat.item()
        return {'class_id': class_id, 'class_name': self.img_class_map[str(class_id)]}

    def loss(self, xs, ys):
        """Return ``(log_probs, nll_loss)`` for a batch of inputs and integer targets."""
        logits = self(xs)
        # forward already applies log_softmax, so NLL here equals cross-entropy.
        loss = F.nll_loss(logits, ys)
        return logits, loss

    def training_step(self, batch, batch_idx):
        """Compute the training loss and log loss/accuracy (per step and per epoch)."""
        xs, ys = batch
        logits, loss = self.loss(xs, ys)
        preds = torch.argmax(logits, 1)
        self.log('train/loss', loss, on_epoch=True)
        self.train_acc(preds, ys)
        self.log('train/acc', self.train_acc, on_epoch=True)
        return loss

    def configure_optimizers(self):
        """Adam over all parameters with the learning rate from ``hparams["lr"]``.

        The backbone's parameters receive no gradients (forward runs it under
        no_grad), so in practice only the classifier head is updated.
        """
        return torch.optim.Adam(self.parameters(), lr=self.hparams["lr"])

    def test_step(self, batch, batch_idx):
        """Accumulate test loss and accuracy; both are logged once per epoch."""
        xs, ys = batch
        logits, loss = self.loss(xs, ys)
        preds = torch.argmax(logits, 1)
        self.test_acc(preds, ys)
        self.log("test/loss_epoch", loss, on_step=False, on_epoch=True)
        self.log("test/acc_epoch", self.test_acc, on_step=False, on_epoch=True)

    def test_epoch_end(self, test_step_outputs):  # args are defined as part of pl API
        """Export the final model to ``model_final.onnx`` and upload it with wandb."""
        # NOTE(review): unlike validation_epoch_end, this export declares no
        # dynamic batch axis, so the ONNX graph is traced with a fixed batch of
        # self.batch_size — confirm that consumers always feed that batch size.
        model_filename = "model_final.onnx"
        self.to_onnx(model_filename, self._dummy_input(), export_params=True)
        wandb.save(model_filename)

    def validation_step(self, batch, batch_idx):
        """Log validation loss/accuracy and return the batch's log-probabilities."""
        xs, ys = batch
        logits, loss = self.loss(xs, ys)
        preds = torch.argmax(logits, 1)
        self.valid_acc(preds, ys)
        self.log("valid/loss_epoch", loss)
        self.log('valid/acc_epoch', self.valid_acc)
        # Collected by Lightning and passed to validation_epoch_end.
        return logits

    def validation_epoch_end(self, validation_step_outputs):
        """Export an ONNX snapshot and log a histogram of the validation outputs."""
        # NOTE(review): "latest_run" is concatenated without a separator, so the
        # file is written as e.g. "latest_runmodel_00042.onnx" in the working
        # directory. The name is kept unchanged. The same path is now handed to
        # wandb.save; previously it uploaded "model_00042.onnx", a file this
        # method never writes.
        model_filename = f"model_{str(self.global_step).zfill(5)}.onnx"
        onnx_path = 'latest_run' + model_filename
        torch.onnx.export(self, self._dummy_input(), onnx_path, opset_version=11,
                          input_names=['input'],
                          output_names=['output'],
                          # Let the exported graph accept any batch size.
                          dynamic_axes={'input': {0: 'batch_size'},
                                        'output': {0: 'batch_size'}})
        wandb.save(onnx_path)
        # torch.cat raises on an empty list (e.g. when no validation batch ran).
        if validation_step_outputs:
            # These are the log_softmax outputs returned by validation_step.
            flattened_logits = torch.flatten(torch.cat(validation_step_outputs))
            self.logger.experiment.log(
                {"valid/logits": wandb.Histogram(flattened_logits.to("cpu")),
                 "global_step": self.global_step})
def get_img_class_map(path='index_to_name.json'):
    """Load the class-index -> class-name mapping from a JSON file.

    Args:
        path: JSON file to read. Defaults to ``'index_to_name.json'``,
            resolved against the current working directory.

    Returns:
        The decoded JSON object. ``LitMLP.predict_app`` indexes it with
        ``str(class_id)``, so it is expected to be a dict keyed by the
        stringified class index.
    """
    # JSON text is UTF-8 (RFC 8259). Without an explicit encoding, open() uses
    # the locale's preferred encoding, and non-ASCII class names can fail to
    # decode or decode wrongly (e.g. on Windows).
    with open(path, encoding='utf-8') as f:
        return json.load(f)
def get_feature_extractor():
    """Build the pretrained feature extractor used as the model backbone.

    Loads timm's pretrained ``resnet50d`` and keeps every top-level child
    module except the last one (presumably the ``fc`` head whose
    ``in_features`` is returned — TODO confirm for other timm versions).

    Returns:
        A ``(module, num_features)`` pair: an ``nn.Sequential`` of the kept
        children, and ``fc.in_features``, the feature width the new classifier
        head must accept.
    """
    model = create_timm_model('resnet50d', pretrained=True)
    feature_dim = model.fc.in_features
    # Unpack the child modules, setting the final one aside.
    *trunk, _head = model.children()
    return nn.Sequential(*trunk), feature_dim