SmilingWolf commited on
Commit
f6dbb10
1 Parent(s): f56e0f7

Add support for model selection

Browse files
Files changed (2) hide show
  1. README.md +1 -1
  2. app.py +17 -9
README.md CHANGED
@@ -4,7 +4,7 @@ emoji: 💬
4
  colorFrom: blue
5
  colorTo: red
6
  sdk: gradio
7
- sdk_version: 3.6
8
  app_file: app.py
9
  pinned: false
10
  duplicated_from: NoCrypt/DeepDanbooru_string
 
4
  colorFrom: blue
5
  colorTo: red
6
  sdk: gradio
7
+ sdk_version: 3.13
8
  app_file: app.py
9
  pinned: false
10
  duplicated_from: NoCrypt/DeepDanbooru_string
app.py CHANGED
@@ -20,7 +20,7 @@ from Utils import dbimutils
20
 
21
  TITLE = "WaifuDiffusion v1.4 Tags"
22
  DESCRIPTION = """
23
- Demo for [SmilingWolf/wd-v1-4-vit-tagger](https://huggingface.co/SmilingWolf/wd-v1-4-vit-tagger) with "ready to copy" prompt and a prompt analyzer.
24
 
25
  Modified from [NoCrypt/DeepDanbooru_string](https://huggingface.co/spaces/NoCrypt/DeepDanbooru_string)
26
  Modified from [hysts/DeepDanbooru](https://huggingface.co/spaces/hysts/DeepDanbooru)
@@ -31,7 +31,8 @@ Example image by [ほし☆☆☆](https://www.pixiv.net/en/users/43565085)
31
  """
32
 
33
  HF_TOKEN = os.environ["HF_TOKEN"]
34
- MODEL_REPO = "SmilingWolf/wd-v1-4-vit-tagger"
 
35
  MODEL_FILENAME = "model.onnx"
36
  LABEL_FILENAME = "selected_tags.csv"
37
 
@@ -44,9 +45,9 @@ def parse_args() -> argparse.Namespace:
44
  return parser.parse_args()
45
 
46
 
47
- def load_model() -> rt.InferenceSession:
48
  path = huggingface_hub.hf_hub_download(
49
- MODEL_REPO, MODEL_FILENAME, use_auth_token=HF_TOKEN
50
  )
51
  model = rt.InferenceSession(path)
52
  return model
@@ -54,7 +55,7 @@ def load_model() -> rt.InferenceSession:
54
 
55
  def load_labels() -> list[str]:
56
  path = huggingface_hub.hf_hub_download(
57
- MODEL_REPO, LABEL_FILENAME, use_auth_token=HF_TOKEN
58
  )
59
  df = pd.read_csv(path)["name"].tolist()
60
  return df
@@ -69,11 +70,14 @@ def plaintext_to_html(text):
69
 
70
  def predict(
71
  image: PIL.Image.Image,
 
72
  score_threshold: float,
73
- model: rt.InferenceSession,
74
  labels: list[str],
75
  ):
76
  rawimage = image
 
 
77
  _, height, width, _ = model.get_inputs()[0].shape
78
 
79
  # Alpha to white
@@ -168,15 +172,19 @@ def predict(
168
 
169
  def main():
170
  args = parse_args()
171
- model = load_model()
 
172
  labels = load_labels()
173
 
174
- func = functools.partial(predict, model=model, labels=labels)
 
 
175
 
176
  gr.Interface(
177
  fn=func,
178
  inputs=[
179
  gr.Image(type="pil", label="Input"),
 
180
  gr.Slider(
181
  0,
182
  1,
@@ -192,7 +200,7 @@ def main():
192
  gr.Label(label="Output (label)"),
193
  gr.HTML(),
194
  ],
195
- examples=[["power.jpg", 0.5]],
196
  title=TITLE,
197
  description=DESCRIPTION,
198
  allow_flagging="never",
 
20
 
21
  TITLE = "WaifuDiffusion v1.4 Tags"
22
  DESCRIPTION = """
23
+ Demo for [SmilingWolf/wd-v1-4-vit-tagger](https://huggingface.co/SmilingWolf/wd-v1-4-vit-tagger) and [SmilingWolf/wd-v1-4-convnext-tagger](https://huggingface.co/SmilingWolf/wd-v1-4-convnext-tagger) with "ready to copy" prompt and a prompt analyzer.
24
 
25
  Modified from [NoCrypt/DeepDanbooru_string](https://huggingface.co/spaces/NoCrypt/DeepDanbooru_string)
26
  Modified from [hysts/DeepDanbooru](https://huggingface.co/spaces/hysts/DeepDanbooru)
 
31
  """
32
 
33
  HF_TOKEN = os.environ["HF_TOKEN"]
34
+ VIT_MODEL_REPO = "SmilingWolf/wd-v1-4-vit-tagger"
35
+ CONV_MODEL_REPO = "SmilingWolf/wd-v1-4-convnext-tagger"
36
  MODEL_FILENAME = "model.onnx"
37
  LABEL_FILENAME = "selected_tags.csv"
38
 
 
45
  return parser.parse_args()
46
 
47
 
48
+ def load_model(model_repo: str, model_filename: str) -> rt.InferenceSession:
49
  path = huggingface_hub.hf_hub_download(
50
+ model_repo, model_filename, use_auth_token=HF_TOKEN
51
  )
52
  model = rt.InferenceSession(path)
53
  return model
 
55
 
56
  def load_labels() -> list[str]:
57
  path = huggingface_hub.hf_hub_download(
58
+ VIT_MODEL_REPO, LABEL_FILENAME, use_auth_token=HF_TOKEN
59
  )
60
  df = pd.read_csv(path)["name"].tolist()
61
  return df
 
70
 
71
  def predict(
72
  image: PIL.Image.Image,
73
+ selected_model: str,
74
  score_threshold: float,
75
+ models: dict,
76
  labels: list[str],
77
  ):
78
  rawimage = image
79
+
80
+ model = models[selected_model]
81
  _, height, width, _ = model.get_inputs()[0].shape
82
 
83
  # Alpha to white
 
172
 
173
  def main():
174
  args = parse_args()
175
+ vit_model = load_model(VIT_MODEL_REPO, MODEL_FILENAME)
176
+ conv_model = load_model(CONV_MODEL_REPO, MODEL_FILENAME)
177
  labels = load_labels()
178
 
179
+ models = {"ViT": vit_model, "ConvNext": conv_model}
180
+
181
+ func = functools.partial(predict, models=models, labels=labels)
182
 
183
  gr.Interface(
184
  fn=func,
185
  inputs=[
186
  gr.Image(type="pil", label="Input"),
187
+ gr.Radio(["ViT", "ConvNext"], label="Model"),
188
  gr.Slider(
189
  0,
190
  1,
 
200
  gr.Label(label="Output (label)"),
201
  gr.HTML(),
202
  ],
203
+ examples=[["power.jpg", "ViT", 0.5]],
204
  title=TITLE,
205
  description=DESCRIPTION,
206
  allow_flagging="never",