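# Header text copied from the hosted file view this script was taken from
# (not part of the program):
#   author: MUTED64 · commit 88b279a "Add app and pth"
#   file scan: "No virus" · size: 1.1 kB · page links: raw, history, blame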
from typing import List, Sequence, Union

import gradio as gr
import torch
from PIL import Image
from torchvision.transforms import functional as F
from transformers import CLIPModel, CLIPProcessor
# Load the pre-trained scoring head.
# NOTE(review): the checkpoint is unpickled as a complete nn.Module (the code
# below calls .eval() on the loaded object), so the class it was saved with
# must be importable when this runs. This file does not define it — confirm
# where it is expected to come from.
model_path = "1024_MLP_best-MSE4.1636_ep75.pth"
# map_location="cpu": a checkpoint saved from a GPU session would otherwise try
# to restore its tensors onto CUDA and fail on CPU-only hosts.
# weights_only=False: newer torch releases (2.6+) default to weights-only
# loading, which refuses to unpickle a full module object. This is only
# acceptable because the .pth ships with the app itself — never load untrusted
# checkpoints this way (unpickling can execute arbitrary code).
# (The `weights_only` argument requires torch >= 1.13.)
model = torch.load(model_path, map_location="cpu", weights_only=False)
model.eval()
# Load the CLIP model and processor.
# "ViT-L/14" is the model name used by OpenAI's `clip` package, not a Hugging
# Face Hub repository id, so `from_pretrained` cannot resolve it. The Hub
# checkpoint for the same model is "openai/clip-vit-large-patch14".
CLIP_MODEL_ID = "openai/clip-vit-large-patch14"
clip_model = CLIPModel.from_pretrained(CLIP_MODEL_ID)
clip_processor = CLIPProcessor.from_pretrained(CLIP_MODEL_ID)
# Define the prediction function
def predict(
    images: Union[Image.Image, Sequence[Image.Image]],
    *,
    normalize: bool = True,
) -> Union[float, List[float]]:
    """Score one image, or a batch of images, with the loaded scoring head.

    Each image is encoded with the CLIP image tower (``clip_model``). The
    resulting embedding is passed to ``model``, and its output is clamped to
    the range [0, 10].

    Args:
        images: A single image, or a list/tuple of images. Each image may be a
            ``PIL.Image.Image`` or anything ``PIL.Image.fromarray`` accepts
            (Gradio's plain ``"image"`` input delivers a numpy array).
        normalize: If true, L2-normalize the CLIP embeddings before scoring.
            NOTE(review): this follows the LAION aesthetic-predictor
            convention. Confirm it matches how this head was trained, and that
            the head's input width equals CLIP's embedding size.

    Returns:
        A float for a single image (what Gradio's ``"number"`` output
        expects), or a list of floats, one per image, for a list/tuple input.

    Raises:
        ValueError: If ``images`` is None (e.g. the Gradio input was cleared).
    """
    if images is None:
        raise ValueError("No image provided.")
    single = not isinstance(images, (list, tuple))
    batch = [images] if single else list(images)
    if not batch:
        return []

    # Give the processor RGB PIL images directly. Converting with F.to_tensor
    # first would scale pixels to [0, 1], and the processor would then rescale
    # them by 1/255 a second time.
    pil_images = []
    for img in batch:
        if not isinstance(img, Image.Image):
            img = Image.fromarray(img)
        pil_images.append(img.convert("RGB"))
    inputs = clip_processor(images=pil_images, return_tensors="pt")

    with torch.no_grad():
        # The scoring head consumes CLIP image embeddings, not raw pixels.
        embeddings = clip_model.get_image_features(
            pixel_values=inputs["pixel_values"].to(clip_model.device)
        )
        if normalize:
            embeddings = embeddings / embeddings.norm(dim=-1, keepdim=True)
        # Match the head's device/dtype in case it was loaded differently.
        head_param = next(model.parameters(), None)
        if head_param is not None:
            embeddings = embeddings.to(
                device=head_param.device, dtype=head_param.dtype
            )
        outputs = model(embeddings)

    scores = outputs.clamp(0, 10).cpu().reshape(-1).tolist()
    return scores[0] if single else scores
# Define the Gradio interface.
# Request a PIL image explicitly (the bare "image" shortcut delivers a numpy
# array). The output is the single numeric score that `predict` returns for one
# image.
iface = gr.Interface(
    fn=predict,
    inputs=gr.Image(type="pil"),
    outputs=gr.Number(),
    title="Kemono Aesthetic Scorer",
    description="Predict the score of a kemono based on aesthetic features.",
)

# Run the Gradio interface only when executed as a script, so importing this
# module (e.g. to reuse `predict`) does not start a web server.
if __name__ == "__main__":
    iface.launch()