Eugeoter committed on
Commit
b18a65a
1 Parent(s): ffc837d

Upload 8 files

Browse files
.gitattributes CHANGED
@@ -33,3 +33,5 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
 
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ examples/1.png filter=lfs diff=lfs merge=lfs -text
37
+ examples/2.png filter=lfs diff=lfs merge=lfs -text
app.py ADDED
@@ -0,0 +1,20 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ from modules.predict import WaifuScorer
3
+
4
+
5
def ui():
    """Build the Gradio interface of the Waifu Scorer demo.

    One ``WaifuScorer`` is created when the interface is built and is reused
    by every prediction request.

    Returns:
        The configured ``gr.Interface``. It is not launched here; the caller
        decides when to call ``launch()``.
    """
    scorer = WaifuScorer()

    def predict(img):
        # The image component is declared with type='pil', so the scorer is
        # handed a PIL image.
        return scorer(img)

    image_input = gr.Image(sources='upload', type='pil', height=512)
    score_output = gr.Number(precision=3)
    example_rows = [['./examples/1.png'], ['./examples/2.png']]

    return gr.Interface(
        title='Waifu Scorer',
        description='A model that scores an anime illustration (0 ~ 10).',
        fn=predict,
        inputs=image_input,
        outputs=score_output,
        allow_flagging='never',
        examples=example_rows,
    )
17
+
18
+
19
if __name__ == '__main__':
    # Start the web app only when this file is executed as a script, so that
    # importing the module has no side effects.
    demo = ui()
    demo.launch()
examples/1.png ADDED

Git LFS Details

  • SHA256: cfd16f47c1f161fb55116f864ac427549a1e1c01e124f418b3f342b92828aaef
  • Pointer size: 132 Bytes
  • Size of remote file: 1.6 MB
examples/2.png ADDED

Git LFS Details

  • SHA256: f273c062c6310b6ceac736d35f626a0e9804185b64bae0da777d5a13c1a81a06
  • Pointer size: 132 Bytes
  • Size of remote file: 1.19 MB
modules/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ from .predict import WaifuScorer
modules/mlp.py ADDED
@@ -0,0 +1,34 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch.nn as nn
2
+ import pytorch_lightning as pl
3
+
4
+
5
class MLP(pl.LightningModule):
    """Fully connected regression head that maps an embedding to one value.

    The network is four hidden stages (Linear -> ReLU -> optional
    BatchNorm1d -> Dropout) with widths 2048, 512, 256 and 128, followed by
    Linear(128, 32) -> ReLU -> Linear(32, 1).

    Args:
        input_size: Number of features of the input embedding.
        xcol: Stored on the instance as ``self.xcol``; not used by this class.
        ycol: Stored on the instance as ``self.ycol``; not used by this class.
        batch_norm: When false, every BatchNorm1d is replaced by ``nn.Identity``
            so the positions of the remaining layers stay the same.
    """

    def __init__(self, input_size, xcol='emb', ycol='avg_rating', batch_norm=True):
        super().__init__()
        self.input_size = input_size
        self.xcol = xcol
        self.ycol = ycol

        # (output width, dropout probability) of each hidden stage.
        hidden_stages = ((2048, 0.3), (512, 0.3), (256, 0.2), (128, 0.1))

        # The modules are appended in a fixed order on purpose: nn.Sequential
        # names its children by position, so the order defines the
        # state-dict keys ("layers.0.weight", "layers.2.weight", ...).
        modules = []
        in_features = self.input_size
        for out_features, drop_p in hidden_stages:
            modules.append(nn.Linear(in_features, out_features))
            modules.append(nn.ReLU())
            modules.append(nn.BatchNorm1d(out_features) if batch_norm else nn.Identity())
            modules.append(nn.Dropout(drop_p))
            in_features = out_features

        modules.append(nn.Linear(in_features, 32))
        modules.append(nn.ReLU())
        modules.append(nn.Linear(32, 1))

        self.layers = nn.Sequential(*modules)

    def forward(self, x):
        """Run ``x`` through the layer stack and return the result."""
        return self.layers(x)
