MUTED64
update model
7ca56ed
raw
history blame contribute delete
No virus
2.57 kB
import gradio as gr
import torch
from PIL import Image
from typing import List
from waifu_scorer.mlp import MLP
import clip
# Load the pre-trained scoring head (MLP) and the CLIP image encoder once at
# import time; the functions below read them as module-level globals.
model_path = "./3072_MLP_best-MSE1.6230_ep23.pth"
device = "cpu"
dtype = torch.float32
# NOTE(review): torch.load unpickles the file, so model_path must only ever
# point at a trusted checkpoint.
s = torch.load(model_path, map_location=device)
model = MLP(input_size=768)
model.load_state_dict(s)
model.to(device=device, dtype=dtype)
# This process only does inference. torch.no_grad() (used by predict) does not
# change layer modes, so switch to eval mode explicitly; otherwise any
# dropout / batch-norm layers inside MLP (its definition is not visible here)
# would keep their training-time behaviour.
model.eval()
# CLIP backbone and its matching image preprocessing transform.
model2, preprocess = clip.load("ViT-L/14", device=device)
def normalized(a: torch.Tensor, order=2, dim=-1):
    """Scale ``a`` to unit ``order``-norm along ``dim``.

    Slices whose norm is exactly zero are divided by 1 instead, so they are
    returned unchanged rather than turning into NaN/inf.
    """
    magnitude = a.norm(order, dim, keepdim=True)
    # Substitute 1 wherever the norm is 0 to keep the division well defined.
    safe_divisor = torch.where(
        magnitude == 0,
        torch.ones_like(magnitude),
        magnitude,
    )
    return a / safe_divisor
@torch.no_grad()
def encode_images(images: List[Image.Image], model2, preprocess, device='cpu') -> torch.Tensor:
    """Embed one or more images with ``model2`` and normalise the result.

    Args:
        images: A list of images; any non-list value is treated as a single
            image and wrapped in a one-element list.
        model2: Encoder object exposing ``encode_image(batch)``.
        preprocess: Callable turning one image into a tensor.
        device: Device the preprocessed batch is moved to before encoding.

    Returns:
        The embeddings after ``normalized``, moved to the CPU and cast to
        float32. Presumably one row per input image -- confirm against the
        encoder in use.
    """
    batch_items = images if isinstance(images, list) else [images]
    # Give every preprocessed image a leading batch axis, then join them.
    pixel_batch = torch.cat([preprocess(item).unsqueeze(0) for item in batch_items])
    features = model2.encode_image(pixel_batch.to(device))
    return normalized(features).cpu().float()
@torch.no_grad()
def predict(inputs: List[Image.Image]) -> List[float]:
    """Score images with the module-level CLIP encoder and MLP head.

    Args:
        inputs: Images to score (passed straight to ``encode_images``).

    Returns:
        A flat list of scores, each clamped to the range [0, 10]. The
        original annotation claimed a single ``float``, but the function has
        always returned a list (callers index into it).
    """
    # Embed the images, then match the MLP's device and dtype.
    embeddings = encode_images(inputs, model2, preprocess, device=device).to(device=device, dtype=dtype)
    predictions = model(embeddings)
    # Clamp to the 0-10 scale and flatten to a plain Python list.
    scores = predictions.clamp(0, 10).cpu().numpy().reshape(-1).tolist()
    return scores
from waifu_scorer.predict import WaifuScorer, load_model
# Scorer object built from the same checkpoint path and device as above; the
# change_model() callback further down replaces its `mlp` attribute.
scorer = WaifuScorer(model_path=model_path, model_type="mlp", device=device)
# Gradio UI: image input on the left, model-path box and score on the right.
with gr.Blocks() as demo:
    with gr.Row():
        with gr.Column():
            image = gr.Image(
                label='Image',
                type='pil',
                height=512,
                sources=['upload', 'clipboard'],
            )
        with gr.Column():
            with gr.Row():
                # NOTE: this rebinds the module-level name `model_path` from
                # the checkpoint path string to the Textbox component; the
                # string is still used as the initial value below.
                model_path = gr.Textbox(
                    label='Model Path',
                    value=model_path,
                    placeholder='Path or URL to the model file',
                    interactive=True,
                )
            with gr.Row():
                score = gr.Number(
                    label='Score',
                )

    def change_model(model_path):
        """Load a new MLP from ``model_path`` and make scoring use it.

        Previously only ``scorer.mlp`` was replaced, but scoring goes through
        ``predict()``, which reads the module-level ``model``; submitting a
        new path therefore had no effect on the scores. The global ``model``
        is now rebound as well.

        Returns ``gr.update()`` with no arguments, leaving the textbox as is.
        """
        global model
        # NOTE(review): assumes load_model returns a callable module that is
        # ready for inference (same role as `model`) -- confirm.
        new_model = load_model(model_path, model_type="mlp", device=device)
        scorer.mlp = new_model
        model = new_model
        print(f"Model changed to `{model_path}`")
        return gr.update()

    # Pressing Enter in the model-path box reloads the model.
    model_path.submit(
        fn=change_model,
        inputs=model_path,
        outputs=model_path,
    )

    # Re-score whenever the image changes. The image is duplicated into a
    # batch of two and only the first score is shown; a cleared image clears
    # the score.
    image.change(
        fn=lambda image: predict([image]*2)[0] if image is not None else None,
        inputs=image,
        outputs=score,
    )

demo.launch()