from typing import List

import clip
import gradio as gr
import torch
from PIL import Image

from waifu_scorer.mlp import MLP
from waifu_scorer.predict import WaifuScorer, load_model

# Load the pre-trained scoring MLP.
model_path = "./2048_MLP_best-MSE2.7738_ep110.pth"
device = "cpu"
dtype = torch.float32

s = torch.load(model_path, map_location=device)
model = MLP(input_size=768)  # 768 = dimensionality of CLIP ViT-L/14 image embeddings
model.load_state_dict(s)
model.to(device=device, dtype=dtype)
model.eval()  # inference only: disable training-time behaviour such as dropout

# CLIP image encoder and its preprocessing transform.
model2, preprocess = clip.load("ViT-L/14", device=device)


def normalized(a: torch.Tensor, order=2, dim=-1):
    """Normalize `a` to unit p-norm (L2 by default) along `dim`, guarding against zero vectors."""
    l2 = a.norm(order, dim, keepdim=True)
    l2[l2 == 0] = 1
    return a / l2
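
# Quick check (sketch): [3.0, 4.0] has L2 norm 5, so normalized() maps it to [0.6, 0.8].
# assert torch.allclose(normalized(torch.tensor([[3.0, 4.0]])), torch.tensor([[0.6, 0.8]]))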


def encode_images(images: List[Image.Image], model2, preprocess, device='cpu') -> torch.Tensor:
    """Embed PIL images with CLIP and return L2-normalized float32 features of shape (N, 768)."""
    if not isinstance(images, list):
        images = [images]
    image_tensors = [preprocess(img).unsqueeze(0) for img in images]
    image_batch = torch.cat(image_tensors).to(device)
    image_features = model2.encode_image(image_batch)
    im_emb_arr = normalized(image_features).cpu().float()
    return im_emb_arr


def predict(inputs: List[Image.Image]) -> List[float]:
    """Return one score per input image, clamped to the range 0-10."""
    with torch.no_grad():  # inference only; also lets .numpy() run without detach()
        images = encode_images(inputs, model2, preprocess, device=device).to(device=device, dtype=dtype)
        predictions = model(images)
    scores = predictions.clamp(0, 10).cpu().numpy().reshape(-1).tolist()
    return scores
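
# Usage sketch (not wired into the app): scoring a local file directly.
# "sample.png" is a placeholder path, not a file shipped with this Space.
# img = Image.open("sample.png").convert("RGB")
# print(predict([img])[0])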


# WaifuScorer wraps the same checkpoint; submitting a new path in the UI below
# reloads its MLP (and the global `model` used by predict()).
scorer = WaifuScorer(
    model_path=model_path,
    model_type="mlp",
    device=device,
)

with gr.Blocks() as demo:
    with gr.Row():
        with gr.Column():
            image = gr.Image(
                label='Image',
                type='pil',
                height=512,
                sources=['upload', 'clipboard'],
            )
        with gr.Column():
            with gr.Row():
                model_path = gr.Textbox(
                    label='Model Path',
                    value=model_path,
                    placeholder='Path or URL to the model file',
                    interactive=True,
                )
            with gr.Row():
                score = gr.Number(
                    label='Score',
                )

    def change_model(model_path):
        global model
        scorer.mlp = load_model(model_path, model_type="mlp", device=device)
        # Also point the global MLP used by predict() at the new weights so the
        # displayed score follows the newly selected checkpoint.
        model = scorer.mlp
        print(f"Model changed to `{model_path}`")
        return gr.update()

    model_path.submit(
        fn=change_model,
        inputs=model_path,
        outputs=model_path,
    )

    # Re-score whenever the image changes. The single image is duplicated so the
    # MLP always receives a batch of two; only the first score is displayed.
    image.change(
        fn=lambda image: predict([image] * 2)[0] if image is not None else None,
        inputs=image,
        outputs=score,
    )

demo.launch()
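
# Optional (assumption: running locally rather than on the Hugging Face Spaces
# runtime): pass share=True to expose a temporary public link.
# demo.launch(share=True)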