pjdevelop's picture
Create app.py
0af6300 verified
raw
history blame
781 Bytes
import gradio as gr
import torch
import numpy as np
from PIL import Image
from torchvision import transforms as T
import joblib
# Load models.
# The .pth file holds a whole pickled nn.Module (predict() calls it directly),
# not a state dict, so full unpickling is required. torch>=2.6 changed the
# torch.load default to weights_only=True, which rejects such files, so opt
# out explicitly.
# NOTE(review): weights_only=False (like joblib.load below) executes pickle
# code from the file -- only ever load these trusted, bundled artifacts.
dinov2_vits14 = torch.load(
    'dinov2_vits14.pth',
    map_location=torch.device('cpu'),
    weights_only=False,
)
# Inference only: force evaluation mode so any train-time-only behaviour
# (e.g. dropout) is disabled regardless of the mode the module was saved in.
dinov2_vits14.eval()
# Presumably a scikit-learn SVM (per the file name); only .predict() is used.
clf = joblib.load('svm_model.joblib')
# Preprocessing applied to each input PIL image before the DINOv2 backbone:
#   1. ToTensor: PIL image -> float CHW tensor (scaled to [0, 1] for 8-bit input)
#   2. Resize(244): shorter side resized to 244 pixels
#   3. CenterCrop(224): central 224x224 patch
#   4. Normalize: (x - 0.5) / 0.5 on every channel (one-element mean/std
#      broadcast over all channels)
# NOTE(review): 244 rather than the more common 256 -- presumably this matches
# the preprocessing used to compute the SVM's training embeddings, so it must
# not be changed on its own. TODO confirm.
transform_image = T.Compose(
    [
        T.ToTensor(),
        T.Resize(244),
        T.CenterCrop(224),
        T.Normalize(mean=[0.5], std=[0.5]),
    ]
)
def predict(image):
    """Classify one image using DINOv2 embeddings and the loaded SVM.

    Args:
        image: Image from the Gradio ``"image"`` input, as a NumPy array
            (anything ``PIL.Image.fromarray`` accepts), or ``None`` when the
            user submitted without uploading an image.

    Returns:
        The first (and only) label returned by ``clf.predict`` for the image.

    Raises:
        gr.Error: If no image was provided, so Gradio shows a readable message
            instead of an opaque error raised from inside PIL.
    """
    if image is None:
        raise gr.Error("Please upload an image.")
    # Force 3-channel RGB: a 2-D (grayscale) array would otherwise become a
    # 1-channel tensor the ViT cannot consume. For RGBA input PIL simply drops
    # alpha, which matches the old ``[:3]`` channel slice.
    pil_image = Image.fromarray(image).convert("RGB")
    batch = transform_image(pil_image).unsqueeze(0)  # (1, 3, 224, 224)
    with torch.no_grad():
        embedding = dinov2_vits14(batch)
    # The SVM expects a 2-D (n_samples, n_features) array; take the single
    # sample's embedding and flatten it into one row.
    features = embedding[0].cpu().numpy().reshape(1, -1)
    prediction = clf.predict(features)
    return prediction[0]
# Web UI: a single image input (passed to predict as a NumPy array by
# default) and a plain-text output showing the predicted label.
iface = gr.Interface(
    fn=predict,
    inputs="image",
    outputs="text",
)
# Start the Gradio server (runs at import/script time, as before).
iface.launch()