modules/predict.py ADDED
@@ -0,0 +1,84 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import clip
3
+ import os
4
+ import time
5
+ from PIL import Image
6
+ from typing import List
7
+ from .mlp import MLP
8
+ from .utils import download_from_url
9
+
10
+ MLP_MODEL_URL = "https://huggingface.co/Eugeoter/waifu-scorer/waifu-scorer-v1-large.pth"
11
+
12
+
13
class WaifuScorer:
    """Score images with a CLIP image encoder followed by an MLP head.

    Args:
        model_path: Path to the MLP checkpoint. When it is ``None`` or does
            not point to an existing file, the checkpoint referenced by
            ``MLP_MODEL_URL`` is downloaded and used instead.
        device: Device to run on. ``'cuda'`` falls back to ``'cpu'`` when CUDA
            is not available.
        verbose: Print the checkpoint location and the loading time.
    """

    def __init__(self, model_path: str = None, device: str = 'cuda', verbose=False):
        self.verbose = verbose
        tic = time.time()

        if model_path is None or not os.path.isfile(model_path):
            model_path = download_from_url(MLP_MODEL_URL)

        # Logged after the path has been resolved so the message shows the
        # file that is actually loaded rather than `None`.
        if self.verbose:
            print(f"loading pretrained model from `{model_path}`")

        if device == 'cuda' and not torch.cuda.is_available():
            device = 'cpu'
            print("CUDA is not available, using CPU instead")

        self.mlp = load_model(model_path, input_size=768, device=device)
        self.model2, self.preprocess = load_clip_models("ViT-L/14", device=device)
        self.device = self.mlp.device
        self.dtype = self.mlp.dtype

        # Inference only: dropout is disabled and batch norm uses its running
        # statistics, which is what allows a batch of one in __call__.
        self.mlp.eval()

        if self.verbose:
            toc = time.time()
            print(f"model loaded: time_cost={toc-tic:.2f} | device={self.device} | dtype={self.dtype}")

    @torch.no_grad()
    def __call__(self, images: List[Image.Image]) -> List[float]:
        """Score one image or a list of images.

        Args:
            images: A single PIL image or a list of PIL images.

        Returns:
            A single float when exactly one image is given (bare or wrapped in
            a one-element list), otherwise a list with one score per image.
            An empty list yields an empty list. Scores are clamped to [0, 10].
        """
        if isinstance(images, Image.Image):
            images = [images]
        n = len(images)
        if n == 0:
            # Nothing to encode; torch cannot build a batch from zero images.
            return []

        # No padding of single-image batches is needed: the MLP is in eval
        # mode, so BatchNorm1d normalizes with running statistics and accepts
        # a batch of size one.
        embeddings = encode_images(images, self.model2, self.preprocess, device=self.device)
        embeddings = embeddings.to(device=self.device, dtype=self.dtype)
        predictions = self.mlp(embeddings)
        scores = predictions.clamp(0, 10).cpu().numpy().reshape(-1).tolist()
        if n == 1:
            return scores[0]
        return scores
52
+
53
+
54
def load_clip_models(name: str = "ViT-L/14", device='cuda'):
    """Load a CLIP model together with its image preprocessing transform.

    Args:
        name: Model name passed to ``clip.load``.
        device: Device passed to ``clip.load``.

    Returns:
        A ``(model, preprocess)`` tuple as produced by ``clip.load``.
    """
    clip_model, transform = clip.load(name, device=device)
    return clip_model, transform
57
+
58
+
59
def load_model(model_path: str = None, input_size=768, device: str = 'cuda', dtype=None):
    """Create the MLP head and optionally restore its weights.

    Args:
        model_path: Checkpoint holding a state dict for ``MLP``. When falsy,
            the model keeps its freshly initialized weights.
        input_size: Input feature size passed to ``MLP``.
        device: Device the checkpoint is mapped to and the model is moved to.
        dtype: Optional dtype to cast the model to; skipped when falsy.

    Returns:
        The ``MLP`` instance on ``device``.
    """
    mlp = MLP(input_size=input_size)

    if model_path:
        # NOTE(review): torch.load unpickles the file, so only checkpoints
        # from a trusted source should be passed here.
        state_dict = torch.load(model_path, map_location=device)
        mlp.load_state_dict(state_dict)

    mlp.to(device)
    return mlp.to(dtype=dtype) if dtype else mlp
