Spaces:
Build error
Build error
import gradio as gr | |
import torch | |
import numpy as np | |
from PIL import Image | |
from torchvision.transforms import functional as F | |
from src.models.model import ShapeClassifier # Import your model class | |
from torchvision import transforms | |
import os | |
from src.data.transform import data_transform | |
# Labels returned to the UI; indexed by the class index the model predicts.
_SHAPE_CLASSES = ["Circle", "Square", "Triangle"]
_MODEL_PATH = 'results/models/model.pth'
# Populated by _get_model() on first use so the weights are read from disk once.
_model = None


def _get_model():
    """Return the trained ShapeClassifier, loading it from disk on first use only."""
    global _model
    if _model is None:
        model = ShapeClassifier(num_classes=len(_SHAPE_CLASSES))
        state_dict = torch.load(_MODEL_PATH, map_location=torch.device('cpu'))
        model.load_state_dict(state_dict)
        model.eval()  # inference mode: the model is never trained in this script
        _model = model
    return _model


def classify_drawing(drawing_image):
    """Classify a drawing as one of the labels in ``_SHAPE_CLASSES``.

    Args:
        drawing_image: The image to classify, or ``None`` when nothing has
            been drawn yet. It is converted with ``np.array`` and rebuilt
            with ``Image.fromarray`` before being passed to
            ``data_transform``.

    Returns:
        The predicted label string, or ``None`` if ``drawing_image`` is
        ``None``.
    """
    if drawing_image is None:
        return None

    # Reuse the cached model: the interface runs this handler repeatedly,
    # and re-reading the checkpoint on every call is wasted work.
    model = _get_model()

    # Rebuild the image from its raw pixel array before applying the
    # project's preprocessing transform.
    # NOTE(review): the round trip looks redundant for ordinary PIL inputs;
    # kept because data_transform's requirements are not visible here.
    drawing = np.array(drawing_image)
    drawing_tensor = data_transform(Image.fromarray(drawing))

    # NOTE(review): argmax over dim=1 assumes the model output has a leading
    # batch dimension -- confirm data_transform / the model provide one.
    with torch.no_grad():
        output = model(drawing_tensor)

    predicted_class = torch.argmax(output, dim=1).item()
    return _SHAPE_CLASSES[predicted_class]
# Input component: an image widget that passes the handler a PIL image.
_drawing_input = gr.Image(type="pil")

# Wire the classifier to the UI. The prediction is shown as plain text, and
# live=True is set so the interface re-runs the handler when the input
# changes rather than waiting for an explicit submit.
iface = gr.Interface(
    classify_drawing,
    _drawing_input,
    "text",
    live=True,
)

# Serve the app on port 7860.
iface.launch(server_port=7860)