# Spaces:
# Sleeping
# Sleeping
import gradio as gr | |
import torch | |
import torchvision | |
from utils.experiment_utils import get_model | |
# Load the DINOv2 model.
def load_model():
    """Create the DINOv2 classifier via the project's ``get_model`` factory.

    ``get_model`` takes an args-style object (attribute access), so a small
    ad-hoc class named ``Args`` is built to stand in for parsed CLI options.

    Returns:
        Whatever ``get_model`` returns, after ``.eval()`` has been called on it.
    """
    # Attribute names/values presumably mirror the training script's CLI
    # flags -- TODO confirm against utils.experiment_utils.get_model.
    settings = type(
        "Args",
        (),
        {"model": 'DINOv2', "pretrained": 'pretrained', "frozen": 'unfrozen'},
    )()
    classifier = get_model(settings)
    classifier.eval()
    return classifier


# Built once at import time so every prediction reuses the same instance.
model = load_model()
# Prediction: return the probability of each class.

# Label for each output index of the model; order must match the logits.
CLASS_NAMES = ["ANTLER", "BEECHWOOD", "BEFOREUSE", "BONE", "IVORY", "SPRUCEWOOD"]


def predict(image):
    """Classify one PIL image and return a probability per class.

    Args:
        image: PIL image from the Gradio input, or None when the user
            submitted without uploading anything.

    Returns:
        dict mapping each name in ``CLASS_NAMES`` to its softmax probability,
        or an empty dict when ``image`` is None (clears the Label output).
    """
    if image is None:
        # Gradio passes None for an empty submission; transforming it would
        # raise a TypeError instead of just showing no prediction.
        return {}

    # Resize, convert to a tensor, and normalize with ImageNet mean/std.
    transform = torchvision.transforms.Compose([
        torchvision.transforms.Resize((224, 224)),
        torchvision.transforms.ToTensor(),
        torchvision.transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])
    # ToTensor keeps the source channel count; force 3-channel RGB so
    # RGBA / grayscale / palette images match the 3-channel Normalize.
    batch = transform(image.convert("RGB")).unsqueeze(0)

    with torch.no_grad():
        logits = model(batch)

    # Take the single batch row explicitly rather than squeeze(), which would
    # collapse a one-element output to a scalar.
    probabilities = torch.nn.functional.softmax(logits, dim=1)[0].tolist()
    return {name: probabilities[i] for i, name in enumerate(CLASS_NAMES)}
# Build the Gradio web UI: one PIL image in, a label/probability view out.
_LABEL_NAMES = ("ANTLER", "BEECHWOOD", "BEFOREUSE", "BONE", "IVORY", "SPRUCEWOOD")

_image_input = gr.Image(type="pil")
# num_top_classes equals the number of labels, so every class is listed.
_label_output = gr.Label(num_top_classes=len(_LABEL_NAMES))

interface = gr.Interface(
    fn=predict,
    inputs=_image_input,
    outputs=_label_output,
    title="LUWA DINOv2 Prediction",
    description=(
        "Upload an image to get the probabilities for each class "
        "using the DINOv2 model."
    ),
)

if __name__ == "__main__":
    # share=True asks Gradio for a public share link in addition to the
    # local server.
    interface.launch(share=True)