SmilingWolf commited on
Commit
999d8f3
1 Parent(s): b079c7b

Update app.py

Browse files

Add newly released MOAT model

Files changed (1) hide show
  1. app.py +20 -6
app.py CHANGED
@@ -19,6 +19,7 @@ from Utils import dbimutils
19
  TITLE = "WaifuDiffusion v1.4 Tags"
20
  DESCRIPTION = """
21
  Demo for:
 
22
  - [SmilingWolf/wd-v1-4-swinv2-tagger-v2](https://huggingface.co/SmilingWolf/wd-v1-4-convnext-tagger-v2)
23
  - [SmilingWolf/wd-v1-4-convnext-tagger-v2](https://huggingface.co/SmilingWolf/wd-v1-4-convnext-tagger-v2)
24
  - [SmilingWolf/wd-v1-4-convnextv2-tagger-v2](https://huggingface.co/SmilingWolf/wd-v1-4-convnextv2-tagger-v2)
@@ -35,6 +36,7 @@ Example image by [ほし☆☆☆](https://www.pixiv.net/en/users/43565085)
35
  """
36
 
37
  HF_TOKEN = os.environ["HF_TOKEN"]
 
38
  SWIN_MODEL_REPO = "SmilingWolf/wd-v1-4-swinv2-tagger-v2"
39
  CONV_MODEL_REPO = "SmilingWolf/wd-v1-4-convnext-tagger-v2"
40
  CONV2_MODEL_REPO = "SmilingWolf/wd-v1-4-convnextv2-tagger-v2"
@@ -63,7 +65,9 @@ def load_model(model_repo: str, model_filename: str) -> rt.InferenceSession:
63
  def change_model(model_name):
64
  global loaded_models
65
 
66
- if model_name == "SwinV2":
 
 
67
  model = load_model(SWIN_MODEL_REPO, MODEL_FILENAME)
68
  elif model_name == "ConvNext":
69
  model = load_model(CONV_MODEL_REPO, MODEL_FILENAME)
@@ -78,7 +82,7 @@ def change_model(model_name):
78
 
79
  def load_labels() -> list[str]:
80
  path = huggingface_hub.hf_hub_download(
81
- CONV2_MODEL_REPO, LABEL_FILENAME, use_auth_token=HF_TOKEN
82
  )
83
  df = pd.read_csv(path)
84
 
@@ -213,11 +217,17 @@ def predict(
213
 
214
  def main():
215
  global loaded_models
216
- loaded_models = {"SwinV2": None, "ConvNext": None, "ConvNextV2": None, "ViT": None}
 
 
 
 
 
 
217
 
218
  args = parse_args()
219
 
220
- change_model("ConvNextV2")
221
 
222
  tag_names, rating_indexes, general_indexes, character_indexes = load_labels()
223
 
@@ -233,7 +243,11 @@ def main():
233
  fn=func,
234
  inputs=[
235
  gr.Image(type="pil", label="Input"),
236
- gr.Radio(["SwinV2", "ConvNext", "ConvNextV2", "ViT"], value="ConvNextV2", label="Model"),
 
 
 
 
237
  gr.Slider(
238
  0,
239
  1,
@@ -257,7 +271,7 @@ def main():
257
  gr.Label(label="Output (tags)"),
258
  gr.HTML(),
259
  ],
260
- examples=[["power.jpg", "ConvNextV2", 0.35, 0.85]],
261
  title=TITLE,
262
  description=DESCRIPTION,
263
  allow_flagging="never",
 
19
  TITLE = "WaifuDiffusion v1.4 Tags"
20
  DESCRIPTION = """
21
  Demo for:
22
+ - [SmilingWolf/wd-v1-4-moat-tagger-v2](https://huggingface.co/SmilingWolf/wd-v1-4-moat-tagger-v2)
23
  - [SmilingWolf/wd-v1-4-swinv2-tagger-v2](https://huggingface.co/SmilingWolf/wd-v1-4-convnext-tagger-v2)
24
  - [SmilingWolf/wd-v1-4-convnext-tagger-v2](https://huggingface.co/SmilingWolf/wd-v1-4-convnext-tagger-v2)
25
  - [SmilingWolf/wd-v1-4-convnextv2-tagger-v2](https://huggingface.co/SmilingWolf/wd-v1-4-convnextv2-tagger-v2)
 
36
  """
37
 
38
  HF_TOKEN = os.environ["HF_TOKEN"]
39
+ MOAT_MODEL_REPO = "SmilingWolf/wd-v1-4-moat-tagger-v2"
40
  SWIN_MODEL_REPO = "SmilingWolf/wd-v1-4-swinv2-tagger-v2"
41
  CONV_MODEL_REPO = "SmilingWolf/wd-v1-4-convnext-tagger-v2"
42
  CONV2_MODEL_REPO = "SmilingWolf/wd-v1-4-convnextv2-tagger-v2"
 
65
  def change_model(model_name):
66
  global loaded_models
67
 
68
+ if model_name == "MOAT":
69
+ model = load_model(MOAT_MODEL_REPO, MODEL_FILENAME)
70
+ elif model_name == "SwinV2":
71
  model = load_model(SWIN_MODEL_REPO, MODEL_FILENAME)
72
  elif model_name == "ConvNext":
73
  model = load_model(CONV_MODEL_REPO, MODEL_FILENAME)
 
82
 
83
  def load_labels() -> list[str]:
84
  path = huggingface_hub.hf_hub_download(
85
+ MOAT_MODEL_REPO, LABEL_FILENAME, use_auth_token=HF_TOKEN
86
  )
87
  df = pd.read_csv(path)
88
 
 
217
 
218
  def main():
219
  global loaded_models
220
+ loaded_models = {
221
+ "MOAT": None,
222
+ "SwinV2": None,
223
+ "ConvNext": None,
224
+ "ConvNextV2": None,
225
+ "ViT": None,
226
+ }
227
 
228
  args = parse_args()
229
 
230
+ change_model("MOAT")
231
 
232
  tag_names, rating_indexes, general_indexes, character_indexes = load_labels()
233
 
 
243
  fn=func,
244
  inputs=[
245
  gr.Image(type="pil", label="Input"),
246
+ gr.Radio(
247
+ ["MOAT", "SwinV2", "ConvNext", "ConvNextV2", "ViT"],
248
+ value="MOAT",
249
+ label="Model",
250
+ ),
251
  gr.Slider(
252
  0,
253
  1,
 
271
  gr.Label(label="Output (tags)"),
272
  gr.HTML(),
273
  ],
274
+ examples=[["power.jpg", "MOAT", 0.35, 0.85]],
275
  title=TITLE,
276
  description=DESCRIPTION,
277
  allow_flagging="never",