File size: 1,559 Bytes
785ef2b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
import gradio as gr
import torch
import torchvision
from utils.experiment_utils import get_model


# Load the DINOv2 model
def load_model():
    """Build the DINOv2 classifier through ``get_model`` and switch it to eval mode.

    ``get_model`` is called with an attribute-only object standing in for a
    command-line namespace (it reads ``.model``, ``.pretrained`` and
    ``.frozen`` — exact meanings are defined in utils.experiment_utils).

    Returns:
        The object produced by ``get_model``, after ``.eval()`` has been called
        on it (presumably a ``torch.nn.Module`` — confirm in get_model).
    """
    class _Config:
        # Settings passed straight through to get_model.
        model = 'DINOv2'
        pretrained = 'pretrained'
        frozen = 'unfrozen'

    network = get_model(_Config())
    # Put the network in evaluation mode before it is used for serving.
    network.eval()
    return network


# Build the model once at import time; predict() reads this module-level
# global on every call instead of reloading the weights per request.
model = load_model()


# Class labels; index i names the i-th column of the model's output.
# NOTE(review): this order must match the label indices used in training — confirm.
CLASS_NAMES = ["ANTLER", "BEECHWOOD", "BEFOREUSE", "BONE", "IVORY", "SPRUCEWOOD"]

# Preprocessing for every uploaded image: resize to 224x224, convert to a float
# tensor, then normalise per channel with the usual ImageNet mean/std values.
# Built once at import time rather than on every request.
_PREPROCESS = torchvision.transforms.Compose([
    torchvision.transforms.Resize((224, 224)),
    torchvision.transforms.ToTensor(),
    torchvision.transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])


# Prediction function: returns the probability of each class.
def predict(image):
    """Classify a single image and return one probability per class.

    Args:
        image: A PIL image, as delivered by ``gr.Image(type="pil")``, or
            ``None`` when the user submits without uploading anything.

    Returns:
        dict mapping every name in ``CLASS_NAMES`` to its softmax probability
        (a Python float) — the format ``gr.Label`` consumes.

    Raises:
        gr.Error: if no image was provided; Gradio shows the message to the user.
        ValueError: if the model's output width differs from ``len(CLASS_NAMES)``,
            which would otherwise silently pair probabilities with wrong labels.
    """
    if image is None:
        raise gr.Error("Please upload an image.")

    # Normalize carries three channel statistics, so force 3-channel input;
    # grayscale ("L") or RGBA uploads would otherwise fail inside Normalize.
    batch = _PREPROCESS(image.convert("RGB")).unsqueeze(0)  # -> (1, 3, 224, 224)
    with torch.no_grad():
        scores = model(batch)
    # Take the single row of the batch instead of squeeze(), which would turn a
    # one-column output into a bare scalar.
    probabilities = torch.nn.functional.softmax(scores, dim=1)[0].tolist()

    if len(probabilities) != len(CLASS_NAMES):
        raise ValueError(
            f"Model returned {len(probabilities)} scores but "
            f"{len(CLASS_NAMES)} class names are defined."
        )

    # Pair each class name with its probability.
    return dict(zip(CLASS_NAMES, probabilities))


# Create the Gradio interface: one PIL image in, a label widget listing every class.
interface = gr.Interface(
    fn=predict,
    inputs=gr.Image(type="pil"),
    outputs=gr.Label(num_top_classes=len(CLASS_NAMES)),
    title="LUWA DINOv2 Prediction",
    description="Upload an image to get the probabilities for each class using the DINOv2 model."
)

# Launch the web UI only when this file is run as a script, not when imported.
# share=True asks Gradio for a publicly shareable link in addition to the local
# server — anyone with that URL can reach the app; drop it for local-only use.
if __name__ == "__main__":
    interface.launch(share=True)