import gradio as gr
import torch
from PIL import Image
from typing import List, Optional

import clip
from waifu_scorer.mlp import MLP
from waifu_scorer.predict import WaifuScorer, load_model

# --- Configuration -----------------------------------------------------------
model_path = "./3072_MLP_best-MSE1.6230_ep23.pth"
device = "cpu"
dtype = torch.float32

# --- Load the pre-trained scoring head ----------------------------------------
# NOTE(review): torch.load unpickles the checkpoint, which can execute arbitrary
# code; only load trusted files. On torch >= 1.13 consider `weights_only=True`.
s = torch.load(model_path, map_location=device)
model = MLP(input_size=768)
model.load_state_dict(s)
model.to(device=device, dtype=dtype)
# Inference only: put any dropout / batch-norm layers into eval mode so scores
# are deterministic and a batch of a single image is accepted.
model.eval()

# CLIP image encoder whose 768-d embeddings feed the MLP (input_size=768 above),
# plus the matching image preprocessing transform.
model2, preprocess = clip.load("ViT-L/14", device=device)


def normalized(a: torch.Tensor, order=2, dim=-1):
    """Return `a` divided by its `order`-norm along `dim`.

    Zero-norm slices are left unchanged (their norm is replaced by 1) to avoid
    division by zero.
    """
    l2 = a.norm(order, dim, keepdim=True)
    l2[l2 == 0] = 1
    return a / l2


@torch.no_grad()
def encode_images(images: List[Image.Image], model2, preprocess, device='cpu') -> torch.Tensor:
    """Embed images with CLIP and L2-normalize the embeddings.

    Args:
        images: a list of PIL images (a single image is also accepted).
        model2: CLIP model exposing `encode_image`.
        preprocess: CLIP preprocessing transform (PIL image -> tensor).
        device: device the batch is moved to before encoding.

    Returns:
        A float32 CPU tensor with one normalized embedding row per image.
    """
    if not isinstance(images, list):
        images = [images]
    image_tensors = [preprocess(img).unsqueeze(0) for img in images]
    image_batch = torch.cat(image_tensors).to(device)
    image_features = model2.encode_image(image_batch)
    im_emb_arr = normalized(image_features).cpu().float()
    return im_emb_arr


@torch.no_grad()
def predict(inputs: List[Image.Image]) -> List[float]:
    """Score each image with the current MLP head.

    Returns:
        One score per input image, clamped to [0, 10].
    """
    images = encode_images(inputs, model2, preprocess, device=device).to(device=device, dtype=dtype)
    predictions = model(images)
    scores = predictions.clamp(0, 10).cpu().numpy().reshape(-1).tolist()
    return scores


scorer = WaifuScorer(
    model_path=model_path,
    model_type="mlp",
    device=device,
)


def change_model(new_model_path: str):
    """Replace the scoring head with the checkpoint at `new_model_path`.

    Updates both `scorer.mlp` and the module-level `model` that `predict`
    actually uses, so the UI's model switch takes effect.

    Raises:
        gr.Error: if the checkpoint cannot be loaded; the message is shown in the UI.
    """
    global model
    try:
        new_mlp = load_model(new_model_path, model_type="mlp", device=device)
    except Exception as e:  # UI boundary: report the failure instead of a raw traceback
        raise gr.Error(f"Failed to load model from `{new_model_path}`: {e}") from e
    # Match the initial model's setup. Assumes load_model returns a
    # torch.nn.Module (.to() returns self) — TODO confirm.
    new_mlp = new_mlp.to(device=device, dtype=dtype)
    new_mlp.eval()
    scorer.mlp = new_mlp
    model = new_mlp
    print(f"Model changed to `{new_model_path}`")
    return gr.update()


def score_image(img: Optional[Image.Image]) -> Optional[float]:
    """Gradio handler: score one image, or clear the score when the input is cleared."""
    if img is None:
        return None
    return predict([img])[0]


with gr.Blocks() as demo:
    with gr.Row():
        with gr.Column():
            image = gr.Image(
                label='Image',
                type='pil',
                height=512,
                sources=['upload', 'clipboard'],
            )
        with gr.Column():
            with gr.Row():
                # Named distinctly so the `model_path` config string is not shadowed.
                model_path_box = gr.Textbox(
                    label='Model Path',
                    value=model_path,
                    placeholder='Path or URL to the model file',
                    interactive=True,
                )
            with gr.Row():
                score = gr.Number(
                    label='Score',
                )

    model_path_box.submit(
        fn=change_model,
        inputs=model_path_box,
        outputs=model_path_box,
    )
    image.change(
        fn=score_image,
        inputs=image,
        outputs=score,
    )

demo.launch()