Spaces:
Running
Running
#!/usr/bin/env python | |
# coding: utf-8 | |
import torch | |
from torch import nn | |
import numpy as np | |
import plotly.express as px | |
from plotly.subplots import make_subplots | |
from src.trainer import LitTrainer | |
from src.models import CNN | |
from src.dataset import DatasetMNIST, download_mnist | |
def load_pl_net(path="checkpoints/lightning_logs/version_26/checkpoints/epoch=9-step=1000.ckpt"):
    """Restore the Lightning-trained network from a checkpoint file.

    Args:
        path: Location of the Lightning ``.ckpt`` checkpoint to read.

    Returns:
        The object produced by ``LitTrainer.load_from_checkpoint``, wrapping a
        freshly constructed ``CNN(1, 10)``.
    """
    # presumably CNN(1, 10) means 1 input channel and 10 classes -- confirm in src.models
    backbone = CNN(1, 10)
    return LitTrainer.load_from_checkpoint(path, model=backbone)
def load_torch_net(path="checkpoints/pytorch/version_0.pt"):
    """Build a ``CNN(1, 10)`` and fill it with weights saved by plain PyTorch.

    Args:
        path: Location of the file holding the saved ``state_dict``.

    Returns:
        The ``CNN`` instance with the stored weights loaded. The caller is
        responsible for moving it to the desired device.
    """
    # map_location="cpu" lets a checkpoint that was saved from CUDA tensors be
    # read on a machine without a GPU; load_state_dict copies the values into
    # the freshly built network's own parameters either way.
    # NOTE(review): torch.load unpickles the file -- only load checkpoints from
    # a trusted source.
    state_dict = torch.load(path, map_location="cpu")
    net = CNN(1, 10)
    net.load_state_dict(state_dict)
    return net
def get_sequence(model):
    """Plot one confidently classified sample for each digit 0 through 9.

    Starting from a random index in ``[0, 30000)``, the module-level
    ``dataset`` is scanned forward. The digits are collected in order: the
    next sample is accepted only when ``predict`` labels it as the digit
    currently being searched for with a confidence above 0.95.

    Args:
        model: Network handed to ``predict`` for every candidate sample.

    Returns:
        A 2x5 subplot figure; digit ``d`` sits at row ``d // 5 + 1``,
        column ``d % 5 + 1``.
    """
    # NOTE: relies on the global ``dataset`` that the __main__ section assigns.
    fig = make_subplots(rows=2, cols=5)
    digit = 0
    idx = np.random.randint(0, 30000)
    while digit < 10:
        sample, _label = dataset[idx]
        idx += 1
        guess, confidence = predict(sample, model)
        if guess != digit or not confidence > 0.95:
            continue
        # Flip the 28x28 image along its first axis before handing it to plotly.
        img = np.flip(np.array(sample.reshape(28, 28)), 0)
        grid_row, grid_col = divmod(digit, 5)
        fig.add_trace(px.imshow(img).data[0], row=grid_row + 1, col=grid_col + 1)
        digit += 1
    return fig
def predict(x, model, device="cuda"):
    """Classify a single sample and report the confidence of the winning class.

    Args:
        x: Input tensor for one sample; it is moved to ``device`` before the
            forward pass.
        model: Callable returning the class scores for ``x``.
        device: Device the input is moved to (the model is expected to
            already live there).

    Returns:
        A ``(predicted, p)`` tuple: ``predicted`` is the index of the highest
        score as a Python ``int``; ``p`` is a 0-dim tensor holding the largest
        softmax probability.
    """
    # Inference only: skip autograd bookkeeping during the forward pass.
    with torch.no_grad():
        scores = model(x.to(device)).detach().cpu()
    # argmax over the flattened scores, so a (1, C) output yields the class index.
    predicted = int(torch.argmax(scores))
    # Normalise over the last (class) axis. With dim=0 a model that returns a
    # (1, C) batch would be normalised over the batch axis and every
    # probability would come out as 1.0. For 1-D scores dim=-1 is the same axis.
    p = torch.max(nn.functional.softmax(scores, dim=-1))
    return predicted, p
def predict_interval(x, model, device="cuda"):
    """Return the softmax probability of every class for a single sample.

    Args:
        x: Input tensor for one sample; it is moved to ``device`` before the
            forward pass.
        model: Callable returning the class scores for ``x``.
        device: Device the input is moved to (the model is expected to
            already live there).

    Returns:
        A dict mapping class index (``int``) to probability (``float``),
        inserted in order of ascending score, so the most likely class is the
        last entry.

    Raises:
        ValueError: If the model output holds scores for more than one sample.
    """
    # Inference only: skip autograd bookkeeping during the forward pass.
    with torch.no_grad():
        scores = model(x.to(device)).detach().cpu()
    if scores.dim() > 1:
        # Accept a batch of one, e.g. shape (1, C), by dropping the singleton
        # axes; anything larger is ambiguous for a single class -> prob mapping.
        if scores.numel() != scores.shape[-1]:
            raise ValueError("predict_interval expects the scores of a single sample")
        scores = scores.reshape(-1)
    probabilities = nn.functional.softmax(scores, dim=0)
    ascending = torch.argsort(scores)
    return {int(k): float(probabilities[k]) for k in ascending}
if __name__ == "__main__":
    # Script entry point: fetch MNIST, then show a 0-9 digit sequence figure
    # for each of the two saved networks.
    mnist = download_mnist("downloads/mnist/")
    # ``dataset`` has to remain a module-level name: get_sequence() reads it
    # as a global.
    dataset = DatasetMNIST(*mnist["train"])
    test_data = DatasetMNIST(*mnist["test"])

    for title, loader in (
        ("PyTorch Lightning Network", load_pl_net),
        ("Manual Network", load_torch_net),
    ):
        print(title)
        get_sequence(loader().to("cuda")).show()