City commited on
Commit
bb0a0a7
1 Parent(s): dbfacdc

Sync with github

Browse files
Files changed (5) hide show
  1. README.md +5 -6
  2. demo_class_gradio.py +62 -0
  3. inference.py +236 -0
  4. model.py +45 -0
  5. requirements.txt +4 -0
README.md CHANGED
@@ -1,13 +1,12 @@
1
  ---
2
  title: AnimeClassifiers Demo
3
- emoji: 📚
4
- colorFrom: gray
5
- colorTo: indigo
6
  sdk: gradio
7
  sdk_version: 4.7.1
8
- app_file: app.py
 
9
  pinned: false
10
  license: apache-2.0
11
  ---
12
-
13
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
1
  ---
2
  title: AnimeClassifiers Demo
3
+ emoji: 🧱
4
+ colorFrom: blue
5
+ colorTo: yellow
6
  sdk: gradio
7
  sdk_version: 4.7.1
8
+ app_file: demo_class_gradio.py
9
+ models: [city96/AnimeClassifiers]
10
  pinned: false
11
  license: apache-2.0
12
  ---
 
 
demo_class_gradio.py ADDED
@@ -0,0 +1,62 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import torch
3
+ import gradio as gr
4
+
5
+ from inference import CityClassifierMultiModelPipeline, get_model_path
6
+
7
+ TOKEN = os.environ.get("HFS_TOKEN")
8
+ HFREPO = "City96/AnimeClassifiers"
9
+ MODELS = [
10
+ "CCAnime-ChromaticAberration-v1.16",
11
+ ]
12
+ article = """\
13
+ These are classifiers meant to work with anime images.
14
+
15
+ For more information, you can check out the [Huggingface Hub](https://huggingface.co/city96/AnimeClassifiers) or [GitHub page](https://github.com/city96/CityClassifiers).
16
+ """
17
+ info_default="""\
18
+ Include default class (unknown/negative) in output results.
19
+ """
20
+ info_tiling = """\
21
+ Divide the image into parts and run classifier on each part separately.
22
+ Greatly improves accuracy but slows down inference.
23
+ """
24
+ info_tiling_combine = """\
25
+ How to combine the confidence scores of the different tiles.
26
+ Mean averages confidence over all tiles. Median takes the value in the middle.
27
+ Max/min take the score from the tile with the highest/lowest confidence respectively, but can results in multiple labels having very high/very low confidence scores.
28
+ """
29
+
30
+ pipeline_args = {}
31
+ if torch.cuda.is_available():
32
+ pipeline_args.update({
33
+ "device" : "cuda",
34
+ "clip_dtype" : torch.float16,
35
+ })
36
+
37
+ pipeline = CityClassifierMultiModelPipeline(
38
+ model_paths = [get_model_path(x, HFREPO, TOKEN) for x in MODELS],
39
+ config_paths = [get_model_path(x, HFREPO, TOKEN, extension="config.json") for x in MODELS],
40
+ **pipeline_args,
41
+ )
42
+ gr.Interface(
43
+ fn = pipeline,
44
+ title = "CityClassifiers demo",
45
+ article = article,
46
+ inputs = [
47
+ gr.Image(label="Input image", type="pil"),
48
+ gr.Checkbox(label="Include default", value=True, info=info_default),
49
+ gr.Checkbox(label="Tiling", value=True, info=info_tiling),
50
+ gr.Dropdown(
51
+ label = "Tiling combine strategy",
52
+ choices = ["mean", "median", "max", "min"],
53
+ value = "mean",
54
+ type = "value",
55
+ info = info_tiling_combine,
56
+ )
57
+ ],
58
+ outputs = [gr.Label(label=x) for x in MODELS],
59
+ examples = "./examples" if os.path.isdir("./examples") else None,
60
+ allow_flagging = "never",
61
+ analytics_enabled = False,
62
+ ).launch()
inference.py ADDED
@@ -0,0 +1,236 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import json
3
+ import torch
4
+ import torchvision.transforms as TF
5
+ from safetensors.torch import load_file
6
+ from huggingface_hub import hf_hub_download
7
+ from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection
8
+
9
+ from model import PredictorModel
10
+
11
+ class CityAestheticsPipeline:
12
+ """
13
+ Demo model pipeline for [image=>score] prediction
14
+ Accepts a single model path on initialization.
15
+ Resulting object can be called directly with a PIL image as the input
16
+ Returns a single float value with the predicted score [0.0;1.0].
17
+ """
18
+ clip_ver = "openai/clip-vit-large-patch14"
19
+ def __init__(self, model_path, device="cpu", clip_dtype=torch.float32):
20
+ self.device = device
21
+ self.clip_dtype = clip_dtype
22
+ self._init_clip()
23
+ self.model = self._load_model(model_path)
24
+ print("CityAesthetics: Pipeline init ok") # debug
25
+
26
+ def __call__(self, raw):
27
+ emb = self.get_clip_emb(raw)
28
+ return self.get_model_pred(self.model, emb)
29
+
30
+ def get_model_pred(self, model, emb):
31
+ with torch.no_grad():
32
+ pred = model(emb)
33
+ return float(pred.detach().cpu().squeeze(0))
34
+
35
+ def get_clip_emb(self, raw):
36
+ img = self.proc(
37
+ images = raw,
38
+ return_tensors = "pt"
39
+ )["pixel_values"].to(self.clip_dtype).to(self.device)
40
+ with torch.no_grad():
41
+ emb = self.clip(pixel_values=img)
42
+ return emb["image_embeds"].detach().to(torch.float32)
43
+
44
+ def _init_clip(self):
45
+ self.proc = CLIPImageProcessor.from_pretrained(self.clip_ver)
46
+ self.clip = CLIPVisionModelWithProjection.from_pretrained(
47
+ self.clip_ver,
48
+ device_map = self.device,
49
+ torch_dtype = self.clip_dtype,
50
+ )
51
+
52
+ def _load_model(self, path):
53
+ sd = load_file(path)
54
+ assert tuple(sd["up.0.weight"].shape) == (1024, 768) # only allow CLIP ver
55
+ model = PredictorModel(outputs=1)
56
+ model.eval()
57
+ model.load_state_dict(sd)
58
+ model.to(self.device)
59
+ return model
60
+
61
+ class CityAestheticsMultiModelPipeline(CityAestheticsPipeline):
62
+ """
63
+ Demo multi-model pipeline for [image=>score] prediction
64
+ Accepts a list of model paths on initialization.
65
+ Resulting object can be called directly with a PIL image as the input.
66
+ Returns a dict with the model name as key and the score [0.0;1.0] as a value.
67
+ """
68
+ def __init__(self, model_paths, device="cpu", clip_dtype=torch.float32):
69
+ self.device = device
70
+ self.clip_dtype = clip_dtype
71
+ self._init_clip()
72
+ self.models = {}
73
+ for path in model_paths:
74
+ name = os.path.splitext(os.path.basename(path))[0]
75
+ self.models[name] = self._load_model(path)
76
+ print("CityAesthetics: Pipeline init ok") # debug
77
+
78
+ def __call__(self, raw):
79
+ emb = self.get_clip_emb(raw)
80
+ out = {}
81
+ for name, model in self.models.items():
82
+ pred = model(emb)
83
+ out[name] = self.get_model_pred(model, emb)
84
+ return out
85
+
86
+ class CityClassifierPipeline:
87
+ """
88
+ Demo model pipeline for [image=>label] prediction
89
+ Accepts a single model path and (optionally) a JSON file on initialization.
90
+ Resulting object can be called directly with a PIL image as the input
91
+ Returns a single float value with the predicted score [0.0;1.0].
92
+ """
93
+ clip_ver = "openai/clip-vit-large-patch14"
94
+ def __init__(self, model_path, config_path=None, device="cpu", clip_dtype=torch.float32):
95
+ self.device = device
96
+ self.clip_dtype = clip_dtype
97
+ self._init_clip()
98
+
99
+ self.labels, model_args = self._load_config(config_path)
100
+ self.model = self._load_model(model_path, model_args)
101
+
102
+ print("CityClassifier: Pipeline init ok") # debug
103
+
104
+ def __call__(self, raw, default=True, tiling=True, tile_strat="mean"):
105
+ emb = self.get_clip_emb(raw, tiling=tiling)
106
+ pred = self.get_model_pred(self.model, emb)
107
+ return self.format_pred(
108
+ pred,
109
+ labels = self.labels,
110
+ drop = [] if default else [0],
111
+ ts = tile_strat if tiling else "raw",
112
+ )
113
+
114
+ def format_pred(self, pred, labels, drop=[], ts="mean"):
115
+ # recombine strategy
116
+ if ts == "mean" : vp = lambda x: float(torch.mean(x))
117
+ elif ts == "median": vp = lambda x: float(torch.median(x))
118
+ elif ts == "max" : vp = lambda x: float(torch.max(x))
119
+ elif ts == "min" : vp = lambda x: float(torch.min(x))
120
+ elif ts == "raw" : vp = lambda x: float(x)
121
+ else: raise NotImplementedError(f"CityClassifier: Invalid combine strategy '{ts}'!")
122
+ # combine pred w/ labels
123
+ out = {}
124
+ for k in range(len(pred)):
125
+ if k in drop: continue
126
+ key = labels.get(str(k), str(k))
127
+ out[key] = vp(pred[k])
128
+ return out
129
+
130
+ def get_model_pred(self, model, emb):
131
+ with torch.no_grad():
132
+ pred = model(emb)
133
+ pred = pred.detach().cpu()
134
+ return [pred[:, x] for x in range(pred.shape[1])] # split
135
+
136
+ def get_clip_emb(self, raw, tiling=False):
137
+ if tiling and min(raw.size)>512:
138
+ if max(raw.size)>1536:
139
+ raw = TF.functional.resize(raw, 1536)
140
+ raw = TF.functional.five_crop(raw, 512)
141
+ img = self.proc(
142
+ images = raw,
143
+ return_tensors = "pt"
144
+ )["pixel_values"].to(self.clip_dtype).to(self.device)
145
+ with torch.no_grad():
146
+ emb = self.clip(pixel_values=img)
147
+ return emb["image_embeds"].detach().to(torch.float32)
148
+
149
+ def _init_clip(self):
150
+ self.proc = CLIPImageProcessor.from_pretrained(self.clip_ver)
151
+ self.clip = CLIPVisionModelWithProjection.from_pretrained(
152
+ self.clip_ver,
153
+ device_map = self.device,
154
+ torch_dtype = self.clip_dtype,
155
+ )
156
+
157
+ def _load_model(self, path, args=None):
158
+ sd = load_file(path)
159
+ assert tuple(sd["up.0.weight"].shape) == (1024, 768) # only allow CLIP ver
160
+ args = args or { # infer from model
161
+ "outputs" : int(sd["down.5.bias"].shape[0])
162
+ }
163
+ model = PredictorModel(**args)
164
+ model.eval()
165
+ model.load_state_dict(sd)
166
+ model.to(self.device)
167
+ return model
168
+
169
+ def _load_config(self, path):
170
+ if not path or not os.path.isfile(path):
171
+ return ({},None)
172
+
173
+ with open(path) as f:
174
+ data = json.loads(f.read())
175
+ return (
176
+ data.get("labels", {}),
177
+ data.get("model_params", {}),
178
+ )
179
+
180
+ class CityClassifierMultiModelPipeline(CityClassifierPipeline):
181
+ """
182
+ Demo model pipeline for [image=>label] prediction
183
+ Accepts a list of model paths on initialization.
184
+ A matching list of JSON files can also be passed in the same order.
185
+ Resulting object can be called directly with a PIL image as the input
186
+ Returns a single float value with the predicted score [0.0;1.0].
187
+ """
188
+ def __init__(self, model_paths, config_paths=[], device="cpu", clip_dtype=torch.float32):
189
+ self.device = device
190
+ self.clip_dtype = clip_dtype
191
+ self._init_clip()
192
+ self.models = {}
193
+ self.labels = {}
194
+ assert len(model_paths) == len(config_paths) or not config_paths, "CityClassifier: Model and config paths must match!"
195
+ for k in range(len(model_paths)):
196
+ name = os.path.splitext(os.path.basename(model_paths[k]))[0] # TODO: read from config
197
+ self.labels[name], model_args = self._load_config(config_paths[k] if config_paths else None)
198
+ self.models[name] = self._load_model(model_paths[k], model_args)
199
+
200
+ print("CityClassifier: Pipeline init ok") # debug
201
+
202
+ def __call__(self, raw, default=True, tiling=True, tile_strat="mean"):
203
+ emb = self.get_clip_emb(raw, tiling=tiling)
204
+ out = {}
205
+ for name, model in self.models.items():
206
+ pred = self.get_model_pred(model, emb)
207
+ out[name] = self.format_pred(
208
+ pred,
209
+ labels = self.labels[name],
210
+ drop = [] if default else [0],
211
+ ts = tile_strat if tiling else "raw",
212
+ )
213
+ if len(out.values()) == 1: return list(out.values())[0] # GRADIO HOTFIX
214
+ return list(out.values())
215
+
216
+ def get_model_path(name, repo, token=True, extension="safetensors", local=False):
217
+ """
218
+ Returns local model path or falls back to HF hub if required.
219
+ """
220
+ fname = f"{name}.{extension}"
221
+
222
+ # local path: [models/AesPred-Anime-v1.8.safetensors]
223
+ path = os.path.join(os.path.dirname(os.path.realpath(__file__)),"models")
224
+ if os.path.isfile(os.path.join(path, fname)):
225
+ print(f"Using local model for '{fname}'")
226
+ return os.path.join(path, fname)
227
+
228
+ if local: raise OSError(f"Can't find local model '{fname}'!")
229
+
230
+ # huggingface hub fallback
231
+ print(f"Using HF Hub model for '{fname}'")
232
+ return str(hf_hub_download(
233
+ token = token,
234
+ repo_id = repo,
235
+ filename = fname,
236
+ ))
model.py ADDED
@@ -0,0 +1,45 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+
4
+ class ResBlock(nn.Module):
5
+ """Linear block with residuals"""
6
+ def __init__(self, ch):
7
+ super().__init__()
8
+ self.join = nn.ReLU()
9
+ self.long = nn.Sequential(
10
+ nn.Linear(ch, ch),
11
+ nn.LeakyReLU(0.1),
12
+ nn.Linear(ch, ch),
13
+ nn.LeakyReLU(0.1),
14
+ nn.Linear(ch, ch),
15
+ )
16
+ def forward(self, x):
17
+ return self.join(self.long(x) + x)
18
+
19
+ class PredictorModel(nn.Module):
20
+ """Main predictor class"""
21
+ def __init__(self, features=768, outputs=1, hidden=1024):
22
+ super().__init__()
23
+ self.features = features
24
+ self.outputs = outputs
25
+ self.hidden = hidden
26
+ self.up = nn.Sequential(
27
+ nn.Linear(self.features, self.hidden),
28
+ ResBlock(ch=self.hidden),
29
+ )
30
+ self.down = nn.Sequential(
31
+ nn.Linear(self.hidden, 128),
32
+ nn.Linear(128, 64),
33
+ nn.Dropout(0.1),
34
+ nn.LeakyReLU(),
35
+ nn.Linear(64, 32),
36
+ nn.Linear(32, self.outputs),
37
+ )
38
+ self.out = nn.Softmax(dim=1) if self.outputs > 1 else nn.Tanh()
39
+ def forward(self, x):
40
+ y = self.up(x)
41
+ z = self.down(y)
42
+ if self.outputs > 1:
43
+ return self.out(z)
44
+ else:
45
+ return (self.out(z)+1.0)/2.0
requirements.txt ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ torch==2.1.0
2
+ accelerate==0.24.1
3
+ safetensors==0.4.0
4
+ transformers==4.35.0