68
+
69
+
70
def normalized(a: torch.Tensor, order=2, dim=-1):
    """Scale ``a`` to unit norm along ``dim``.

    Args:
        a: Tensor to normalize.
        order: Order of the norm (2 by default).
        dim: Dimension along which the norm is taken.

    Returns:
        ``a`` divided by its norm along ``dim``. Slices whose norm is zero are
        divided by 1 instead, so they come back unchanged rather than as NaN.
    """
    norms = a.norm(order, dim, keepdim=True)
    # Swap zero norms for 1 to avoid a division by zero.
    safe_norms = torch.where(norms == 0, torch.ones_like(norms), norms)
    return a / safe_norms
74
+
75
+
76
@torch.no_grad()
def encode_images(images: List[Image.Image], model2, preprocess, device='cuda') -> torch.Tensor:
    """Encode images into unit-norm embeddings.

    Args:
        images: A single PIL image or a list of PIL images.
        model2: Model exposing ``encode_image(batch)``.
        preprocess: Callable turning one image into a tensor.
        device: Device on which the batch is encoded.

    Returns:
        A float32 tensor on the CPU with one normalized embedding per image.
    """
    if isinstance(images, Image.Image):
        images = [images]
    # Stack the per-image tensors along a new leading batch dimension.
    batch = torch.stack([preprocess(img) for img in images]).to(device)
    features = model2.encode_image(batch)
    return normalized(features).cpu().float()
modules/utils.py ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
+ from huggingface_hub import hf_hub_download
2
+
3
+
4
def download_from_url(url, cache_dir=None, verbose=True):
    """Download a file from the Hugging Face Hub given a URL-like location.

    Two layouts are understood (scheme, host and query string are optional
    and ignored):

    * ``.../<user>/<repo>/<filename>`` - the last three path components are
      used, e.g. ``https://huggingface.co/user/repo/model.pth``.
    * ``.../<user>/<repo>/resolve/<revision>/<path/to/file>`` (or ``blob``
      instead of ``resolve``) - the layout of real Hub file links; the
      revision is forwarded to ``hf_hub_download``.

    Args:
        url: Location of the file in one of the layouts above.
        cache_dir: Forwarded to ``hf_hub_download``.
        verbose: Print which file is being downloaded.

    Returns:
        The local path returned by ``hf_hub_download``.

    Raises:
        ValueError: If ``url`` has fewer than three path components.
    """
    from urllib.parse import urlsplit

    # Only the path matters; this drops "?download=true" style suffixes.
    parts = [p for p in urlsplit(url).path.split("/") if p]

    # A "resolve"/"blob" marker needs user and repo before it, and a revision
    # plus at least one file component after it.
    marker = next(
        (
            i for i, part in enumerate(parts)
            if part in ("resolve", "blob") and i >= 2 and i + 2 < len(parts)
        ),
        None,
    )

    if marker is not None:
        username, repo_id = parts[marker - 2], parts[marker - 1]
        revision = parts[marker + 1]
        model_name = "/".join(parts[marker + 2:])
    elif len(parts) >= 3:
        username, repo_id, model_name = parts[-3:]
        revision = None
    else:
        raise ValueError(
            f"cannot parse `{url}`: expected '<user>/<repo>/<filename>' "
            "or '<user>/<repo>/resolve/<revision>/<filename>'"
        )

    if verbose:
        print(f"downloading: {username}/{repo_id}/{model_name}")
    model_path = hf_hub_download(
        f"{username}/{repo_id}", model_name, revision=revision, cache_dir=cache_dir
    )
    return model_path
requirements.txt ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
torch
# The PyPI package named "clip" is unrelated to OpenAI CLIP; the `clip` module
# imported by modules/predict.py is installed from the official repository.
git+https://github.com/openai/CLIP.git
pytorch-lightning
pillow
huggingface-hub