narugo1992 commited on
Commit
d9e6838
1 Parent(s): 90a66ed

dev(narugo): add text export feature on gradio

Browse files
Files changed (1) hide show
  1. app.py +29 -3
app.py CHANGED
@@ -1,3 +1,4 @@
 
1
  from typing import Mapping, Tuple, Dict
2
 
3
  import cv2
@@ -118,17 +119,37 @@ WAIFU_MODELS: Mapping[str, WaifuDiffusionInterrogator] = {
118
  repo='SmilingWolf/wd-v1-4-convnext-tagger'
119
  ),
120
  }
 
121
 
122
 
123
- def image_to_wd14_tags(image: Image.Image, model_name: str, threshold: float):
 
 
124
  model = WAIFU_MODELS[model_name]
125
  ratings, tags = model.interrogate(image)
126
 
127
- return ratings, {
128
  tag: score for tag, score in tags.items()
129
  if score >= threshold
130
  }
131
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
132
 
133
  if __name__ == '__main__':
134
  interface = gr.Interface(
@@ -136,10 +157,15 @@ if __name__ == '__main__':
136
  inputs=[
137
  gr.Image(type='pil', label='Original Image'),
138
  gr.Radio(list(WAIFU_MODELS.keys()), value='wd14-vit', label='Waifu Model'),
139
- gr.Slider(0.0, 1.0, value=0.5, label='Tagging Confidence Threshold')
 
 
 
 
140
  ],
141
  outputs=[
142
  gr.Label(label='Ratings'),
 
143
  gr.Label(label='Tags'),
144
  ],
145
  interpretation="default"
 
1
+ import re
2
  from typing import Mapping, Tuple, Dict
3
 
4
  import cv2
 
119
  repo='SmilingWolf/wd-v1-4-convnext-tagger'
120
  ),
121
  }
122
+ RE_SPECIAL = re.compile(r'([\\()])')
123
 
124
 
125
+ def image_to_wd14_tags(image: Image.Image, model_name: str, threshold: float,
126
+ use_spaces: bool, use_escape: bool, include_ranks: bool, score_descend: bool) \
127
+ -> Tuple[Mapping[str, float], str, Mapping[str, float]]:
128
  model = WAIFU_MODELS[model_name]
129
  ratings, tags = model.interrogate(image)
130
 
131
+ filtered_tags = {
132
  tag: score for tag, score in tags.items()
133
  if score >= threshold
134
  }
135
 
136
+ text_items = []
137
+ tags_pairs = filtered_tags.items()
138
+ if score_descend:
139
+ tags_pairs = sorted(tags_pairs, key=lambda x: (-x[1], x[0]))
140
+ for tag, score in tags_pairs:
141
+ tag_outformat = tag
142
+ if use_spaces:
143
+ tag_outformat = tag_outformat.replace('_', ' ')
144
+ if use_escape:
145
+ tag_outformat = re.sub(RE_SPECIAL, r'\\\1', tag_outformat)
146
+ if include_ranks:
147
+ tag_outformat = f"({tag_outformat}:{score:.3f})"
148
+ text_items.append(tag_outformat)
149
+ output_text = ', '.join(text_items)
150
+
151
+ return ratings, output_text, filtered_tags
152
+
153
 
154
  if __name__ == '__main__':
155
  interface = gr.Interface(
 
157
  inputs=[
158
  gr.Image(type='pil', label='Original Image'),
159
  gr.Radio(list(WAIFU_MODELS.keys()), value='wd14-vit', label='Waifu Model'),
160
+ gr.Slider(0.0, 1.0, value=0.5, label='Tagging Confidence Threshold'),
161
+ gr.Checkbox(value=False, label='Use Space Instead Of _'),
162
+ gr.Checkbox(value=True, label='Use Text Escape'),
163
+ gr.Checkbox(value=False, label='Keep Confidences'),
164
+ gr.Checkbox(value=True, label='Descend By Confidence'),
165
  ],
166
  outputs=[
167
  gr.Label(label='Ratings'),
168
+ gr.TextArea(label='Exported Text'),
169
  gr.Label(label='Tags'),
170
  ],
171
  interpretation="default"