# Mini-Vision-V2 / demo.py
# LWWZH's picture
# Upload Mini-Vision-V2
# 5b6d90c verified
from model import MiniVisionV2
import torch
import torchvision
import gradio as gr
import webbrowser
# Load the serialized model from disk. The result is used directly as a model
# object (``.eval()`` is called on it), so the checkpoint presumably holds the
# whole pickled module rather than a state dict -- which is why
# ``weights_only=False`` is required, and why the model's class must be
# importable when this line runs.
# SECURITY: ``weights_only=False`` allows the unpickler to execute arbitrary
# code; only load a checkpoint file that comes from a trusted source.
minivisionv2 = torch.load("Mini-Vision-V2.pth", weights_only=False)
# Switch to evaluation mode; the model is only used for prediction here.
minivisionv2.eval()

# Preprocessing applied to every sketch before inference.
# ``Resize`` is given an explicit (height, width) pair: a bare int only pins
# the *shorter* edge and keeps the aspect ratio, so a non-square canvas would
# yield a tensor that is not 28x28. For square input the result is unchanged.
transform = torchvision.transforms.Compose([
    torchvision.transforms.Resize((28, 28)),
    torchvision.transforms.ToTensor(),
])
def classifier(img):
    """Classify the drawing currently on the sketch pad.

    Args:
        img: Value delivered by the sketch pad component: a mapping whose
            ``"composite"`` entry is the drawing as an image, or ``None``
            when the pad has been cleared.

    Returns:
        A dict mapping each class index (as a string: ``"0"``, ``"1"``, ...)
        to its softmax probability, or ``None`` when there is no drawing to
        classify.
    """
    # A cleared or untouched pad carries no composite image; skip inference
    # instead of failing inside the transform.
    if img is None or img["composite"] is None:
        return None

    pixels = transform(img["composite"])
    # Invert intensities so strokes and background swap brightness.
    # NOTE(review): presumably the model expects light strokes on a dark
    # background while the pad draws dark on light -- confirm against training.
    pixels = 1.0 - pixels
    # Add a leading batch dimension of size 1 for the model call.
    batch = pixels.unsqueeze(0)

    # No gradients are needed for prediction.
    with torch.no_grad():
        logits = minivisionv2(batch)
        probabilities = torch.softmax(logits, dim=1)

    # One entry per class the model outputs, keyed by its index as a string.
    return {
        str(index): probability
        for index, probability in enumerate(probabilities[0].tolist())
    }
# Input component: a 280x280 sketch pad that hands the drawing to the
# classifier as a PIL image in mode "L".
_sketch_pad = gr.Sketchpad(
    height=280,
    width=280,
    image_mode="L",
    label="Sketch Pad",
    type="pil",
)

# Output component: shows the label -> confidence mapping returned by the
# classifier.
_result_label = gr.Label(label="Classifying Results")

# The web UI wiring the sketch pad to the classifier and its result label.
demo = gr.Interface(
    fn=classifier,
    inputs=_sketch_pad,
    outputs=_result_label,
    title="Mini-Vision-V2",
    description="Write number 0-9 in the sketch pad below",
)
if __name__ == '__main__':
    # Let Gradio open the browser tab itself (inbrowser=True): it does so only
    # after the server is up and uses the address actually bound, instead of
    # racing the startup and assuming port 7860 is free.
    # share=True additionally requests a public share link for the app.
    demo.launch(share=True, inbrowser=True)