# File-viewer header (page-scrape residue): author npv2k1, commit 06c8a6d
# "update" (verified), file size 1.48 kB.
import functools
import os

import gradio as gr
import numpy as np
import torch
from PIL import Image
from torchvision import transforms
from torchvision.transforms import functional as F

from src.models.model import ShapeClassifier # Import your model class
from src.data.transform import data_transform
# Class labels, in the index order of the model's output logits.
SHAPE_CLASSES = ["Circle", "Square", "Triangle"]

# Trained weights; the path is relative to the current working directory.
MODEL_PATH = "results/models/model.pth"


@functools.lru_cache(maxsize=1)
def _load_model():
    """Build the shape classifier, load its trained weights on CPU, and cache it.

    Cached so the weights are read from disk once per process rather than on
    every prediction (Gradio may invoke the handler on each input change).
    """
    model = ShapeClassifier(num_classes=len(SHAPE_CLASSES))
    state_dict = torch.load(MODEL_PATH, map_location=torch.device("cpu"))
    model.load_state_dict(state_dict)
    model.eval()  # evaluation mode (affects layers such as dropout/batch-norm)
    return model


def classify_drawing(drawing_image):
    """Predict which shape a drawing depicts.

    Args:
        drawing_image: The image supplied by the UI (a PIL image when the
            input component uses ``type="pil"``), or ``None`` when empty.

    Returns:
        One of ``SHAPE_CLASSES`` ("Circle", "Square", "Triangle"), or ``None``
        if no image was provided.
    """
    if drawing_image is None:
        return None

    model = _load_model()

    # Round-trip through NumPy so whatever array-like the UI hands over ends
    # up as a PIL image; any resizing/grayscale work is done by data_transform.
    drawing = np.array(drawing_image)
    drawing_tensor = data_transform(Image.fromarray(drawing))
    # NOTE(review): assumes data_transform may return an unbatched (C, H, W)
    # tensor; add the batch axis so argmax(dim=1) below sees (N, classes).
    # No-op if the transform already batches — confirm against data_transform.
    if drawing_tensor.dim() == 3:
        drawing_tensor = drawing_tensor.unsqueeze(0)

    with torch.no_grad():
        output = model(drawing_tensor)

    predicted_class = torch.argmax(output, dim=1).item()
    return SHAPE_CLASSES[predicted_class]
# Gradio UI: an image input feeding classify_drawing, with the predicted
# shape name shown as text.
iface = gr.Interface(
    fn=classify_drawing,
    inputs=gr.Image(type="pil"),  # image input, delivered to the handler as a PIL image
    outputs="text",
    live=True,  # re-run classification whenever the input changes
)

# Launch only when run as a script, so importing this module (e.g. for tests
# or Gradio's reload mode) does not start a server.
if __name__ == "__main__":
    iface.launch(server_port=7860)