File size: 2,583 Bytes
88b279a
 
 
 
cefcefa
 
88b279a
 
cefcefa
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
88b279a
 
cefcefa
 
 
 
 
 
88b279a
 
cefcefa
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
import gradio as gr
import torch
from PIL import Image
from typing import List
from waifu_scorer.mlp import MLP
import clip

# Load the pre-trained MLP scoring head from a state_dict checkpoint.
# input_size=768 matches the image-embedding width of the CLIP ViT-L/14 encoder loaded below.
model_path = "./1024_MLP_best-MSE4.1636_ep75.pth"
device = "cpu"
dtype = torch.float32
# NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
# Passing weights_only=True would be safer if the installed torch supports it.
s = torch.load(model_path, map_location=device)
model = MLP(input_size=768)
model.load_state_dict(s)
model.to(device=device, dtype=dtype)
# Switch to inference mode. Without this, any Dropout/BatchNorm layers inside MLP
# would behave as in training, making scores random and/or dependent on the batch.
# (Presumably such layers exist: the UI scores a duplicated image, which looks like a
# batch-size-1 BatchNorm workaround — TODO confirm against waifu_scorer.mlp.MLP.)
model.eval()

# CLIP ViT-L/14 image encoder and its matching preprocessing transform.
model2, preprocess = clip.load("ViT-L/14", device=device)

def normalized(a: torch.Tensor, order=2, dim=-1):
    """Scale ``a`` to unit ``order``-norm along ``dim``.

    Slices whose norm is exactly zero are divided by 1 (i.e. left unchanged)
    instead of producing NaNs from a division by zero.
    """
    norms = a.norm(order, dim, keepdim=True)
    safe_norms = torch.where(norms == 0, torch.ones_like(norms), norms)
    return a / safe_norms

@torch.no_grad()
def encode_images(images: List[Image.Image], model2, preprocess, device='cpu') -> torch.Tensor:
    """Embed PIL images with a CLIP image encoder and L2-normalise the features.

    Args:
        images: a list of PIL images; a non-list value is treated as a single image.
        model2: model exposing ``encode_image`` (the CLIP model in this file).
        preprocess: transform applied to each image before batching
            (presumably PIL -> (C, H, W) tensor — TODO confirm).
        device: device the batch is moved to before encoding.

    Returns:
        The normalised features, moved to the CPU and cast to float32.
    """
    batch_list = images if isinstance(images, list) else [images]
    # Stacking per-image tensors is equivalent to concatenating them after unsqueeze(0).
    batch = torch.stack([preprocess(img) for img in batch_list]).to(device)
    features = model2.encode_image(batch)
    return normalized(features).cpu().float()

@torch.no_grad()
def predict(inputs: List[Image.Image]) -> List[float]:
    """Score images with the module-level CLIP encoder and MLP head.

    Args:
        inputs: PIL images to score (``encode_images`` also accepts a single image).

    Returns:
        The MLP outputs clamped to [0, 10], flattened into a list of Python floats
        (one per image, assuming the head emits a single value per row).
    """
    embeddings = encode_images(inputs, model2, preprocess, device=device)
    # encode_images returns CPU float32 features; match the MLP's device/dtype.
    embeddings = embeddings.to(device=device, dtype=dtype)
    predictions = model(embeddings)
    # Flatten so callers always get a flat list regardless of the head's output shape.
    return predictions.clamp(0, 10).cpu().reshape(-1).tolist()


from waifu_scorer.predict import WaifuScorer, load_model
# A second handle on the same checkpoint via the package's own loader. In this file
# only its ``mlp`` attribute is used (replaced by the UI's model-change handler);
# the displayed score is computed by ``predict``.
scorer = WaifuScorer(model_path=model_path, model_type="mlp", device=device)

with gr.Blocks() as demo:
    with gr.Row():
        with gr.Column():
            image = gr.Image(
                label='Image',
                type='pil',
                height=512,
                sources=['upload', 'clipboard'],
            )
        with gr.Column():
            with gr.Row():
                # Separate name so the module-level ``model_path`` string is not
                # shadowed by the UI component.
                model_path_box = gr.Textbox(
                    label='Model Path',
                    value=model_path,
                    placeholder='Path or URL to the model file',
                )
            with gr.Row():
                score = gr.Number(
                    label='Score',
                )

    def change_model(model_path):
        """Load the MLP checkpoint at ``model_path`` and use it for subsequent scoring.

        ``predict`` reads the module-level ``model``, so that global is rebound in
        addition to ``scorer.mlp``; updating only the scorer would leave the
        displayed scores unchanged.
        """
        global model
        new_model = load_model(model_path, model_type="mlp", device=device)
        # Prepare it like the startup model: same device/dtype, inference mode.
        # assumes load_model returns a torch nn.Module — TODO confirm
        new_model.to(device=device, dtype=dtype)
        new_model.eval()
        scorer.mlp = new_model
        model = new_model
        print(f"Model changed to `{model_path}`")
        return gr.update()

    def score_image(image):
        """Return the score for the uploaded image, or None when the image is cleared."""
        if image is None:
            return None
        # Scored as a batch of two identical copies, keeping only the first score —
        # presumably a workaround for layers that reject batch size 1 (e.g. BatchNorm
        # in training mode); TODO confirm against MLP.
        return predict([image] * 2)[0]

    model_path_box.submit(
        fn=change_model,
        inputs=model_path_box,
        outputs=model_path_box,
    )

    image.change(
        fn=score_image,
        inputs=image,
        outputs=score,
    )

demo.launch()