MUTED64 committed
Commit cefcefa
1 Parent(s): 0668dff

change scorer

__init__.py ADDED
File without changes
api.py ADDED
@@ -0,0 +1,5 @@
+ from waifu_scorer.ui import launch, parse_args
+
+ if __name__ == '__main__':
+     args = parse_args()
+     launch(args)
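
Note: api.py is a thin CLI entry point; parse_args() builds the argument namespace and launch() starts the Gradio demo. A minimal sketch of driving the same entry point programmatically, bypassing the CLI parser (the attribute set below is an assumption mirroring the flags defined in waifu_scorer/ui.py, and the checkpoint path is a placeholder):

from types import SimpleNamespace
from waifu_scorer.ui import launch

# Assumed attributes, mirroring --model_path / --model_type / --fix_model_path / --device / --share.
args = SimpleNamespace(
    model_path='./1024_MLP_best-MSE4.1636_ep75.pth',
    model_type='mlp',
    fix_model_path=False,
    device='cpu',
    share=False,
)
launch(args)  # builds the gr.Blocks UI and calls demo.launch(share=args.share)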
app.py CHANGED
@@ -1,36 +1,88 @@
  import gradio as gr
  import torch
  from PIL import Image
- from torchvision.transforms import functional as F
  from typing import List
- from transformers import CLIPModel, CLIPProcessor
+ from waifu_scorer.mlp import MLP
+ import clip

  # Load the pre-trained model
- model_path = "1024_MLP_best-MSE4.1636_ep75.pth"
- model = torch.load(model_path)
- model.eval()
-
- # Load the CLIP model and processor
- clip_model = CLIPModel.from_pretrained("ViT-L/14")
- clip_processor = CLIPProcessor.from_pretrained("ViT-L/14")
-
- # Define the prediction function
- def predict(images: List[Image.Image]) -> float:
-     image_tensors = [F.to_tensor(img) for img in images]
-     inputs = clip_processor(images=image_tensors, return_tensors="pt", padding=True)
-     with torch.no_grad():
-         outputs = model(inputs.pixel_values)
-         scores = outputs.clamp(0, 10).cpu().numpy().reshape(-1).tolist()
+ model_path = "./1024_MLP_best-MSE4.1636_ep75.pth"
+ device = "cpu"
+ dtype = torch.float32
+ s = torch.load(model_path, map_location=device)
+ model = MLP(input_size=768)
+ model.load_state_dict(s)
+ model.to(device=device, dtype=dtype)
+
+ model2, preprocess = clip.load("ViT-L/14", device=device)
+
+ def normalized(a: torch.Tensor, order=2, dim=-1):
+     l2 = a.norm(order, dim, keepdim=True)
+     l2[l2 == 0] = 1
+     return a / l2
+
+ @torch.no_grad()
+ def encode_images(images: List[Image.Image], model2, preprocess, device='cpu') -> torch.Tensor:
+     if not isinstance(images, list):
+         images = [images]
+     image_tensors = [preprocess(img).unsqueeze(0) for img in images]
+     image_batch = torch.cat(image_tensors).to(device)
+     image_features = model2.encode_image(image_batch)
+     im_emb_arr = normalized(image_features).cpu().float()
+     return im_emb_arr
+
+ @torch.no_grad()
+ def predict(inputs: List[Image.Image]) -> float:
+     images = encode_images(inputs, model2, preprocess, device=device).to(device=device, dtype=dtype)
+     predictions = model(images)
+     scores = predictions.clamp(0, 10).cpu().numpy().reshape(-1).tolist()
      return scores

- # Define the Gradio interface
- iface = gr.Interface(
-     fn=predict,
-     inputs="image",
-     outputs="number",
-     title="Kemono Aesthetic Scorer",
-     description="Predict the score of a kemono based on aesthetic features.",
+
+ from waifu_scorer.predict import WaifuScorer, load_model
+ scorer = WaifuScorer(
+     model_path=model_path,
+     model_type="mlp",
+     device=device,
  )

- # Run the Gradio interface
- iface.launch()
+ with gr.Blocks() as demo:
+     with gr.Row():
+         with gr.Column():
+             image = gr.Image(
+                 label='Image',
+                 type='pil',
+                 height=512,
+                 sources=['upload', 'clipboard'],
+             )
+         with gr.Column():
+             with gr.Row():
+                 model_path = gr.Textbox(
+                     label='Model Path',
+                     value=model_path,
+                     placeholder='Path or URL to the model file',
+                     # interactive=not fix_model_path,
+                 )
+             with gr.Row():
+                 score = gr.Number(
+                     label='Score',
+                 )
+
+     def change_model(model_path):
+         scorer.mlp = load_model(model_path, model_type="mlp", device=device)
+         print(f"Model changed to `{model_path}`")
+         return gr.update()
+
+     model_path.submit(
+         fn=change_model,
+         inputs=model_path,
+         outputs=model_path,
+     )
+
+     image.change(
+         fn=lambda image: predict([image]*2)[0] if image is not None else None,
+         inputs=image,
+         outputs=score,
+     )
+
+ demo.launch()
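
Note: the new app.py swaps the transformers-based gr.Interface for an OpenAI CLIP ViT-L/14 encoder whose 768-dim embedding feeds the MLP head, wired into a gr.Blocks demo. A minimal sketch of calling the module-level predict() outside Gradio (the image path is a placeholder assumption; the image is duplicated only to mirror the demo's own image.change callback):

from PIL import Image

img = Image.open('./example.jpg').convert('RGB')  # hypothetical local test image
score = predict([img] * 2)[0]                     # CLIP embedding -> MLP score, clamped to [0, 10]
print(f'score: {score:.2f}')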
requirements.txt CHANGED
@@ -3,4 +3,5 @@ torch
  Pillow
  torchvision
  typing
- transformers
+ pytorch_lightning
+ clip
setup.py ADDED
@@ -0,0 +1,28 @@
+ from setuptools import setup, find_packages
+ with open('./requirements.txt') as f:
+     requirements = f.read().splitlines()
+
+ for i, req in enumerate(requirements):
+     if req.startswith('git+'):
+         package_name = req.split('/')[-1].split('.')[0]  # Extract package name from URL
+         requirements[i] = f"{package_name} @ {req}"
+
+ setup(
+     name='waifu-scorer',
+     version='0.1',
+     packages=find_packages(),
+     include_package_data=True,
+     description='Image caption tools',
+     long_description='',
+     author='euge',
+     author_email='1507064225@qq.com',
+     url='https://github.com/Eugeoter/waifu-scorer',
+     install_requires=requirements,
+     classifiers=[
+         'Development Status :: 3 - Alpha',
+         'Intended Audience :: Developers',
+         'License :: OSI Approved :: MIT License',
+         'Programming Language :: Python :: 3',
+         'Programming Language :: Python :: 3.7',
+     ],
+ )
waifu_scorer/__init__.py ADDED
@@ -0,0 +1 @@
+ from .predict import WaifuScorer
waifu_scorer/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (233 Bytes).
waifu_scorer/__pycache__/__init__.cpython-312.pyc ADDED
Binary file (240 Bytes).
waifu_scorer/__pycache__/mlp.cpython-312.pyc ADDED
Binary file (5.52 kB).
waifu_scorer/__pycache__/predict.cpython-310.pyc ADDED
Binary file (3.04 kB).
waifu_scorer/__pycache__/predict.cpython-312.pyc ADDED
Binary file (4.98 kB).
waifu_scorer/__pycache__/train.cpython-312.pyc ADDED
Binary file (14.2 kB).
waifu_scorer/__pycache__/train_utils.cpython-312.pyc ADDED
Binary file (14.3 kB).
waifu_scorer/__pycache__/ui.cpython-312.pyc ADDED
Binary file (3.89 kB).
waifu_scorer/__pycache__/utils.cpython-312.pyc ADDED
Binary file (3.75 kB).
waifu_scorer/mlp.py ADDED
@@ -0,0 +1,127 @@
+ import torch.nn as nn
+ import torch.nn.functional as F
+ import pytorch_lightning as pl
+
+
+ class MLP(pl.LightningModule):
+     def __init__(self, input_size, xcol='emb', ycol='avg_rating', batch_norm=True):
+         super().__init__()
+         self.input_size = input_size
+         self.xcol = xcol
+         self.ycol = ycol
+         # self.layers = nn.Sequential(
+         #     nn.Linear(self.input_size, 2048),
+         #     nn.ReLU(),
+         #     nn.BatchNorm1d(2048),
+         #     nn.Dropout(0.4),
+
+         #     nn.Linear(2048, 512),
+         #     nn.ReLU(),
+         #     nn.BatchNorm1d(512),
+         #     nn.Dropout(0.3),
+
+         #     nn.Linear(512, 256),
+         #     nn.ReLU(),
+         #     nn.BatchNorm1d(256),
+         #     nn.Dropout(0.2),
+
+         #     nn.Linear(256, 128),
+         #     nn.ReLU(),
+         #     nn.BatchNorm1d(128),
+         #     nn.Dropout(0.1),
+
+         #     nn.Linear(128, 32),
+         #     nn.ReLU(),
+         #     nn.Linear(32, 1)
+         # )
+         self.layers = nn.Sequential(
+             nn.Linear(self.input_size, 1024),
+             # nn.ReLU(),
+             nn.Dropout(0.2),
+             nn.Linear(1024, 128),
+             # nn.ReLU(),
+             nn.Dropout(0.2),
+             nn.Linear(128, 64),
+             # nn.ReLU(),
+             nn.Dropout(0.1),
+
+             nn.Linear(64, 16),
+             # nn.ReLU(),
+
+             nn.Linear(16, 1)
+         )
+
+     def forward(self, x):
+         return self.layers(x)
+
+     def training_step(self, batch, batch_idx):
+         x = batch[self.xcol]
+         y = batch[self.ycol].reshape(-1, 1)
+         x_hat = self.layers(x)
+         loss = F.mse_loss(x_hat, y)
+         return loss
+
+     def validation_step(self, batch, batch_idx):
+         x = batch[self.xcol]
+         y = batch[self.ycol].reshape(-1, 1)
+         x_hat = self.layers(x)
+         loss = F.mse_loss(x_hat, y)
+         return loss
+
+     # def configure_optimizers(self):
+     #     optimizer = torch.optim.Adam(self.parameters(), lr=1e-3)
+     #     return optimizer
+
+
+ class ResidualBlock(nn.Module):
+     def __init__(self, input_size, output_size, batch_norm=True, dropout_rate=0.0):
+         super(ResidualBlock, self).__init__()
+         self.linear = nn.Linear(input_size, output_size)
+         self.relu = nn.ReLU()
+         self.batch_norm = nn.BatchNorm1d(output_size) if batch_norm else nn.Identity()
+         self.dropout = nn.Dropout(dropout_rate)
+         self.adjust_dims = nn.Linear(input_size, output_size) if input_size != output_size else nn.Identity()
+
+     def forward(self, x):
+         identity = self.adjust_dims(x)
+         out = self.linear(x)
+         out = self.relu(out)
+         out = self.batch_norm(out)
+         out = self.dropout(out)
+         out += identity
+         out = self.relu(out)
+         return out
+
+
+ class ResMLP(pl.LightningModule):
+     def __init__(self, input_size, xcol='emb', ycol='avg_rating', batch_norm=True):
+         super().__init__()
+         self.input_size = input_size
+         self.xcol = xcol
+         self.ycol = ycol
+         self.layers = nn.Sequential(
+             ResidualBlock(input_size, 2048, batch_norm, dropout_rate=0.3),
+             ResidualBlock(2048, 512, batch_norm, dropout_rate=0.3),
+             ResidualBlock(512, 256, batch_norm, dropout_rate=0.2),
+             ResidualBlock(256, 128, batch_norm, dropout_rate=0.1),
+             nn.Linear(128, 32),
+             nn.ReLU(),
+             nn.Linear(32, 1)
+         )
+
+     def forward(self, x):
+         return self.layers(x)
+
+     def training_step(self, batch, batch_idx):
+         x = batch[self.xcol]
+         y = batch[self.ycol].reshape(-1, 1)
+         x_hat = self.layers(x)
+         loss = F.mse_loss(x_hat, y)
+         return loss
+
+     def validation_step(self, batch, batch_idx):
+         x = batch[self.xcol]
+         y = batch[self.ycol].reshape(-1, 1)
+         x_hat = self.layers(x)
+         loss = F.mse_loss(x_hat, y)
+         return loss
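
Note: a quick smoke-test sketch for the MLP head above on dummy 768-dim embeddings (batch size and values are arbitrary assumptions; the instance is untrained, so the outputs are meaningless):

import torch
from waifu_scorer.mlp import MLP

model = MLP(input_size=768)
model.eval()                   # disable the Dropout layers for deterministic inference
with torch.no_grad():
    emb = torch.randn(4, 768)  # stand-in for normalized CLIP ViT-L/14 image embeddings
    scores = model(emb)        # shape: (4, 1)
print(scores.squeeze(-1))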
waifu_scorer/predict.py ADDED
@@ -0,0 +1,63 @@
+ import torch
+ import clip
+ import os
+ from PIL import Image
+ from typing import List
+ from .utils import get_model_cls
+
+ WAIFU_FILTER_V1_MODEL_REPO = 'Eugeoter/waifu-filter-v1/waifu-filter-v1.pth'
+
+
+ def download_from_url(url):
+     from huggingface_hub import hf_hub_download
+     split = url.split("/")
+     username, repo_id, model_name = split[-3], split[-2], split[-1]
+     model_path = hf_hub_download(f"{username}/{repo_id}", model_name)
+     return model_path
+
+
+ def load_model(model_path: str = None, model_type='mlp', input_size=768, device: str = 'cuda', dtype=torch.float32):
+     model_cls = get_model_cls(model_type)
+     model = model_cls(input_size=input_size)
+     if not os.path.isfile(model_path):
+         model_path = download_from_url(model_path)
+     s = torch.load(model_path, map_location=device)
+     model.load_state_dict(s)
+     model.to(device=device, dtype=dtype)
+     return model
+
+
+ def normalized(a: torch.Tensor, order=2, dim=-1):
+     l2 = a.norm(order, dim, keepdim=True)
+     l2[l2 == 0] = 1
+     return a / l2
+
+
+ @torch.no_grad()
+ def encode_images(images: List[Image.Image], model2, preprocess, device='cuda') -> torch.Tensor:
+     if isinstance(images, Image.Image):
+         images = [images]
+     image_tensors = [preprocess(img).unsqueeze(0) for img in images]
+     image_batch = torch.cat(image_tensors).to(device)
+     image_features = model2.encode_image(image_batch)
+     im_emb_arr = normalized(image_features).cpu().float()
+     return im_emb_arr
+
+
+ class WaifuScorer:
+     def __init__(self, model_path: str = WAIFU_FILTER_V1_MODEL_REPO, model_type='mlp', device: str = None, dtype=torch.float32):
+         print(f"loading model from `{model_path}`...")
+         device = device or ('cuda' if torch.cuda.is_available() else 'cpu')
+         self.mlp = load_model(model_path, model_type=model_type, input_size=768, device=device, dtype=dtype)
+         self.mlp.eval()
+         self.model2, self.preprocess = clip.load("ViT-L/14", device=device)
+         self.device = self.mlp.device
+         self.dtype = self.mlp.dtype
+         print(f"model loaded: cls={model_type} | device={self.device} | dtype={self.dtype}")
+
+     @torch.no_grad()
+     def predict(self, images: List[Image.Image]) -> float:
+         images = encode_images(images, self.model2, self.preprocess, device=self.device).to(device=self.device, dtype=self.dtype)
+         predictions = self.mlp(images)
+         scores = predictions.clamp(0, 10).cpu().numpy().reshape(-1).tolist()
+         return scores
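
Note: a minimal end-to-end sketch of the WaifuScorer class above (the local checkpoint and image paths are placeholder assumptions; non-local paths fall back to hf_hub_download):

from PIL import Image
from waifu_scorer.predict import WaifuScorer

scorer = WaifuScorer(
    model_path='./1024_MLP_best-MSE4.1636_ep75.pth',
    model_type='mlp',
    device='cpu',
)
img = Image.open('./example.jpg').convert('RGB')  # hypothetical test image
print(scorer.predict([img])[0])                   # score in [0, 10]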
waifu_scorer/train.py ADDED
@@ -0,0 +1,307 @@
+ # os.environ['CUDA_VISIBLE_DEVICES'] = "0"  # in case you are using a multi GPU workstation, choose your GPU here
+
+ import os
+ import torch
+ import random
+ import torch.nn as nn
+ from pathlib import Path
+ from tqdm import tqdm
+ from accelerate import Accelerator
+ from typing import Literal, Callable, Optional, Union
+ from waifuset.utils import log_utils
+ from waifuset.classes import Dataset, ImageInfo
+ from . import mlp, utils, train_utils
+
+ StrPath = Union[str, Path]
+
+
+ def train(
+     dataset_source,
+     save_path,
+     resume_path: StrPath = None,
+     data_preprocessor: Optional[Callable[[ImageInfo], float]] = None,
+     rating_func_type: Union[Callable[[ImageInfo], float], Literal['direct', 'label', 'quality']] = 'quality',
+     num_train_epochs=50,
+     learning_rate=1e-3,
+     train_batch_size=256,
+     shuffle=True,
+     flip_aug=True,
+     val_batch_size=512,
+     val_every_n_epochs=1,
+     val_percentage=0.05,  # 5% of the training data will be used for validation
+     save_best_model=True,
+     clip_batch_size=1,
+     cache_to_disk: bool = False,
+     cache_path: StrPath = None,
+     mixed_precision=None,
+     max_data_loader_n_workers: int = 4,
+     persistent_workers=False,
+     mlp_model_type: Literal['default', 'large'] = 'default',
+     clip_model_name: str = "ViT-L/14",
+     input_size: int = 768,
+     batch_norm: bool = True,
+ ):
+     r"""
+     :param dataset_source: any dataset source, e.g. path to the dataset.
+     :param save_path: path to save the trained model.
+     :param resume_path: path to the model to resume from.
+     :param cache_to_disk: whether to cache the training data to disk.
+     :param cache_path: path to the cached training data. If it does not exist, it will be created from `dataset_source`; if it exists, it will be loaded from disk.
+     :param num_train_epochs: number of training epochs.
+     :param learning_rate: learning rate.
+     :param train_batch_size: training batch size.
+     :param val_batch_size: validation batch size.
+     :param val_every_n_epochs: validation frequency.
+     :param val_percentage: percentage of the training data to be used for validation.
+     :param clip_batch_size: batch size for encoding images.
+     :param mixed_precision: whether to use mixed precision training.
+     :param max_data_loader_n_workers: maximum number of workers for data loaders.
+     :param persistent_workers: whether to use persistent workers for data loaders.
+     :param input_size: input size of the model.
+     """
+     log_utils.info(f"prepare for training")
+     accelerator = Accelerator(mixed_precision=mixed_precision)
+     weight_dtype = train_utils.prepare_dtype(mixed_precision)
+     device = accelerator.device
+     max_data_loader_n_workers = min(max_data_loader_n_workers, os.cpu_count()-1)
+     if callable(rating_func_type):
+         rating_func = rating_func_type
+     else:
+         rating_func = train_utils.get_rating_func(rating_func_type)
+
+     model2, preprocess = utils.load_clip_models(name=clip_model_name, device=device)  # RN50x64
+
+     dataset = Dataset(dataset_source, verbose=True, condition=lambda img_info: img_info.image_path.is_file())
+     if data_preprocessor:
+         for img_key, img_info in dataset.items():
+             img_info = data_preprocessor(img_info)
+     keys = list(dataset.keys())
+     random.shuffle(keys)
+     dataset = Dataset({k: dataset[k] for k in keys})
+
+     num_pos = 0
+     num_neg = 0
+     num_mid = 0
+     for img_key, img_info in dataset.items():
+         rating = rating_func(img_info)
+         if rating == 10:
+             num_pos += 1
+         elif rating == 0:
+             num_neg += 1
+         else:
+             num_mid += 1
+     log_utils.info(f"num_pos: {num_pos} | num_mid: {num_mid} | num_neg: {num_neg}")
+
+     train_size = int(len(dataset) * (1 - val_percentage))
+     val_size = len(dataset) - train_size
+     train_dataset, val_dataset = Dataset(dataset.values()[:train_size]), Dataset(dataset.values()[train_size:])
+
+     log_utils.info(f"train_size: {train_size} | val_size: {val_size}")
+
+     train_dataset, train_loader = train_utils.prepare_dataloader(
+         train_dataset,
+         batch_size=train_batch_size,
+         clip_batch_size=clip_batch_size,
+         model2=model2,
+         preprocess=preprocess,
+         input_size=input_size,
+         rating_func=rating_func,
+         shuffle=shuffle,
+         flip_aug=flip_aug,
+         cache_to_disk=cache_to_disk,
+         cache_path=cache_path,
+         max_data_loader_n_workers=max_data_loader_n_workers,
+         persistent_workers=persistent_workers,
+         device=device,
+     )
+
+     val_dataset, val_loader = train_utils.prepare_dataloader(
+         val_dataset,
+         batch_size=val_batch_size,
+         clip_batch_size=clip_batch_size,
+         model2=model2,
+         preprocess=preprocess,
+         rating_func=rating_func,
+         shuffle=shuffle,
+         flip_aug=flip_aug,
+         cache_to_disk=cache_to_disk,
+         cache_path=cache_path,
+         max_data_loader_n_workers=max_data_loader_n_workers,
+         persistent_workers=persistent_workers,
+         device=device,
+     )
+
+     rating_stat = {}
+     for i in range(len(train_dataset)):
+         # to list
+         ratings: torch.Tensor = train_dataset[i]['ratings']
+         ratings = ratings.squeeze().tolist()
+         for rating in ratings:
+             if rating not in rating_stat:
+                 rating_stat[rating] = 0
+             rating_stat[rating] += 1
+
+     log_utils.info("rating_stat:\n", '\n'.join(f'{k}: {v}' for k, v in rating_stat.items()))
+
+     # prepare model
+
+     model: mlp.MLP = utils.load_model(resume_path, model_type=mlp_model_type, input_size=input_size, batch_norm=batch_norm, device=device, dtype=weight_dtype)
+
+     # import prodigyopt
+     # print(f"use Prodigy optimizer | {optimizer_kwargs}")
+     # optimizer_class = prodigyopt.Prodigy
+     # optimizer = optimizer_class(trainable_params, lr=lr, **optimizer_kwargs)
+
+     optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)
+     lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer, T_0=20, T_mult=2)
+
+     # choose the loss you want to optimize for
+     criterion = nn.MSELoss(reduction='mean')
+     criterion2 = nn.L1Loss(reduction='mean')
+
+     model, optimizer, train_loader, val_loader = accelerator.prepare(
+         model, optimizer, train_loader, val_loader
+     )
+
+     log_utils.info(f"device: {accelerator.device}")
+
+     # training loop
+     best_loss = 999  # best validation loss
+     total_train_steps = len(train_loader) * num_train_epochs
+     progress_bar = tqdm(range(total_train_steps), position=0, leave=True)
+     print(f"total_train_steps: {total_train_steps}")
+
+     class LossRecorder:
+         def __init__(self):
+             self.loss_list = []
+             self.loss_total: float = 0.0
+
+         def add(self, *, epoch: int, step: int, loss: float) -> None:
+             if epoch == 0:
+                 self.loss_list.append(loss)
+             else:
+                 self.loss_total -= self.loss_list[step]
+                 self.loss_list[step] = loss
+             self.loss_total += loss
+
+         @property
+         def moving_average(self) -> float:
+             return self.loss_total / len(self.loss_list)
+
+     loss_recorder = LossRecorder()
+     model.requires_grad_(True)
+     save_on_end = False
+
+     try:
+         for epoch in range(num_train_epochs):
+             model.train()
+             losses = []
+             losses2 = []
+             for step, input_data in enumerate(train_loader):
+                 optimizer.zero_grad(set_to_none=True)
+                 im_emb_arr: torch.Tensor = input_data['im_emb_arrs'].to(accelerator.device).to(dtype=weight_dtype)  # shape: (batch_size, input_size)
+                 rating: torch.Tensor = input_data['ratings'].to(accelerator.device).to(dtype=weight_dtype)  # shape: (batch_size, 1)
+
+                 # randomize the rating
+                 # rating_std = 0.5
+                 # rating = rating + torch.randn_like(rating) * rating_std
+
+                 # log_utils.debug(f"x.dtype: {x.dtype} | y.dtype: {y.dtype} | model.dtype: {model.dtype}")
+
+                 with accelerator.autocast():
+                     output = model(im_emb_arr)
+
+                 loss = criterion(output, rating)
+
+                 accelerator.backward(loss)
+
+                 losses.append(loss.detach().item())
+
+                 optimizer.step()
+
+                 # if step % 1000 == 0:
+                 #     print('\tEpoch %d | Batch %d | Loss %6.2f' % (epoch, step, loss.item()))
+                 #     # print(y)
+
+                 progress_bar.update(1)
+
+                 current_loss = loss.detach().item()
+                 loss_recorder.add(epoch=epoch, step=step, loss=current_loss)
+                 avr_loss: float = loss_recorder.moving_average
+                 pbar_logs = {
+                     'lr': f"{lr_scheduler.get_last_lr()[0]:.3e}",
+                     'epoch': epoch,
+                     'loss': avr_loss,
+                 }
+                 progress_bar.set_postfix(pbar_logs)
+
+             progress_bar.write('epoch %d | avg loss %6.6f' % (epoch, avr_loss))
+
+             # validation
+             if accelerator.is_main_process and epoch > 0 and epoch % val_every_n_epochs == 0:
+                 model.eval()
+                 with torch.no_grad():
+                     losses = []
+                     losses2 = []
+                     for step, input_data in enumerate(val_loader):
+                         # optimizer.zero_grad(set_to_none=True)
+                         im_emb_arr = input_data['im_emb_arrs'].to(accelerator.device).to(dtype=weight_dtype)
+                         rating = input_data['ratings'].to(accelerator.device).to(dtype=weight_dtype)
+
+                         with accelerator.autocast():
+                             output = model(im_emb_arr)
+                         loss = criterion(output, rating)
+                         lossMAE = criterion2(output, rating)
+                         # loss.backward()
+                         losses.append(loss.detach().item())
+                         losses2.append(lossMAE.detach().item())
+                         # optimizer.step()
+
+                         # if step % 1000 == 0:
+                         #     print('\tValidation - Epoch %d | Batch %d | MSE Loss %6.2f' % (epoch, step, loss.item()))
+                         #     print('\tValidation - Epoch %d | Batch %d | MAE Loss %6.2f' % (epoch, step, lossMAE.item()))
+
+                     # print(y)
+                     current_loss = sum(losses)/len(losses)
+                     s = [f"validation - epoch {log_utils.stylize(epoch, log_utils.ANSI.YELLOW)}"]
+                     s.append(f"avg MSE loss {log_utils.stylize(current_loss, log_utils.ANSI.GREEN, format_spec='.4f')}")
+                     s.append(f"avg MAE loss {log_utils.stylize(sum(losses2)/len(losses2), log_utils.ANSI.YELLOW, format_spec='.4f')}")
+                     progress_bar.write(' | '.join(s))
+                     # progress_bar.write('validation - epoch %d | avg MSE loss %6.4f' % (epoch, sum(losses)/len(losses)))
+                     # progress_bar.write('validation - epoch %d | avg MAE loss %6.4f' % (epoch, sum(losses2)/len(losses2)))
+
+                     if save_best_model and current_loss < best_loss:
+                         best_loss = current_loss
+                         progress_bar.write(f"best MSE val loss ({log_utils.stylize(best_loss, log_utils.ANSI.BOLD, log_utils.ANSI.GREEN)}) so far. saving model...")
+                         best_save_path = Path(save_path).parent / f"{Path(save_path).stem}_best-MSE{best_loss:.4f}{Path(save_path).suffix}"
+                         train_utils.save_model(model, best_save_path, epoch=epoch)
+                         progress_bar.write(f"model saved: `{save_path}`")
+
+             lr_scheduler.step()
+             accelerator.wait_for_everyone()
+     except KeyboardInterrupt:
+         log_utils.warn("KeyboardInterrupt")
+         if input(f"save model to {save_path}? [y/n]") == 'y':
+             save_on_end = True
+     else:
+         save_on_end = True
+
+     progress_bar.close()
+     model = accelerator.unwrap_model(model)
+     accelerator.wait_for_everyone()
+
+     if accelerator.is_main_process and save_on_end:
+         log_utils.info("saving model...")
+         train_utils.save_model(model, save_path)
+         log_utils.info(f"model saved: `{save_path}`")
+
+     del accelerator
+
+     log_utils.success(f"training done. best loss: {best_loss}")
+
+     # inference test with dummy samples from the val set, sanity check
+     # log_utils.info("inference test with dummy samples from the val set, sanity check")
+     # model.eval()
+     # output = model(x[:5].to(device))
+     # log_utils.info(output.size())
+     # log_utils.info(output)
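
Note: a hypothetical invocation sketch for train() with mostly default hyperparameters (the dataset source, cache path, and save path are placeholder assumptions; the source must be something waifuset.classes.Dataset can load):

from waifu_scorer.train import train

train(
    dataset_source='./data/images',        # assumed waifuset-compatible dataset source
    save_path='./models/waifu_scorer.pth',
    cache_to_disk=True,
    cache_path='./cache/clip_embs.h5',     # CLIP embeddings cached as HDF5
    num_train_epochs=50,
    train_batch_size=256,
)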
waifu_scorer/train_utils.py ADDED
@@ -0,0 +1,333 @@
+ import os
+ import torch
+ import h5py
+ import math
+ import random
+ from torch.utils.data import DataLoader
+ from pathlib import Path
+ from typing import List, Callable, Tuple
+ from tqdm import tqdm
+ from PIL import Image
+ from waifuset.classes import Dataset, ImageInfo
+ from waifuset.utils import log_utils
+ from .utils import encode_images, load_clip_models, quality_rating
+
+
+ class LaionImageInfo:
+     def __init__(
+         self,
+         img_path=None,
+         im_emb_arr=None,
+         rating=None,
+         im_emb_arr_flipped=None,
+         num_repeats=1,
+     ):
+         self.img_path = img_path
+         self.im_emb_arr = im_emb_arr
+         self.rating = rating
+         self.im_emb_arr_flipped = im_emb_arr_flipped
+         self.num_repeats = num_repeats
+
+
+ class LaionDataset:
+     def __init__(
+         self,
+         source,
+         cache_to_disk=True,
+         cache_path=None,
+         batch_size=1,
+         clip_batch_size=4,
+         model2=None,
+         preprocess=None,
+         input_size=768,
+         rating_func: Callable = quality_rating,
+         repeating_func: Callable = None,
+         shuffle=True,
+         flip_aug: bool = True,
+         device='cuda'
+     ):
+         if model2 is None or preprocess is None:
+             model2, preprocess = load_clip_models(device)  # RN50x64
+         if cache_to_disk and cache_path is None:
+             raise ValueError("cache_path must be specified when cache_to_disk is True.")
+         self.source = source
+         self.cache_to_disk = cache_to_disk
+         self.cache_path = Path(cache_path)
+         self.model2, self.preprocess = model2, preprocess
+         self.input_size = input_size
+         self.rating_func = rating_func
+         self.batch_size = batch_size
+         self.encoder_batch_size = clip_batch_size
+         self.shuffle = shuffle
+         self.flip_aug = flip_aug
+         self.device = device
+
+         dataset: Dataset = Dataset(source, verbose=True)
+
+         self.image_data = []
+
+         for img_key, img_info in tqdm(dataset.items(), desc='prepare dataset'):
+             img_path = img_info.image_path
+             rating = self.rating_func(img_info)
+             laion_image_info = LaionImageInfo(
+                 img_path=img_path,
+                 rating=rating,
+             )
+             self.register_image_info(laion_image_info)
+
+         rating_counter = {}
+         for laion_img_info in tqdm(self.image_data, desc='calculating num repeats (1/2)'):
+             # to list
+             rating: torch.Tensor = laion_img_info.rating
+             rating_counter.setdefault(rating, 0)
+             rating_counter[rating] += 1
+
+         for laion_img_info in tqdm(self.image_data, desc='calculating num repeats (2/2)'):
+             benchmark = 30000
+             num_repeats = benchmark / rating_counter[laion_img_info.rating]
+             prob = num_repeats - math.floor(num_repeats)
+             num_repeats = math.floor(num_repeats) if random.random() < prob else math.ceil(num_repeats)
+             laion_img_info.num_repeats = max(1, num_repeats)
+
+         self.cache_embs()
+         self.batches = self.make_batches()
+
+     def register_image_info(self, image_info: LaionImageInfo):
+         self.image_data.append(image_info)
+
+     def cache_embs(self):
+         self.cache_path.parent.mkdir(parents=True, exist_ok=True)
+
+         not_cached = []  # list of (image_info, flipped)
+         num_cached = 0
+
+         # load cache
+         if self.cache_to_disk:
+             pbar = tqdm(total=len(self.image_data), desc='loading cache')
+
+             def load_cached_emb(h5, image_info: LaionImageInfo, flipped=False):
+                 nonlocal num_cached
+                 image_key = image_info.img_path.stem
+                 if flipped:
+                     image_key = image_key + '_flipped'
+                 if image_key in h5:
+                     im_emb_arr = torch.from_numpy(f[image_key][:])
+                     if im_emb_arr.shape[-1] != self.input_size:
+                         raise ValueError(f"Input size mismatched. Except {self.input_size} dim, but got {im_emb_arr.shape[-1]} dim loaded. Please check your cache file.")
+                     assert im_emb_arr.device == torch.device('cpu'), "flipped image emb should be on cpu"
+                     if flipped:
+                         image_info.im_emb_arr_flipped = im_emb_arr
+                     else:
+                         image_info.im_emb_arr = im_emb_arr
+                     num_cached += 1
+                 else:
+                     not_cached.append((image_info, flipped))
+
+             if not is_h5_file(self.cache_path):
+                 # create cache
+                 log_utils.info(f"cache file not found, creating new cache file: {self.cache_path}")
+                 with h5py.File(self.cache_path, 'w') as f:
+                     pass
+             else:
+                 log_utils.info(f"loading cache file: {self.cache_path}")
+                 with h5py.File(self.cache_path, 'r') as f:
+                     for image_info in self.image_data:
+                         load_cached_emb(f, image_info, flipped=False)
+                         if self.flip_aug:
+                             load_cached_emb(f, image_info, flipped=True)
+                         pbar.update()
+             pbar.close()
+         else:
+             not_cached = [(image_info, False) for image_info in self.image_data]
+             if self.flip_aug:
+                 not_cached += [(image_info, True) for image_info in self.image_data]
+
+         # encode not-cached images
+         if len(not_cached) == 0:
+             log_utils.info("all images are cached.")
+         else:
+             log_utils.info(f"number of cached instances: {num_cached}")
+             log_utils.info(f"number of not cached instances: {len(not_cached)}")
+
+             batches = [not_cached[i:i + self.encoder_batch_size] for i in range(0, len(not_cached), self.encoder_batch_size)]
+             pbar = tqdm(total=len(batches), desc='encoding images')
+
+             def cache_batch_embs(h5, batch: List[Tuple[LaionImageInfo, bool]]):
+                 try:
+                     images = [Image.open(image_info.img_path) if not flipped else Image.open(image_info.img_path).transpose(Image.FLIP_LEFT_RIGHT) for image_info, flipped in batch]
+                 except:
+                     log_utils.error(f"Error occurred when loading one of the images: {[image_info.img_path for image_info, flipped in batch]}")
+                     raise
+                 im_emb_arrs = encode_images(images, self.model2, self.preprocess, device=self.device)  # shape: [batch_size, input_size]
+                 for i, item in enumerate(batch):
+                     image_info, flipped = item
+                     im_emb_arr = im_emb_arrs[i]
+                     shape_size = len(im_emb_arr.shape)
+                     if shape_size == 1:
+                         im_emb_arr = im_emb_arr.unsqueeze(0)
+                     elif shape_size == 3:
+                         im_emb_arr = im_emb_arr.squeeze(1)
+
+                     image_key = image_info.img_path.stem
+                     assert im_emb_arr.device == torch.device('cpu'), "flipped image emb should be on cpu"
+                     if flipped:
+                         image_key = image_key + '_flipped'
+                         image_info.im_emb_arr_flipped = im_emb_arr
+                     else:
+                         image_info.im_emb_arr = im_emb_arr
+
+                     if self.cache_to_disk:
+                         if image_key in h5:
+                             continue
+                         h5.create_dataset(image_key, data=im_emb_arr.cpu().numpy())
+
+             try:
+                 h5 = h5py.File(self.cache_path, 'a') if self.cache_to_disk else None
+                 for batch in batches:
+                     cache_batch_embs(h5, batch)
+                     pbar.update()
+             finally:
+                 if h5:
+                     h5.close()
+                 pbar.close()
+
+     def make_batches(self):
+         batches = []
+         repeated_image_data = []
+         for image_info in self.image_data:
+             repeated_image_data += [image_info] * image_info.num_repeats
+         log_utils.info(f"number of instances (repeated): {len(repeated_image_data)}")
+         for i in range(0, len(repeated_image_data), self.batch_size):
+             batch = repeated_image_data[i:i + self.batch_size]
+             batches.append(batch)
+         if self.shuffle:
+             random.shuffle(batches)
+         return batches
+
+     def __getitem__(self, index):
+         batch = self.batches[index]
+         im_emb_arrs = []
+         ratings = []
+         for image_info in batch:
+             flip = self.flip_aug and random.random() > 0.5
+             if not flip:
+                 im_emb_arr = image_info.im_emb_arr
+             else:
+                 im_emb_arr = image_info.im_emb_arr_flipped
+             rating = image_info.rating
+
+             im_emb_arrs.append(im_emb_arr)
+             ratings.append(rating)
+
+         im_emb_arrs = torch.cat(im_emb_arrs, dim=0)
+         ratings = torch.tensor(ratings).unsqueeze(-1)
+         sample = dict(
+             im_emb_arrs=im_emb_arrs,
+             ratings=ratings,
+         )
+         return sample
+
+     def __len__(self):
+         return len(self.batches)
+
+
+ def collate_fn(batch):
+     return batch[0]
+
+
+ def get_rating_func(rating_func_type: str):
+     if rating_func_type == 'quality':
+         from .utils import quality_rating
+         rating_func = quality_rating
+     else:
+         raise ValueError(f"Invalid rating type: {rating_func_type}")
+     return rating_func
+
+
+ def prepare_dataloader(
+     dataset_source,
+     cache_to_disk=True,
+     cache_path=None,
+     batch_size=1,
+     clip_batch_size=4,
+     model2=None,
+     preprocess=None,
+     input_size=768,
+     rating_func: Callable = quality_rating,
+     shuffle=True,
+     flip_aug: bool = True,
+     device='cuda',
+     persistent_workers=False,
+     max_data_loader_n_workers=0,
+ ):
+     dataset = LaionDataset(
+         dataset_source,
+         cache_to_disk=cache_to_disk,
+         cache_path=cache_path,
+         batch_size=batch_size,
+         clip_batch_size=clip_batch_size,
+         model2=model2,
+         preprocess=preprocess,
+         input_size=input_size,
+         rating_func=rating_func,
+         shuffle=shuffle,
+         flip_aug=flip_aug,
+         device=device,
+     )
+
+     dataloader = DataLoader(
+         dataset,
+         batch_size=1,  # fix to 1
+         shuffle=shuffle,
+         num_workers=max_data_loader_n_workers,
+         persistent_workers=persistent_workers,
+         collate_fn=collate_fn,
+     )
+
+     return dataset, dataloader
+
+
+ def is_h5_file(cache_path):
+     if not cache_path or not h5py.is_hdf5(cache_path):
+         return False
+     return True
+
+
+ # def make_train_data(
+ #     dataset_source,
+ #     rating_func: Callable = quality_rating,
+ #     batch_size=1,
+ #     flip_aug: bool = True,
+ #     device='cuda'
+ # ):
+ #     model2, preprocess = clip.load("ViT-L/14", device=device)  # RN50x64
+ #     dataset = Dataset.from_source(dataset_source, verbose=True)
+ #     x_train = []
+ #     y_train = []
+ #     batches = [dataset[i:i + batch_size] for i in range(0, len(dataset), batch_size)]
+ #     for batch in tqdm(batches, desc='encoding images', smoothing=1):
+ #         im_emb_arr = encode_images([d.pil_img for d in batch], model2, preprocess, device=device)  # shape: [batch_size, 768]
+ #         ratings = torch.tensor([rating_func(data) for data in batch]).unsqueeze(-1).to(device)  # shape: [batch_size, 1]
+ #         x_train.append(im_emb_arr)
+ #         y_train.append(ratings)
+ #     x_train = torch.cat(x_train, dim=0)
+ #     y_train = torch.cat(y_train, dim=0)
+ #     return x_train, y_train
+
+
+ def prepare_dtype(mixed_precision: str):
+     weight_dtype = torch.float32
+     if mixed_precision == "fp16":
+         weight_dtype = torch.float16
+     elif mixed_precision == "bf16":
+         weight_dtype = torch.bfloat16
+     return weight_dtype
+
+
+ def save_model(model, save_path, epoch=None):
+     save_path = str(save_path)
+     os.makedirs(os.path.dirname(save_path), exist_ok=True)
+     if epoch is not None:
+         save_path = save_path.replace('.pth', f'_ep{epoch}.pth')
+     torch.save(model.state_dict(), save_path)
+     return save_path
waifu_scorer/ui.py ADDED
@@ -0,0 +1,91 @@
+ import gradio as gr
+ from argparse import ArgumentParser
+
+
+ def parse_args():
+     parser = ArgumentParser()
+     parser.add_argument(
+         '--model_path',
+         type=str,
+         default='./model/v3.pth',
+         help='Path or url to the model file',
+     )
+     parser.add_argument(
+         '--model_type',
+         type=str,
+         default='mlp',
+         help='Type of the model',
+     )
+     parser.add_argument(
+         '--fix_model_path',
+         action='store_true',
+         help='Fix the model path',
+     )
+     parser.add_argument(
+         '--device',
+         type=str,
+         default='cuda',
+         help='Device to use',
+     )
+     parser.add_argument(
+         '--share',
+         action='store_true',
+         help='Share the demo',
+     )
+     return parser.parse_args()
+
+
+ def ui(args):
+     from waifu_scorer.predict import WaifuScorer, load_model
+     scorer = WaifuScorer(
+         model_path=args.model_path,
+         model_type=args.model_type,
+         device=args.device,
+     )
+
+     with gr.Blocks() as demo:
+         with gr.Row():
+             with gr.Column():
+                 image = gr.Image(
+                     label='Image',
+                     type='pil',
+                     height=512,
+                     sources=['upload', 'clipboard'],
+                 )
+             with gr.Column():
+                 with gr.Row():
+                     model_path = gr.Textbox(
+                         label='Model Path',
+                         value=args.model_path,
+                         placeholder='Path or URL to the model file',
+                         interactive=not args.fix_model_path,
+                     )
+                 with gr.Row():
+                     score = gr.Number(
+                         label='Score',
+                     )
+
+         def change_model(model_path):
+             nonlocal scorer
+             scorer.mlp = load_model(model_path, model_type=args.model_type, device=args.device)
+             print(f"Model changed to `{model_path}`")
+             return gr.update()
+
+         model_path.submit(
+             fn=change_model,
+             inputs=model_path,
+             outputs=model_path,
+         )
+
+         image.change(
+             fn=lambda image: scorer.predict([image]*2)[0] if image is not None else None,
+             inputs=image,
+             outputs=score,
+         )
+
+     return demo
+
+
+ def launch(args):
+     demo = ui(args)
+     demo.launch(share=args.share)
waifu_scorer/utils.py ADDED
@@ -0,0 +1,72 @@
+ import torch
+ import clip
+ from PIL import Image
+ from typing import List, Union
+ from . import mlp
+
+ QUALITY_TO_RATING = {
+     'amazing': 10,
+     'best': 8.5,
+     'high': 7,
+     'normal': 5,
+     'low': 2.5,
+     'worst': 0,
+     'horrible': 0,
+ }
+
+ MODEL_TYPE = {
+     'mlp': mlp.MLP,
+     'res_mlp': mlp.ResMLP,
+ }
+
+
+ def quality_rating(img_info):
+     quality = (img_info.caption.quality or 'normal') if img_info.caption is not None else 'normal'
+     rating = QUALITY_TO_RATING[quality]
+     return rating
+
+
+ def get_model_cls(model_type) -> Union[mlp.MLP, None]:
+     return MODEL_TYPE.get(model_type, mlp.MLP)
+
+
+ def load_clip_models(name: str = "ViT-L/14", device='cuda'):
+     model2, preprocess = clip.load(name, device=device)  # RN50x64
+     return model2, preprocess
+
+
+ def load_model(model_path: str = None, model_type=None, input_size=768, batch_norm: bool = True, device: str = 'cuda', dtype=None):
+     model_cls = get_model_cls(model_type)
+     print(f"Loading model from class `{model_cls}`...")
+     model_kwargs = {}
+     if model_type in ('large', 'res_large'):
+         model_kwargs['batch_norm'] = True
+     model = model_cls(input_size, **model_kwargs)
+     if model_path:
+         try:
+             s = torch.load(model_path, map_location=device)
+             model.load_state_dict(s)
+         except Exception as e:
+             print(f"Model type mismatch. Desired model type: `{model_type}` (model class: `{model_cls}`).")
+             raise e
+     model.to(device)
+     if dtype:
+         model = model.to(dtype=dtype)
+     return model
+
+
+ def normalized(a: torch.Tensor, order=2, dim=-1):
+     l2 = a.norm(order, dim, keepdim=True)
+     l2[l2 == 0] = 1
+     return a / l2
+
+
+ @torch.no_grad()
+ def encode_images(images: List[Image.Image], model2, preprocess, device='cuda') -> torch.Tensor:
+     if isinstance(images, Image.Image):
+         images = [images]
+     image_tensors = [preprocess(img).unsqueeze(0) for img in images]
+     image_batch = torch.cat(image_tensors).to(device)
+     image_features = model2.encode_image(image_batch)
+     im_emb_arr = normalized(image_features).cpu().float()
+     return im_emb_arr
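
Note: a small sketch wiring the helpers above together on CPU (the checkpoint and image paths are placeholder assumptions):

import torch
from PIL import Image
from waifu_scorer.utils import load_clip_models, load_model, encode_images

device = 'cpu'
clip_model, preprocess = load_clip_models("ViT-L/14", device=device)
mlp = load_model('./1024_MLP_best-MSE4.1636_ep75.pth', model_type='mlp', device=device, dtype=torch.float32)
mlp.eval()

emb = encode_images(Image.open('./example.jpg').convert('RGB'), clip_model, preprocess, device=device)
with torch.no_grad():
    print(mlp(emb).clamp(0, 10).item())  # single-image score in [0, 10]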