pjdevelop's picture
commit files to HF hub
e444fbb
raw
history blame
781 Bytes
import gradio as gr
import torch
import numpy as np
from PIL import Image
from torchvision import transforms as T
import joblib
# Load models.
# The .pth file holds a complete pickled nn.Module (it is later called
# directly as a model), not a state_dict, so it has to be fully unpickled.
# torch >= 2.6 defaults torch.load to weights_only=True, which rejects
# arbitrary pickled modules, hence the explicit weights_only=False.
# SECURITY NOTE: full unpickling can execute arbitrary code -- only ever
# point this at a trusted, locally shipped checkpoint.
try:
    dinov2_vits14 = torch.load(
        "dinov2_vits14.pth",
        map_location=torch.device("cpu"),
        weights_only=False,
    )
except TypeError:
    # torch < 1.13 has no weights_only parameter (it always unpickles fully).
    dinov2_vits14 = torch.load("dinov2_vits14.pth", map_location=torch.device("cpu"))
# Inference only: make sure any training-mode layers (e.g. dropout) are off,
# regardless of the mode the module happened to be pickled in.
dinov2_vits14.eval()
# Classifier applied to the DINOv2 embedding (presumably a scikit-learn SVM,
# judging by the filename -- confirm). joblib also unpickles: trusted files only.
clf = joblib.load("svm_model.joblib")
# Preprocessing applied to every input image before it reaches DINOv2:
#   1. PIL image -> float tensor (C, H, W) scaled to [0, 1]
#   2. resize so the shorter side is 244 px
#   3. center-crop to 224 x 224
#   4. (x - 0.5) / 0.5 on every channel, mapping [0, 1] to [-1, 1]
# NOTE(review): these values presumably mirror the preprocessing used when the
# classifier was trained -- changing them likely requires retraining it.
transform_image = T.Compose(
    [
        T.ToTensor(),
        T.Resize(244),
        T.CenterCrop(224),
        T.Normalize([0.5], [0.5]),
    ]
)
def predict(image):
    """Classify an image using DINOv2 embeddings and the pre-trained classifier.

    Args:
        image: Image as a NumPy array, as delivered by Gradio's ``"image"``
            input. RGB, RGBA, grayscale and palette images are all accepted.

    Returns:
        The single label predicted by ``clf`` for this image.

    Raises:
        ValueError: If no image was supplied (Gradio passes ``None`` when the
            user submits without uploading anything).
    """
    if image is None:
        raise ValueError("No image provided.")
    # Force 3-channel RGB before transforming. The DINOv2 backbone expects
    # exactly 3 input channels; merely slicing channels would leave grayscale
    # or palette images with a single channel. For RGB/RGBA input this yields
    # exactly the same pixels as keeping the first three channels.
    pil_image = Image.fromarray(image).convert("RGB")
    # Add a leading batch dimension: (1, 3, H, W).
    batch = transform_image(pil_image).unsqueeze(0)
    with torch.no_grad():
        embedding = dinov2_vits14(batch)
    # The classifier's predict() takes a 2-D (n_samples, n_features) array.
    features = embedding[0].cpu().numpy().reshape(1, -1)
    prediction = clf.predict(features)
    return prediction[0]
# Web UI: a single image upload, passed through `predict`, shown as text.
iface = gr.Interface(
    fn=predict,
    inputs="image",
    outputs="text",
)
# Start serving the app.
iface.launch()