File size: 781 Bytes
e444fbb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26

import gradio as gr
import torch
import numpy as np
from PIL import Image
from torchvision import transforms as T
import joblib

# Load both models once at import time so every request reuses them.
# SECURITY NOTE: both files are pickles (torch.load / joblib.load), and
# unpickling can execute arbitrary code. Only ever load trusted copies that
# ship with the app.
# The .pth file holds a whole model object (it is called directly, not
# restored via load_state_dict). Since PyTorch 2.6, torch.load defaults to
# weights_only=True, which refuses to unpickle such objects, so request full
# unpickling explicitly.
dinov2_vits14 = torch.load(
    'dinov2_vits14.pth',
    map_location=torch.device('cpu'),
    weights_only=False,
)
# Inference only: disable any training-mode behaviour (e.g. dropout).
dinov2_vits14.eval()
# Classifier exposing a scikit-learn-style .predict(); presumably an SVM,
# judging by the filename.
clf = joblib.load('svm_model.joblib')

# Preprocessing pipeline applied to every input image, one step per line.
# NOTE(review): these values presumably match the pipeline used when the
# classifier was trained. Changing them would shift the embeddings it sees.
transform_image = T.Compose([
    T.ToTensor(),               # PIL image -> float tensor laid out C x H x W
    T.Resize(244),              # scale so the shorter side is 244 px
    T.CenterCrop(224),          # keep the central 224 x 224 window
    T.Normalize([0.5], [0.5]),  # (x - 0.5) / 0.5 applied to every channel
])

def predict(image):
    """Classify one image using DINOv2 embeddings and the loaded classifier.

    Args:
        image: The value passed by Gradio's ``"image"`` input. It is converted
            with ``Image.fromarray``, so it is expected to be an array
            (presumably H x W or H x W x C uint8; confirm against the Gradio
            version in use). It is ``None`` when the user submits without an
            image.

    Returns:
        The single label predicted by ``clf`` for this image.

    Raises:
        gr.Error: If no image was provided.
    """
    if image is None:
        # Show a readable message in the UI instead of a PIL traceback.
        raise gr.Error("Please upload an image.")
    # Force exactly 3 channels. A grayscale (2-D) array would otherwise become
    # a 1-channel tensor, while the backbone expects RGB input. For RGBA input
    # this simply drops alpha, matching the old "first 3 channels" slice.
    pil_image = Image.fromarray(image).convert("RGB")
    batch = transform_image(pil_image).unsqueeze(0)  # -> (1, 3, 224, 224)
    with torch.no_grad():
        embedding = dinov2_vits14(batch)
    # A batch of one becomes a single feature row, (1, n_features), as
    # expected by the classifier's predict().
    features = embedding.cpu().numpy().reshape(1, -1)
    prediction = clf.predict(features)
    return prediction[0]

# Minimal web UI: an image input wired to `predict`, with the label shown as text.
iface = gr.Interface(
    fn=predict,
    inputs="image",
    outputs="text",
)
# Start the Gradio server (runs as soon as this module executes).
iface.launch()