# LUWA / run_gradio.py
# Source page metadata (scraped from the hosting page): author DanielXu0208,
# commit "Initial commit" (785ef2b), file size 1.56 kB.
import gradio as gr
import torch
import torchvision
from utils.experiment_utils import get_model
# Load the DINOv2 model.
def load_model():
    """Build the DINOv2 classifier through ``get_model`` and put it in eval mode.

    ``get_model`` (imported from ``utils.experiment_utils``) is handed a small
    argument object exposing three attributes, hard-coded for this demo:
    ``model="DINOv2"``, ``pretrained="pretrained"`` and ``frozen="unfrozen"``.

    Returns:
        The object returned by ``get_model``, after ``.eval()`` has been
        called on it.
    """

    # Stand-in for a parsed command-line namespace: get_model only needs
    # attribute access, so plain class attributes are enough.
    class _Config:
        model = 'DINOv2'
        pretrained = 'pretrained'
        frozen = 'unfrozen'

    network = get_model(_Config())
    # Inference-only demo: switch layers such as dropout/batch-norm to eval behaviour.
    network.eval()
    return network


# Built once at import time and reused by every prediction request.
model = load_model()
# Prediction function: returns the probability of each class.

# Class names, listed in the same order as the model's output scores
# (probabilities are paired with names by position in predict()).
CLASS_NAMES = ["ANTLER", "BEECHWOOD", "BEFOREUSE", "BONE", "IVORY", "SPRUCEWOOD"]

# Preprocessing pipeline, built once at import rather than on every request.
# It resizes to 224x224, converts to a float tensor, then normalises each
# channel with the standard ImageNet mean/std constants.
_TRANSFORM = torchvision.transforms.Compose([
    torchvision.transforms.Resize((224, 224)),
    torchvision.transforms.ToTensor(),
    torchvision.transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])


def predict(image):
    """Classify one uploaded image and return a probability for every class.

    Args:
        image: PIL image supplied by the ``gr.Image(type="pil")`` input below,
            or ``None`` when the user submits without uploading anything.

    Returns:
        dict mapping each name in ``CLASS_NAMES`` to its softmax probability
        (a float). This is the format consumed by the ``gr.Label`` output.

    Raises:
        gr.Error: if no image was provided (shown to the user in the UI).
        RuntimeError: if the model's number of output scores does not match
            ``len(CLASS_NAMES)``.
    """
    if image is None:
        # Gradio passes None when Submit is pressed on an empty input.
        # Show a readable message instead of a transform traceback.
        raise gr.Error("Please upload an image.")

    # Uploads can be RGBA (PNG with alpha), grayscale or palette images.
    # Normalize above has exactly three channel statistics, so force RGB first.
    batch = _TRANSFORM(image.convert("RGB")).unsqueeze(0)

    with torch.no_grad():
        logits = model(batch)

    # Drop only the batch dimension so the class dimension is never squeezed away.
    probabilities = torch.nn.functional.softmax(logits, dim=1).squeeze(0).tolist()

    if len(probabilities) != len(CLASS_NAMES):
        # Without this check, extra scores would be silently dropped,
        # or too few would raise an IndexError.
        raise RuntimeError(
            f"Model produced {len(probabilities)} scores but "
            f"{len(CLASS_NAMES)} class names are defined."
        )

    # Pair each class name with its probability.
    return dict(zip(CLASS_NAMES, probabilities))


# Build the Gradio interface: a PIL image in, per-class probabilities out.
interface = gr.Interface(
    fn=predict,
    inputs=gr.Image(type="pil"),
    outputs=gr.Label(num_top_classes=len(CLASS_NAMES)),
    title="LUWA DINOv2 Prediction",
    description="Upload an image to get the probabilities for each class using the DINOv2 model."
)
def main():
    """Launch the Gradio demo.

    ``share=True`` requests a public Gradio share link in addition to the
    local server.
    """
    interface.launch(share=True)


if __name__ == "__main__":
    main()