SmilingWolf committed
Commit 5c69e57
Parent: 999d8f3

Update app

Files changed (4):
  1. README.md +2 -3
  2. Utils/dbimutils.py +0 -54
  3. app.py +288 -234
  4. requirements.txt +0 -2
README.md CHANGED
@@ -1,13 +1,12 @@
 ---
-title: WaifuDiffusion v1.4 Tags
+title: WaifuDiffusion Tagger
 emoji: 💬
 colorFrom: blue
 colorTo: red
 sdk: gradio
-sdk_version: 3.16.2
+sdk_version: 4.20.1
 app_file: app.py
 pinned: false
-duplicated_from: NoCrypt/DeepDanbooru_string
 ---
 
 # Configuration
Utils/dbimutils.py DELETED
@@ -1,54 +0,0 @@
-# DanBooru IMage Utility functions
-
-import cv2
-import numpy as np
-from PIL import Image
-
-
-def smart_imread(img, flag=cv2.IMREAD_UNCHANGED):
-    if img.endswith(".gif"):
-        img = Image.open(img)
-        img = img.convert("RGB")
-        img = cv2.cvtColor(np.array(img), cv2.COLOR_RGB2BGR)
-    else:
-        img = cv2.imread(img, flag)
-    return img
-
-
-def smart_24bit(img):
-    if img.dtype is np.dtype(np.uint16):
-        img = (img / 257).astype(np.uint8)
-
-    if len(img.shape) == 2:
-        img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)
-    elif img.shape[2] == 4:
-        trans_mask = img[:, :, 3] == 0
-        img[trans_mask] = [255, 255, 255, 255]
-        img = cv2.cvtColor(img, cv2.COLOR_BGRA2BGR)
-    return img
-
-
-def make_square(img, target_size):
-    old_size = img.shape[:2]
-    desired_size = max(old_size)
-    desired_size = max(desired_size, target_size)
-
-    delta_w = desired_size - old_size[1]
-    delta_h = desired_size - old_size[0]
-    top, bottom = delta_h // 2, delta_h - (delta_h // 2)
-    left, right = delta_w // 2, delta_w - (delta_w // 2)
-
-    color = [255, 255, 255]
-    new_im = cv2.copyMakeBorder(
-        img, top, bottom, left, right, cv2.BORDER_CONSTANT, value=color
-    )
-    return new_im
-
-
-def smart_resize(img, size):
-    # Assumes the image has already gone through make_square
-    if img.shape[0] > size:
-        img = cv2.resize(img, (size, size), interpolation=cv2.INTER_AREA)
-    elif img.shape[0] < size:
-        img = cv2.resize(img, (size, size), interpolation=cv2.INTER_CUBIC)
-    return img
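Note: the deleted module was the old OpenCV preprocessing layer. A minimal sketch of how these helpers chain together, for reference (448 stands in for the model input size, which the old app read from the ONNX session; only make_square and smart_resize were actually called from app.py):

    import cv2
    from Utils import dbimutils

    img = dbimutils.smart_imread("input.png")  # BGR array; .gif handled via PIL
    img = dbimutils.smart_24bit(img)           # flatten 16-bit depth / alpha to 24-bit BGR
    img = dbimutils.make_square(img, 448)      # pad to a white-bordered square
    img = dbimutils.smart_resize(img, 448)     # resize the square to the model input size

The same pad-to-square and resize steps are now done with PIL alone in app.py's new prepare_image method (below), dropping the opencv-python dependency.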
app.py CHANGED
@@ -1,8 +1,4 @@
-from __future__ import annotations
-
 import argparse
-import functools
-import html
 import os
 
 import gradio as gr
@@ -10,40 +6,56 @@ import huggingface_hub
 import numpy as np
 import onnxruntime as rt
 import pandas as pd
-import piexif
-import piexif.helper
-import PIL.Image
-
-from Utils import dbimutils
+from PIL import Image
 
-TITLE = "WaifuDiffusion v1.4 Tags"
+TITLE = "WaifuDiffusion Tagger"
 DESCRIPTION = """
-Demo for:
-- [SmilingWolf/wd-v1-4-moat-tagger-v2](https://huggingface.co/SmilingWolf/wd-v1-4-moat-tagger-v2)
-- [SmilingWolf/wd-v1-4-swinv2-tagger-v2](https://huggingface.co/SmilingWolf/wd-v1-4-swinv2-tagger-v2)
-- [SmilingWolf/wd-v1-4-convnext-tagger-v2](https://huggingface.co/SmilingWolf/wd-v1-4-convnext-tagger-v2)
-- [SmilingWolf/wd-v1-4-convnextv2-tagger-v2](https://huggingface.co/SmilingWolf/wd-v1-4-convnextv2-tagger-v2)
-- [SmilingWolf/wd-v1-4-vit-tagger-v2](https://huggingface.co/SmilingWolf/wd-v1-4-vit-tagger-v2)
-
-Includes "ready to copy" prompt and a prompt analyzer.
-
-Modified from [NoCrypt/DeepDanbooru_string](https://huggingface.co/spaces/NoCrypt/DeepDanbooru_string)
-Modified from [hysts/DeepDanbooru](https://huggingface.co/spaces/hysts/DeepDanbooru)
-
-PNG Info code forked from [AUTOMATIC1111/stable-diffusion-webui](https://github.com/AUTOMATIC1111/stable-diffusion-webui)
+Demo for the WaifuDiffusion tagger models
 
 Example image by [ほし☆☆☆](https://www.pixiv.net/en/users/43565085)
 """
 
 HF_TOKEN = os.environ["HF_TOKEN"]
-MOAT_MODEL_REPO = "SmilingWolf/wd-v1-4-moat-tagger-v2"
-SWIN_MODEL_REPO = "SmilingWolf/wd-v1-4-swinv2-tagger-v2"
-CONV_MODEL_REPO = "SmilingWolf/wd-v1-4-convnext-tagger-v2"
-CONV2_MODEL_REPO = "SmilingWolf/wd-v1-4-convnextv2-tagger-v2"
-VIT_MODEL_REPO = "SmilingWolf/wd-v1-4-vit-tagger-v2"
+
+# Dataset v3 series of models:
+SWINV2_MODEL_DSV3_REPO = "SmilingWolf/wd-swinv2-tagger-v3"
+CONV_MODEL_DSV3_REPO = "SmilingWolf/wd-convnext-tagger-v3"
+VIT_MODEL_DSV3_REPO = "SmilingWolf/wd-vit-tagger-v3"
+
+# Dataset v2 series of models:
+MOAT_MODEL_DSV2_REPO = "SmilingWolf/wd-v1-4-moat-tagger-v2"
+SWIN_MODEL_DSV2_REPO = "SmilingWolf/wd-v1-4-swinv2-tagger-v2"
+CONV_MODEL_DSV2_REPO = "SmilingWolf/wd-v1-4-convnext-tagger-v2"
+CONV2_MODEL_DSV2_REPO = "SmilingWolf/wd-v1-4-convnextv2-tagger-v2"
+VIT_MODEL_DSV2_REPO = "SmilingWolf/wd-v1-4-vit-tagger-v2"
+
+# Files to download from the repos
 MODEL_FILENAME = "model.onnx"
 LABEL_FILENAME = "selected_tags.csv"
 
+# https://github.com/toriato/stable-diffusion-webui-wd14-tagger/blob/a9eacb1eff904552d3012babfa28b57e1d3e295c/tagger/ui.py#L368
+kaomojis = [
+    "0_0",
+    "(o)_(o)",
+    "+_+",
+    "+_-",
+    "._.",
+    "<o>_<o>",
+    "<|>_<|>",
+    "=_=",
+    ">_<",
+    "3_3",
+    "6_9",
+    ">_o",
+    "@_@",
+    "^_^",
+    "o_o",
+    "u_u",
+    "x_x",
+    "|_|",
+    "||_||",
+]
+
 
 def parse_args() -> argparse.Namespace:
     parser = argparse.ArgumentParser()
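Note on the kaomojis list added above: tag names come from the dataset with underscores (e.g. long_hair), which the app converts to spaces for prompt-style output, but emoticon tags such as >_< would be mangled by that substitution, so they are exempted. A quick illustration of the mapping that the new load_labels (next hunk) applies, with made-up tags:

    tags = ["long_hair", "0_0", "looking_at_viewer", ">_<"]
    [x.replace("_", " ") if x not in kaomojis else x for x in tags]
    # -> ['long hair', '0_0', 'looking at viewer', '>_<']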
@@ -54,231 +66,273 @@ def parse_args() -> argparse.Namespace:
     return parser.parse_args()
 
 
-def load_model(model_repo: str, model_filename: str) -> rt.InferenceSession:
-    path = huggingface_hub.hf_hub_download(
-        model_repo, model_filename, use_auth_token=HF_TOKEN
+def load_labels(dataframe) -> list[str]:
+    name_series = dataframe["name"]
+    name_series = name_series.map(
+        lambda x: x.replace("_", " ") if x not in kaomojis else x
     )
-    model = rt.InferenceSession(path)
-    return model
+    tag_names = name_series.tolist()
 
+    rating_indexes = list(np.where(dataframe["category"] == 9)[0])
+    general_indexes = list(np.where(dataframe["category"] == 0)[0])
+    character_indexes = list(np.where(dataframe["category"] == 4)[0])
+    return tag_names, rating_indexes, general_indexes, character_indexes
 
-def change_model(model_name):
-    global loaded_models
 
-    if model_name == "MOAT":
-        model = load_model(MOAT_MODEL_REPO, MODEL_FILENAME)
-    elif model_name == "SwinV2":
-        model = load_model(SWIN_MODEL_REPO, MODEL_FILENAME)
-    elif model_name == "ConvNext":
-        model = load_model(CONV_MODEL_REPO, MODEL_FILENAME)
-    elif model_name == "ConvNextV2":
-        model = load_model(CONV2_MODEL_REPO, MODEL_FILENAME)
-    elif model_name == "ViT":
-        model = load_model(VIT_MODEL_REPO, MODEL_FILENAME)
+def mcut_threshold(probs):
+    """
+    Maximum Cut Thresholding (MCut)
+    Largeron, C., Moulin, C., & Gery, M. (2012). MCut: A Thresholding Strategy
+    for Multi-label Classification. In 11th International Symposium, IDA 2012
+    (pp. 172-183).
+    """
+    sorted_probs = probs[probs.argsort()[::-1]]
+    difs = sorted_probs[:-1] - sorted_probs[1:]
+    t = difs.argmax()
+    thresh = (sorted_probs[t] + sorted_probs[t + 1]) / 2
+    return thresh
+
+
+class Predictor:
+    def __init__(self):
+        self.model_target_size = None
+        self.last_loaded_repo = None
+
+    def download_model(self, model_repo):
+        csv_path = huggingface_hub.hf_hub_download(
+            model_repo,
+            LABEL_FILENAME,
+            use_auth_token=HF_TOKEN,
+        )
+        model_path = huggingface_hub.hf_hub_download(
+            model_repo,
+            MODEL_FILENAME,
+            use_auth_token=HF_TOKEN,
+        )
+        return csv_path, model_path
 
-    loaded_models[model_name] = model
-    return loaded_models[model_name]
+    def load_model(self, model_repo):
+        if model_repo == self.last_loaded_repo:
+            return
 
+        csv_path, model_path = self.download_model(model_repo)
 
-def load_labels() -> list[str]:
-    path = huggingface_hub.hf_hub_download(
-        MOAT_MODEL_REPO, LABEL_FILENAME, use_auth_token=HF_TOKEN
-    )
-    df = pd.read_csv(path)
+        tags_df = pd.read_csv(csv_path)
+        sep_tags = load_labels(tags_df)
 
-    tag_names = df["name"].tolist()
-    rating_indexes = list(np.where(df["category"] == 9)[0])
-    general_indexes = list(np.where(df["category"] == 0)[0])
-    character_indexes = list(np.where(df["category"] == 4)[0])
-    return tag_names, rating_indexes, general_indexes, character_indexes
+        self.tag_names = sep_tags[0]
+        self.rating_indexes = sep_tags[1]
+        self.general_indexes = sep_tags[2]
+        self.character_indexes = sep_tags[3]
 
+        model = rt.InferenceSession(model_path)
+        _, height, width, _ = model.get_inputs()[0].shape
+        self.model_target_size = height
 
-def plaintext_to_html(text):
-    text = (
-        "<p>" + "<br>\n".join([f"{html.escape(x)}" for x in text.split("\n")]) + "</p>"
-    )
-    return text
-
-
-def predict(
-    image: PIL.Image.Image,
-    model_name: str,
-    general_threshold: float,
-    character_threshold: float,
-    tag_names: list[str],
-    rating_indexes: list[np.int64],
-    general_indexes: list[np.int64],
-    character_indexes: list[np.int64],
-):
-    global loaded_models
-
-    rawimage = image
-
-    model = loaded_models[model_name]
-    if model is None:
-        model = change_model(model_name)
-
-    _, height, width, _ = model.get_inputs()[0].shape
-
-    # Alpha to white
-    image = image.convert("RGBA")
-    new_image = PIL.Image.new("RGBA", image.size, "WHITE")
-    new_image.paste(image, mask=image)
-    image = new_image.convert("RGB")
-    image = np.asarray(image)
-
-    # PIL RGB to OpenCV BGR
-    image = image[:, :, ::-1]
-
-    image = dbimutils.make_square(image, height)
-    image = dbimutils.smart_resize(image, height)
-    image = image.astype(np.float32)
-    image = np.expand_dims(image, 0)
-
-    input_name = model.get_inputs()[0].name
-    label_name = model.get_outputs()[0].name
-    probs = model.run([label_name], {input_name: image})[0]
-
-    labels = list(zip(tag_names, probs[0].astype(float)))
-
-    # First 4 labels are actually ratings: pick one with argmax
-    ratings_names = [labels[i] for i in rating_indexes]
-    rating = dict(ratings_names)
-
-    # Then we have general tags: pick any where prediction confidence > threshold
-    general_names = [labels[i] for i in general_indexes]
-    general_res = [x for x in general_names if x[1] > general_threshold]
-    general_res = dict(general_res)
-
-    # Everything else is characters: pick any where prediction confidence > threshold
-    character_names = [labels[i] for i in character_indexes]
-    character_res = [x for x in character_names if x[1] > character_threshold]
-    character_res = dict(character_res)
-
-    b = dict(sorted(general_res.items(), key=lambda item: item[1], reverse=True))
-    a = (
-        ", ".join(list(b.keys()))
-        .replace("_", " ")
-        .replace("(", "\(")
-        .replace(")", "\)")
-    )
-    c = ", ".join(list(b.keys()))
-
-    items = rawimage.info
-    geninfo = ""
-
-    if "exif" in rawimage.info:
-        exif = piexif.load(rawimage.info["exif"])
-        exif_comment = (exif or {}).get("Exif", {}).get(piexif.ExifIFD.UserComment, b"")
-        try:
-            exif_comment = piexif.helper.UserComment.load(exif_comment)
-        except ValueError:
-            exif_comment = exif_comment.decode("utf8", errors="ignore")
-
-        items["exif comment"] = exif_comment
-        geninfo = exif_comment
-
-    for field in [
-        "jfif",
-        "jfif_version",
-        "jfif_unit",
-        "jfif_density",
-        "dpi",
-        "exif",
-        "loop",
-        "background",
-        "timestamp",
-        "duration",
-    ]:
-        items.pop(field, None)
-
-    geninfo = items.get("parameters", geninfo)
-
-    info = f"""
-    <p><h4>PNG Info</h4></p>
-    """
-    for key, text in items.items():
-        info += (
-            f"""
-            <div>
-            <p><b>{plaintext_to_html(str(key))}</b></p>
-            <p>{plaintext_to_html(str(text))}</p>
-            </div>
-            """.strip()
-            + "\n"
-        )
+        self.last_loaded_repo = model_repo
+        self.model = model
 
-    if len(info) == 0:
-        message = "Nothing found in the image."
-        info = f"<div><p>{message}<p></div>"
+    def prepare_image(self, image):
+        target_size = self.model_target_size
 
-    return (a, c, rating, character_res, general_res, info)
+        canvas = Image.new("RGBA", image.size, (255, 255, 255))
+        canvas.alpha_composite(image)
+        image = canvas.convert("RGB")
 
+        # Pad image to square
+        image_shape = image.size
+        max_dim = max(image_shape)
+        pad_left = (max_dim - image_shape[0]) // 2
+        pad_top = (max_dim - image_shape[1]) // 2
 
-def main():
-    global loaded_models
-    loaded_models = {
-        "MOAT": None,
-        "SwinV2": None,
-        "ConvNext": None,
-        "ConvNextV2": None,
-        "ViT": None,
-    }
+        padded_image = Image.new("RGB", (max_dim, max_dim), (255, 255, 255))
+        padded_image.paste(image, (pad_left, pad_top))
 
-    args = parse_args()
+        # Resize
+        if max_dim != target_size:
+            padded_image = padded_image.resize(
+                (target_size, target_size),
+                Image.BICUBIC,
+            )
 
-    change_model("MOAT")
+        # Convert to numpy array
+        image_array = np.asarray(padded_image, dtype=np.float32)
 
-    tag_names, rating_indexes, general_indexes, character_indexes = load_labels()
+        # Convert PIL-native RGB to BGR
+        image_array = image_array[:, :, ::-1]
 
-    func = functools.partial(
-        predict,
-        tag_names=tag_names,
-        rating_indexes=rating_indexes,
-        general_indexes=general_indexes,
-        character_indexes=character_indexes,
-    )
+        return np.expand_dims(image_array, axis=0)
 
-    gr.Interface(
-        fn=func,
-        inputs=[
-            gr.Image(type="pil", label="Input"),
-            gr.Radio(
-                ["MOAT", "SwinV2", "ConvNext", "ConvNextV2", "ViT"],
-                value="MOAT",
-                label="Model",
-            ),
-            gr.Slider(
-                0,
-                1,
-                step=args.score_slider_step,
-                value=args.score_general_threshold,
-                label="General Tags Threshold",
-            ),
-            gr.Slider(
-                0,
-                1,
-                step=args.score_slider_step,
-                value=args.score_character_threshold,
-                label="Character Tags Threshold",
-            ),
-        ],
-        outputs=[
-            gr.Textbox(label="Output (string)"),
-            gr.Textbox(label="Output (raw string)"),
-            gr.Label(label="Rating"),
-            gr.Label(label="Output (characters)"),
-            gr.Label(label="Output (tags)"),
-            gr.HTML(),
-        ],
-        examples=[["power.jpg", "MOAT", 0.35, 0.85]],
-        title=TITLE,
-        description=DESCRIPTION,
-        allow_flagging="never",
-    ).launch(
-        enable_queue=True,
-        share=args.share,
-    )
+    def predict(
+        self,
+        image,
+        model_repo,
+        general_thresh,
+        general_mcut_enabled,
+        character_thresh,
+        character_mcut_enabled,
+    ):
+        self.load_model(model_repo)
+
+        image = self.prepare_image(image)
+
+        input_name = self.model.get_inputs()[0].name
+        label_name = self.model.get_outputs()[0].name
+        preds = self.model.run([label_name], {input_name: image})[0]
+
+        labels = list(zip(self.tag_names, preds[0].astype(float)))
+
+        # First 4 labels are actually ratings: pick one with argmax
+        ratings_names = [labels[i] for i in self.rating_indexes]
+        rating = dict(ratings_names)
+
+        # Then we have general tags: pick any where prediction confidence > threshold
+        general_names = [labels[i] for i in self.general_indexes]
+
+        if general_mcut_enabled:
+            general_probs = np.array([x[1] for x in general_names])
+            general_thresh = mcut_threshold(general_probs)
+
+        general_res = [x for x in general_names if x[1] > general_thresh]
+        general_res = dict(general_res)
+
+        # Everything else is characters: pick any where prediction confidence > threshold
+        character_names = [labels[i] for i in self.character_indexes]
+
+        if character_mcut_enabled:
+            character_probs = np.array([x[1] for x in character_names])
+            character_thresh = mcut_threshold(character_probs)
+            character_thresh = max(0.15, character_thresh)
+
+        character_res = [x for x in character_names if x[1] > character_thresh]
+        character_res = dict(character_res)
+
+        sorted_general_strings = sorted(
+            general_res.items(),
+            key=lambda x: x[1],
+            reverse=True,
+        )
+        sorted_general_strings = [x[0] for x in sorted_general_strings]
+        sorted_general_strings = (
+            ", ".join(sorted_general_strings).replace("(", "\(").replace(")", "\)")
+        )
+
+        return sorted_general_strings, rating, character_res, general_res
+
+
+def main():
+    args = parse_args()
+
+    predictor = Predictor()
+
+    dropdown_list = [
+        SWINV2_MODEL_DSV3_REPO,
+        CONV_MODEL_DSV3_REPO,
+        VIT_MODEL_DSV3_REPO,
+        MOAT_MODEL_DSV2_REPO,
+        SWIN_MODEL_DSV2_REPO,
+        CONV_MODEL_DSV2_REPO,
+        CONV2_MODEL_DSV2_REPO,
+        VIT_MODEL_DSV2_REPO,
+    ]
+
+    with gr.Blocks(title=TITLE) as demo:
+        with gr.Column():
+            gr.Markdown(
+                value=f"<h1 style='text-align: center; margin-bottom: 1rem'>{TITLE}</h1>"
+            )
+            gr.Markdown(value=DESCRIPTION)
+            with gr.Row():
+                with gr.Column(variant="panel"):
+                    image = gr.Image(type="pil", image_mode="RGBA", label="Input")
+                    model_repo = gr.Dropdown(
+                        dropdown_list,
+                        value=VIT_MODEL_DSV3_REPO,
+                        label="Model",
+                    )
+                    with gr.Row():
+                        general_thresh = gr.Slider(
+                            0,
+                            1,
+                            step=args.score_slider_step,
+                            value=args.score_general_threshold,
+                            label="General Tags Threshold",
+                            scale=3,
+                        )
+                        general_mcut_enabled = gr.Checkbox(
+                            value=False,
+                            label="Use MCut threshold",
+                            scale=1,
+                        )
+                    with gr.Row():
+                        character_thresh = gr.Slider(
+                            0,
+                            1,
+                            step=args.score_slider_step,
+                            value=args.score_character_threshold,
+                            label="Character Tags Threshold",
+                            scale=3,
+                        )
+                        character_mcut_enabled = gr.Checkbox(
+                            value=False,
+                            label="Use MCut threshold",
+                            scale=1,
+                        )
+                    with gr.Row():
+                        clear = gr.ClearButton(
+                            components=[
+                                image,
+                                model_repo,
+                                general_thresh,
+                                general_mcut_enabled,
+                                character_thresh,
+                                character_mcut_enabled,
+                            ],
+                            variant="secondary",
+                            size="lg",
+                        )
+                        submit = gr.Button(value="Submit", variant="primary", size="lg")
+                with gr.Column(variant="panel"):
+                    sorted_general_strings = gr.Textbox(label="Output (string)")
+                    rating = gr.Label(label="Rating")
+                    character_res = gr.Label(label="Output (characters)")
+                    general_res = gr.Label(label="Output (tags)")
+                    clear.add(
+                        [
+                            sorted_general_strings,
+                            rating,
+                            character_res,
+                            general_res,
+                        ]
+                    )
+
+        submit.click(
+            predictor.predict,
+            inputs=[
+                image,
+                model_repo,
+                general_thresh,
+                general_mcut_enabled,
+                character_thresh,
+                character_mcut_enabled,
+            ],
+            outputs=[sorted_general_strings, rating, character_res, general_res],
+        )
+
+        gr.Examples(
+            [["power.jpg", VIT_MODEL_DSV3_REPO, 0.35, False, 0.85, False]],
+            inputs=[
+                image,
+                model_repo,
+                general_thresh,
+                general_mcut_enabled,
+                character_thresh,
+                character_mcut_enabled,
+            ],
+        )
+
+    demo.queue(max_size=10)
+    demo.launch()
 
 
 if __name__ == "__main__":
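The new mcut_threshold helper places the cutoff at the largest gap in the sorted probability list, so no per-model threshold has to be hand-tuned. A worked example with toy numbers (not actual model output):

    import numpy as np

    probs = np.array([0.92, 0.90, 0.85, 0.30, 0.10])
    sorted_probs = probs[probs.argsort()[::-1]]  # descending; unchanged here
    difs = sorted_probs[:-1] - sorted_probs[1:]  # [0.02, 0.05, 0.55, 0.20]
    t = difs.argmax()                            # 2: the 0.85 -> 0.30 drop
    thresh = (sorted_probs[t] + sorted_probs[t + 1]) / 2
    print(thresh)                                # 0.575, keeping the top three tags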
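The new Predictor class can also be driven without the Gradio UI; a minimal sketch, assuming HF_TOKEN is set and a local power.jpg exists (this snippet is not part of the commit):

    from PIL import Image

    predictor = Predictor()
    tags, rating, characters, general = predictor.predict(
        Image.open("power.jpg").convert("RGBA"),  # prepare_image expects RGBA
        VIT_MODEL_DSV3_REPO,
        0.35,   # general_thresh
        False,  # general_mcut_enabled
        0.85,   # character_thresh
        False,  # character_mcut_enabled
    )
    print(tags)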
requirements.txt CHANGED
@@ -1,5 +1,3 @@
 pillow>=9.0.0
-piexif>=1.1.3
 onnxruntime>=1.12.0
-opencv-python
 huggingface-hub