DamarJati committed on
Commit c0b6d7e · verified · 1 Parent(s): e2b528e

Upload app (12).py

Files changed (1)
  1. app (12).py +378 -0
app (12).py ADDED
@@ -0,0 +1,378 @@
import argparse
import os

import gradio as gr
import huggingface_hub
import numpy as np
import onnxruntime as rt
import pandas as pd
from PIL import Image


# Available models, VAEs, ControlNet types, and schedulers
models = ["Model A", "Model B", "Model C"]
vae = ["VAE A", "VAE B", "VAE C"]
controlnet_types = ["Canny", "Depth", "Normal", "Pose"]
schedulers = ["Euler", "LMS", "DDIM"]


# Placeholder functions
def load_model(selected_model):
    return f"Model {selected_model} has been loaded."


def generate_image(prompt, neg_prompt, width, height, scheduler, num_steps, num_images, cfg_scale, seed, model):
    # Logic for generating images from a text prompt with the selected model
    return [f"Image {i+1} for prompt '{prompt}' with model '{model}'" for i in range(num_images)], {"prompt": prompt, "neg_prompt": neg_prompt}


def process_image(image, prompt, neg_prompt, model):
    # Logic for processing an image with the selected model
    return f"Processed image with prompt '{prompt}' and model '{model}'"


def controlnet_process(image, controlnet_type, model):
    # Logic for processing an image with ControlNet
    return f"Processed image with ControlNet '{controlnet_type}' and model '{model}'"


def controlnet_process_func(image, controlnet_type, model, vae):
    # Wrapper wired to the UI; vae is accepted but unused by this placeholder
    return controlnet_process(image, controlnet_type, model)


def inpaint_func(image, prompt, negative_prompt, guidance_scale, steps, strength, scheduler):
    # Inpainting placeholder; its signature matches the inputs wired up in the Inpainting tab
    return f"Inpainted image for prompt '{prompt}' with scheduler '{scheduler}'"
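
# Note: the functions above are UI placeholders; a real implementation would call an
# actual generation backend (e.g. a diffusers pipeline) instead of returning strings.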


# wd tagger

# Dataset v3 series of models:
SWINV2_MODEL_DSV3_REPO = "SmilingWolf/wd-swinv2-tagger-v3"
CONV_MODEL_DSV3_REPO = "SmilingWolf/wd-convnext-tagger-v3"
VIT_MODEL_DSV3_REPO = "SmilingWolf/wd-vit-tagger-v3"
VIT_LARGE_MODEL_DSV3_REPO = "SmilingWolf/wd-vit-large-tagger-v3"
EVA02_LARGE_MODEL_DSV3_REPO = "SmilingWolf/wd-eva02-large-tagger-v3"

# Dataset v2 series of models:
MOAT_MODEL_DSV2_REPO = "SmilingWolf/wd-v1-4-moat-tagger-v2"
SWIN_MODEL_DSV2_REPO = "SmilingWolf/wd-v1-4-swinv2-tagger-v2"
CONV_MODEL_DSV2_REPO = "SmilingWolf/wd-v1-4-convnext-tagger-v2"
CONV2_MODEL_DSV2_REPO = "SmilingWolf/wd-v1-4-convnextv2-tagger-v2"
VIT_MODEL_DSV2_REPO = "SmilingWolf/wd-v1-4-vit-tagger-v2"

# Files to download from the repos
MODEL_FILENAME = "model.onnx"
LABEL_FILENAME = "selected_tags.csv"

# https://github.com/toriato/stable-diffusion-webui-wd14-tagger/blob/a9eacb1eff904552d3012babfa28b57e1d3e295c/tagger/ui.py#L368
kaomojis = [
    "0_0", "(o)_(o)", "+_+", "+_-", "._.", "<o>_<o>", "<|>_<|>", "=_=", ">_<",
    "3_3", "6_9", ">_o", "@_@", "^_^", "o_o", "u_u", "x_x", "|_|", "||_||",
]


def parse_args() -> argparse.Namespace:
    parser = argparse.ArgumentParser()
    parser.add_argument("--score-slider-step", type=float, default=0.05)
    parser.add_argument("--score-general-threshold", type=float, default=0.35)
    parser.add_argument("--score-character-threshold", type=float, default=0.85)
    parser.add_argument("--share", action="store_true")
    return parser.parse_args()


def load_labels(dataframe):
    name_series = dataframe["name"]
    name_series = name_series.map(
        lambda x: x.replace("_", " ") if x not in kaomojis else x
    )
    tag_names = name_series.tolist()

    rating_indexes = list(np.where(dataframe["category"] == 9)[0])
    general_indexes = list(np.where(dataframe["category"] == 0)[0])
    character_indexes = list(np.where(dataframe["category"] == 4)[0])
    return tag_names, rating_indexes, general_indexes, character_indexes
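
# load_labels only assumes selected_tags.csv provides "name" and "category" columns;
# categories 9, 0, and 4 are treated as rating, general, and character tags.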


def mcut_threshold(probs):
    """
    Maximum Cut Thresholding (MCut)
    Largeron, C., Moulin, C., & Gery, M. (2012). MCut: A Thresholding Strategy
    for Multi-label Classification. In 11th International Symposium, IDA 2012
    (pp. 172-183).
    """
    sorted_probs = probs[probs.argsort()[::-1]]
    difs = sorted_probs[:-1] - sorted_probs[1:]
    t = difs.argmax()
    thresh = (sorted_probs[t] + sorted_probs[t + 1]) / 2
    return thresh
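
# Worked example (illustrative): for probs = np.array([0.9, 0.85, 0.2, 0.1]) the
# largest gap is between 0.85 and 0.2, so mcut_threshold returns (0.85 + 0.2) / 2
# = 0.525 and only the first two tags would pass the threshold.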


class Predictor:
    def __init__(self):
        self.model_target_size = None
        self.last_loaded_repo = None

    def download_model(self, model_repo):
        csv_path = huggingface_hub.hf_hub_download(
            model_repo,
            LABEL_FILENAME,
        )
        model_path = huggingface_hub.hf_hub_download(
            model_repo,
            MODEL_FILENAME,
        )
        return csv_path, model_path

    def load_model(self, model_repo):
        if model_repo == self.last_loaded_repo:
            return

        csv_path, model_path = self.download_model(model_repo)

        tags_df = pd.read_csv(csv_path)
        sep_tags = load_labels(tags_df)

        self.tag_names = sep_tags[0]
        self.rating_indexes = sep_tags[1]
        self.general_indexes = sep_tags[2]
        self.character_indexes = sep_tags[3]

        model = rt.InferenceSession(model_path)
        _, height, width, _ = model.get_inputs()[0].shape
        self.model_target_size = height

        self.last_loaded_repo = model_repo
        self.model = model

    def prepare_image(self, image):
        target_size = self.model_target_size

        # Flatten alpha onto a white background
        canvas = Image.new("RGBA", image.size, (255, 255, 255))
        canvas.alpha_composite(image)
        image = canvas.convert("RGB")

        # Pad image to square
        image_shape = image.size
        max_dim = max(image_shape)
        pad_left = (max_dim - image_shape[0]) // 2
        pad_top = (max_dim - image_shape[1]) // 2

        padded_image = Image.new("RGB", (max_dim, max_dim), (255, 255, 255))
        padded_image.paste(image, (pad_left, pad_top))

        # Resize to the model's expected input size
        if max_dim != target_size:
            padded_image = padded_image.resize(
                (target_size, target_size),
                Image.BICUBIC,
            )

        # Convert to numpy array
        image_array = np.asarray(padded_image, dtype=np.float32)

        # Convert PIL-native RGB to BGR
        image_array = image_array[:, :, ::-1]

        return np.expand_dims(image_array, axis=0)
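
    # prepare_image returns a float32 array of shape (1, target_size, target_size, 3)
    # in BGR channel order, matching the input layout read from the ONNX model above.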

    def predict(
        self,
        image,
        model_repo,
        general_thresh,
        general_mcut_enabled,
        character_thresh,
        character_mcut_enabled,
    ):
        self.load_model(model_repo)

        image = self.prepare_image(image)

        input_name = self.model.get_inputs()[0].name
        label_name = self.model.get_outputs()[0].name
        preds = self.model.run([label_name], {input_name: image})[0]

        labels = list(zip(self.tag_names, preds[0].astype(float)))

        # First 4 labels are actually ratings: pick one with argmax
        ratings_names = [labels[i] for i in self.rating_indexes]
        rating = dict(ratings_names)

        # Then we have general tags: pick any where prediction confidence > threshold
        general_names = [labels[i] for i in self.general_indexes]

        if general_mcut_enabled:
            general_probs = np.array([x[1] for x in general_names])
            general_thresh = mcut_threshold(general_probs)

        general_res = [x for x in general_names if x[1] > general_thresh]
        general_res = dict(general_res)

        # Everything else is characters: pick any where prediction confidence > threshold
        character_names = [labels[i] for i in self.character_indexes]

        if character_mcut_enabled:
            character_probs = np.array([x[1] for x in character_names])
            character_thresh = mcut_threshold(character_probs)
            character_thresh = max(0.15, character_thresh)

        character_res = [x for x in character_names if x[1] > character_thresh]
        character_res = dict(character_res)

        sorted_general_strings = sorted(
            general_res.items(),
            key=lambda x: x[1],
            reverse=True,
        )
        sorted_general_strings = [x[0] for x in sorted_general_strings]
        sorted_general_strings = (
            ", ".join(sorted_general_strings).replace("(", "\\(").replace(")", "\\)")
        )

        return sorted_general_strings, rating, character_res, general_res
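
# Minimal standalone usage sketch (assumes a local RGBA image; "example.png" is a
# hypothetical file name):
#
#   img = Image.open("example.png").convert("RGBA")
#   tags, rating, characters, general = Predictor().predict(
#       img, SWINV2_MODEL_DSV3_REPO, 0.35, False, 0.85, False
#   )
#   print(tags)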


args = parse_args()
predictor = Predictor()

dropdown_list = [
    SWINV2_MODEL_DSV3_REPO,
    CONV_MODEL_DSV3_REPO,
    VIT_MODEL_DSV3_REPO,
    VIT_LARGE_MODEL_DSV3_REPO,
    EVA02_LARGE_MODEL_DSV3_REPO,
    MOAT_MODEL_DSV2_REPO,
    SWIN_MODEL_DSV2_REPO,
    CONV_MODEL_DSV2_REPO,
    CONV2_MODEL_DSV2_REPO,
    VIT_MODEL_DSV2_REPO,
]


with gr.Blocks(css="style.css") as app:
    # Model and VAE dropdowns shown above the tabs
    with gr.Row():
        model_dropdown = gr.Dropdown(choices=models, label="Model", value="Model B")
        vae_dropdown = gr.Dropdown(choices=vae, label="VAE", value="VAE C")

    # Prompt and negative prompt
    with gr.Row():
        with gr.Column(scale=1):  # Scale 1 ensures full width
            prompt_input = gr.Textbox(label="Prompt", placeholder="Enter a text prompt", lines=2, elem_id="prompt-input")
            neg_prompt_input = gr.Textbox(label="Neg Prompt", placeholder="Enter a negative prompt", lines=2, elem_id="neg-prompt-input")

        generate_button = gr.Button("Generate", elem_id="generate-button", scale=0)

    # Text-to-Image tab
    with gr.Tab("Text-to-Image"):

        with gr.Row():
            with gr.Column():
                # Generation settings
                scheduler_input = gr.Dropdown(choices=schedulers, label="Sampling method", value=schedulers[0])
                num_steps_input = gr.Slider(minimum=1, maximum=100, step=1, label="Sampling steps", value=20)
                width_input = gr.Slider(minimum=128, maximum=2048, step=128, label="Width", value=512)
                height_input = gr.Slider(minimum=128, maximum=2048, step=128, label="Height", value=512)
                cfg_scale_input = gr.Slider(minimum=1, maximum=20, step=1, label="CFG Scale", value=7)
                seed_input = gr.Number(label="Seed", value=-1)
                batch_size = gr.Slider(minimum=1, maximum=24, step=1, label="Batch size", value=1)
                batch_count = gr.Slider(minimum=1, maximum=24, step=1, label="Batch Count", value=1)

                with gr.Accordion("Hires. fix"):
                    use_hires = gr.Checkbox(label="Use Hires?", value=False, scale=0)
                    with gr.Row():
                        upscaler = gr.Dropdown(choices=schedulers, label="Upscaler", value=schedulers[0])
                        upscale_by = gr.Slider(minimum=1, maximum=8, step=1, label="Upscale by", value=2)
                    with gr.Row():
                        hires_steps = gr.Slider(minimum=1, maximum=50, step=1, label="Hires Steps", value=20)
                        denois_strength = gr.Slider(minimum=0, maximum=1, step=0.02, label="Denoising Strength", value=0.7)

            with gr.Column():
                # Gallery for generated images
                output_gallery = gr.Gallery(label="Generated Images")
                # JSON text output below the gallery
                output_text = gr.Textbox(label="Output JSON", placeholder="Result in JSON format", lines=2)

        def update_images(prompt, neg_prompt, width, height, scheduler, num_steps, num_images, num_batches, cfg_scale, seed, model, vae):
            # Adapt as needed; num_batches and vae are accepted to match the wired inputs
            return generate_image(prompt, neg_prompt, width, height, scheduler, num_steps, num_images, cfg_scale, seed, model)

        generate_button.click(fn=update_images, inputs=[prompt_input, neg_prompt_input, width_input, height_input, scheduler_input, num_steps_input, batch_size, batch_count, cfg_scale_input, seed_input, model_dropdown, vae_dropdown], outputs=[output_gallery, output_text])

    # Image-to-Image tab
    with gr.Tab("Image-to-Image"):
        with gr.Row():
            with gr.Column():
                image_input = gr.Image(label="Upload Image")
                prompt_input_i2i = gr.Textbox(label="Prompt", placeholder="Enter a text prompt", lines=2)
                neg_prompt_input_i2i = gr.Textbox(label="Neg Prompt", placeholder="Enter a negative prompt", lines=2)
                generate_button_i2i = gr.Button("Process Image")

            with gr.Column():
                output_image_i2i = gr.Image(label="Result Image")

        def process_image_func(image, prompt, neg_prompt, model, vae):
            # Adapt as needed; vae is accepted to match the wired inputs
            return process_image(image, prompt, neg_prompt, model)

        generate_button_i2i.click(fn=process_image_func, inputs=[image_input, prompt_input_i2i, neg_prompt_input_i2i, model_dropdown, vae_dropdown], outputs=output_image_i2i)

    # ControlNet tab
    with gr.Tab("ControlNet"):
        with gr.Row():
            with gr.Column():
                controlnet_dropdown = gr.Dropdown(choices=controlnet_types, label="Select ControlNet Type")
                controlnet_image_input = gr.Image(label="Upload Image for ControlNet")
                controlnet_button = gr.Button("Process with ControlNet")

            with gr.Column():
                controlnet_output_image = gr.Image(label="ControlNet Result")

        controlnet_button.click(fn=controlnet_process_func, inputs=[controlnet_image_input, controlnet_dropdown, model_dropdown, vae_dropdown], outputs=controlnet_output_image)

    # Inpainting tab
    with gr.Tab("Inpainting"):
        with gr.Row():
            with gr.Column():
                image = gr.ImageMask(sources=["upload"], layers=False, transforms=[], format="png", label="base image", show_label=True)
                btn = gr.Button("Inpaint!", elem_id="run_button")
                prompt = gr.Textbox(placeholder="Your prompt (what you want in place of what is erased)", show_label=False, elem_id="prompt")
                negative_prompt = gr.Textbox(label="negative_prompt", placeholder="Your negative prompt", info="what you don't want to see in the image")
                guidance_scale = gr.Number(value=7.5, minimum=1.0, maximum=20.0, step=0.1, label="guidance_scale")
                steps = gr.Number(value=20, minimum=10, maximum=30, step=1, label="steps")
                strength = gr.Number(value=0.99, minimum=0.01, maximum=1.0, step=0.01, label="strength")
                scheduler = gr.Dropdown(label="Schedulers", choices=schedulers, value=schedulers[0])
            with gr.Column():
                image_out = gr.Image(label="Output", elem_id="output-img")

        btn.click(fn=inpaint_func, inputs=[image, prompt, negative_prompt, guidance_scale, steps, strength, scheduler], outputs=[image_out])

    # Describe tab
    with gr.Tab("Describe"):
        with gr.Row():
            with gr.Column():
                # Components
                image = gr.Image(type="pil", image_mode="RGBA", label="Input")
                submit_button = gr.Button(value="Submit", variant="primary", size="lg")
                model_repo = gr.Dropdown(dropdown_list, value=SWINV2_MODEL_DSV3_REPO, label="Model")
                general_thresh = gr.Slider(0, 1, step=args.score_slider_step, value=args.score_general_threshold, label="General Tags Threshold", scale=3)
                general_mcut_enabled = gr.Checkbox(value=False, label="Use MCut threshold", scale=1)
                character_thresh = gr.Slider(0, 1, step=args.score_slider_step, value=args.score_character_threshold, label="Character Tags Threshold", scale=3)
                character_mcut_enabled = gr.Checkbox(value=False, label="Use MCut threshold", scale=1)
                clear_button = gr.ClearButton(components=[image, model_repo, general_thresh, general_mcut_enabled, character_thresh, character_mcut_enabled], variant="secondary", size="lg")

            with gr.Column():
                sorted_general_strings = gr.Textbox(label="Output (string)")
                rating = gr.Label(label="Rating")
                character_res = gr.Label(label="Output (characters)")
                general_res = gr.Label(label="Output (tags)")

        clear_button.add([sorted_general_strings, rating, character_res, general_res])
        submit_button.click(predictor.predict, inputs=[image, model_repo, general_thresh, general_mcut_enabled, character_thresh, character_mcut_enabled], outputs=[sorted_general_strings, rating, character_res, general_res])

# Launch the interface
app.launch(share=args.share)
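
# Example invocation (sketch; the flags are the ones defined in parse_args):
#   python "app (12).py" --share --score-general-threshold 0.35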