John6666 committed on
Commit
9f748cc
1 Parent(s): d04e577

Upload 24 files

Files changed (8)
  1. app.py +16 -8
  2. constants.py +453 -0
  3. dc.py +278 -334
  4. env.py +15 -18
  5. llmdolphin.py +218 -196
  6. modutils.py +263 -50
  7. requirements.txt +2 -2
  8. utils.py +421 -0
app.py CHANGED
@@ -3,11 +3,11 @@ import gradio as gr
import numpy as np

# DiffuseCraft
- from dc import (infer, _infer, pass_result, get_diffusers_model_list, get_samplers,
+ from dc import (infer, _infer, pass_result, get_diffusers_model_list, get_samplers, save_image_history,
get_vaes, enable_model_recom_prompt, enable_diffusers_model_detail, extract_exif_data, esrgan_upscale, UPSCALER_KEYS,
preset_quality, preset_styles, process_style_prompt, get_all_lora_tupled_list, update_loras, apply_lora_prompt,
- download_my_lora, search_civitai_lora, update_civitai_selection, select_civitai_lora, search_civitai_lora_json)
- from modutils import get_t2i_model_info, get_civitai_tag, CIVITAI_SORT, CIVITAI_PERIOD, CIVITAI_BASEMODEL
+ download_my_lora, search_civitai_lora, update_civitai_selection, select_civitai_lora, search_civitai_lora_json,
+ get_t2i_model_info, get_civitai_tag, CIVITAI_SORT, CIVITAI_PERIOD, CIVITAI_BASEMODEL)
# Translator
from llmdolphin import (dolphin_respond_auto, dolphin_parse_simple,
get_llm_formats, get_dolphin_model_format, get_dolphin_models,
@@ -57,9 +57,15 @@ with gr.Blocks(fill_width=True, elem_id="container", css=css, delete_cache=(60,
run_translate_button = gr.Button("Run with LLM Enhance", variant="secondary", scale=3)
auto_trans = gr.Checkbox(label="Auto translate to English", value=False, scale=2)

- result = gr.Image(label="Result", elem_id="result", format="png", show_label=False, interactive=False,
- show_download_button=True, show_share_button=False, container=True)
-
+ result = gr.Image(label="Result", elem_id="result", format="png", type="filepath", show_label=False, interactive=False,
+ show_download_button=True, show_share_button=False, container=True)
+ with gr.Accordion("History", open=False):
+ history_gallery = gr.Gallery(label="History", columns=6, object_fit="contain", format="png", interactive=False, show_share_button=False,
+ show_download_button=True)
+ history_files = gr.Files(interactive=False, visible=False)
+ history_clear_button = gr.Button(value="Clear History", variant="secondary")
+ history_clear_button.click(lambda: ([], []), None, [history_gallery, history_files], queue=False, show_api=False)
+
with gr.Accordion("Advanced Settings", open=False):
with gr.Row():
negative_prompt = gr.Text(label="Negative prompt", lines=1, max_lines=6, placeholder="Enter a negative prompt",
@@ -215,7 +221,7 @@ with gr.Blocks(fill_width=True, elem_id="container", css=css, delete_cache=(60,
).success(
fn=dolphin_respond_auto,
inputs=[prompt, chatbot],
- outputs=[chatbot],
+ outputs=[chatbot, result, prompt],
queue=True,
show_progress="full",
show_api=False,
@@ -238,6 +244,8 @@ with gr.Blocks(fill_width=True, elem_id="container", css=css, delete_cache=(60,
).success(lambda: None, None, chatbot, queue=False, show_api=False)\
.success(pass_result, [result], [result], queue=False, show_api=False) # dummy fn for api

+ result.change(save_image_history, [result, history_gallery, history_files, model_name], [history_gallery, history_files], queue=False, show_api=False)
+
gr.on(
triggers=[lora1.change, lora1_wt.change, lora2.change, lora2_wt.change, lora3.change, lora3_wt.change,
lora4.change, lora4_wt.change, lora5.change, lora5_wt.change],
@@ -425,4 +433,4 @@ with gr.Blocks(fill_width=True, elem_id="container", css=css, delete_cache=(60,
gr.DuplicateButton(value="Duplicate Space for private use (This demo does not work on CPU. Requires GPU Space)")

demo.queue()
- demo.launch()
+ demo.launch(show_error=True, debug=True)
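The new history panel is wired entirely through component events: `result.change` fans the freshly generated image out into the gallery and the hidden file list, and the clear button resets both. A minimal self-contained sketch of that wiring follows; the body of `save_image_history` is hypothetical here, since the commit only pins its signature through the `.change()` call above.

```python
import gradio as gr

# Hypothetical stand-in for dc.save_image_history; the commit only fixes
# its signature: (image, gallery, files, model_name) -> (gallery, files).
def save_image_history(image, gallery, files, model_name):
    if image is None:  # nothing generated yet
        return gallery or [], files or []
    gallery = (gallery or []) + [(image, model_name)]  # thumbnail + caption
    files = (files or []) + [image]                    # downloadable file paths
    return gallery, files

with gr.Blocks() as demo:
    result = gr.Image(type="filepath", interactive=False)
    model_name = gr.Textbox(value="demo-model", visible=False)
    with gr.Accordion("History", open=False):
        history_gallery = gr.Gallery(label="History", columns=6)
        history_files = gr.Files(interactive=False, visible=False)
        clear = gr.Button("Clear History")
        clear.click(lambda: ([], []), None, [history_gallery, history_files], queue=False)
    # Each change of `result` appends the new image to both components.
    result.change(save_image_history,
                  [result, history_gallery, history_files, model_name],
                  [history_gallery, history_files], queue=False, show_api=False)
```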
constants.py ADDED
@@ -0,0 +1,453 @@
+ import os
+ from stablepy.diffusers_vanilla.constants import FLUX_CN_UNION_MODES
+ from stablepy import (
+ scheduler_names,
+ SD15_TASKS,
+ SDXL_TASKS,
+ )
+
+ # - **Download Models**
+ DOWNLOAD_MODEL = "https://civitai.com/api/download/models/574369, https://huggingface.co/TechnoByte/MilkyWonderland/resolve/main/milkyWonderland_v40.safetensors"
+
+ # - **Download VAEs**
+ DOWNLOAD_VAE = "https://huggingface.co/nubby/blessed-sdxl-vae-fp16-fix/resolve/main/sdxl_vae-fp16fix-c-1.1-b-0.5.safetensors?download=true, https://huggingface.co/nubby/blessed-sdxl-vae-fp16-fix/resolve/main/sdxl_vae-fp16fix-blessed.safetensors?download=true, https://huggingface.co/digiplay/VAE/resolve/main/vividReal_v20.safetensors?download=true, https://huggingface.co/fp16-guy/anything_kl-f8-anime2_vae-ft-mse-840000-ema-pruned_blessed_clearvae_fp16_cleaned/resolve/main/vae-ft-mse-840000-ema-pruned_fp16.safetensors?download=true"
+
+ # - **Download LoRAs**
+ DOWNLOAD_LORA = "https://huggingface.co/Leopain/color/resolve/main/Coloring_book_-_LineArt.safetensors, https://civitai.com/api/download/models/135867, https://huggingface.co/Linaqruf/anime-detailer-xl-lora/resolve/main/anime-detailer-xl.safetensors?download=true, https://huggingface.co/Linaqruf/style-enhancer-xl-lora/resolve/main/style-enhancer-xl.safetensors?download=true, https://huggingface.co/ByteDance/Hyper-SD/resolve/main/Hyper-SD15-8steps-CFG-lora.safetensors?download=true, https://huggingface.co/ByteDance/Hyper-SD/resolve/main/Hyper-SDXL-8steps-CFG-lora.safetensors?download=true"
+
+ LOAD_DIFFUSERS_FORMAT_MODEL = [
+ 'stabilityai/stable-diffusion-xl-base-1.0',
+ 'black-forest-labs/FLUX.1-dev',
+ 'John6666/blue-pencil-flux1-v021-fp8-flux',
+ 'John6666/wai-ani-flux-v10forfp8-fp8-flux',
+ 'John6666/xe-anime-flux-v04-fp8-flux',
+ 'John6666/lyh-anime-flux-v2a1-fp8-flux',
+ 'John6666/carnival-unchained-v10-fp8-flux',
+ 'cagliostrolab/animagine-xl-3.1',
+ 'John6666/epicrealism-xl-v8kiss-sdxl',
+ 'misri/epicrealismXL_v7FinalDestination',
+ 'misri/juggernautXL_juggernautX',
+ 'misri/zavychromaxl_v80',
+ 'SG161222/RealVisXL_V4.0',
+ 'SG161222/RealVisXL_V5.0',
+ 'misri/newrealityxlAllInOne_Newreality40',
+ 'eienmojiki/Anything-XL',
+ 'eienmojiki/Starry-XL-v5.2',
+ 'gsdf/CounterfeitXL',
+ 'KBlueLeaf/Kohaku-XL-Zeta',
+ 'John6666/silvermoon-mix-01xl-v11-sdxl',
+ 'WhiteAiZ/autismmixSDXL_autismmixConfetti_diffusers',
+ 'kitty7779/ponyDiffusionV6XL',
+ 'GraydientPlatformAPI/aniverse-pony',
+ 'John6666/ras-real-anime-screencap-v1-sdxl',
+ 'John6666/duchaiten-pony-xl-no-score-v60-sdxl',
+ 'John6666/mistoon-anime-ponyalpha-sdxl',
+ 'John6666/3x3x3mixxl-v2-sdxl',
+ 'John6666/3x3x3mixxl-3dv01-sdxl',
+ 'John6666/ebara-mfcg-pony-mix-v12-sdxl',
+ 'John6666/t-ponynai3-v51-sdxl',
+ 'John6666/t-ponynai3-v65-sdxl',
+ 'John6666/prefect-pony-xl-v3-sdxl',
+ 'John6666/mala-anime-mix-nsfw-pony-xl-v5-sdxl',
+ 'John6666/wai-real-mix-v11-sdxl',
+ 'John6666/wai-c-v6-sdxl',
+ 'John6666/iniverse-mix-xl-sfwnsfw-pony-guofeng-v43-sdxl',
+ 'John6666/photo-realistic-pony-v5-sdxl',
+ 'John6666/pony-realism-v21main-sdxl',
+ 'John6666/pony-realism-v22main-sdxl',
+ 'John6666/cyberrealistic-pony-v63-sdxl',
+ 'John6666/cyberrealistic-pony-v64-sdxl',
+ 'John6666/cyberrealistic-pony-v65-sdxl',
+ 'GraydientPlatformAPI/realcartoon-pony-diffusion',
+ 'John6666/nova-anime-xl-pony-v5-sdxl',
+ 'John6666/autismmix-sdxl-autismmix-pony-sdxl',
+ 'John6666/aimz-dream-real-pony-mix-v3-sdxl',
+ 'John6666/duchaiten-pony-real-v11fix-sdxl',
+ 'John6666/duchaiten-pony-real-v20-sdxl',
+ 'yodayo-ai/kivotos-xl-2.0',
+ 'yodayo-ai/holodayo-xl-2.1',
+ 'yodayo-ai/clandestine-xl-1.0',
+ 'digiplay/majicMIX_sombre_v2',
+ 'digiplay/majicMIX_realistic_v6',
+ 'digiplay/majicMIX_realistic_v7',
+ 'digiplay/DreamShaper_8',
+ 'digiplay/BeautifulArt_v1',
+ 'digiplay/DarkSushi2.5D_v1',
+ 'digiplay/darkphoenix3D_v1.1',
+ 'digiplay/BeenYouLiteL11_diffusers',
+ 'Yntec/RevAnimatedV2Rebirth',
+ 'youknownothing/cyberrealistic_v50',
+ 'youknownothing/deliberate-v6',
+ 'GraydientPlatformAPI/deliberate-cyber3',
+ 'GraydientPlatformAPI/picx-real',
+ 'GraydientPlatformAPI/perfectworld6',
+ 'emilianJR/epiCRealism',
+ 'votepurchase/counterfeitV30_v30',
+ 'votepurchase/ChilloutMix',
+ 'Meina/MeinaMix_V11',
+ 'Meina/MeinaUnreal_V5',
+ 'Meina/MeinaPastel_V7',
+ 'GraydientPlatformAPI/realcartoon3d-17',
+ 'GraydientPlatformAPI/realcartoon-pixar11',
+ 'GraydientPlatformAPI/realcartoon-real17',
+ ]
+
+ DIFFUSERS_FORMAT_LORAS = [
+ "nerijs/animation2k-flux",
+ "XLabs-AI/flux-RealismLora",
+ ]
+
+ DOWNLOAD_EMBEDS = [
+ 'https://huggingface.co/datasets/Nerfgun3/bad_prompt/blob/main/bad_prompt_version2.pt',
+ 'https://huggingface.co/embed/negative/resolve/main/EasyNegativeV2.safetensors',
+ 'https://huggingface.co/embed/negative/resolve/main/bad-hands-5.pt',
+ ]
+
+ CIVITAI_API_KEY = os.environ.get("CIVITAI_API_KEY")
+ HF_TOKEN = os.environ.get("HF_READ_TOKEN")
+
+ DIRECTORY_MODELS = 'models'
+ DIRECTORY_LORAS = 'loras'
+ DIRECTORY_VAES = 'vaes'
+ DIRECTORY_EMBEDS = 'embedings'
+
+ PREPROCESSOR_CONTROLNET = {
+ "openpose": [
+ "Openpose",
+ "None",
+ ],
+ "scribble": [
+ "HED",
+ "PidiNet",
+ "None",
+ ],
+ "softedge": [
+ "PidiNet",
+ "HED",
+ "HED safe",
+ "PidiNet safe",
+ "None",
+ ],
+ "segmentation": [
+ "UPerNet",
+ "None",
+ ],
+ "depth": [
+ "DPT",
+ "Midas",
+ "None",
+ ],
+ "normalbae": [
+ "NormalBae",
+ "None",
+ ],
+ "lineart": [
+ "Lineart",
+ "Lineart coarse",
+ "Lineart (anime)",
+ "None",
+ "None (anime)",
+ ],
+ "lineart_anime": [
+ "Lineart",
+ "Lineart coarse",
+ "Lineart (anime)",
+ "None",
+ "None (anime)",
+ ],
+ "shuffle": [
+ "ContentShuffle",
+ "None",
+ ],
+ "canny": [
+ "Canny",
+ "None",
+ ],
+ "mlsd": [
+ "MLSD",
+ "None",
+ ],
+ "ip2p": [
+ "ip2p"
+ ],
+ "recolor": [
+ "Recolor luminance",
+ "Recolor intensity",
+ "None",
+ ],
+ "tile": [
+ "Mild Blur",
+ "Moderate Blur",
+ "Heavy Blur",
+ "None",
+ ],
+
+ }
+
+ TASK_STABLEPY = {
+ 'txt2img': 'txt2img',
+ 'img2img': 'img2img',
+ 'inpaint': 'inpaint',
+ # 'canny T2I Adapter': 'sdxl_canny_t2i', # No step callback parameters, so it does not work with diffusers 0.29.0
+ # 'sketch T2I Adapter': 'sdxl_sketch_t2i',
+ # 'lineart T2I Adapter': 'sdxl_lineart_t2i',
+ # 'depth-midas T2I Adapter': 'sdxl_depth-midas_t2i',
+ # 'openpose T2I Adapter': 'sdxl_openpose_t2i',
+ 'openpose ControlNet': 'openpose',
+ 'canny ControlNet': 'canny',
+ 'mlsd ControlNet': 'mlsd',
+ 'scribble ControlNet': 'scribble',
+ 'softedge ControlNet': 'softedge',
+ 'segmentation ControlNet': 'segmentation',
+ 'depth ControlNet': 'depth',
+ 'normalbae ControlNet': 'normalbae',
+ 'lineart ControlNet': 'lineart',
+ 'lineart_anime ControlNet': 'lineart_anime',
+ 'shuffle ControlNet': 'shuffle',
+ 'ip2p ControlNet': 'ip2p',
+ 'optical pattern ControlNet': 'pattern',
+ 'recolor ControlNet': 'recolor',
+ 'tile ControlNet': 'tile',
+ }
+
+ TASK_MODEL_LIST = list(TASK_STABLEPY.keys())
+
+ UPSCALER_DICT_GUI = {
+ None: None,
+ "Lanczos": "Lanczos",
+ "Nearest": "Nearest",
+ 'Latent': 'Latent',
+ 'Latent (antialiased)': 'Latent (antialiased)',
+ 'Latent (bicubic)': 'Latent (bicubic)',
+ 'Latent (bicubic antialiased)': 'Latent (bicubic antialiased)',
+ 'Latent (nearest)': 'Latent (nearest)',
+ 'Latent (nearest-exact)': 'Latent (nearest-exact)',
+ "RealESRGAN_x4plus": "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.0/RealESRGAN_x4plus.pth",
+ "RealESRNet_x4plus": "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.1/RealESRNet_x4plus.pth",
+ "RealESRGAN_x4plus_anime_6B": "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.2.4/RealESRGAN_x4plus_anime_6B.pth",
+ "RealESRGAN_x2plus": "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.1/RealESRGAN_x2plus.pth",
+ "realesr-animevideov3": "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-animevideov3.pth",
+ "realesr-general-x4v3": "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-general-x4v3.pth",
+ "realesr-general-wdn-x4v3": "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-general-wdn-x4v3.pth",
+ "4x-UltraSharp": "https://huggingface.co/Shandypur/ESRGAN-4x-UltraSharp/resolve/main/4x-UltraSharp.pth",
+ "4x_foolhardy_Remacri": "https://huggingface.co/FacehugmanIII/4x_foolhardy_Remacri/resolve/main/4x_foolhardy_Remacri.pth",
+ "Remacri4xExtraSmoother": "https://huggingface.co/hollowstrawberry/upscalers-backup/resolve/main/ESRGAN/Remacri%204x%20ExtraSmoother.pth",
+ "AnimeSharp4x": "https://huggingface.co/hollowstrawberry/upscalers-backup/resolve/main/ESRGAN/AnimeSharp%204x.pth",
+ "lollypop": "https://huggingface.co/hollowstrawberry/upscalers-backup/resolve/main/ESRGAN/lollypop.pth",
+ "RealisticRescaler4x": "https://huggingface.co/hollowstrawberry/upscalers-backup/resolve/main/ESRGAN/RealisticRescaler%204x.pth",
+ "NickelbackFS4x": "https://huggingface.co/hollowstrawberry/upscalers-backup/resolve/main/ESRGAN/NickelbackFS%204x.pth"
+ }
+
+ UPSCALER_KEYS = list(UPSCALER_DICT_GUI.keys())
+
+ PROMPT_W_OPTIONS = [
+ ("Compel format: (word)weight", "Compel"),
+ ("Classic format: (word:weight)", "Classic"),
+ ("Classic-original format: (word:weight)", "Classic-original"),
+ ("Classic-no_norm format: (word:weight)", "Classic-no_norm"),
+ ("Classic-ignore", "Classic-ignore"),
+ ("None", "None"),
+ ]
+
+ WARNING_MSG_VAE = (
+ "Use the right VAE for your model to maintain image quality. The wrong"
+ " VAE can lead to poor results, like blurriness in the generated images."
+ )
+
+ SDXL_TASK = [k for k, v in TASK_STABLEPY.items() if v in SDXL_TASKS]
+ SD_TASK = [k for k, v in TASK_STABLEPY.items() if v in SD15_TASKS]
+ FLUX_TASK = list(TASK_STABLEPY.keys())[:3] + [k for k, v in TASK_STABLEPY.items() if v in FLUX_CN_UNION_MODES.keys()]
+
+ MODEL_TYPE_TASK = {
+ "SD 1.5": SD_TASK,
+ "SDXL": SDXL_TASK,
+ "FLUX": FLUX_TASK,
+ }
+
+ MODEL_TYPE_CLASS = {
+ "diffusers:StableDiffusionPipeline": "SD 1.5",
+ "diffusers:StableDiffusionXLPipeline": "SDXL",
+ "diffusers:FluxPipeline": "FLUX",
+ }
+
+ POST_PROCESSING_SAMPLER = ["Use same sampler"] + scheduler_names[:-2]
+
+ SUBTITLE_GUI = (
+ "### This demo uses [diffusers](https://github.com/huggingface/diffusers)"
+ " to perform different tasks in image generation."
+ )
+
+ HELP_GUI = (
+ """### Help:
+ - The current space runs on a ZERO GPU which is assigned for approximately 60 seconds; therefore, if you submit expensive tasks, the operation may be canceled upon reaching the maximum allowed time with 'GPU TASK ABORTED'.
+ - Distorted or strange images often result from high prompt weights, so it's best to use low weights and scales, and consider using Classic variants like 'Classic-original'.
+ - For better results with Pony Diffusion, try using sampler DPM++ 1s or DPM2 with Compel or Classic prompt weights.
+ """
+ )
+
+ EXAMPLES_GUI_HELP = (
+ """### The following examples perform specific tasks:
+ 1. Generation with SDXL and upscale
+ 2. Generation with FLUX dev
+ 3. ControlNet Canny SDXL
+ 4. Optical pattern (Optical illusion) SDXL
+ 5. Convert an image to a coloring drawing
+ 6. ControlNet OpenPose SD 1.5 and Latent upscale
+
+ - Different tasks can be performed, such as img2img or using the IP adapter, to preserve a person's appearance or a specific style based on an image.
+ """
+ )
+
+ EXAMPLES_GUI = [
+ [
+ "1girl, souryuu asuka langley, neon genesis evangelion, rebuild of evangelion, lance of longinus, cat hat, plugsuit, pilot suit, red bodysuit, sitting, crossed legs, black eye patch, throne, looking down, from bottom, looking at viewer, outdoors, (masterpiece), (best quality), (ultra-detailed), very aesthetic, illustration, disheveled hair, perfect composition, moist skin, intricate details",
+ "nfsw, lowres, bad anatomy, bad hands, text, error, missing fingers, extra digit, cropped, worst quality, low quality, normal quality, jpeg artifacts, signature, watermark, username, blurry, unfinished, very displeasing, oldest, early, chromatic aberration, artistic error, scan, abstract",
+ 28,
+ 7.0,
+ -1,
+ "None",
+ 0.33,
+ "Euler a",
+ 1152,
+ 896,
+ "cagliostrolab/animagine-xl-3.1",
+ "txt2img",
+ "image.webp", # img control
+ 1024, # img resolution
+ 0.35, # strength
+ 1.0, # cn scale
+ 0.0, # cn start
+ 1.0, # cn end
+ "Classic",
+ "Nearest",
+ 45,
+ False,
+ ],
+ [
+ "a digital illustration of a movie poster titled 'Finding Emo', finding nemo parody poster, featuring a depressed cartoon clownfish with black emo hair, eyeliner, and piercings, bored expression, swimming in a dark underwater scene, in the background, movie title in a dripping, grungy font, moody blue and purple color palette",
+ "",
+ 24,
+ 3.5,
+ -1,
+ "None",
+ 0.33,
+ "Euler a",
+ 1152,
+ 896,
+ "black-forest-labs/FLUX.1-dev",
+ "txt2img",
+ None, # img control
+ 1024, # img resolution
+ 0.35, # strength
+ 1.0, # cn scale
+ 0.0, # cn start
+ 1.0, # cn end
+ "Classic",
+ None,
+ 70,
+ True,
+ ],
+ [
+ "((masterpiece)), best quality, blonde disco girl, detailed face, realistic face, realistic hair, dynamic pose, pink pvc, intergalactic disco background, pastel lights, dynamic contrast, airbrush, fine detail, 70s vibe, midriff",
+ "(worst quality:1.2), (bad quality:1.2), (poor quality:1.2), (missing fingers:1.2), bad-artist-anime, bad-artist, bad-picture-chill-75v",
+ 48,
+ 3.5,
+ -1,
+ "None",
+ 0.33,
+ "DPM++ 2M SDE Lu",
+ 1024,
+ 1024,
+ "misri/epicrealismXL_v7FinalDestination",
+ "canny ControlNet",
+ "image.webp", # img control
+ 1024, # img resolution
+ 0.35, # strength
+ 1.0, # cn scale
+ 0.0, # cn start
+ 1.0, # cn end
+ "Classic",
+ None,
+ 44,
+ False,
+ ],
+ [
+ "cinematic scenery old city ruins",
+ "(worst quality, low quality, illustration, 3d, 2d, painting, cartoons, sketch), (illustration, 3d, 2d, painting, cartoons, sketch, blurry, film grain, noise), (low quality, worst quality:1.2)",
+ 50,
+ 4.0,
+ -1,
+ "None",
+ 0.33,
+ "Euler a",
+ 1024,
+ 1024,
+ "misri/juggernautXL_juggernautX",
+ "optical pattern ControlNet",
+ "spiral_no_transparent.png", # img control
+ 1024, # img resolution
+ 0.35, # strength
+ 1.0, # cn scale
+ 0.05, # cn start
+ 0.75, # cn end
+ "Classic",
+ None,
+ 35,
+ False,
+ ],
+ [
+ "black and white, line art, coloring drawing, clean line art, black strokes, no background, white, black, free lines, black scribbles, on paper, A blend of comic book art and lineart full of black and white color, masterpiece, high-resolution, trending on Pixiv fan box, palette knife, brush strokes, two-dimensional, planar vector, T-shirt design, stickers, and T-shirt design, vector art, fantasy art, Adobe Illustrator, hand-painted, digital painting, low polygon, soft lighting, aerial view, isometric style, retro aesthetics, 8K resolution, black sketch lines, monochrome, invert color",
+ "color, red, green, yellow, colored, duplicate, blurry, abstract, disfigured, deformed, animated, toy, figure, framed, 3d, bad art, poorly drawn, extra limbs, close up, b&w, weird colors, blurry, watermark, blur haze, 2 heads, long neck, watermark, elongated body, cropped image, out of frame, draft, deformed hands, twisted fingers, double image, malformed hands, multiple heads, extra limb, ugly, poorly drawn hands, missing limb, cut-off, over satured, grain, lowères, bad anatomy, poorly drawn face, mutation, mutated, floating limbs, disconnected limbs, out of focus, long body, disgusting, extra fingers, groos proportions, missing arms, mutated hands, cloned face, missing legs, ugly, tiling, poorly drawn hands, poorly drawn feet, poorly drawn face, out of frame, extra limbs, disfigured, deformed, body out of frame, blurry, bad anatomy, blurred, watermark, grainy, signature, cut off, draft, deformed, blurry, bad anatomy, disfigured, poorly drawn face, mutation, bluelish, blue",
+ 20,
+ 4.0,
+ -1,
+ "loras/Coloring_book_-_LineArt.safetensors",
+ 1.0,
+ "DPM++ 2M SDE Karras",
+ 1024,
+ 1024,
+ "cagliostrolab/animagine-xl-3.1",
+ "lineart ControlNet",
+ "color_image.png", # img control
+ 896, # img resolution
+ 0.35, # strength
+ 1.0, # cn scale
+ 0.0, # cn start
+ 1.0, # cn end
+ "Compel",
+ None,
+ 35,
+ False,
+ ],
+ [
+ "1girl,face,curly hair,red hair,white background,",
+ "(worst quality:2),(low quality:2),(normal quality:2),lowres,watermark,",
+ 38,
+ 5.0,
+ -1,
+ "None",
+ 0.33,
+ "DPM++ 2M SDE Karras",
+ 512,
+ 512,
+ "digiplay/majicMIX_realistic_v7",
+ "openpose ControlNet",
+ "image.webp", # img control
+ 1024, # img resolution
+ 0.35, # strength
+ 1.0, # cn scale
+ 0.0, # cn start
+ 0.9, # cn end
+ "Compel",
+ "Latent (antialiased)",
+ 46,
+ False,
+ ],
+ ]
+
+ RESOURCES = (
+ """### Resources
+ - John6666's space has some great features you might find helpful [link](https://huggingface.co/spaces/John6666/DiffuseCraftMod).
+ - You can also try the image generator in Colab’s free tier, which provides free GPU [link](https://github.com/R3gm/SD_diffusers_interactive).
+ """
+ )
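As a quick usage sketch of how these tables compose (hedged; dc.py below does the equivalent when a checkpoint is selected): the detected model type narrows the task dropdown via `MODEL_TYPE_TASK`, and the chosen GUI label is then mapped to a stablepy task id through `TASK_STABLEPY`.

```python
# Assumes constants.py from this commit is importable.
from constants import MODEL_TYPE_TASK, TASK_STABLEPY, TASK_MODEL_LIST

def task_choices(model_type: str) -> list[str]:
    # GUI labels valid for this architecture; unknown types fall back
    # to the full list of tasks.
    return MODEL_TYPE_TASK.get(model_type, TASK_MODEL_LIST)

label = task_choices("SDXL")[0]           # a GUI label such as 'txt2img'
print(label, "->", TASK_STABLEPY[label])  # stablepy task id for that label
```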
dc.py CHANGED
@@ -1,33 +1,52 @@
import spaces
import os
from stablepy import Model_Diffusers
+ from constants import (
+ PREPROCESSOR_CONTROLNET,
+ TASK_STABLEPY,
+ TASK_MODEL_LIST,
+ UPSCALER_DICT_GUI,
+ UPSCALER_KEYS,
+ PROMPT_W_OPTIONS,
+ WARNING_MSG_VAE,
+ SDXL_TASK,
+ MODEL_TYPE_TASK,
+ POST_PROCESSING_SAMPLER,
+
+ )
from stablepy.diffusers_vanilla.style_prompt_config import STYLE_NAMES
- from stablepy.diffusers_vanilla.constants import FLUX_CN_UNION_MODES
import torch
import re
- from huggingface_hub import HfApi
from stablepy import (
- CONTROLNET_MODEL_IDS,
- VALID_TASKS,
- T2I_PREPROCESSOR_NAME,
- FLASH_LORA,
- SCHEDULER_CONFIG_MAP,
scheduler_names,
- IP_ADAPTER_MODELS,
IP_ADAPTERS_SD,
IP_ADAPTERS_SDXL,
- REPO_IMAGE_ENCODER,
- ALL_PROMPT_WEIGHT_OPTIONS,
- SD15_TASKS,
- SDXL_TASKS,
)
import time
from PIL import ImageFile
- #import urllib.parse
+ from utils import (
+ get_model_list,
+ extract_parameters,
+ get_model_type,
+ extract_exif_data,
+ create_mask_now,
+ download_diffuser_repo,
+ progress_step_bar,
+ html_template_message,
+ )
+ from datetime import datetime
+ import gradio as gr
+ import logging
+ import diffusers
+ import warnings
+ from stablepy import logger
+ # import urllib.parse

ImageFile.LOAD_TRUNCATED_IMAGES = True
+ # os.environ["PYTORCH_NO_CUDA_MEMORY_CACHING"] = "1"
print(os.getenv("SPACES_ZERO_GPU"))

+ ## BEGIN MOD
import gradio as gr
import logging
logging.getLogger("diffusers").setLevel(logging.ERROR)
@@ -38,205 +57,63 @@ warnings.filterwarnings(action="ignore", category=FutureWarning, module="diffuse
warnings.filterwarnings(action="ignore", category=UserWarning, module="diffusers")
warnings.filterwarnings(action="ignore", category=FutureWarning, module="transformers")
from stablepy import logger
- logger.setLevel(logging.CRITICAL)
+ logger.setLevel(logging.DEBUG)

from env import (
- HF_TOKEN, hf_read_token, # to use only for private repos
+ HF_TOKEN, HF_READ_TOKEN, # to use only for private repos
CIVITAI_API_KEY, HF_LORA_PRIVATE_REPOS1, HF_LORA_PRIVATE_REPOS2,
HF_LORA_ESSENTIAL_PRIVATE_REPO, HF_VAE_PRIVATE_REPO,
HF_SDXL_EMBEDS_NEGATIVE_PRIVATE_REPO, HF_SDXL_EMBEDS_POSITIVE_PRIVATE_REPO,
- directory_models, directory_loras, directory_vaes, directory_embeds,
- directory_embeds_sdxl, directory_embeds_positive_sdxl,
- load_diffusers_format_model, download_model_list, download_lora_list,
- download_vae_list, download_embeds)
-
- PREPROCESSOR_CONTROLNET = {
- "openpose": [
- "Openpose",
- "None",
- ],
- "scribble": [
- "HED",
- "PidiNet",
- "None",
- ],
- "softedge": [
- "PidiNet",
- "HED",
- "HED safe",
- "PidiNet safe",
- "None",
- ],
- "segmentation": [
- "UPerNet",
- "None",
- ],
- "depth": [
- "DPT",
- "Midas",
- "None",
- ],
- "normalbae": [
- "NormalBae",
- "None",
- ],
- "lineart": [
- "Lineart",
- "Lineart coarse",
- "Lineart (anime)",
- "None",
- "None (anime)",
- ],
- "lineart_anime": [
- "Lineart",
- "Lineart coarse",
- "Lineart (anime)",
- "None",
- "None (anime)",
- ],
- "shuffle": [
- "ContentShuffle",
- "None",
- ],
- "canny": [
- "Canny",
- "None",
- ],
- "mlsd": [
- "MLSD",
- "None",
- ],
- "ip2p": [
- "ip2p"
- ],
- "recolor": [
- "Recolor luminance",
- "Recolor intensity",
- "None",
- ],
- "tile": [
- "Mild Blur",
- "Moderate Blur",
- "Heavy Blur",
- "None",
- ],
- }
-
- TASK_STABLEPY = {
- 'txt2img': 'txt2img',
- 'img2img': 'img2img',
- 'inpaint': 'inpaint',
- # 'canny T2I Adapter': 'sdxl_canny_t2i', # NO HAVE STEP CALLBACK PARAMETERS SO NOT WORKS WITH DIFFUSERS 0.29.0
- # 'sketch T2I Adapter': 'sdxl_sketch_t2i',
- # 'lineart T2I Adapter': 'sdxl_lineart_t2i',
- # 'depth-midas T2I Adapter': 'sdxl_depth-midas_t2i',
- # 'openpose T2I Adapter': 'sdxl_openpose_t2i',
- 'openpose ControlNet': 'openpose',
- 'canny ControlNet': 'canny',
- 'mlsd ControlNet': 'mlsd',
- 'scribble ControlNet': 'scribble',
- 'softedge ControlNet': 'softedge',
- 'segmentation ControlNet': 'segmentation',
- 'depth ControlNet': 'depth',
- 'normalbae ControlNet': 'normalbae',
- 'lineart ControlNet': 'lineart',
- 'lineart_anime ControlNet': 'lineart_anime',
- 'shuffle ControlNet': 'shuffle',
- 'ip2p ControlNet': 'ip2p',
- 'optical pattern ControlNet': 'pattern',
- 'recolor ControlNet': 'recolor',
- 'tile ControlNet': 'tile',
- }
-
- TASK_MODEL_LIST = list(TASK_STABLEPY.keys())
-
- UPSCALER_DICT_GUI = {
- None: None,
- "Lanczos": "Lanczos",
- "Nearest": "Nearest",
- 'Latent': 'Latent',
- 'Latent (antialiased)': 'Latent (antialiased)',
- 'Latent (bicubic)': 'Latent (bicubic)',
- 'Latent (bicubic antialiased)': 'Latent (bicubic antialiased)',
- 'Latent (nearest)': 'Latent (nearest)',
- 'Latent (nearest-exact)': 'Latent (nearest-exact)',
- "RealESRGAN_x4plus": "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.0/RealESRGAN_x4plus.pth",
- "RealESRNet_x4plus": "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.1/RealESRNet_x4plus.pth",
- "RealESRGAN_x4plus_anime_6B": "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.2.4/RealESRGAN_x4plus_anime_6B.pth",
- "RealESRGAN_x2plus": "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.1/RealESRGAN_x2plus.pth",
- "realesr-animevideov3": "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-animevideov3.pth",
- "realesr-general-x4v3": "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-general-x4v3.pth",
- "realesr-general-wdn-x4v3": "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-general-wdn-x4v3.pth",
- "4x-UltraSharp": "https://huggingface.co/Shandypur/ESRGAN-4x-UltraSharp/resolve/main/4x-UltraSharp.pth",
- "4x_foolhardy_Remacri": "https://huggingface.co/FacehugmanIII/4x_foolhardy_Remacri/resolve/main/4x_foolhardy_Remacri.pth",
- "Remacri4xExtraSmoother": "https://huggingface.co/hollowstrawberry/upscalers-backup/resolve/main/ESRGAN/Remacri%204x%20ExtraSmoother.pth",
- "AnimeSharp4x": "https://huggingface.co/hollowstrawberry/upscalers-backup/resolve/main/ESRGAN/AnimeSharp%204x.pth",
- "lollypop": "https://huggingface.co/hollowstrawberry/upscalers-backup/resolve/main/ESRGAN/lollypop.pth",
- "RealisticRescaler4x": "https://huggingface.co/hollowstrawberry/upscalers-backup/resolve/main/ESRGAN/RealisticRescaler%204x.pth",
- "NickelbackFS4x": "https://huggingface.co/hollowstrawberry/upscalers-backup/resolve/main/ESRGAN/NickelbackFS%204x.pth"
- }
-
- UPSCALER_KEYS = list(UPSCALER_DICT_GUI.keys())
-
-
- def get_model_list(directory_path):
- model_list = []
- valid_extensions = {'.ckpt', '.pt', '.pth', '.safetensors', '.bin'}
-
- for filename in os.listdir(directory_path):
- if os.path.splitext(filename)[1] in valid_extensions:
- # name_without_extension = os.path.splitext(filename)[0]
- file_path = os.path.join(directory_path, filename)
- # model_list.append((name_without_extension, file_path))
- model_list.append(file_path)
- print('\033[34mFILE: ' + file_path + '\033[0m')
- return model_list
+ DIRECTORY_MODELS, DIRECTORY_LORAS, DIRECTORY_VAES, DIRECTORY_EMBEDS,
+ DIRECTORY_EMBEDS_SDXL, DIRECTORY_EMBEDS_POSITIVE_SDXL,
+ LOAD_DIFFUSERS_FORMAT_MODEL, DOWNLOAD_MODEL_LIST, DOWNLOAD_LORA_LIST,
+ DOWNLOAD_VAE_LIST, DOWNLOAD_EMBEDS)

- ## BEGIN MOD
from modutils import (to_list, list_uniq, list_sub, get_model_id_list, get_tupled_embed_list,
get_tupled_model_list, get_lora_model_list, download_private_repo, download_things)

# - **Download Models**
- download_model = ", ".join(download_model_list)
+ download_model = ", ".join(DOWNLOAD_MODEL_LIST)
# - **Download VAEs**
- download_vae = ", ".join(download_vae_list)
+ download_vae = ", ".join(DOWNLOAD_VAE_LIST)
# - **Download LoRAs**
- download_lora = ", ".join(download_lora_list)
+ download_lora = ", ".join(DOWNLOAD_LORA_LIST)

- #download_private_repo(HF_LORA_ESSENTIAL_PRIVATE_REPO, directory_loras, True)
- download_private_repo(HF_VAE_PRIVATE_REPO, directory_vaes, False)
+ #download_private_repo(HF_LORA_ESSENTIAL_PRIVATE_REPO, DIRECTORY_LORAS, True)
+ download_private_repo(HF_VAE_PRIVATE_REPO, DIRECTORY_VAES, False)

- load_diffusers_format_model = list_uniq(load_diffusers_format_model + get_model_id_list())
+ load_diffusers_format_model = list_uniq(LOAD_DIFFUSERS_FORMAT_MODEL + get_model_id_list())
## END MOD

# Download stuffs
for url in [url.strip() for url in download_model.split(',')]:
if not os.path.exists(f"./models/{url.split('/')[-1]}"):
- download_things(directory_models, url, HF_TOKEN, CIVITAI_API_KEY)
+ download_things(DIRECTORY_MODELS, url, HF_TOKEN, CIVITAI_API_KEY)
for url in [url.strip() for url in download_vae.split(',')]:
if not os.path.exists(f"./vaes/{url.split('/')[-1]}"):
- download_things(directory_vaes, url, HF_TOKEN, CIVITAI_API_KEY)
+ download_things(DIRECTORY_VAES, url, HF_TOKEN, CIVITAI_API_KEY)
for url in [url.strip() for url in download_lora.split(',')]:
if not os.path.exists(f"./loras/{url.split('/')[-1]}"):
- download_things(directory_loras, url, HF_TOKEN, CIVITAI_API_KEY)
+ download_things(DIRECTORY_LORAS, url, HF_TOKEN, CIVITAI_API_KEY)

# Download Embeddings
- for url_embed in download_embeds:
+ for url_embed in DOWNLOAD_EMBEDS:
if not os.path.exists(f"./embedings/{url_embed.split('/')[-1]}"):
- download_things(directory_embeds, url_embed, HF_TOKEN, CIVITAI_API_KEY)
+ download_things(DIRECTORY_EMBEDS, url_embed, HF_TOKEN, CIVITAI_API_KEY)

# Build list models
- embed_list = get_model_list(directory_embeds)
- model_list = get_model_list(directory_models)
+ embed_list = get_model_list(DIRECTORY_EMBEDS)
+ model_list = get_model_list(DIRECTORY_MODELS)
model_list = load_diffusers_format_model + model_list
+
## BEGIN MOD
lora_model_list = get_lora_model_list()
- vae_model_list = get_model_list(directory_vaes)
+ vae_model_list = get_model_list(DIRECTORY_VAES)
vae_model_list.insert(0, "None")

- #download_private_repo(HF_SDXL_EMBEDS_NEGATIVE_PRIVATE_REPO, directory_embeds_sdxl, False)
- #download_private_repo(HF_SDXL_EMBEDS_POSITIVE_PRIVATE_REPO, directory_embeds_positive_sdxl, False)
- embed_sdxl_list = get_model_list(directory_embeds_sdxl) + get_model_list(directory_embeds_positive_sdxl)
+ #download_private_repo(HF_SDXL_EMBEDS_NEGATIVE_PRIVATE_REPO, DIRECTORY_EMBEDS_SDXL, False)
+ #download_private_repo(HF_SDXL_EMBEDS_POSITIVE_PRIVATE_REPO, DIRECTORY_EMBEDS_POSITIVE_SDXL, False)
+ embed_sdxl_list = get_model_list(DIRECTORY_EMBEDS_SDXL) + get_model_list(DIRECTORY_EMBEDS_POSITIVE_SDXL)

def get_embed_list(pipeline_name):
return get_tupled_embed_list(embed_sdxl_list if pipeline_name == "StableDiffusionXLPipeline" else embed_list)
@@ -244,99 +121,13 @@ def get_embed_list(pipeline_name):

print('\033[33m🏁 Download and listing of valid models completed.\033[0m')

- msg_inc_vae = (
- "Use the right VAE for your model to maintain image quality. The wrong"
- " VAE can lead to poor results, like blurriness in the generated images."
- )
-
- SDXL_TASK = [k for k, v in TASK_STABLEPY.items() if v in SDXL_TASKS]
- SD_TASK = [k for k, v in TASK_STABLEPY.items() if v in SD15_TASKS]
- FLUX_TASK = list(TASK_STABLEPY.keys())[:3] + [k for k, v in TASK_STABLEPY.items() if v in FLUX_CN_UNION_MODES.keys()]
-
- MODEL_TYPE_TASK = {
- "SD 1.5": SD_TASK,
- "SDXL": SDXL_TASK,
- "FLUX": FLUX_TASK,
- }
-
- MODEL_TYPE_CLASS = {
- "diffusers:StableDiffusionPipeline": "SD 1.5",
- "diffusers:StableDiffusionXLPipeline": "SDXL",
- "diffusers:FluxPipeline": "FLUX",
- }
-
- POST_PROCESSING_SAMPLER = ["Use same sampler"] + scheduler_names[:-2]
-
- def extract_parameters(input_string):
- parameters = {}
- input_string = input_string.replace("\n", "")
-
- if "Negative prompt:" not in input_string:
- if "Steps:" in input_string:
- input_string = input_string.replace("Steps:", "Negative prompt: Steps:")
- else:
- print("Invalid metadata")
- parameters["prompt"] = input_string
- return parameters
-
- parm = input_string.split("Negative prompt:")
- parameters["prompt"] = parm[0].strip()
- if "Steps:" not in parm[1]:
- print("Steps not detected")
- parameters["neg_prompt"] = parm[1].strip()
- return parameters
- parm = parm[1].split("Steps:")
- parameters["neg_prompt"] = parm[0].strip()
- input_string = "Steps:" + parm[1]
-
- # Extracting Steps
- steps_match = re.search(r'Steps: (\d+)', input_string)
- if steps_match:
- parameters['Steps'] = int(steps_match.group(1))
-
- # Extracting Size
- size_match = re.search(r'Size: (\d+x\d+)', input_string)
- if size_match:
- parameters['Size'] = size_match.group(1)
- width, height = map(int, parameters['Size'].split('x'))
- parameters['width'] = width
- parameters['height'] = height
-
- # Extracting other parameters
- other_parameters = re.findall(r'(\w+): (.*?)(?=, \w+|$)', input_string)
- for param in other_parameters:
- parameters[param[0]] = param[1].strip('"')
-
- return parameters
-
- def get_model_type(repo_id: str):
- api = HfApi(token=os.environ.get("HF_TOKEN")) # if use private or gated model
- default = "SD 1.5"
- try:
- model = api.model_info(repo_id=repo_id, timeout=5.0)
- tags = model.tags
- for tag in tags:
- if tag in MODEL_TYPE_CLASS.keys(): return MODEL_TYPE_CLASS.get(tag, default)
- except Exception:
- return default
- return default
-
## BEGIN MOD
class GuiSD:
- def __init__(self):
+ def __init__(self, stream=True):
self.model = None
-
- print("Loading model...")
- self.model = Model_Diffusers(
- base_model_id="Lykon/dreamshaper-8",
- task_name="txt2img",
- vae_model=None,
- type_model_precision=torch.float16,
- retain_task_model_in_cache=False,
- device="cpu",
- )
- self.model.load_beta_styles()
- #self.model.device = torch.device("cpu") #
+ self.status_loading = False
+ self.sleep_loading = 4
+ self.last_load = datetime.now()

def infer_short(self, model, pipe_params, progress=gr.Progress(track_tqdm=True)):
#progress(0, desc="Start inference...")
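For reference, the deleted `get_model_type` (re-imported from the new utils.py per the import hunk above) resolves a repo's pipeline tag against `MODEL_TYPE_CLASS`. A standalone sketch mirroring the removed lines:

```python
from huggingface_hub import HfApi

# Same mapping as in the removed block / new constants.py.
MODEL_TYPE_CLASS = {
    "diffusers:StableDiffusionPipeline": "SD 1.5",
    "diffusers:StableDiffusionXLPipeline": "SDXL",
    "diffusers:FluxPipeline": "FLUX",
}

def get_model_type(repo_id: str) -> str:
    api = HfApi()  # pass a token only for private or gated repos
    default = "SD 1.5"
    try:
        tags = api.model_info(repo_id=repo_id, timeout=5.0).tags
        for tag in tags:
            if tag in MODEL_TYPE_CLASS:
                return MODEL_TYPE_CLASS[tag]
    except Exception:
        return default
    return default

# e.g. get_model_type("stabilityai/stable-diffusion-xl-base-1.0") -> "SDXL"
```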
@@ -350,28 +141,83 @@ class GuiSD:
return img

def load_new_model(self, model_name, vae_model, task, progress=gr.Progress(track_tqdm=True)):
-
- #yield f"Loading model: {model_name}"
-
vae_model = vae_model if vae_model != "None" else None
model_type = get_model_type(model_name)
+ dtype_model = torch.bfloat16 if model_type == "FLUX" else torch.float16
+
+ if not os.path.exists(model_name):
+ _ = download_diffuser_repo(
+ repo_name=model_name,
+ model_type=model_type,
+ revision="main",
+ token=True,
+ )
+
+ for i in range(68):
+ if not self.status_loading:
+ self.status_loading = True
+ if i > 0:
+ time.sleep(self.sleep_loading)
+ print("Previous model ops...")
+ break
+ time.sleep(0.5)
+ print(f"Waiting queue {i}")
+ yield "Waiting queue"
+
+ self.status_loading = True
+
+ yield f"Loading model: {model_name}"

if vae_model:
vae_type = "SDXL" if "sdxl" in vae_model.lower() else "SD 1.5"
if model_type != vae_type:
- gr.Warning(msg_inc_vae)
+ gr.Warning(WARNING_MSG_VAE)

- self.model.device = torch.device("cpu")
- dtype_model = torch.bfloat16 if model_type == "FLUX" else torch.float16
-
- self.model.load_pipe(
- model_name,
- task_name=TASK_STABLEPY[task],
- vae_model=vae_model if vae_model != "None" else None,
- type_model_precision=dtype_model,
- retain_task_model_in_cache=False,
- )
- #yield f"Model loaded: {model_name}"
+ print("Loading model...")
+
+ try:
+ start_time = time.time()
+
+ if self.model is None:
+ self.model = Model_Diffusers(
+ base_model_id=model_name,
+ task_name=TASK_STABLEPY[task],
+ vae_model=vae_model,
+ type_model_precision=dtype_model,
+ retain_task_model_in_cache=False,
+ device="cpu",
+ )
+ else:
+
+ if self.model.base_model_id != model_name:
+ load_now_time = datetime.now()
+ elapsed_time = max((load_now_time - self.last_load).total_seconds(), 0)
+
+ if elapsed_time <= 8:
+ print("Waiting for the previous model's time ops...")
+ time.sleep(8-elapsed_time)
+
+ self.model.device = torch.device("cpu")
+ self.model.load_pipe(
+ model_name,
+ task_name=TASK_STABLEPY[task],
+ vae_model=vae_model,
+ type_model_precision=dtype_model,
+ retain_task_model_in_cache=False,
+ )
+
+ end_time = time.time()
+ self.sleep_loading = max(min(int(end_time - start_time), 10), 4)
+ except Exception as e:
+ self.last_load = datetime.now()
+ self.status_loading = False
+ self.sleep_loading = 4
+ raise e
+
+ self.last_load = datetime.now()
+ self.status_loading = False
+
+ yield f"Model loaded: {model_name}"

#@spaces.GPU
@torch.inference_mode()
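The rewritten loader serializes overlapping load requests with a polled boolean flag and an adaptive sleep rather than a real lock. A minimal sketch of that pattern in isolation (names mirror the hunk; `do_load` is a placeholder for the actual pipe load):

```python
import time
from datetime import datetime

class LoaderGate:
    """Crude cooperative gate, as in GuiSD.load_new_model above: poll a
    boolean flag instead of threading.Lock, and scale the post-release
    wait to how long the last load took (clamped to 4-10 s). Not
    race-free, but adequate when calls arrive via Gradio's event queue."""
    def __init__(self):
        self.status_loading = False
        self.sleep_loading = 4
        self.last_load = datetime.now()

    def load(self, do_load):
        for i in range(68):                 # bounded wait, ~34 s worst case
            if not self.status_loading:
                self.status_loading = True
                if i > 0:                   # another load just finished
                    time.sleep(self.sleep_loading)
                break
            time.sleep(0.5)                 # still busy; keep polling
        self.status_loading = True          # claim the gate
        try:
            start = time.time()
            do_load()                       # placeholder for the model load
            self.sleep_loading = max(min(int(time.time() - start), 10), 4)
        finally:
            self.last_load = datetime.now()
            self.status_loading = False     # release the gate
```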
@@ -479,23 +325,24 @@ class GuiSD:
mode_ip2,
scale_ip2,
pag_scale,
- #progress=gr.Progress(track_tqdm=True),
):
- #progress(0, desc="Preparing inference...")
-
+ info_state = html_template_message("Navigating latent space...")
+ yield info_state, gr.update(), gr.update()
+
vae_model = vae_model if vae_model != "None" else None
loras_list = [lora1, lora2, lora3, lora4, lora5]
vae_msg = f"VAE: {vae_model}" if vae_model else ""
msg_lora = ""

- print("Config model:", model_name, vae_model, loras_list)
-
## BEGIN MOD
+ loras_list = [s if s else "None" for s in loras_list]
prompt, neg_prompt = insert_model_recom_prompt(prompt, neg_prompt, model_name)
global lora_model_list
lora_model_list = get_lora_model_list()
## END MOD

+ print("Config model:", model_name, vae_model, loras_list)
+
task = TASK_STABLEPY[task]

params_ip_img = []
@@ -518,6 +365,9 @@ class GuiSD:
params_ip_mode.append(modeip)
params_ip_scale.append(scaleip)

+ concurrency = 5
+ self.model.stream_config(concurrency=concurrency, latent_resize_by=1, vae_decoding=False)
+
if task != "txt2img" and not image_control:
raise ValueError("No control image found: To use this function, you have to upload an image in 'Image ControlNet/Inpaint/Img2img'")

@@ -589,15 +439,15 @@ class GuiSD:
"high_threshold": high_threshold,
"value_threshold": value_threshold,
"distance_threshold": distance_threshold,
- "lora_A": lora1 if lora1 != "None" and lora1 != "" else None,
+ "lora_A": lora1 if lora1 != "None" else None,
"lora_scale_A": lora_scale1,
- "lora_B": lora2 if lora2 != "None" and lora2 != "" else None,
+ "lora_B": lora2 if lora2 != "None" else None,
"lora_scale_B": lora_scale2,
- "lora_C": lora3 if lora3 != "None" and lora3 != "" else None,
+ "lora_C": lora3 if lora3 != "None" else None,
"lora_scale_C": lora_scale3,
- "lora_D": lora4 if lora4 != "None" and lora4 != "" else None,
+ "lora_D": lora4 if lora4 != "None" else None,
"lora_scale_D": lora_scale4,
- "lora_E": lora5 if lora5 != "None" and lora5 != "" else None,
+ "lora_E": lora5 if lora5 != "None" else None,
"lora_scale_E": lora_scale5,
## BEGIN MOD
"textual_inversion": get_embed_list(self.model.class_name) if textual_inversion else [],
@@ -647,21 +497,61 @@ class GuiSD:
}

self.model.device = torch.device("cuda:0")
- if hasattr(self.model.pipe, "transformer") and loras_list != ["None"] * 5 and loras_list != [""] * 5:
+ if hasattr(self.model.pipe, "transformer") and loras_list != ["None"] * 5:
self.model.pipe.transformer.to(self.model.device)
print("transformer to cuda")

- #progress(1, desc="Inference preparation completed. Starting inference...")
-
- info_state = "" # for yield version
- return self.infer_short(self.model, pipe_params), info_state
+ #return self.infer_short(self.model, pipe_params), info_state
+
+ actual_progress = 0
+ info_images = gr.update()
+ for img, seed, image_path, metadata in self.model(**pipe_params):
+ info_state = progress_step_bar(actual_progress, steps)
+ actual_progress += concurrency
+ if image_path:
+ info_images = f"Seeds: {str(seed)}"
+ if vae_msg:
+ info_images = info_images + "<br>" + vae_msg
+
+ if "Cannot copy out of meta tensor; no data!" in self.model.last_lora_error:
+ msg_ram = "Unable to process the LoRAs due to high RAM usage; please try again later."
+ print(msg_ram)
+ msg_lora += f"<br>{msg_ram}"
+
+ for status, lora in zip(self.model.lora_status, self.model.lora_memory):
+ if status:
+ msg_lora += f"<br>Loaded: {lora}"
+ elif status is not None:
+ msg_lora += f"<br>Error with: {lora}"
+
+ if msg_lora:
+ info_images += msg_lora
+
+ info_images = info_images + "<br>" + "GENERATION DATA:<br>" + metadata[0].replace("\n", "<br>") + "<br>-------<br>"
+
+ download_links = "<br>".join(
+ [
+ f'<a href="{path.replace("/images/", "/file=/home/user/app/images/")}" download="{os.path.basename(path)}">Download Image {i + 1}</a>'
+ for i, path in enumerate(image_path)
+ ]
+ )
+ if save_generated_images:
+ info_images += f"<br>{download_links}"
+ ## BEGIN MOD
+ if not isinstance(img, list): img = [img]
+ img = save_images(img, metadata)
+ img = [(i, None) for i in img]
## END MOD
+ info_state = "COMPLETE"
+
+ yield info_state, img, info_images
+ #return info_state, img, info_images

def dynamic_gpu_duration(func, duration, *args):

@spaces.GPU(duration=duration)
def wrapped_func():
- return func(*args)
+ yield from func(*args)

return wrapped_func()

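`dynamic_gpu_duration` now uses `yield from`, so the streamed previews from `generate_pipeline` survive the ZeroGPU wrapper instead of being collapsed into a single return. A sketch of the pattern with a stand-in generator (`spaces.GPU` is the Hugging Face ZeroGPU decorator and only works inside a Space):

```python
import spaces  # Hugging Face ZeroGPU helper, available in HF Spaces

def dynamic_gpu_duration(func, duration, *args):
    # The decorator wraps a fresh function per call so `duration` can
    # vary at runtime; `yield from` keeps the inner streaming intact.
    @spaces.GPU(duration=duration)
    def wrapped_func():
        yield from func(*args)

    return wrapped_func()

def fake_pipeline(steps):  # stand-in for sd_gen.generate_pipeline
    for i in range(steps):
        yield f"step {i + 1}/{steps}"

for msg in dynamic_gpu_duration(fake_pipeline, 59, 5):
    print(msg)  # each intermediate yield reaches the caller
```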
@@ -678,7 +568,7 @@ def sd_gen_generate_pipeline(*args):
load_lora_cpu = args[-3]
generation_args = args[:-3]
lora_list = [
- None if item == "None" or item == "" else item
+ None if item == "None" or item == "" else item # MOD
for item in [args[7], args[9], args[11], args[13], args[15]]
]
lora_status = [None] * 5
@@ -687,8 +577,8 @@ def sd_gen_generate_pipeline(*args):
if load_lora_cpu:
msg_load_lora = "Updating LoRAs in CPU (Slow but saves GPU usage)..."

- #if lora_list != sd_gen.model.lora_memory and lora_list != [None] * 5:
- # yield None, msg_load_lora
+ if lora_list != sd_gen.model.lora_memory and lora_list != [None] * 5:
+ yield msg_load_lora, gr.update(), gr.update()

# Load lora in CPU
if load_lora_cpu:
@@ -714,46 +604,36 @@ def sd_gen_generate_pipeline(*args):
)
gr.Info(f"LoRAs in cache: {lora_cache_msg}")

- msg_request = f"Requesting {gpu_duration_arg}s. of GPU time"
+ msg_request = f"Requesting {gpu_duration_arg}s. of GPU time.\nModel: {sd_gen.model.base_model_id}"
+ if verbose_arg:
gr.Info(msg_request)
print(msg_request)
-
- # yield from sd_gen.generate_pipeline(*generation_args)
+ yield msg_request.replace("\n", "<br>"), gr.update(), gr.update()

start_time = time.time()

- return dynamic_gpu_duration(
+ # yield from sd_gen.generate_pipeline(*generation_args)
+ yield from dynamic_gpu_duration(
+ #return dynamic_gpu_duration(
sd_gen.generate_pipeline,
gpu_duration_arg,
*generation_args,
)

end_time = time.time()
+ execution_time = end_time - start_time
+ msg_task_complete = (
+ f"GPU task complete in: {int(round(execution_time, 0) + 1)} seconds"
+ )

if verbose_arg:
- execution_time = end_time - start_time
- msg_task_complete = (
- f"GPU task complete in: {round(execution_time, 0) + 1} seconds"
- )
gr.Info(msg_task_complete)
print(msg_task_complete)

- def extract_exif_data(image):
- if image is None: return ""
-
- try:
- metadata_keys = ['parameters', 'metadata', 'prompt', 'Comment']
-
- for key in metadata_keys:
- if key in image.info:
- return image.info[key]
-
- return str(image.info)
-
- except Exception as e:
- return f"Error extracting metadata: {str(e)}"
-
- @spaces.GPU(duration=20)
+ yield msg_task_complete, gr.update(), gr.update()
+
+
+ @spaces.GPU(duration=15)
def esrgan_upscale(image, upscaler_name, upscaler_size):
if image is None: return None

@@ -775,18 +655,22 @@ def esrgan_upscale(image, upscaler_name, upscaler_size):

return image_path

+
dynamic_gpu_duration.zerogpu = True
sd_gen_generate_pipeline.zerogpu = True
+ sd_gen = GuiSD()
+

from pathlib import Path
from PIL import Image
import random, json
from modutils import (safe_float, escape_lora_basename, to_lora_key, to_lora_path,
get_local_model_list, get_private_lora_model_lists, get_valid_lora_name,
- get_valid_lora_path, get_valid_lora_wt, get_lora_info, CIVITAI_SORT, CIVITAI_PERIOD,
- normalize_prompt_list, get_civitai_info, search_lora_on_civitai, translate_to_en)
+ get_valid_lora_path, get_valid_lora_wt, get_lora_info, CIVITAI_SORT, CIVITAI_PERIOD, CIVITAI_BASEMODEL,
+ normalize_prompt_list, get_civitai_info, search_lora_on_civitai, translate_to_en, get_t2i_model_info, get_civitai_tag, save_image_history)
+
+

- sd_gen = GuiSD()
#@spaces.GPU
def infer(prompt, negative_prompt, seed, randomize_seed, width, height, guidance_scale, num_inference_steps,
model_name = load_diffusers_format_model[0], lora1 = None, lora1_wt = 1.0, lora2 = None, lora2_wt = 1.0,
@@ -796,12 +680,72 @@ def infer(prompt, negative_prompt, seed, randomize_seed, width, height, guidance
import numpy as np
MAX_SEED = np.iinfo(np.int32).max

load_lora_cpu = False
verbose_info = False
gpu_duration = 59

images: list[tuple[PIL.Image.Image, str | None]] = []
- info: str = ""
progress(0, desc="Preparing...")

if randomize_seed:
@@ -828,7 +772,7 @@ def infer(prompt, negative_prompt, seed, randomize_seed, width, height, guidance
sd_gen.load_new_model(model_name, vae, TASK_MODEL_LIST[0])
progress(1, desc="Model loaded.")
progress(0, desc="Starting Inference...")
- images, info = sd_gen_generate_pipeline(prompt, negative_prompt, 1, num_inference_steps,
guidance_scale, True, generator, lora1, lora1_wt, lora2, lora2_wt, lora3, lora3_wt,
lora4, lora4_wt, lora5, lora5_wt, sampler,
height, width, model_name, vae, TASK_MODEL_LIST[0], None, "Canny", 512, 1024,
@@ -1008,14 +952,14 @@ def update_lora_dict(path: str):
def download_lora(dl_urls: str):
global loras_url_to_path_dict
dl_path = ""
- before = get_local_model_list(directory_loras)
urls = []
for url in [url.strip() for url in dl_urls.split(',')]:
- local_path = f"{directory_loras}/{url.split('/')[-1]}"
if not Path(local_path).exists():
- download_things(directory_loras, url, HF_TOKEN, CIVITAI_API_KEY)
urls.append(url)
- after = get_local_model_list(directory_loras)
new_files = list_sub(after, before)
i = 0
for file in new_files:
1
  import spaces
2
  import os
3
  from stablepy import Model_Diffusers
4
+ from constants import (
5
+ PREPROCESSOR_CONTROLNET,
6
+ TASK_STABLEPY,
7
+ TASK_MODEL_LIST,
8
+ UPSCALER_DICT_GUI,
9
+ UPSCALER_KEYS,
10
+ PROMPT_W_OPTIONS,
11
+ WARNING_MSG_VAE,
12
+ SDXL_TASK,
13
+ MODEL_TYPE_TASK,
14
+ POST_PROCESSING_SAMPLER,
15
+
16
+ )
17
  from stablepy.diffusers_vanilla.style_prompt_config import STYLE_NAMES
 
18
  import torch
19
  import re
 
20
  from stablepy import (
 
 
 
 
 
21
  scheduler_names,
 
22
  IP_ADAPTERS_SD,
23
  IP_ADAPTERS_SDXL,
 
 
 
 
24
  )
25
  import time
26
  from PIL import ImageFile
27
+ from utils import (
28
+ get_model_list,
29
+ extract_parameters,
30
+ get_model_type,
31
+ extract_exif_data,
32
+ create_mask_now,
33
+ download_diffuser_repo,
34
+ progress_step_bar,
35
+ html_template_message,
36
+ )
37
+ from datetime import datetime
38
+ import gradio as gr
39
+ import logging
40
+ import diffusers
41
+ import warnings
42
+ from stablepy import logger
43
+ # import urllib.parse
44
 
45
  ImageFile.LOAD_TRUNCATED_IMAGES = True
46
+ # os.environ["PYTORCH_NO_CUDA_MEMORY_CACHING"] = "1"
47
  print(os.getenv("SPACES_ZERO_GPU"))
48
 
49
+ ## BEGIN MOD
50
  import gradio as gr
51
  import logging
52
  logging.getLogger("diffusers").setLevel(logging.ERROR)
 
57
  warnings.filterwarnings(action="ignore", category=UserWarning, module="diffusers")
58
  warnings.filterwarnings(action="ignore", category=FutureWarning, module="transformers")
59
  from stablepy import logger
60
+ logger.setLevel(logging.DEBUG)
61
 
62
  from env import (
63
+ HF_TOKEN, HF_READ_TOKEN, # to use only for private repos
64
  CIVITAI_API_KEY, HF_LORA_PRIVATE_REPOS1, HF_LORA_PRIVATE_REPOS2,
65
  HF_LORA_ESSENTIAL_PRIVATE_REPO, HF_VAE_PRIVATE_REPO,
66
  HF_SDXL_EMBEDS_NEGATIVE_PRIVATE_REPO, HF_SDXL_EMBEDS_POSITIVE_PRIVATE_REPO,
67
+ DIRECTORY_MODELS, DIRECTORY_LORAS, DIRECTORY_VAES, DIRECTORY_EMBEDS,
68
+ DIRECTORY_EMBEDS_SDXL, DIRECTORY_EMBEDS_POSITIVE_SDXL,
69
+ LOAD_DIFFUSERS_FORMAT_MODEL, DOWNLOAD_MODEL_LIST, DOWNLOAD_LORA_LIST,
70
+ DOWNLOAD_VAE_LIST, DOWNLOAD_EMBEDS)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
71
 
 
72
  from modutils import (to_list, list_uniq, list_sub, get_model_id_list, get_tupled_embed_list,
73
  get_tupled_model_list, get_lora_model_list, download_private_repo, download_things)
74
 
75
  # - **Download Models**
76
+ download_model = ", ".join(DOWNLOAD_MODEL_LIST)
77
  # - **Download VAEs**
78
+ download_vae = ", ".join(DOWNLOAD_VAE_LIST)
79
  # - **Download LoRAs**
80
+ download_lora = ", ".join(DOWNLOAD_LORA_LIST)
81
 
82
+ #download_private_repo(HF_LORA_ESSENTIAL_PRIVATE_REPO, DIRECTORY_LORAS, True)
83
+ download_private_repo(HF_VAE_PRIVATE_REPO, DIRECTORY_VAES, False)
84
 
85
+ load_diffusers_format_model = list_uniq(LOAD_DIFFUSERS_FORMAT_MODEL + get_model_id_list())
86
  ## END MOD
87
 
88
  # Download stuffs
89
  for url in [url.strip() for url in download_model.split(',')]:
90
  if not os.path.exists(f"./models/{url.split('/')[-1]}"):
91
+ download_things(DIRECTORY_MODELS, url, HF_TOKEN, CIVITAI_API_KEY)
92
  for url in [url.strip() for url in download_vae.split(',')]:
93
  if not os.path.exists(f"./vaes/{url.split('/')[-1]}"):
94
+ download_things(DIRECTORY_VAES, url, HF_TOKEN, CIVITAI_API_KEY)
95
  for url in [url.strip() for url in download_lora.split(',')]:
96
  if not os.path.exists(f"./loras/{url.split('/')[-1]}"):
97
+ download_things(DIRECTORY_LORAS, url, HF_TOKEN, CIVITAI_API_KEY)
98
 
99
  # Download Embeddings
100
+ for url_embed in DOWNLOAD_EMBEDS:
101
  if not os.path.exists(f"./embedings/{url_embed.split('/')[-1]}"):
102
+ download_things(DIRECTORY_EMBEDS, url_embed, HF_TOKEN, CIVITAI_API_KEY)
103
 
104
  # Build list models
105
+ embed_list = get_model_list(DIRECTORY_EMBEDS)
106
+ model_list = get_model_list(DIRECTORY_MODELS)
107
  model_list = load_diffusers_format_model + model_list
108
+
109
  ## BEGIN MOD
110
  lora_model_list = get_lora_model_list()
111
+ vae_model_list = get_model_list(DIRECTORY_VAES)
112
  vae_model_list.insert(0, "None")
113
 
114
+ #download_private_repo(HF_SDXL_EMBEDS_NEGATIVE_PRIVATE_REPO, DIRECTORY_EMBEDS_SDXL, False)
115
+ #download_private_repo(HF_SDXL_EMBEDS_POSITIVE_PRIVATE_REPO, DIRECTORY_EMBEDS_POSITIVE_SDXL, False)
116
+ embed_sdxl_list = get_model_list(DIRECTORY_EMBEDS_SDXL) + get_model_list(DIRECTORY_EMBEDS_POSITIVE_SDXL)
117
 
118
  def get_embed_list(pipeline_name):
119
  return get_tupled_embed_list(embed_sdxl_list if pipeline_name == "StableDiffusionXLPipeline" else embed_list)
 
121
 
122
  print('\033[33m🏁 Download and listing of valid models completed.\033[0m')
123
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
124
  ## BEGIN MOD
125
  class GuiSD:
126
+ def __init__(self, stream=True):
127
  self.model = None
128
+ self.status_loading = False
129
+ self.sleep_loading = 4
130
+ self.last_load = datetime.now()
 
 
 
 
 
 
 
 
 
131
 
132
  def infer_short(self, model, pipe_params, progress=gr.Progress(track_tqdm=True)):
133
  #progress(0, desc="Start inference...")
 
141
  return img
142
 
143
  def load_new_model(self, model_name, vae_model, task, progress=gr.Progress(track_tqdm=True)):
 
 
 
144
  vae_model = vae_model if vae_model != "None" else None
145
  model_type = get_model_type(model_name)
146
+ dtype_model = torch.bfloat16 if model_type == "FLUX" else torch.float16
147
+
148
+ if not os.path.exists(model_name):
149
+ _ = download_diffuser_repo(
150
+ repo_name=model_name,
151
+ model_type=model_type,
152
+ revision="main",
153
+ token=True,
154
+ )
155
+
156
+ for i in range(68):
157
+ if not self.status_loading:
158
+ self.status_loading = True
159
+ if i > 0:
160
+ time.sleep(self.sleep_loading)
161
+ print("Previous model ops...")
162
+ break
163
+ time.sleep(0.5)
164
+ print(f"Waiting queue {i}")
165
+ yield "Waiting queue"
166
+
167
+ self.status_loading = True
168
+
169
+ yield f"Loading model: {model_name}"
170
 
171
  if vae_model:
172
  vae_type = "SDXL" if "sdxl" in vae_model.lower() else "SD 1.5"
173
  if model_type != vae_type:
174
+ gr.Warning(WARNING_MSG_VAE)
175
 
176
+ print("Loading model...")
177
+
178
+ try:
179
+ start_time = time.time()
180
+
181
+ if self.model is None:
182
+ self.model = Model_Diffusers(
183
+ base_model_id=model_name,
184
+ task_name=TASK_STABLEPY[task],
185
+ vae_model=vae_model,
186
+ type_model_precision=dtype_model,
187
+ retain_task_model_in_cache=False,
188
+ device="cpu",
189
+ )
190
+ else:
191
+
192
+ if self.model.base_model_id != model_name:
193
+ load_now_time = datetime.now()
194
+ elapsed_time = max((load_now_time - self.last_load).total_seconds(), 0)
195
+
196
+ if elapsed_time <= 8:
197
+ print("Waiting for the previous model's time ops...")
198
+ time.sleep(8-elapsed_time)
199
+
200
+ self.model.device = torch.device("cpu")
201
+ self.model.load_pipe(
202
+ model_name,
203
+ task_name=TASK_STABLEPY[task],
204
+ vae_model=vae_model,
205
+ type_model_precision=dtype_model,
206
+ retain_task_model_in_cache=False,
207
+ )
208
+
209
+ end_time = time.time()
210
+ self.sleep_loading = max(min(int(end_time - start_time), 10), 4)
211
+ except Exception as e:
212
+ self.last_load = datetime.now()
213
+ self.status_loading = False
214
+ self.sleep_loading = 4
215
+ raise e
216
+
217
+ self.last_load = datetime.now()
218
+ self.status_loading = False
219
+
220
+ yield f"Model loaded: {model_name}"
221
 
222
  #@spaces.GPU
223
  @torch.inference_mode()
 
325
  mode_ip2,
326
  scale_ip2,
327
  pag_scale,
 
328
  ):
329
+ info_state = html_template_message("Navigating latent space...")
330
+ yield info_state, gr.update(), gr.update()
331
+
332
  vae_model = vae_model if vae_model != "None" else None
333
  loras_list = [lora1, lora2, lora3, lora4, lora5]
334
  vae_msg = f"VAE: {vae_model}" if vae_model else ""
335
  msg_lora = ""
336
 
 
 
337
  ## BEGIN MOD
338
+ loras_list = [s if s else "None" for s in loras_list]
339
  prompt, neg_prompt = insert_model_recom_prompt(prompt, neg_prompt, model_name)
340
  global lora_model_list
341
  lora_model_list = get_lora_model_list()
342
  ## END MOD
343
 
344
+ print("Config model:", model_name, vae_model, loras_list)
345
+
346
  task = TASK_STABLEPY[task]
347
 
348
  params_ip_img = []
 
365
  params_ip_mode.append(modeip)
366
  params_ip_scale.append(scaleip)
367
 
368
+ concurrency = 5
369
+ self.model.stream_config(concurrency=concurrency, latent_resize_by=1, vae_decoding=False)
370
+
371
  if task != "txt2img" and not image_control:
372
  raise ValueError("No control image found: To use this function, you have to upload an image in 'Image ControlNet/Inpaint/Img2img'")
373
 
 
439
  "high_threshold": high_threshold,
440
  "value_threshold": value_threshold,
441
  "distance_threshold": distance_threshold,
442
+ "lora_A": lora1 if lora1 != "None" else None,
443
  "lora_scale_A": lora_scale1,
444
+ "lora_B": lora2 if lora2 != "None" else None,
445
  "lora_scale_B": lora_scale2,
446
+ "lora_C": lora3 if lora3 != "None" else None,
447
  "lora_scale_C": lora_scale3,
448
+ "lora_D": lora4 if lora4 != "None" else None,
449
  "lora_scale_D": lora_scale4,
450
+ "lora_E": lora5 if lora5 != "None" else None,
451
  "lora_scale_E": lora_scale5,
452
  ## BEGIN MOD
453
  "textual_inversion": get_embed_list(self.model.class_name) if textual_inversion else [],
 
497
  }
498
 
499
  self.model.device = torch.device("cuda:0")
500
+ if hasattr(self.model.pipe, "transformer") and loras_list != ["None"] * 5:
501
  self.model.pipe.transformer.to(self.model.device)
502
  print("transformer to cuda")
503
 
504
+ #return self.infer_short(self.model, pipe_params), info_state
505
+
506
+ actual_progress = 0
507
+ info_images = gr.update()
508
+ for img, seed, image_path, metadata in self.model(**pipe_params):
509
+ info_state = progress_step_bar(actual_progress, steps)
510
+ actual_progress += concurrency
511
+ if image_path:
512
+ info_images = f"Seeds: {str(seed)}"
513
+ if vae_msg:
514
+ info_images = info_images + "<br>" + vae_msg
515
+
516
+ if "Cannot copy out of meta tensor; no data!" in self.model.last_lora_error:
517
+ msg_ram = "Unable to process the LoRAs due to high RAM usage; please try again later."
518
+ print(msg_ram)
519
+ msg_lora += f"<br>{msg_ram}"
520
+
521
+ for status, lora in zip(self.model.lora_status, self.model.lora_memory):
522
+ if status:
523
+ msg_lora += f"<br>Loaded: {lora}"
524
+ elif status is not None:
525
+ msg_lora += f"<br>Error with: {lora}"
526
+
527
+ if msg_lora:
528
+ info_images += msg_lora
529
+
530
+ info_images = info_images + "<br>" + "GENERATION DATA:<br>" + metadata[0].replace("\n", "<br>") + "<br>-------<br>"
531
+
532
+ download_links = "<br>".join(
533
+ [
534
+ f'<a href="{path.replace("/images/", "/file=/home/user/app/images/")}" download="{os.path.basename(path)}">Download Image {i + 1}</a>'
535
+ for i, path in enumerate(image_path)
536
+ ]
537
+ )
538
+ if save_generated_images:
539
+ info_images += f"<br>{download_links}"
540
+ ## BEGIN MOD
541
+ if not isinstance(img, list): img = [img]
542
+ img = save_images(img, metadata)
543
+ img = [(i, None) for i in img]
544
  ## END MOD
545
+ info_state = "COMPLETE"
546
+
547
+ yield info_state, img, info_images
548
+ #return info_state, img, info_images
549
 
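Editor's note: each intermediate yield above refreshes the status HTML while `actual_progress` advances by `concurrency` per decoded batch. A rough stand-in for the `progress_step_bar` helper used in the loop (the real helper lives elsewhere in this repo; this sketch only illustrates the percentage math):

    def progress_step_bar_sketch(step: int, total_steps: int) -> str:
        # clamp to 100% since actual_progress can overshoot by up to `concurrency`
        pct = min(100, int(step / max(total_steps, 1) * 100))
        return f"<progress value='{pct}' max='100'></progress> {pct}%"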
550
  def dynamic_gpu_duration(func, duration, *args):
551
 
552
  @spaces.GPU(duration=duration)
553
  def wrapped_func():
554
+ yield from func(*args)
555
 
556
  return wrapped_func()
557
 
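Editor's note: `dynamic_gpu_duration` decorates an inner function at call time, so the ZeroGPU quota can differ per request, and `yield from` keeps the wrapped pipeline's streaming yields intact. The same pattern in isolation, as a sketch:

    import spaces

    def run_with_gpu_budget(fn, seconds, *args):
        @spaces.GPU(duration=seconds)  # duration chosen per call, not at import time
        def _wrapped():
            yield from fn(*args)       # forwards every yield from the generator
        return _wrapped()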
 
568
  load_lora_cpu = args[-3]
569
  generation_args = args[:-3]
570
  lora_list = [
571
+ None if item == "None" or item == "" else item # MOD
572
  for item in [args[7], args[9], args[11], args[13], args[15]]
573
  ]
574
  lora_status = [None] * 5
 
577
  if load_lora_cpu:
578
  msg_load_lora = "Updating LoRAs in CPU (Slow but saves GPU usage)..."
579
 
580
+ if lora_list != sd_gen.model.lora_memory and lora_list != [None] * 5:
581
+ yield msg_load_lora, gr.update(), gr.update()
582
 
583
  # Load lora in CPU
584
  if load_lora_cpu:
 
604
  )
605
  gr.Info(f"LoRAs in cache: {lora_cache_msg}")
606
 
607
+ msg_request = f"Requesting {gpu_duration_arg}s. of GPU time.\nModel: {sd_gen.model.base_model_id}"
608
+ if verbose_arg:
609
  gr.Info(msg_request)
610
  print(msg_request)
611
+ yield msg_request.replace("\n", "<br>"), gr.update(), gr.update()
 
612
 
613
  start_time = time.time()
614
 
615
+ # yield from sd_gen.generate_pipeline(*generation_args)
616
+ yield from dynamic_gpu_duration(
617
+ #return dynamic_gpu_duration(
618
  sd_gen.generate_pipeline,
619
  gpu_duration_arg,
620
  *generation_args,
621
  )
622
 
623
  end_time = time.time()
624
+ execution_time = end_time - start_time
625
+ msg_task_complete = (
626
+ f"GPU task complete in: {int(round(execution_time, 0) + 1)} seconds"
627
+ )
628
 
629
  if verbose_arg:
 
 
 
 
630
  gr.Info(msg_task_complete)
631
  print(msg_task_complete)
632
 
633
+ yield msg_task_complete, gr.update(), gr.update()
 
 
 
 
634
 
 
 
 
635
 
636
+ @spaces.GPU(duration=15)
 
 
 
 
 
637
  def esrgan_upscale(image, upscaler_name, upscaler_size):
638
  if image is None: return None
639
 
 
655
 
656
  return image_path
657
 
658
+
659
  dynamic_gpu_duration.zerogpu = True
660
  sd_gen_generate_pipeline.zerogpu = True
661
+ sd_gen = GuiSD()
662
+
663
 
664
  from pathlib import Path
665
  from PIL import Image
666
  import random, json
667
  from modutils import (safe_float, escape_lora_basename, to_lora_key, to_lora_path,
668
  get_local_model_list, get_private_lora_model_lists, get_valid_lora_name,
669
+ get_valid_lora_path, get_valid_lora_wt, get_lora_info, CIVITAI_SORT, CIVITAI_PERIOD, CIVITAI_BASEMODEL,
670
+ normalize_prompt_list, get_civitai_info, search_lora_on_civitai, translate_to_en, get_t2i_model_info, get_civitai_tag, save_image_history)
671
+
672
+
673
 
 
674
  #@spaces.GPU
675
  def infer(prompt, negative_prompt, seed, randomize_seed, width, height, guidance_scale, num_inference_steps,
676
  model_name = load_diffusers_format_model[0], lora1 = None, lora1_wt = 1.0, lora2 = None, lora2_wt = 1.0,
 
680
  import numpy as np
681
  MAX_SEED = np.iinfo(np.int32).max
682
 
683
+ image_previews = True
684
+ load_lora_cpu = False
685
+ verbose_info = False
686
+ gpu_duration = 59
687
+
688
+ images: list[tuple[PIL.Image.Image, str | None]] = []
689
+ progress(0, desc="Preparing...")
690
+
691
+ if randomize_seed:
692
+ seed = random.randint(0, MAX_SEED)
693
+
694
+ generator = torch.Generator().manual_seed(seed).initial_seed()  # initial_seed() returns the seed just set; .seed() would re-seed from entropy and discard it
695
+
696
+ if translate:
697
+ prompt = translate_to_en(prompt)
698
+ negative_prompt = translate_to_en(negative_prompt)
699
+
700
+ prompt, negative_prompt = insert_model_recom_prompt(prompt, negative_prompt, model_name)
701
+ progress(0.5, desc="Preparing...")
702
+ lora1, lora1_wt, lora2, lora2_wt, lora3, lora3_wt, lora4, lora4_wt, lora5, lora5_wt = \
703
+ set_prompt_loras(prompt, model_name, lora1, lora1_wt, lora2, lora2_wt, lora3, lora3_wt, lora4, lora4_wt, lora5, lora5_wt)
704
+ lora1 = get_valid_lora_path(lora1)
705
+ lora2 = get_valid_lora_path(lora2)
706
+ lora3 = get_valid_lora_path(lora3)
707
+ lora4 = get_valid_lora_path(lora4)
708
+ lora5 = get_valid_lora_path(lora5)
709
+ progress(1, desc="Preparation completed. Starting inference...")
710
+
711
+ progress(0, desc="Loading model...")
712
+ for _ in sd_gen.load_new_model(model_name, vae, TASK_MODEL_LIST[0]):
713
+ pass
714
+ progress(1, desc="Model loaded.")
715
+ progress(0, desc="Starting Inference...")
716
+ for info_state, stream_images, info_images in sd_gen_generate_pipeline(prompt, negative_prompt, 1, num_inference_steps,
717
+ guidance_scale, True, generator, lora1, lora1_wt, lora2, lora2_wt, lora3, lora3_wt,
718
+ lora4, lora4_wt, lora5, lora5_wt, sampler,
719
+ height, width, model_name, vae, TASK_MODEL_LIST[0], None, "Canny", 512, 1024,
720
+ None, None, None, 0.35, 100, 200, 0.1, 0.1, 1.0, 0., 1., False, "Classic", None,
721
+ 1.0, 100, 10, 30, 0.55, "Use same sampler", "", "",
722
+ False, True, 1, True, False, image_previews, False, False, "./images", False, False, False, True, 1, 0.55,
723
+ False, False, False, True, False, "Use same sampler", False, "", "", 0.35, True, True, False, 4, 4, 32,
724
+ False, "", "", 0.35, True, True, False, 4, 4, 32,
725
+ True, None, None, "plus_face", "original", 0.7, None, None, "base", "style", 0.7, 0.0,
726
+ load_lora_cpu, verbose_info, gpu_duration
727
+ ):
728
+ images = stream_images if isinstance(stream_images, list) else images
729
+ progress(1, desc="Inference completed.")
730
+ output_image = images[0][0] if images else None
731
+
732
+ return output_image
733
+
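Editor's note: a hypothetical direct call to `infer` (outside the Gradio UI), using only the leading positional parameters; the remaining keyword arguments fall back to the defaults in the signature above:

    image = infer(
        "1girl, solo, looking at viewer",   # prompt
        "lowres, bad anatomy",              # negative prompt
        0, True,                            # seed, randomize_seed
        1024, 1024,                         # width, height
        7.0, 28,                            # guidance_scale, num_inference_steps
    )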
734
+ #@spaces.GPU
735
+ def __infer(prompt, negative_prompt, seed, randomize_seed, width, height, guidance_scale, num_inference_steps,
736
+ model_name = load_diffusers_format_model[0], lora1 = None, lora1_wt = 1.0, lora2 = None, lora2_wt = 1.0,
737
+ lora3 = None, lora3_wt = 1.0, lora4 = None, lora4_wt = 1.0, lora5 = None, lora5_wt = 1.0,
738
+ sampler = "Euler a", vae = None, translate=True, progress=gr.Progress(track_tqdm=True)):
739
+ import PIL
740
+ import numpy as np
741
+ MAX_SEED = np.iinfo(np.int32).max
742
+
743
  load_lora_cpu = False
744
  verbose_info = False
745
  gpu_duration = 59
746
 
747
  images: list[tuple[PIL.Image.Image, str | None]] = []
748
+ info_state = info_images = ""
749
  progress(0, desc="Preparing...")
750
 
751
  if randomize_seed:
 
772
  for _ in sd_gen.load_new_model(model_name, vae, TASK_MODEL_LIST[0]): pass  # load_new_model is now a generator and must be drained to run
773
  progress(1, desc="Model loaded.")
774
  progress(0, desc="Starting Inference...")
775
+ info_state, images, info_images = sd_gen_generate_pipeline(prompt, negative_prompt, 1, num_inference_steps,
776
  guidance_scale, True, generator, lora1, lora1_wt, lora2, lora2_wt, lora3, lora3_wt,
777
  lora4, lora4_wt, lora5, lora5_wt, sampler,
778
  height, width, model_name, vae, TASK_MODEL_LIST[0], None, "Canny", 512, 1024,
 
952
  def download_lora(dl_urls: str):
953
  global loras_url_to_path_dict
954
  dl_path = ""
955
+ before = get_local_model_list(DIRECTORY_LORAS)
956
  urls = []
957
  for url in [url.strip() for url in dl_urls.split(',')]:
958
+ local_path = f"{DIRECTORY_LORAS}/{url.split('/')[-1]}"
959
  if not Path(local_path).exists():
960
+ download_things(DIRECTORY_LORAS, url, HF_TOKEN, CIVITAI_API_KEY)
961
  urls.append(url)
962
+ after = get_local_model_list(DIRECTORY_LORAS)
963
  new_files = list_sub(after, before)
964
  i = 0
965
  for file in new_files:
env.py CHANGED
@@ -2,10 +2,10 @@ import os
2
 
3
  CIVITAI_API_KEY = os.environ.get("CIVITAI_API_KEY")
4
  HF_TOKEN = os.environ.get("HF_TOKEN")
5
- hf_read_token = os.environ.get('HF_READ_TOKEN') # only use for private repo
6
 
7
  # - **List Models**
8
- load_diffusers_format_model = [
9
  'votepurchase/animagine-xl-3.1',
10
  'votepurchase/NSFW-GEN-ANIME-v2',
11
  'votepurchase/kivotos-xl-2.0',
@@ -138,11 +138,11 @@ HF_MODEL_USER_EX = ["John6666"] # sorted by a special rule
138
 
139
 
140
  # - **Download Models**
141
- download_model_list = [
142
  ]
143
 
144
  # - **Download VAEs**
145
- download_vae_list = [
146
  'https://huggingface.co/madebyollin/sdxl-vae-fp16-fix/resolve/main/sdxl.vae.safetensors?download=true',
147
  'https://huggingface.co/nubby/blessed-sdxl-vae-fp16-fix/resolve/main/sdxl_vae-fp16fix-c-1.1-b-0.5.safetensors?download=true',
148
  "https://huggingface.co/nubby/blessed-sdxl-vae-fp16-fix/blob/main/sdxl_vae-fp16fix-blessed.safetensors",
@@ -151,29 +151,26 @@ download_vae_list = [
151
  ]
152
 
153
  # - **Download LoRAs**
154
- download_lora_list = [
155
  ]
156
 
157
  # Download Embeddings
158
- download_embeds = [
159
  'https://huggingface.co/datasets/Nerfgun3/bad_prompt/blob/main/bad_prompt_version2.pt',
160
  'https://huggingface.co/embed/negative/resolve/main/EasyNegativeV2.safetensors',
161
  'https://huggingface.co/embed/negative/resolve/main/bad-hands-5.pt',
162
  ]
163
 
164
- directory_models = 'models'
165
- os.makedirs(directory_models, exist_ok=True)
166
- directory_loras = 'loras'
167
- os.makedirs(directory_loras, exist_ok=True)
168
- directory_vaes = 'vaes'
169
- os.makedirs(directory_vaes, exist_ok=True)
170
- directory_embeds = 'embedings'
171
- os.makedirs(directory_embeds, exist_ok=True)
172
 
173
- directory_embeds_sdxl = 'embedings_xl'
174
- os.makedirs(directory_embeds_sdxl, exist_ok=True)
175
- directory_embeds_positive_sdxl = 'embedings_xl/positive'
176
- os.makedirs(directory_embeds_positive_sdxl, exist_ok=True)
177
 
178
  HF_LORA_PRIVATE_REPOS1 = ['John6666/loratest1', 'John6666/loratest3', 'John6666/loratest4', 'John6666/loratest6']
179
  HF_LORA_PRIVATE_REPOS2 = ['John6666/loratest10', 'John6666/loratest11','John6666/loratest'] # to be sorted as 1 repo
 
2
 
3
  CIVITAI_API_KEY = os.environ.get("CIVITAI_API_KEY")
4
  HF_TOKEN = os.environ.get("HF_TOKEN")
5
+ HF_READ_TOKEN = os.environ.get('HF_READ_TOKEN') # used only for private repos
6
 
7
  # - **List Models**
8
+ LOAD_DIFFUSERS_FORMAT_MODEL = [
9
  'votepurchase/animagine-xl-3.1',
10
  'votepurchase/NSFW-GEN-ANIME-v2',
11
  'votepurchase/kivotos-xl-2.0',
 
138
 
139
 
140
  # - **Download Models**
141
+ DOWNLOAD_MODEL_LIST = [
142
  ]
143
 
144
  # - **Download VAEs**
145
+ DOWNLOAD_VAE_LIST = [
146
  'https://huggingface.co/madebyollin/sdxl-vae-fp16-fix/resolve/main/sdxl.vae.safetensors?download=true',
147
  'https://huggingface.co/nubby/blessed-sdxl-vae-fp16-fix/resolve/main/sdxl_vae-fp16fix-c-1.1-b-0.5.safetensors?download=true',
148
  "https://huggingface.co/nubby/blessed-sdxl-vae-fp16-fix/blob/main/sdxl_vae-fp16fix-blessed.safetensors",
 
151
  ]
152
 
153
  # - **Download LoRAs**
154
+ DOWNLOAD_LORA_LIST = [
155
  ]
156
 
157
  # Download Embeddings
158
+ DOWNLOAD_EMBEDS = [
159
  'https://huggingface.co/datasets/Nerfgun3/bad_prompt/blob/main/bad_prompt_version2.pt',
160
  'https://huggingface.co/embed/negative/resolve/main/EasyNegativeV2.safetensors',
161
  'https://huggingface.co/embed/negative/resolve/main/bad-hands-5.pt',
162
  ]
163
 
164
+ DIRECTORY_MODELS = 'models'
165
+ DIRECTORY_LORAS = 'loras'
166
+ DIRECTORY_VAES = 'vaes'
167
+ DIRECTORY_EMBEDS = 'embedings'
168
+ DIRECTORY_EMBEDS_SDXL = 'embedings_xl'
169
+ DIRECTORY_EMBEDS_POSITIVE_SDXL = 'embedings_xl/positive'
 
 
170
 
171
+ directories = [DIRECTORY_MODELS, DIRECTORY_LORAS, DIRECTORY_VAES, DIRECTORY_EMBEDS, DIRECTORY_EMBEDS_SDXL, DIRECTORY_EMBEDS_POSITIVE_SDXL]
172
+ for directory in directories:
173
+ os.makedirs(directory, exist_ok=True)
 
174
 
175
  HF_LORA_PRIVATE_REPOS1 = ['John6666/loratest1', 'John6666/loratest3', 'John6666/loratest4', 'John6666/loratest6']
176
  HF_LORA_PRIVATE_REPOS2 = ['John6666/loratest10', 'John6666/loratest11','John6666/loratest'] # to be sorted as 1 repo
llmdolphin.py CHANGED
@@ -1,5 +1,9 @@
1
  import spaces
2
  import gradio as gr
 
 
 
 
3
  from llama_cpp import Llama
4
  from llama_cpp_agent import LlamaCppAgent, MessagesFormatterType
5
  from llama_cpp_agent.providers import LlamaCppPythonProvider
@@ -7,7 +11,6 @@ from llama_cpp_agent.chat_history import BasicChatHistory
7
  from llama_cpp_agent.chat_history.messages import Roles
8
  from ja_to_danbooru.ja_to_danbooru import jatags_to_danbooru_tags
9
  import wrapt_timeout_decorator
10
- from pathlib import Path
11
  from llama_cpp_agent.messages_formatter import MessagesFormatter
12
  from formatter import mistral_v1_formatter, mistral_v2_formatter, mistral_v3_tekken_formatter
13
 
@@ -19,6 +22,7 @@ llm_models = {
19
  #"": ["", MessagesFormatterType.OPEN_CHAT],
20
  #"": ["", MessagesFormatterType.CHATML],
21
  #"": ["", MessagesFormatterType.PHI_3],
 
22
  "mn-12b-lyra-v2a1-q5_k_m.gguf": ["HalleyStarbun/MN-12B-Lyra-v2a1-Q5_K_M-GGUF", MessagesFormatterType.CHATML],
23
  "L3-8B-Tamamo-v1.i1-Q5_K_M.gguf": ["mradermacher/L3-8B-Tamamo-v1-i1-GGUF", MessagesFormatterType.LLAMA_3],
24
  "MN-Chinofun-12B-2.i1-Q4_K_M.gguf": ["mradermacher/MN-Chinofun-12B-2-i1-GGUF", MessagesFormatterType.MISTRAL],
@@ -68,6 +72,19 @@ llm_models = {
68
  "ChatWaifu_22B_v2.0_preview.Q4_K_S.gguf": ["mradermacher/ChatWaifu_22B_v2.0_preview-GGUF", MessagesFormatterType.MISTRAL],
69
  "ChatWaifu_v1.4.Q5_K_M.gguf": ["mradermacher/ChatWaifu_v1.4-GGUF", MessagesFormatterType.MISTRAL],
70
  "ChatWaifu_v1.3.1.Q4_K_M.gguf": ["mradermacher/ChatWaifu_v1.3.1-GGUF", MessagesFormatterType.MISTRAL],
 
 
 
 
 
 
 
 
 
 
 
 
 
71
  "hermes-llama3-roleplay-1000-v2.Q5_K_M.gguf": ["mradermacher/hermes-llama3-roleplay-1000-v2-GGUF", MessagesFormatterType.LLAMA_3],
72
  "hermes-stheno-8B-v0.1.i1-Q5_K_M.gguf": ["mradermacher/hermes-stheno-8B-v0.1-i1-GGUF", MessagesFormatterType.LLAMA_3],
73
  "qwen-carpmuscle-r-v0.3.Q4_K_M.gguf": ["mradermacher/qwen-carpmuscle-r-v0.3-GGUF", MessagesFormatterType.OPEN_CHAT],
@@ -832,6 +849,7 @@ llm_languages = ["English", "Japanese", "Chinese", "Korean", "Spanish", "Portugu
832
  llm_models_tupled_list = []
833
  default_llm_model_filename = list(llm_models.keys())[0]
834
  override_llm_format = None
 
835
 
836
 
837
  def to_list(s):
@@ -844,7 +862,6 @@ def list_uniq(l):
844
 
845
  @wrapt_timeout_decorator.timeout(dec_timeout=3.5)
846
  def to_list_ja(s):
847
- import re
848
  s = re.sub(r'[、。]', ',', s)
849
  return [x.strip() for x in s.split(",") if not s == ""]
850
 
@@ -859,7 +876,6 @@ def is_japanese(s):
859
 
860
 
861
  def update_llm_model_tupled_list():
862
- from pathlib import Path
863
  global llm_models_tupled_list
864
  llm_models_tupled_list = []
865
  for k, v in llm_models.items():
@@ -876,7 +892,6 @@ def update_llm_model_tupled_list():
876
 
877
 
878
  def download_llm_models():
879
- from huggingface_hub import hf_hub_download
880
  global llm_models_tupled_list
881
  llm_models_tupled_list = []
882
  for k, v in llm_models.items():
@@ -890,7 +905,6 @@ def download_llm_models():
890
 
891
 
892
  def download_llm_model(filename):
893
- from huggingface_hub import hf_hub_download
894
  if not filename in llm_models.keys(): return default_llm_model_filename
895
  try:
896
  hf_hub_download(repo_id = llm_models[filename][0], filename = filename, local_dir = llm_models_dir)
@@ -951,8 +965,6 @@ def get_dolphin_model_format(filename):
951
 
952
 
953
  def add_dolphin_models(query, format_name):
954
- import re
955
- from huggingface_hub import HfApi
956
  global llm_models
957
  api = HfApi()
958
  add_models = {}
@@ -964,20 +976,19 @@ def add_dolphin_models(query, format_name):
964
  if s and "" in s: s.remove("")
965
  if len(s) == 1:
966
  repo = s[0]
967
- if not api.repo_exists(repo_id = repo): return gr.update(visible=True)
968
  files = api.list_repo_files(repo_id = repo)
969
  for file in files:
970
  if str(file).endswith(".gguf"): add_models[filename] = [repo, format]
971
  elif len(s) >= 2:
972
  repo = s[0]
973
  filename = s[1]
974
- if not api.repo_exists(repo_id = repo) or not api.file_exists(repo_id = repo, filename = filename): return gr.update(visible=True)
975
  add_models[filename] = [repo, format]
976
- else: return gr.update(visible=True)
977
  except Exception as e:
978
  print(e)
979
- return gr.update(visible=True)
980
- #print(add_models)
981
  llm_models = (llm_models | add_models).copy()
982
  update_llm_model_tupled_list()
983
  choices = get_dolphin_models()
@@ -1177,7 +1188,6 @@ Output should be enclosed in //GENBEGIN//:// and //://GENEND//. The text to be g
1177
 
1178
 
1179
  def get_dolphin_sysprompt():
1180
- import re
1181
  prompt = re.sub('<LANGUAGE>', dolphin_output_language, dolphin_system_prompt.get(dolphin_sysprompt_mode, ""))
1182
  return prompt
1183
 
@@ -1207,11 +1217,11 @@ def select_dolphin_language(lang: str):
1207
 
1208
  @wrapt_timeout_decorator.timeout(dec_timeout=5.0)
1209
  def get_raw_prompt(msg: str):
1210
- import re
1211
  m = re.findall(r'/GENBEGIN/(.+?)/GENEND/', msg, re.DOTALL)
1212
  return re.sub(r'[*/:_"#\n]', ' ', ", ".join(m)).lower() if m else ""
1213
 
1214
 
 
1215
  @spaces.GPU(duration=60)
1216
  def dolphin_respond(
1217
  message: str,
@@ -1225,87 +1235,92 @@ def dolphin_respond(
1225
  repeat_penalty: float = 1.1,
1226
  progress=gr.Progress(track_tqdm=True),
1227
  ):
1228
- from pathlib import Path
1229
- progress(0, desc="Processing...")
 
 
 
 
1230
 
1231
- if override_llm_format:
1232
- chat_template = override_llm_format
1233
- else:
1234
- chat_template = llm_models[model][1]
1235
-
1236
- llm = Llama(
1237
- model_path=str(Path(f"{llm_models_dir}/{model}")),
1238
- flash_attn=True,
1239
- n_gpu_layers=81, # 81
1240
- n_batch=1024,
1241
- n_ctx=8192, #8192
1242
- )
1243
- provider = LlamaCppPythonProvider(llm)
1244
-
1245
- agent = LlamaCppAgent(
1246
- provider,
1247
- system_prompt=f"{system_message}",
1248
- predefined_messages_formatter_type=chat_template if not isinstance(chat_template, MessagesFormatter) else None,
1249
- custom_messages_formatter=chat_template if isinstance(chat_template, MessagesFormatter) else None,
1250
- debug_output=False
1251
- )
1252
-
1253
- settings = provider.get_provider_default_settings()
1254
- settings.temperature = temperature
1255
- settings.top_k = top_k
1256
- settings.top_p = top_p
1257
- settings.max_tokens = max_tokens
1258
- settings.repeat_penalty = repeat_penalty
1259
- settings.stream = True
1260
-
1261
- messages = BasicChatHistory()
1262
-
1263
- for msn in history:
1264
- user = {
1265
- 'role': Roles.user,
1266
- 'content': msn[0]
1267
- }
1268
- assistant = {
1269
- 'role': Roles.assistant,
1270
- 'content': msn[1]
1271
- }
1272
- messages.add_message(user)
1273
- messages.add_message(assistant)
1274
-
1275
- stream = agent.get_chat_response(
1276
- message,
1277
- llm_sampling_settings=settings,
1278
- chat_history=messages,
1279
- returns_streaming_generator=True,
1280
- print_output=False
1281
- )
1282
-
1283
- progress(0.5, desc="Processing...")
1284
-
1285
- outputs = ""
1286
- for output in stream:
1287
- outputs += output
1288
- yield [(outputs, None)]
1289
 
1290
 
1291
  def dolphin_parse(
1292
  history: list[tuple[str, str]],
1293
  ):
1294
- if dolphin_sysprompt_mode == "Chat with LLM" or not history or len(history) < 1:
1295
- return "", gr.update(visible=True), gr.update(visible=True)
1296
  try:
 
 
1297
  msg = history[-1][0]
1298
  raw_prompt = get_raw_prompt(msg)
1299
- except Exception:
1300
- return "", gr.update(visible=True), gr.update(visible=True)
1301
- prompts = []
1302
- if dolphin_sysprompt_mode == "Japanese to Danbooru Dictionary" and is_japanese(raw_prompt):
1303
- prompts = list_uniq(jatags_to_danbooru_tags(to_list_ja(raw_prompt)) + ["nsfw", "explicit"])
1304
- else:
1305
- prompts = list_uniq(to_list(raw_prompt) + ["nsfw", "explicit"])
1306
- return ", ".join(prompts), gr.update(interactive=True), gr.update(interactive=True)
 
1307
 
1308
 
 
1309
  @spaces.GPU(duration=60)
1310
  def dolphin_respond_auto(
1311
  message: str,
@@ -1319,94 +1334,100 @@ def dolphin_respond_auto(
1319
  repeat_penalty: float = 1.1,
1320
  progress=gr.Progress(track_tqdm=True),
1321
  ):
1322
- #if not is_japanese(message): return [(None, None)]
1323
- from pathlib import Path
1324
- progress(0, desc="Processing...")
 
 
 
 
 
 
1325
 
1326
- if override_llm_format:
1327
- chat_template = override_llm_format
1328
- else:
1329
- chat_template = llm_models[model][1]
1330
-
1331
- llm = Llama(
1332
- model_path=str(Path(f"{llm_models_dir}/{model}")),
1333
- flash_attn=True,
1334
- n_gpu_layers=81, # 81
1335
- n_batch=1024,
1336
- n_ctx=8192, #8192
1337
- )
1338
- provider = LlamaCppPythonProvider(llm)
1339
-
1340
- agent = LlamaCppAgent(
1341
- provider,
1342
- system_prompt=f"{system_message}",
1343
- predefined_messages_formatter_type=chat_template if not isinstance(chat_template, MessagesFormatter) else None,
1344
- custom_messages_formatter=chat_template if isinstance(chat_template, MessagesFormatter) else None,
1345
- debug_output=False
1346
- )
1347
-
1348
- settings = provider.get_provider_default_settings()
1349
- settings.temperature = temperature
1350
- settings.top_k = top_k
1351
- settings.top_p = top_p
1352
- settings.max_tokens = max_tokens
1353
- settings.repeat_penalty = repeat_penalty
1354
- settings.stream = True
1355
-
1356
- messages = BasicChatHistory()
1357
-
1358
- for msn in history:
1359
- user = {
1360
- 'role': Roles.user,
1361
- 'content': msn[0]
1362
- }
1363
- assistant = {
1364
- 'role': Roles.assistant,
1365
- 'content': msn[1]
1366
- }
1367
- messages.add_message(user)
1368
- messages.add_message(assistant)
1369
-
1370
- progress(0, desc="Translating...")
1371
- stream = agent.get_chat_response(
1372
- message,
1373
- llm_sampling_settings=settings,
1374
- chat_history=messages,
1375
- returns_streaming_generator=True,
1376
- print_output=False
1377
- )
1378
-
1379
- progress(0.5, desc="Processing...")
1380
-
1381
- outputs = ""
1382
- for output in stream:
1383
- outputs += output
1384
- yield [(outputs, None)]
1385
 
1386
 
1387
  def dolphin_parse_simple(
1388
  message: str,
1389
  history: list[tuple[str, str]],
1390
  ):
1391
- #if not is_japanese(message): return message
1392
- if dolphin_sysprompt_mode == "Chat with LLM" or not history or len(history) < 1: return message
1393
  try:
 
 
1394
  msg = history[-1][0]
1395
  raw_prompt = get_raw_prompt(msg)
1396
- except Exception:
 
 
 
 
 
 
 
1397
  return ""
1398
- prompts = []
1399
- if dolphin_sysprompt_mode == "Japanese to Danbooru Dictionary" and is_japanese(raw_prompt):
1400
- prompts = list_uniq(jatags_to_danbooru_tags(to_list_ja(raw_prompt)) + ["nsfw", "explicit", "rating_explicit"])
1401
- else:
1402
- prompts = list_uniq(to_list(raw_prompt) + ["nsfw", "explicit", "rating_explicit"])
1403
- return ", ".join(prompts)
1404
 
1405
 
1406
  # https://huggingface.co/spaces/CaioXapelaum/GGUF-Playground
1407
  import cv2
1408
  cv2.setNumThreads(1)
1409
 
 
 
1410
  @spaces.GPU()
1411
  def respond_playground(
1412
  message,
@@ -1419,47 +1440,47 @@ def respond_playground(
1419
  top_k,
1420
  repeat_penalty,
1421
  ):
1422
- if override_llm_format:
1423
- chat_template = override_llm_format
1424
- else:
1425
- chat_template = llm_models[model][1]
1426
-
1427
- llm = Llama(
1428
- model_path=str(Path(f"{llm_models_dir}/{model}")),
1429
- flash_attn=True,
1430
- n_gpu_layers=81, # 81
1431
- n_batch=1024,
1432
- n_ctx=8192, #8192
1433
- )
1434
- provider = LlamaCppPythonProvider(llm)
1435
-
1436
- agent = LlamaCppAgent(
1437
- provider,
1438
- system_prompt=f"{system_message}",
1439
- predefined_messages_formatter_type=chat_template if not isinstance(chat_template, MessagesFormatter) else None,
1440
- custom_messages_formatter=chat_template if isinstance(chat_template, MessagesFormatter) else None,
1441
- debug_output=False
1442
- )
1443
-
1444
- settings = provider.get_provider_default_settings()
1445
- settings.temperature = temperature
1446
- settings.top_k = top_k
1447
- settings.top_p = top_p
1448
- settings.max_tokens = max_tokens
1449
- settings.repeat_penalty = repeat_penalty
1450
- settings.stream = True
1451
-
1452
- messages = BasicChatHistory()
1453
-
1454
- # Add user and assistant messages to the history
1455
- for msn in history:
1456
- user = {'role': Roles.user, 'content': msn[0]}
1457
- assistant = {'role': Roles.assistant, 'content': msn[1]}
1458
- messages.add_message(user)
1459
- messages.add_message(assistant)
1460
-
1461
- # Stream the response
1462
  try:
 
 
 
 
 
 
 
1463
  stream = agent.get_chat_response(
1464
  message,
1465
  llm_sampling_settings=settings,
@@ -1473,4 +1494,5 @@ def respond_playground(
1473
  outputs += output
1474
  yield outputs
1475
  except Exception as e:
1476
- yield f"Error during response generation: {str(e)}"
 
 
1
  import spaces
2
  import gradio as gr
3
+ from pathlib import Path
4
+ import re
5
+ import torch
6
+ from huggingface_hub import hf_hub_download, HfApi
7
  from llama_cpp import Llama
8
  from llama_cpp_agent import LlamaCppAgent, MessagesFormatterType
9
  from llama_cpp_agent.providers import LlamaCppPythonProvider
 
11
  from llama_cpp_agent.chat_history.messages import Roles
12
  from ja_to_danbooru.ja_to_danbooru import jatags_to_danbooru_tags
13
  import wrapt_timeout_decorator
 
14
  from llama_cpp_agent.messages_formatter import MessagesFormatter
15
  from formatter import mistral_v1_formatter, mistral_v2_formatter, mistral_v3_tekken_formatter
16
 
 
22
  #"": ["", MessagesFormatterType.OPEN_CHAT],
23
  #"": ["", MessagesFormatterType.CHATML],
24
  #"": ["", MessagesFormatterType.PHI_3],
25
+ #"": ["", MessagesFormatterType.GEMMA_2],
26
  "mn-12b-lyra-v2a1-q5_k_m.gguf": ["HalleyStarbun/MN-12B-Lyra-v2a1-Q5_K_M-GGUF", MessagesFormatterType.CHATML],
27
  "L3-8B-Tamamo-v1.i1-Q5_K_M.gguf": ["mradermacher/L3-8B-Tamamo-v1-i1-GGUF", MessagesFormatterType.LLAMA_3],
28
  "MN-Chinofun-12B-2.i1-Q4_K_M.gguf": ["mradermacher/MN-Chinofun-12B-2-i1-GGUF", MessagesFormatterType.MISTRAL],
 
72
  "ChatWaifu_22B_v2.0_preview.Q4_K_S.gguf": ["mradermacher/ChatWaifu_22B_v2.0_preview-GGUF", MessagesFormatterType.MISTRAL],
73
  "ChatWaifu_v1.4.Q5_K_M.gguf": ["mradermacher/ChatWaifu_v1.4-GGUF", MessagesFormatterType.MISTRAL],
74
  "ChatWaifu_v1.3.1.Q4_K_M.gguf": ["mradermacher/ChatWaifu_v1.3.1-GGUF", MessagesFormatterType.MISTRAL],
75
+ "Magnum_Dark_Madness_12b.Q4_K_S.gguf": ["mradermacher/Magnum_Dark_Madness_12b-GGUF", MessagesFormatterType.MISTRAL],
76
+ "Magnum_Lyra_Darkness_12b.Q4_K_M.gguf": ["mradermacher/Magnum_Lyra_Darkness_12b-GGUF", MessagesFormatterType.MISTRAL],
77
+ "Heart_Stolen-8B-task.i1-Q4_K_M.gguf": ["mradermacher/Heart_Stolen-8B-task-i1-GGUF", MessagesFormatterType.LLAMA_3],
78
+ "Magnum_Backyard_Party_12b.Q4_K_M.gguf": ["mradermacher/Magnum_Backyard_Party_12b-GGUF", MessagesFormatterType.MISTRAL],
79
+ "Magnum_Madness-12b.Q4_K_M.gguf": ["mradermacher/Magnum_Madness-12b-GGUF", MessagesFormatterType.MISTRAL],
80
+ "L3.1-Moe-2x8B-v0.2.i1-Q4_K_M.gguf": ["mradermacher/L3.1-Moe-2x8B-v0.2-i1-GGUF", MessagesFormatterType.LLAMA_3],
81
+ "Qwen2.5-14B-Wernicke-DPO.i1-Q4_K_M.gguf": ["mradermacher/Qwen2.5-14B-Wernicke-DPO-i1-GGUF", MessagesFormatterType.OPEN_CHAT],
82
+ "Gemma-2-Ataraxy-v4d-9B.i1-Q4_K_M.gguf": ["mradermacher/Gemma-2-Ataraxy-v4d-9B-i1-GGUF", MessagesFormatterType.GEMMA_2],
83
+ "qwen2.5-14b-megamerge-pt2-q5_k_m.gguf": ["CultriX/Qwen2.5-14B-MegaMerge-pt2-Q5_K_M-GGUF", MessagesFormatterType.OPEN_CHAT],
84
+ "quantqwen2-merged-16bit-q4_k_m.gguf": ["davidbzyk/QuantQwen2-merged-16bit-Q4_K_M-GGUF", MessagesFormatterType.OPEN_CHAT],
85
+ "Mistral-nemo-ja-rp-v0.2-Q4_K_S.gguf": ["ascktgcc/Mistral-nemo-ja-rp-v0.2-GGUF", MessagesFormatterType.MISTRAL],
86
+ "llama3.1-darkstorm-aspire-8b-q4_k_m.gguf": ["ZeroXClem/Llama3.1-DarkStorm-Aspire-8B-Q4_K_M-GGUF", MessagesFormatterType.LLAMA_3],
87
+ "llama-3-yggdrasil-astralspice-8b-q4_k_m.gguf": ["ZeroXClem/Llama-3-Yggdrasil-AstralSpice-8B-Q4_K_M-GGUF", MessagesFormatterType.LLAMA_3],
88
  "hermes-llama3-roleplay-1000-v2.Q5_K_M.gguf": ["mradermacher/hermes-llama3-roleplay-1000-v2-GGUF", MessagesFormatterType.LLAMA_3],
89
  "hermes-stheno-8B-v0.1.i1-Q5_K_M.gguf": ["mradermacher/hermes-stheno-8B-v0.1-i1-GGUF", MessagesFormatterType.LLAMA_3],
90
  "qwen-carpmuscle-r-v0.3.Q4_K_M.gguf": ["mradermacher/qwen-carpmuscle-r-v0.3-GGUF", MessagesFormatterType.OPEN_CHAT],
 
849
  llm_models_tupled_list = []
850
  default_llm_model_filename = list(llm_models.keys())[0]
851
  override_llm_format = None
852
+ device = "cuda" if torch.cuda.is_available() else "cpu"
853
 
854
 
855
  def to_list(s):
 
862
 
863
  @wrapt_timeout_decorator.timeout(dec_timeout=3.5)
864
  def to_list_ja(s):
 
865
  s = re.sub(r'[、。]', ',', s)
866
  return [x.strip() for x in s.split(",") if not s == ""]
867
 
 
876
 
877
 
878
  def update_llm_model_tupled_list():
 
879
  global llm_models_tupled_list
880
  llm_models_tupled_list = []
881
  for k, v in llm_models.items():
 
892
 
893
 
894
  def download_llm_models():
 
895
  global llm_models_tupled_list
896
  llm_models_tupled_list = []
897
  for k, v in llm_models.items():
 
905
 
906
 
907
  def download_llm_model(filename):
 
908
  if not filename in llm_models.keys(): return default_llm_model_filename
909
  try:
910
  hf_hub_download(repo_id = llm_models[filename][0], filename = filename, local_dir = llm_models_dir)
 
965
 
966
 
967
  def add_dolphin_models(query, format_name):
 
 
968
  global llm_models
969
  api = HfApi()
970
  add_models = {}
 
976
  if s and "" in s: s.remove("")
977
  if len(s) == 1:
978
  repo = s[0]
979
+ if not api.repo_exists(repo_id = repo): return gr.update()
980
  files = api.list_repo_files(repo_id = repo)
981
  for file in files:
982
  if str(file).endswith(".gguf"): add_models[filename] = [repo, format]
983
  elif len(s) >= 2:
984
  repo = s[0]
985
  filename = s[1]
986
+ if not api.repo_exists(repo_id = repo) or not api.file_exists(repo_id = repo, filename = filename): return gr.update()
987
  add_models[filename] = [repo, format]
988
+ else: return gr.update()
989
  except Exception as e:
990
  print(e)
991
+ return gr.update()
 
992
  llm_models = (llm_models | add_models).copy()
993
  update_llm_model_tupled_list()
994
  choices = get_dolphin_models()
 
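Editor's note: `add_dolphin_models` accepts either a bare repo id (every `.gguf` file in the repo is registered) or a repo id followed by a filename. Illustrative calls; the format name shown here is an assumption about the keys of the format list:

    add_dolphin_models("mradermacher/ChatWaifu_v1.4-GGUF", "Mistral")  # whole repo
    add_dolphin_models("mradermacher/ChatWaifu_v1.4-GGUF ChatWaifu_v1.4.Q5_K_M.gguf", "Mistral")  # single file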
1188
 
1189
 
1190
  def get_dolphin_sysprompt():
 
1191
  prompt = re.sub('<LANGUAGE>', dolphin_output_language, dolphin_system_prompt.get(dolphin_sysprompt_mode, ""))
1192
  return prompt
1193
 
 
1217
 
1218
  @wrapt_timeout_decorator.timeout(dec_timeout=5.0)
1219
  def get_raw_prompt(msg: str):
 
1220
  m = re.findall(r'/GENBEGIN/(.+?)/GENEND/', msg, re.DOTALL)
1221
  return re.sub(r'[*/:_"#\n]', ' ', ", ".join(m)).lower() if m else ""
1222
 
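Editor's note: `get_raw_prompt` pulls out whatever the LLM wrapped in the `/GENBEGIN/.../GENEND/` markers required by the system prompts, then flattens markup characters to spaces:

    raw = get_raw_prompt('chat text //GENBEGIN//:masterpiece, 1girl://GENEND// more text')
    # -> roughly "masterpiece, 1girl" in lowercase, with [*/:_"#] and newlines replaced by spaces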
1223
 
1224
+ @torch.inference_mode()
1225
  @spaces.GPU(duration=60)
1226
  def dolphin_respond(
1227
  message: str,
 
1235
  repeat_penalty: float = 1.1,
1236
  progress=gr.Progress(track_tqdm=True),
1237
  ):
1238
+ try:
1239
+ progress(0, desc="Processing...")
1240
+
1241
+ if override_llm_format:
1242
+ chat_template = override_llm_format
1243
+ else:
1244
+ chat_template = llm_models[model][1]
1245
+
1246
+ llm = Llama(
1247
+ model_path=str(Path(f"{llm_models_dir}/{model}")),
1248
+ flash_attn=True,
1249
+ n_gpu_layers=81, # 81
1250
+ n_batch=1024,
1251
+ n_ctx=8192, #8192
1252
+ )
1253
+ provider = LlamaCppPythonProvider(llm)
1254
+
1255
+ agent = LlamaCppAgent(
1256
+ provider,
1257
+ system_prompt=f"{system_message}",
1258
+ predefined_messages_formatter_type=chat_template if not isinstance(chat_template, MessagesFormatter) else None,
1259
+ custom_messages_formatter=chat_template if isinstance(chat_template, MessagesFormatter) else None,
1260
+ debug_output=False
1261
+ )
1262
+
1263
+ settings = provider.get_provider_default_settings()
1264
+ settings.temperature = temperature
1265
+ settings.top_k = top_k
1266
+ settings.top_p = top_p
1267
+ settings.max_tokens = max_tokens
1268
+ settings.repeat_penalty = repeat_penalty
1269
+ settings.stream = True
1270
+
1271
+ messages = BasicChatHistory()
1272
+
1273
+ for msn in history:
1274
+ user = {
1275
+ 'role': Roles.user,
1276
+ 'content': msn[0]
1277
+ }
1278
+ assistant = {
1279
+ 'role': Roles.assistant,
1280
+ 'content': msn[1]
1281
+ }
1282
+ messages.add_message(user)
1283
+ messages.add_message(assistant)
1284
+
1285
+ stream = agent.get_chat_response(
1286
+ message,
1287
+ llm_sampling_settings=settings,
1288
+ chat_history=messages,
1289
+ returns_streaming_generator=True,
1290
+ print_output=False
1291
+ )
1292
+
1293
+ progress(0.5, desc="Processing...")
1294
 
1295
+ outputs = ""
1296
+ for output in stream:
1297
+ outputs += output
1298
+ yield [(outputs, None)]
1299
+ except Exception as e:
1300
+ print(e)
1301
+ yield [("", None)]
 
 
 
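Editor's note: wrapping the whole body in `try`/`except` and yielding an empty message tuple keeps the Chatbot from hanging when the model file is missing or the GPU allocation fails. The same guard, reduced to a reusable sketch (`safe_stream` is a hypothetical helper, not part of this repo):

    def safe_stream(handler, *args):
        try:
            yield from handler(*args)
        except Exception as e:
            print(e)
            yield [("", None)]  # an empty (user, assistant) tuple keeps the UI responsive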
 
 
 
 
 
 
1302
 
1303
 
1304
  def dolphin_parse(
1305
  history: list[tuple[str, str]],
1306
  ):
 
 
1307
  try:
1308
+ if dolphin_sysprompt_mode == "Chat with LLM" or not history or len(history) < 1:
1309
+ return "", gr.update(), gr.update()
1310
  msg = history[-1][0]
1311
  raw_prompt = get_raw_prompt(msg)
1312
+ prompts = []
1313
+ if dolphin_sysprompt_mode == "Japanese to Danbooru Dictionary" and is_japanese(raw_prompt):
1314
+ prompts = list_uniq(jatags_to_danbooru_tags(to_list_ja(raw_prompt)) + ["nsfw", "explicit"])
1315
+ else:
1316
+ prompts = list_uniq(to_list(raw_prompt) + ["nsfw", "explicit"])
1317
+ return ", ".join(prompts), gr.update(interactive=True), gr.update(interactive=True)
1318
+ except Exception as e:
1319
+ print(e)
1320
+ return "", gr.update(), gr.update()
1321
 
1322
 
1323
+ @torch.inference_mode()
1324
  @spaces.GPU(duration=60)
1325
  def dolphin_respond_auto(
1326
  message: str,
 
1334
  repeat_penalty: float = 1.1,
1335
  progress=gr.Progress(track_tqdm=True),
1336
  ):
1337
+ try:
1338
+ #if not is_japanese(message): return [(None, None)]
1339
+ progress(0, desc="Processing...")
1340
+
1341
+ if override_llm_format:
1342
+ chat_template = override_llm_format
1343
+ else:
1344
+ chat_template = llm_models[model][1]
1345
+
1346
+ llm = Llama(
1347
+ model_path=str(Path(f"{llm_models_dir}/{model}")),
1348
+ flash_attn=True,
1349
+ n_gpu_layers=81, # 81
1350
+ n_batch=1024,
1351
+ n_ctx=8192, #8192
1352
+ )
1353
+ provider = LlamaCppPythonProvider(llm)
1354
+
1355
+ agent = LlamaCppAgent(
1356
+ provider,
1357
+ system_prompt=f"{system_message}",
1358
+ predefined_messages_formatter_type=chat_template if not isinstance(chat_template, MessagesFormatter) else None,
1359
+ custom_messages_formatter=chat_template if isinstance(chat_template, MessagesFormatter) else None,
1360
+ debug_output=False
1361
+ )
1362
+
1363
+ settings = provider.get_provider_default_settings()
1364
+ settings.temperature = temperature
1365
+ settings.top_k = top_k
1366
+ settings.top_p = top_p
1367
+ settings.max_tokens = max_tokens
1368
+ settings.repeat_penalty = repeat_penalty
1369
+ settings.stream = True
1370
+
1371
+ messages = BasicChatHistory()
1372
+
1373
+ for msn in history:
1374
+ user = {
1375
+ 'role': Roles.user,
1376
+ 'content': msn[0]
1377
+ }
1378
+ assistant = {
1379
+ 'role': Roles.assistant,
1380
+ 'content': msn[1]
1381
+ }
1382
+ messages.add_message(user)
1383
+ messages.add_message(assistant)
1384
+
1385
+ progress(0, desc="Translating...")
1386
+ stream = agent.get_chat_response(
1387
+ message,
1388
+ llm_sampling_settings=settings,
1389
+ chat_history=messages,
1390
+ returns_streaming_generator=True,
1391
+ print_output=False
1392
+ )
1393
 
1394
+ progress(0.5, desc="Processing...")
1395
+
1396
+ outputs = ""
1397
+ for output in stream:
1398
+ outputs += output
1399
+ yield [(outputs, None)], gr.update(), gr.update()
1400
+ except Exception as e:
1401
+ print(e)
1402
+ yield [("", None)], gr.update(), gr.update()
 
 
 
 
 
 
 
1403
 
1404
 
1405
  def dolphin_parse_simple(
1406
  message: str,
1407
  history: list[tuple[str, str]],
1408
  ):
 
 
1409
  try:
1410
+ #if not is_japanese(message): return message
1411
+ if dolphin_sysprompt_mode == "Chat with LLM" or not history or len(history) < 1: return message
1412
  msg = history[-1][0]
1413
  raw_prompt = get_raw_prompt(msg)
1414
+ prompts = []
1415
+ if dolphin_sysprompt_mode == "Japanese to Danbooru Dictionary" and is_japanese(raw_prompt):
1416
+ prompts = list_uniq(jatags_to_danbooru_tags(to_list_ja(raw_prompt)) + ["nsfw", "explicit", "rating_explicit"])
1417
+ else:
1418
+ prompts = list_uniq(to_list(raw_prompt) + ["nsfw", "explicit", "rating_explicit"])
1419
+ return ", ".join(prompts)
1420
+ except Exception as e:
1421
+ print(e)
1422
  return ""
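Editor's note: an illustrative run of the non-Japanese path, assuming `to_list` splits on commas and `list_uniq` deduplicates while preserving first occurrences:

    tags = list_uniq(to_list("1girl, smile, 1girl") + ["nsfw", "explicit", "rating_explicit"])
    print(", ".join(tags))  # 1girl, smile, nsfw, explicit, rating_explicit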
 
 
 
 
 
 
1423
 
1424
 
1425
  # https://huggingface.co/spaces/CaioXapelaum/GGUF-Playground
1426
  import cv2
1427
  cv2.setNumThreads(1)
1428
 
1429
+
1430
+ @torch.inference_mode()
1431
  @spaces.GPU()
1432
  def respond_playground(
1433
  message,
 
1440
  top_k,
1441
  repeat_penalty,
1442
  ):
 
 
 
 
1443
  try:
1444
+ if override_llm_format:
1445
+ chat_template = override_llm_format
1446
+ else:
1447
+ chat_template = llm_models[model][1]
1448
+
1449
+ llm = Llama(
1450
+ model_path=str(Path(f"{llm_models_dir}/{model}")),
1451
+ flash_attn=True,
1452
+ n_gpu_layers=81, # 81
1453
+ n_batch=1024,
1454
+ n_ctx=8192, #8192
1455
+ )
1456
+ provider = LlamaCppPythonProvider(llm)
1457
+
1458
+ agent = LlamaCppAgent(
1459
+ provider,
1460
+ system_prompt=f"{system_message}",
1461
+ predefined_messages_formatter_type=chat_template if not isinstance(chat_template, MessagesFormatter) else None,
1462
+ custom_messages_formatter=chat_template if isinstance(chat_template, MessagesFormatter) else None,
1463
+ debug_output=False
1464
+ )
1465
+
1466
+ settings = provider.get_provider_default_settings()
1467
+ settings.temperature = temperature
1468
+ settings.top_k = top_k
1469
+ settings.top_p = top_p
1470
+ settings.max_tokens = max_tokens
1471
+ settings.repeat_penalty = repeat_penalty
1472
+ settings.stream = True
1473
+
1474
+ messages = BasicChatHistory()
1475
+
1476
+ # Add user and assistant messages to the history
1477
+ for msn in history:
1478
+ user = {'role': Roles.user, 'content': msn[0]}
1479
+ assistant = {'role': Roles.assistant, 'content': msn[1]}
1480
+ messages.add_message(user)
1481
+ messages.add_message(assistant)
1482
+
1483
+ # Stream the response
1484
  stream = agent.get_chat_response(
1485
  message,
1486
  llm_sampling_settings=settings,
 
1494
  outputs += output
1495
  yield outputs
1496
  except Exception as e:
1497
+ print(e)
1498
+ yield ""
modutils.py CHANGED
@@ -5,6 +5,7 @@ import os
5
  import re
6
  from pathlib import Path
7
  from PIL import Image
 
8
  import shutil
9
  import requests
10
  from requests.adapters import HTTPAdapter
@@ -12,11 +13,16 @@ from urllib3.util import Retry
12
  import urllib.parse
13
  import pandas as pd
14
  from huggingface_hub import HfApi, HfFolder, hf_hub_download, snapshot_download
 
 
 
 
 
15
 
16
 
17
  from env import (HF_LORA_PRIVATE_REPOS1, HF_LORA_PRIVATE_REPOS2,
18
  HF_MODEL_USER_EX, HF_MODEL_USER_LIKES, DIFFUSERS_FORMAT_LORAS,
19
- directory_loras, hf_read_token, HF_TOKEN, CIVITAI_API_KEY)
20
 
21
 
22
  MODEL_TYPE_DICT = {
@@ -46,7 +52,6 @@ def is_repo_name(s):
46
  return re.fullmatch(r'^[^/]+?/[^/]+?$', s)
47
 
48
 
49
- from translatepy import Translator
50
  translator = Translator()
51
  def translate_to_en(input: str):
52
  try:
@@ -64,6 +69,7 @@ def get_local_model_list(dir_path):
64
  if file.suffix in valid_extensions:
65
  file_path = str(Path(f"{dir_path}/{file.name}"))
66
  model_list.append(file_path)
 
67
  return model_list
68
 
69
 
@@ -98,21 +104,81 @@ def split_hf_url(url: str):
98
  print(e)
99
 
100
 
101
- def download_hf_file(directory, url, progress=gr.Progress(track_tqdm=True)):
102
- hf_token = get_token()
103
  repo_id, filename, subfolder, repo_type = split_hf_url(url)
 
 
 
104
  try:
105
- print(f"Downloading {url} to {directory}")
106
- if subfolder is not None: path = hf_hub_download(repo_id=repo_id, filename=filename, subfolder=subfolder, repo_type=repo_type, local_dir=directory, token=hf_token)
107
- else: path = hf_hub_download(repo_id=repo_id, filename=filename, repo_type=repo_type, local_dir=directory, token=hf_token)
108
  return path
109
  except Exception as e:
110
- print(f"Failed to download: {e}")
111
  return None
112
 
113
 
114
- def download_things(directory, url, hf_token="", civitai_api_key=""):
 
 
 
 
 
 
115
  url = url.strip()
 
 
116
  if "drive.google.com" in url:
117
  original_dir = os.getcwd()
118
  os.chdir(directory)
@@ -123,18 +189,48 @@ def download_things(directory, url, hf_token="", civitai_api_key=""):
123
  # url = urllib.parse.quote(url, safe=':/') # fix encoding
124
  if "/blob/" in url:
125
  url = url.replace("/blob/", "/resolve/")
126
- download_hf_file(directory, url)
 
 
 
 
 
 
127
  elif "civitai.com" in url:
128
- if "?" in url:
129
- url = url.split("?")[0]
130
- if civitai_api_key:
131
- url = url + f"?token={civitai_api_key}"
132
- os.system(f"aria2c --console-log-level=error --summary-interval=10 -c -x 16 -k 1M -s 16 -d {directory} {url}")
133
- else:
134
  print("\033[91mYou need an API key to download Civitai models.\033[0m")
 
 
 
 
135
  else:
136
  os.system(f"aria2c --console-log-level=error --summary-interval=10 -c -x 16 -k 1M -s 16 -d {directory} {url}")
137
 
 
 
138
 
139
  def get_download_file(temp_dir, url, civitai_key="", progress=gr.Progress(track_tqdm=True)):
140
  if not "http" in url and is_repo_name(url) and not Path(url).exists():
@@ -173,7 +269,7 @@ def to_lora_key(path: str):
173
 
174
  def to_lora_path(key: str):
175
  if Path(key).is_file(): return key
176
- path = Path(f"{directory_loras}/{escape_lora_basename(key)}.safetensors")
177
  return str(path)
178
 
179
 
@@ -203,25 +299,21 @@ def save_images(images: list[Image.Image], metadatas: list[str]):
203
  raise Exception(f"Failed to save image file:") from e
204
 
205
 
206
- def save_gallery_images(images, progress=gr.Progress(track_tqdm=True)):
207
- from datetime import datetime, timezone, timedelta
208
  progress(0, desc="Updating gallery...")
209
- dt_now = datetime.now(timezone(timedelta(hours=9)))
210
- basename = dt_now.strftime('%Y%m%d_%H%M%S_')
211
- i = 1
212
- if not images: return images, gr.update(visible=False)
213
  output_images = []
214
  output_paths = []
215
- for image in images:
216
- filename = basename + str(i) + ".png"
217
- i += 1
218
  oldpath = Path(image[0])
219
  newpath = oldpath
220
  try:
221
  if oldpath.exists():
222
  newpath = oldpath.resolve().rename(Path(filename).resolve())
223
  except Exception as e:
224
- print(e)
225
  finally:
226
  output_paths.append(str(newpath))
227
  output_images.append((str(newpath), str(filename)))
@@ -229,10 +321,47 @@ def save_gallery_images(images, progress=gr.Progress(track_tqdm=True)):
229
  return gr.update(value=output_images), gr.update(value=output_paths, visible=True)
230
 
231
 
 
 
 
 
 
232
  def download_private_repo(repo_id, dir_path, is_replace):
233
- if not hf_read_token: return
234
  try:
235
- snapshot_download(repo_id=repo_id, local_dir=dir_path, allow_patterns=['*.ckpt', '*.pt', '*.pth', '*.safetensors', '*.bin'], use_auth_token=hf_read_token)
236
  except Exception as e:
237
  print(f"Error: Failed to download {repo_id}.")
238
  print(e)
@@ -250,9 +379,9 @@ private_model_path_repo_dict = {} # {"local filepath": "huggingface repo_id", ..
250
  def get_private_model_list(repo_id, dir_path):
251
  global private_model_path_repo_dict
252
  api = HfApi()
253
- if not hf_read_token: return []
254
  try:
255
- files = api.list_repo_files(repo_id, token=hf_read_token)
256
  except Exception as e:
257
  print(f"Error: Failed to list {repo_id}.")
258
  print(e)
@@ -270,11 +399,11 @@ def get_private_model_list(repo_id, dir_path):
270
  def download_private_file(repo_id, path, is_replace):
271
  file = Path(path)
272
  newpath = Path(f'{file.parent.name}/{escape_lora_basename(file.stem)}{file.suffix}') if is_replace else file
273
- if not hf_read_token or newpath.exists(): return
274
  filename = file.name
275
  dirname = file.parent.name
276
  try:
277
- hf_hub_download(repo_id=repo_id, filename=filename, local_dir=dirname, use_auth_token=hf_read_token)
278
  except Exception as e:
279
  print(f"Error: Failed to download {filename}.")
280
  print(e)
@@ -404,9 +533,9 @@ def get_private_lora_model_lists():
404
  models1 = []
405
  models2 = []
406
  for repo in HF_LORA_PRIVATE_REPOS1:
407
- models1.extend(get_private_model_list(repo, directory_loras))
408
  for repo in HF_LORA_PRIVATE_REPOS2:
409
- models2.extend(get_private_model_list(repo, directory_loras))
410
  models = list_uniq(models1 + sorted(models2))
411
  private_lora_model_list = models.copy()
412
  return models
@@ -451,7 +580,7 @@ def get_civitai_info(path):
451
 
452
 
453
  def get_lora_model_list():
454
- loras = list_uniq(get_private_lora_model_lists() + DIFFUSERS_FORMAT_LORAS + get_local_model_list(directory_loras))
455
  loras.insert(0, "None")
456
  loras.insert(0, "")
457
  return loras
@@ -503,14 +632,14 @@ def update_lora_dict(path):
503
  def download_lora(dl_urls: str):
504
  global loras_url_to_path_dict
505
  dl_path = ""
506
- before = get_local_model_list(directory_loras)
507
  urls = []
508
  for url in [url.strip() for url in dl_urls.split(',')]:
509
- local_path = f"{directory_loras}/{url.split('/')[-1]}"
510
  if not Path(local_path).exists():
511
- download_things(directory_loras, url, HF_TOKEN, CIVITAI_API_KEY)
512
  urls.append(url)
513
- after = get_local_model_list(directory_loras)
514
  new_files = list_sub(after, before)
515
  i = 0
516
  for file in new_files:
@@ -761,12 +890,14 @@ def update_loras(prompt, prompt_syntax, lora1, lora1_wt, lora2, lora2_wt, lora3,
761
  gr.update(value=tag5, label=label5, visible=on5), gr.update(visible=on5), gr.update(value=md5, visible=on5)
762
 
763
 
764
- def get_my_lora(link_url):
765
- before = get_local_model_list(directory_loras)
 
 
766
  for url in [url.strip() for url in link_url.split(',')]:
767
- if not Path(f"{directory_loras}/{url.split('/')[-1]}").exists():
768
- download_things(directory_loras, url, HF_TOKEN, CIVITAI_API_KEY)
769
- after = get_local_model_list(directory_loras)
770
  new_files = list_sub(after, before)
771
  for file in new_files:
772
  path = Path(file)
@@ -774,11 +905,16 @@ def get_my_lora(link_url):
774
  new_path = Path(f'{path.parent.name}/{escape_lora_basename(path.stem)}{path.suffix}')
775
  path.resolve().rename(new_path.resolve())
776
  update_lora_dict(str(new_path))
 
777
  new_lora_model_list = get_lora_model_list()
778
  new_lora_tupled_list = get_all_lora_tupled_list()
779
-
 
 
 
 
780
  return gr.update(
781
- choices=new_lora_tupled_list, value=new_lora_model_list[-1]
782
  ), gr.update(
783
  choices=new_lora_tupled_list
784
  ), gr.update(
@@ -787,6 +923,8 @@ def get_my_lora(link_url):
787
  choices=new_lora_tupled_list
788
  ), gr.update(
789
  choices=new_lora_tupled_list
 
 
790
  )
791
 
792
 
@@ -794,12 +932,12 @@ def upload_file_lora(files, progress=gr.Progress(track_tqdm=True)):
794
  progress(0, desc="Uploading...")
795
  file_paths = [file.name for file in files]
796
  progress(1, desc="Uploaded.")
797
- return gr.update(value=file_paths, visible=True), gr.update(visible=True)
798
 
799
 
800
  def move_file_lora(filepaths):
801
  for file in filepaths:
802
- path = Path(shutil.move(Path(file).resolve(), Path(f"./{directory_loras}").resolve()))
803
  newpath = Path(f'{path.parent.name}/{escape_lora_basename(path.stem)}{path.suffix}')
804
  path.resolve().rename(newpath.resolve())
805
  update_lora_dict(str(newpath))
@@ -941,7 +1079,7 @@ def update_civitai_selection(evt: gr.SelectData):
941
  selected = civitai_last_choices[selected_index][1]
942
  return gr.update(value=selected)
943
  except Exception:
944
- return gr.update(visible=True)
945
 
946
 
947
  def select_civitai_lora(search_result):
@@ -1425,3 +1563,78 @@ def get_model_pipeline(repo_id: str):
1425
  else:
1426
  return default
1427
 
 
 
 
 
 
 
 
5
  import re
6
  from pathlib import Path
7
  from PIL import Image
8
+ import numpy as np
9
  import shutil
10
  import requests
11
  from requests.adapters import HTTPAdapter
 
13
  import urllib.parse
14
  import pandas as pd
15
  from huggingface_hub import HfApi, HfFolder, hf_hub_download, snapshot_download
16
+ from translatepy import Translator
17
+ from unidecode import unidecode
18
+ import copy
19
+ from datetime import datetime, timezone, timedelta
20
+ FILENAME_TIMEZONE = timezone(timedelta(hours=9)) # JST
21
 
22
 
23
  from env import (HF_LORA_PRIVATE_REPOS1, HF_LORA_PRIVATE_REPOS2,
24
  HF_MODEL_USER_EX, HF_MODEL_USER_LIKES, DIFFUSERS_FORMAT_LORAS,
25
+ DIRECTORY_LORAS, HF_READ_TOKEN, HF_TOKEN, CIVITAI_API_KEY)
26
 
27
 
28
  MODEL_TYPE_DICT = {
 
52
  return re.fullmatch(r'^[^/]+?/[^/]+?$', s)
53
 
54
 
 
55
  translator = Translator()
56
  def translate_to_en(input: str):
57
  try:
 
69
  if file.suffix in valid_extensions:
70
  file_path = str(Path(f"{dir_path}/{file.name}"))
71
  model_list.append(file_path)
72
+ #print('\033[34mFILE: ' + file_path + '\033[0m')
73
  return model_list
74
 
75
 
 
104
  print(e)
105
 
106
 
107
+ def download_hf_file(directory, url, force_filename="", hf_token="", progress=gr.Progress(track_tqdm=True)):
 
108
  repo_id, filename, subfolder, repo_type = split_hf_url(url)
109
+ kwargs = {}
110
+ if subfolder is not None: kwargs["subfolder"] = subfolder
111
+ if force_filename: kwargs["force_filename"] = force_filename
112
  try:
113
+ print(f"Start downloading: {url} to {directory}")
114
+ path = hf_hub_download(repo_id=repo_id, filename=filename, repo_type=repo_type, local_dir=directory, token=hf_token, **kwargs)
 
115
  return path
116
  except Exception as e:
117
+ print(f"Download failed: {url} {e}")
118
  return None
119
 
120
 
121
+ USER_AGENT = 'Mozilla/5.0 (Windows NT 10.0; Win64; x64; rv:127.0) Gecko/20100101 Firefox/127.0'
122
+
123
+
124
+ def request_json_data(url):
125
+ model_version_id = url.split('/')[-1]
126
+ if "?modelVersionId=" in model_version_id:
127
+ match = re.search(r'modelVersionId=(\d+)', url)
128
+ model_version_id = match.group(1)
129
+
130
+ endpoint_url = f"https://civitai.com/api/v1/model-versions/{model_version_id}"
131
+
132
+ params = {}
133
+ headers = {'User-Agent': USER_AGENT, 'content-type': 'application/json'}
134
+ session = requests.Session()
135
+ retries = Retry(total=5, backoff_factor=1, status_forcelist=[500, 502, 503, 504])
136
+ session.mount("https://", HTTPAdapter(max_retries=retries))
137
+
138
+ try:
139
+ result = session.get(endpoint_url, params=params, headers=headers, stream=True, timeout=(3.0, 15))
140
+ result.raise_for_status()
141
+ json_data = result.json()
142
+ return json_data if json_data else None
143
+ except Exception as e:
144
+ print(f"Error: {e}")
145
+ return None
146
+
147
+
148
+ class ModelInformation:
149
+ def __init__(self, json_data):
150
+ self.model_version_id = json_data.get("id", "")
151
+ self.model_id = json_data.get("modelId", "")
152
+ self.download_url = json_data.get("downloadUrl", "")
153
+ self.model_url = f"https://civitai.com/models/{self.model_id}?modelVersionId={self.model_version_id}"
154
+ self.filename_url = next(
155
+ (v.get("name", "") for v in json_data.get("files", []) if str(self.model_version_id) in v.get("downloadUrl", "")), ""
156
+ )
157
+ self.filename_url = self.filename_url if self.filename_url else ""
158
+ self.description = json_data.get("description", "")
159
+ if self.description is None: self.description = ""
160
+ self.model_name = json_data.get("model", {}).get("name", "")
161
+ self.model_type = json_data.get("model", {}).get("type", "")
162
+ self.nsfw = json_data.get("model", {}).get("nsfw", False)
163
+ self.poi = json_data.get("model", {}).get("poi", False)
164
+ self.images = [img.get("url", "") for img in json_data.get("images", [])]
165
+ self.example_prompt = json_data.get("trainedWords", [""])[0] if json_data.get("trainedWords") else ""
166
+ self.original_json = copy.deepcopy(json_data)
167
+
168
+
169
+ def retrieve_model_info(url):
170
+ json_data = request_json_data(url)
171
+ if not json_data:
172
+ return None
173
+ model_descriptor = ModelInformation(json_data)
174
+ return model_descriptor
175
+
176
+
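Editor's note: a hypothetical lookup that resolves a Civitai page URL to its canonical file name and download URL; `retrieve_model_info` returns None when the API call fails, so the result must be checked:

    info = retrieve_model_info("https://civitai.com/models/123456?modelVersionId=654321")  # illustrative IDs
    if info is not None:
        print(info.filename_url, info.download_url)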
177
+ def download_things(directory, url, hf_token="", civitai_api_key="", romanize=False):
178
+ hf_token = get_token()
179
  url = url.strip()
180
+ downloaded_file_path = None
181
+
182
  if "drive.google.com" in url:
183
  original_dir = os.getcwd()
184
  os.chdir(directory)
 
189
  # url = urllib.parse.quote(url, safe=':/') # fix encoding
190
  if "/blob/" in url:
191
  url = url.replace("/blob/", "/resolve/")
192
+
193
+ filename = unidecode(url.split('/')[-1]) if romanize else url.split('/')[-1]
194
+
195
+ download_hf_file(directory, url, filename, hf_token)
196
+
197
+ downloaded_file_path = os.path.join(directory, filename)
198
+
199
  elif "civitai.com" in url:
200
+
201
+ if not civitai_api_key:
 
 
 
 
202
  print("\033[91mYou need an API key to download Civitai models.\033[0m")
203
+
204
+ model_profile = retrieve_model_info(url)
205
+ if model_profile is not None and model_profile.download_url and model_profile.filename_url:  # guard: retrieve_model_info returns None on API failure
206
+ url = model_profile.download_url
207
+ filename = unidecode(model_profile.filename_url) if romanize else model_profile.filename_url
208
+ else:
209
+ if "?" in url:
210
+ url = url.split("?")[0]
211
+ filename = ""
212
+
213
+ url_dl = url + f"?token={civitai_api_key}"
214
+ print(f"Filename: {filename}")
215
+
216
+ param_filename = ""
217
+ if filename:
218
+ param_filename = f"-o '{filename}'"
219
+
220
+ aria2_command = (
221
+ f'aria2c --console-log-level=error --summary-interval=10 -c -x 16 '
222
+ f'-k 1M -s 16 -d "{directory}" {param_filename} "{url_dl}"'
223
+ )
224
+ os.system(aria2_command)
225
+
226
+ if param_filename and os.path.exists(os.path.join(directory, filename)):
227
+ downloaded_file_path = os.path.join(directory, filename)
228
+
229
  else:
230
  os.system(f"aria2c --console-log-level=error --summary-interval=10 -c -x 16 -k 1M -s 16 -d {directory} {url}")
231
 
232
+ return downloaded_file_path
233
+
234
 
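Editor's note: sketch of a caller. `download_things` now returns the local file path when it can be determined (HF downloads and named Civitai downloads) and None otherwise, so callers should tolerate both; the URL below is illustrative:

    url = "https://civitai.com/models/123456"  # illustrative
    path = download_things(DIRECTORY_LORAS, url, HF_TOKEN, CIVITAI_API_KEY, romanize=True)
    if path is None:
        print("Download finished (or failed) without a resolvable file path")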
235
  def get_download_file(temp_dir, url, civitai_key="", progress=gr.Progress(track_tqdm=True)):
236
  if not "http" in url and is_repo_name(url) and not Path(url).exists():
 
269
 
270
  def to_lora_path(key: str):
271
  if Path(key).is_file(): return key
272
+ path = Path(f"{DIRECTORY_LORAS}/{escape_lora_basename(key)}.safetensors")
273
  return str(path)
274
 
275
 
 
299
  raise Exception(f"Failed to save image file:") from e
300
 
301
 
+ def save_gallery_images(images, model_name="", progress=gr.Progress(track_tqdm=True)):
      progress(0, desc="Updating gallery...")
+     basename = f"{model_name.split('/')[-1]}_{datetime.now(FILENAME_TIMEZONE).strftime('%Y%m%d_%H%M%S')}_"
+     if not images: return images, gr.update()
      output_images = []
      output_paths = []
+     for i, image in enumerate(images):
+         filename = f"{basename}{i + 1}.png"
          oldpath = Path(image[0])
          newpath = oldpath
          try:
              if oldpath.exists():
                  newpath = oldpath.resolve().rename(Path(filename).resolve())
          except Exception as e:
+             print(e)
          finally:
              output_paths.append(str(newpath))
              output_images.append((str(newpath), str(filename)))

      return gr.update(value=output_images), gr.update(value=output_paths, visible=True)


+ def save_gallery_history(images, files, history_gallery, history_files, progress=gr.Progress(track_tqdm=True)):
+     if not images or not files: return gr.update(), gr.update()
+     if not history_gallery: history_gallery = []
+     if not history_files: history_files = []
+     output_gallery = images + history_gallery
+     output_files = files + history_files
+     return gr.update(value=output_gallery), gr.update(value=output_files, visible=True)
+
+
+ def save_image_history(image, gallery, files, model_name: str, progress=gr.Progress(track_tqdm=True)):
+     if not gallery: gallery = []
+     if not files: files = []
+     try:
+         basename = f"{model_name.split('/')[-1]}_{datetime.now(FILENAME_TIMEZONE).strftime('%Y%m%d_%H%M%S')}"
+         if image is None or not isinstance(image, (str, Image.Image, np.ndarray, tuple)): return gr.update(), gr.update()
+         filename = f"{basename}.png"
+         if isinstance(image, tuple): image = image[0]
+         if isinstance(image, str): oldpath = image
+         elif isinstance(image, Image.Image):
+             oldpath = "temp.png"
+             image.save(oldpath)
+         elif isinstance(image, np.ndarray):
+             oldpath = "temp.png"
+             Image.fromarray(image).convert('RGBA').save(oldpath)
+         oldpath = Path(oldpath)
+         newpath = oldpath
+         if oldpath.exists():
+             shutil.copy(oldpath.resolve(), Path(filename).resolve())
+             newpath = Path(filename).resolve()
+         files.insert(0, str(newpath))
+         gallery.insert(0, (str(newpath), str(filename)))
+     except Exception as e:
+         print(e)
+     finally:
+         return gr.update(value=gallery), gr.update(value=files, visible=True)
+
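All three history helpers return pairs of `gr.update()` values, so they can be wired directly as Gradio callbacks. A hedged sketch of that wiring (the component names `result`, `history_gallery`, `history_files`, and `model_name` are assumptions, not taken from this file):

    # Prepend each new result to the history gallery and the downloadable file list.
    result.change(
        fn=save_image_history,
        inputs=[result, history_gallery, history_files, model_name],
        outputs=[history_gallery, history_files],
    )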
  def download_private_repo(repo_id, dir_path, is_replace):
+     if not HF_READ_TOKEN: return
      try:
+         snapshot_download(repo_id=repo_id, local_dir=dir_path, allow_patterns=['*.ckpt', '*.pt', '*.pth', '*.safetensors', '*.bin'], token=HF_READ_TOKEN)
      except Exception as e:
          print(f"Error: Failed to download {repo_id}.")
          print(e)

  def get_private_model_list(repo_id, dir_path):
      global private_model_path_repo_dict
      api = HfApi()
+     if not HF_READ_TOKEN: return []
      try:
+         files = api.list_repo_files(repo_id, token=HF_READ_TOKEN)
      except Exception as e:
          print(f"Error: Failed to list {repo_id}.")
          print(e)

  def download_private_file(repo_id, path, is_replace):
      file = Path(path)
      newpath = Path(f'{file.parent.name}/{escape_lora_basename(file.stem)}{file.suffix}') if is_replace else file
+     if not HF_READ_TOKEN or newpath.exists(): return
      filename = file.name
      dirname = file.parent.name
      try:
+         hf_hub_download(repo_id=repo_id, filename=filename, local_dir=dirname, token=HF_READ_TOKEN)
      except Exception as e:
          print(f"Error: Failed to download {filename}.")
          print(e)

      models1 = []
      models2 = []
      for repo in HF_LORA_PRIVATE_REPOS1:
+         models1.extend(get_private_model_list(repo, DIRECTORY_LORAS))
      for repo in HF_LORA_PRIVATE_REPOS2:
+         models2.extend(get_private_model_list(repo, DIRECTORY_LORAS))
      models = list_uniq(models1 + sorted(models2))
      private_lora_model_list = models.copy()
      return models


  def get_lora_model_list():
+     loras = list_uniq(get_private_lora_model_lists() + DIFFUSERS_FORMAT_LORAS + get_local_model_list(DIRECTORY_LORAS))
      loras.insert(0, "None")
      loras.insert(0, "")
      return loras

  def download_lora(dl_urls: str):
      global loras_url_to_path_dict
      dl_path = ""
+     before = get_local_model_list(DIRECTORY_LORAS)
      urls = []
      for url in [url.strip() for url in dl_urls.split(',')]:
+         local_path = f"{DIRECTORY_LORAS}/{url.split('/')[-1]}"
          if not Path(local_path).exists():
+             download_things(DIRECTORY_LORAS, url, HF_TOKEN, CIVITAI_API_KEY)
              urls.append(url)
+     after = get_local_model_list(DIRECTORY_LORAS)
      new_files = list_sub(after, before)
      i = 0
      for file in new_files:

          gr.update(value=tag5, label=label5, visible=on5), gr.update(visible=on5), gr.update(value=md5, visible=on5)


+ def get_my_lora(link_url, romanize):
+     l_name = ""
+     l_path = ""
+     before = get_local_model_list(DIRECTORY_LORAS)
      for url in [url.strip() for url in link_url.split(',')]:
+         if not Path(f"{DIRECTORY_LORAS}/{url.split('/')[-1]}").exists():
+             l_name = download_things(DIRECTORY_LORAS, url, HF_TOKEN, CIVITAI_API_KEY, romanize)
+     after = get_local_model_list(DIRECTORY_LORAS)
      new_files = list_sub(after, before)
      for file in new_files:
          path = Path(file)

              new_path = Path(f'{path.parent.name}/{escape_lora_basename(path.stem)}{path.suffix}')
              path.resolve().rename(new_path.resolve())
              update_lora_dict(str(new_path))
+             l_path = str(new_path)
      new_lora_model_list = get_lora_model_list()
      new_lora_tupled_list = get_all_lora_tupled_list()
+     msg_lora = "Downloaded"
+     if l_name:
+         msg_lora += f": <b>{l_name}</b>"
+     print(msg_lora)
+
      return gr.update(
+         choices=new_lora_tupled_list, value=l_path
      ), gr.update(
          choices=new_lora_tupled_list
      ), gr.update(

          choices=new_lora_tupled_list
      ), gr.update(
          choices=new_lora_tupled_list
+     ), gr.update(
+         value=msg_lora
      )


      progress(0, desc="Uploading...")
      file_paths = [file.name for file in files]
      progress(1, desc="Uploaded.")
+     return gr.update(value=file_paths, visible=True), gr.update()


  def move_file_lora(filepaths):
      for file in filepaths:
+         path = Path(shutil.move(Path(file).resolve(), Path(f"./{DIRECTORY_LORAS}").resolve()))
          newpath = Path(f'{path.parent.name}/{escape_lora_basename(path.stem)}{path.suffix}')
          path.resolve().rename(newpath.resolve())
          update_lora_dict(str(newpath))

          selected = civitai_last_choices[selected_index][1]
          return gr.update(value=selected)
      except Exception:
+         return gr.update()


  def select_civitai_lora(search_result):

      else:
          return default

+
+ EXAMPLES_GUI = [
+     [
+         "1girl, souryuu asuka langley, neon genesis evangelion, plugsuit, pilot suit, red bodysuit, sitting, crossing legs, black eye patch, cat hat, throne, symmetrical, looking down, from bottom, looking at viewer, outdoors, masterpiece, best quality, very aesthetic, absurdres",
+         "nsfw, lowres, (bad), text, error, fewer, extra, missing, worst quality, jpeg artifacts, low quality, watermark, unfinished, displeasing, oldest, early, chromatic aberration, signature, extra digits, artistic error, username, scan, [abstract]",
+         1,
+         30,
+         7.5,
+         True,
+         -1,
+         "Euler a",
+         1152,
+         896,
+         "votepurchase/animagine-xl-3.1",
+     ],
+     [
+         "solo, princess Zelda OOT, score_9, score_8_up, score_8, medium breasts, cute, eyelashes, cute small face, long hair, crown braid, hairclip, pointy ears, soft curvy body, looking at viewer, smile, blush, white dress, medium body, (((holding the Master Sword))), standing, deep forest in the background",
+         "score_6, score_5, score_4, busty, ugly face, mutated hands, low res, blurry face, black and white,",
+         1,
+         30,
+         5.0,
+         True,
+         -1,
+         "Euler a",
+         1024,
+         1024,
+         "votepurchase/ponyDiffusionV6XL",
+     ],
+     [
+         "1girl, oomuro sakurako, yuru yuri, official art, school uniform, anime artwork, anime style, vibrant, studio anime, highly detailed, masterpiece, best quality, very aesthetic, absurdres",
+         "photo, deformed, black and white, realism, disfigured, low contrast, lowres, (bad), text, error, fewer, extra, missing, worst quality, jpeg artifacts, low quality, watermark, unfinished, displeasing, oldest, early, chromatic aberration, signature, extra digits, artistic error, username, scan, [abstract]",
+         1,
+         40,
+         7.0,
+         True,
+         -1,
+         "Euler a",
+         1024,
+         1024,
+         "Raelina/Rae-Diffusion-XL-V2",
+     ],
+     [
+         "1girl, akaza akari, yuru yuri, official art, anime artwork, anime style, vibrant, studio anime, highly detailed, masterpiece, best quality, very aesthetic, absurdres",
+         "photo, deformed, black and white, realism, disfigured, low contrast, lowres, (bad), text, error, fewer, extra, missing, worst quality, jpeg artifacts, low quality, watermark, unfinished, displeasing, oldest, early, chromatic aberration, signature, extra digits, artistic error, username, scan, [abstract]",
+         1,
+         35,
+         7.0,
+         True,
+         -1,
+         "Euler a",
+         1024,
+         1024,
+         "Raelina/Raemu-XL-V4",
+     ],
+     [
+         "yoshida yuuko, machikado mazoku, 1girl, solo, demon horns, horns, school uniform, long hair, open mouth, skirt, demon girl, ahoge, shiny, shiny hair, anime artwork",
+         "nsfw, lowres, (bad), text, error, fewer, extra, missing, worst quality, jpeg artifacts, low quality, watermark, unfinished, displeasing, oldest, early, chromatic aberration, signature, extra digits, artistic error, username, scan, [abstract]",
+         1,
+         50,
+         7.0,
+         True,
+         -1,
+         "Euler a",
+         1024,
+         1024,
+         "cagliostrolab/animagine-xl-3.1",
+     ],
+ ]
+
+
+ RESOURCES = (
+     """### Resources
+ - You can also try the image generator in Colab’s free tier, which provides a free GPU [link](https://github.com/R3gm/SD_diffusers_interactive).
+ """
+ )
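Each `EXAMPLES_GUI` row is one positional preset. Judging by the values, the order appears to be prompt, negative prompt, image count, steps, guidance scale, a boolean toggle, seed (-1 for random), sampler, height, width, and model repo, but the authoritative mapping is the GUI wiring, not this diff. A hedged sketch of that wiring (all component names assumed):

    gr.Examples(
        examples=EXAMPLES_GUI,
        inputs=[prompt, negative_prompt, num_images, steps, guidance_scale,
                boolean_flag, seed, sampler, height, width, model_name],
    )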
requirements.txt CHANGED
@@ -1,6 +1,5 @@
  spaces
  accelerate
- spaces>=0.30.3
  diffusers
  invisible_watermark
  transformers
@@ -21,4 +20,5 @@ dartrs
  translatepy
  timm
  wrapt-timeout-decorator
- sentencepiece
+ sentencepiece
+ unidecode
utils.py ADDED
@@ -0,0 +1,421 @@
+ import os
+ import re
+ import gradio as gr
+ from constants import (
+     DIFFUSERS_FORMAT_LORAS,
+     CIVITAI_API_KEY,
+     HF_TOKEN,
+     MODEL_TYPE_CLASS,
+     DIRECTORY_LORAS,
+ )
+ from huggingface_hub import HfApi
+ from diffusers import DiffusionPipeline
+ from huggingface_hub import model_info as model_info_data
+ from diffusers.pipelines.pipeline_loading_utils import variant_compatible_siblings
+ from pathlib import PosixPath
+ from unidecode import unidecode
+ import urllib.parse
+ import copy
+ import requests
+ from requests.adapters import HTTPAdapter
+ from urllib3.util import Retry
+
+ USER_AGENT = 'Mozilla/5.0 (Windows NT 10.0; Win64; x64; rv:127.0) Gecko/20100101 Firefox/127.0'
+
+
+ def request_json_data(url):
+     model_version_id = url.split('/')[-1]
+     if "?modelVersionId=" in model_version_id:
+         match = re.search(r'modelVersionId=(\d+)', url)
+         model_version_id = match.group(1)
+
+     endpoint_url = f"https://civitai.com/api/v1/model-versions/{model_version_id}"
+
+     params = {}
+     headers = {'User-Agent': USER_AGENT, 'content-type': 'application/json'}
+     session = requests.Session()
+     retries = Retry(total=5, backoff_factor=1, status_forcelist=[500, 502, 503, 504])
+     session.mount("https://", HTTPAdapter(max_retries=retries))
+
+     try:
+         result = session.get(endpoint_url, params=params, headers=headers, stream=True, timeout=(3.0, 15))
+         result.raise_for_status()
+         json_data = result.json()
+         return json_data if json_data else None
+     except Exception as e:
+         print(f"Error: {e}")
+         return None
+
+
+ class ModelInformation:
+     def __init__(self, json_data):
+         self.model_version_id = json_data.get("id", "")
+         self.model_id = json_data.get("modelId", "")
+         self.download_url = json_data.get("downloadUrl", "")
+         self.model_url = f"https://civitai.com/models/{self.model_id}?modelVersionId={self.model_version_id}"
+         self.filename_url = next(
+             (v.get("name", "") for v in json_data.get("files", []) if str(self.model_version_id) in v.get("downloadUrl", "")), ""
+         )
+         self.filename_url = self.filename_url if self.filename_url else ""
+         self.description = json_data.get("description", "")
+         if self.description is None: self.description = ""
+         self.model_name = json_data.get("model", {}).get("name", "")
+         self.model_type = json_data.get("model", {}).get("type", "")
+         self.nsfw = json_data.get("model", {}).get("nsfw", False)
+         self.poi = json_data.get("model", {}).get("poi", False)
+         self.images = [img.get("url", "") for img in json_data.get("images", [])]
+         self.example_prompt = json_data.get("trainedWords", [""])[0] if json_data.get("trainedWords") else ""
+         self.original_json = copy.deepcopy(json_data)
+
+
+ def retrieve_model_info(url):
+     json_data = request_json_data(url)
+     if not json_data:
+         return None
+     model_descriptor = ModelInformation(json_data)
+     return model_descriptor
+
+
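A quick hedged example of the Civitai lookup (the model URL is hypothetical; `retrieve_model_info` returns None on any API failure):

    info = retrieve_model_info("https://civitai.com/models/12345?modelVersionId=67890")
    if info is not None:
        print(info.model_name, info.model_type, info.download_url, info.filename_url)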
+ def download_things(directory, url, hf_token="", civitai_api_key="", romanize=False):
+     url = url.strip()
+     downloaded_file_path = None
+
+     if "drive.google.com" in url:
+         original_dir = os.getcwd()
+         os.chdir(directory)
+         os.system(f"gdown --fuzzy {url}")
+         os.chdir(original_dir)
+     elif "huggingface.co" in url:
+         url = url.replace("?download=true", "")
+         # url = urllib.parse.quote(url, safe=':/')  # fix encoding
+         if "/blob/" in url:
+             url = url.replace("/blob/", "/resolve/")
+         user_header = f'"Authorization: Bearer {hf_token}"'
+
+         filename = unidecode(url.split('/')[-1]) if romanize else url.split('/')[-1]
+
+         if hf_token:
+             os.system(f"aria2c --console-log-level=error --summary-interval=10 --header={user_header} -c -x 16 -k 1M -s 16 {url} -d {directory} -o {filename}")
+         else:
+             os.system(f"aria2c --optimize-concurrent-downloads --console-log-level=error --summary-interval=10 -c -x 16 -k 1M -s 16 {url} -d {directory} -o {filename}")
+
+         downloaded_file_path = os.path.join(directory, filename)
+
+     elif "civitai.com" in url:
+         if not civitai_api_key:
+             print("\033[91mYou need an API key to download Civitai models.\033[0m")
+
+         model_profile = retrieve_model_info(url)
+         if model_profile is not None and model_profile.download_url and model_profile.filename_url:
+             url = model_profile.download_url
+             filename = unidecode(model_profile.filename_url) if romanize else model_profile.filename_url
+         else:
+             if "?" in url:
+                 url = url.split("?")[0]
+             filename = ""
+
+         url_dl = url + f"?token={civitai_api_key}"
+         print(f"Filename: {filename}")
+
+         param_filename = ""
+         if filename:
+             param_filename = f"-o '{filename}'"
+
+         aria2_command = (
+             f'aria2c --console-log-level=error --summary-interval=10 -c -x 16 '
+             f'-k 1M -s 16 -d "{directory}" {param_filename} "{url_dl}"'
+         )
+         os.system(aria2_command)
+
+         if param_filename and os.path.exists(os.path.join(directory, filename)):
+             downloaded_file_path = os.path.join(directory, filename)
+
+         # # PLAN B
+         # # Follow the redirect to get the actual download URL
+         # curl_command = (
+         #     f'curl -L -sI --connect-timeout 5 --max-time 5 '
+         #     f'-H "Content-Type: application/json" '
+         #     f'-H "Authorization: Bearer {civitai_api_key}" "{url}"'
+         # )
+
+         # headers = os.popen(curl_command).read()
+
+         # # Look for the redirected "Location" URL
+         # location_match = re.search(r'location: (.+)', headers, re.IGNORECASE)
+
+         # if location_match:
+         #     redirect_url = location_match.group(1).strip()
+
+         #     # Extract the filename from the redirect URL's "Content-Disposition"
+         #     filename_match = re.search(r'filename%3D%22(.+?)%22', redirect_url)
+         #     if filename_match:
+         #         encoded_filename = filename_match.group(1)
+         #         # Decode the URL-encoded filename
+         #         decoded_filename = urllib.parse.unquote(encoded_filename)
+
+         #         filename = unidecode(decoded_filename) if romanize else decoded_filename
+         #         print(f"Filename: {filename}")
+
+         #         aria2_command = (
+         #             f'aria2c --console-log-level=error --summary-interval=10 -c -x 16 '
+         #             f'-k 1M -s 16 -d "{directory}" -o "{filename}" "{redirect_url}"'
+         #         )
+         #         return_code = os.system(aria2_command)
+
+         #         # if return_code != 0:
+         #         #     raise RuntimeError(f"Failed to download file: {filename}. Error code: {return_code}")
+         #         downloaded_file_path = os.path.join(directory, filename)
+         #         if not os.path.exists(downloaded_file_path):
+         #             downloaded_file_path = None
+
+         # if not downloaded_file_path:
+         #     # Old method
+         #     if "?" in url:
+         #         url = url.split("?")[0]
+         #     url = url + f"?token={civitai_api_key}"
+         #     os.system(f"aria2c --console-log-level=error --summary-interval=10 -c -x 16 -k 1M -s 16 -d {directory} {url}")
+
+     else:
+         os.system(f"aria2c --console-log-level=error --summary-interval=10 -c -x 16 -k 1M -s 16 -d {directory} {url}")
+
+     return downloaded_file_path
+
+
+ def get_model_list(directory_path):
+     model_list = []
+     valid_extensions = {'.ckpt', '.pt', '.pth', '.safetensors', '.bin'}
+
+     for filename in os.listdir(directory_path):
+         if os.path.splitext(filename)[1] in valid_extensions:
+             # name_without_extension = os.path.splitext(filename)[0]
+             file_path = os.path.join(directory_path, filename)
+             # model_list.append((name_without_extension, file_path))
+             model_list.append(file_path)
+             print('\033[34mFILE: ' + file_path + '\033[0m')
+     return model_list
+
+
+ def extract_parameters(input_string):
+     parameters = {}
+     input_string = input_string.replace("\n", "")
+
+     if "Negative prompt:" not in input_string:
+         if "Steps:" in input_string:
+             input_string = input_string.replace("Steps:", "Negative prompt: Steps:")
+         else:
+             print("Invalid metadata")
+             parameters["prompt"] = input_string
+             return parameters
+
+     parm = input_string.split("Negative prompt:")
+     parameters["prompt"] = parm[0].strip()
+     if "Steps:" not in parm[1]:
+         print("Steps not detected")
+         parameters["neg_prompt"] = parm[1].strip()
+         return parameters
+     parm = parm[1].split("Steps:")
+     parameters["neg_prompt"] = parm[0].strip()
+     input_string = "Steps:" + parm[1]
+
+     # Extracting Steps
+     steps_match = re.search(r'Steps: (\d+)', input_string)
+     if steps_match:
+         parameters['Steps'] = int(steps_match.group(1))
+
+     # Extracting Size
+     size_match = re.search(r'Size: (\d+x\d+)', input_string)
+     if size_match:
+         parameters['Size'] = size_match.group(1)
+         width, height = map(int, parameters['Size'].split('x'))
+         parameters['width'] = width
+         parameters['height'] = height
+
+     # Extracting other parameters
+     other_parameters = re.findall(r'(\w+): (.*?)(?=, \w+|$)', input_string)
+     for param in other_parameters:
+         parameters[param[0]] = param[1].strip('"')
+
+     return parameters
+
+
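A worked example of the metadata parser on an A1111-style string, to pin down the output shape:

    meta = "masterpiece, 1girl Negative prompt: lowres Steps: 30, Sampler: Euler a, CFG scale: 7, Seed: 42, Size: 896x1152"
    params = extract_parameters(meta)
    # params["prompt"]     == "masterpiece, 1girl"
    # params["neg_prompt"] == "lowres"
    # params["width"], params["height"] == 896, 1152
    # params["Sampler"] == "Euler a"; "CFG scale" is captured as params["scale"] == "7"
    # Note: the generic fallback regex re-captures Steps, so params["Steps"] ends up as the string "30".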
+ def get_my_lora(link_url, romanize):
+     l_name = ""
+     for url in [url.strip() for url in link_url.split(',')]:
+         if not os.path.exists(f"{DIRECTORY_LORAS}/{url.split('/')[-1]}"):
+             l_name = download_things(DIRECTORY_LORAS, url, HF_TOKEN, CIVITAI_API_KEY, romanize)
+     new_lora_model_list = get_model_list(DIRECTORY_LORAS)
+     new_lora_model_list.insert(0, "None")
+     new_lora_model_list = new_lora_model_list + DIFFUSERS_FORMAT_LORAS
+     msg_lora = "Downloaded"
+     if l_name:
+         msg_lora += f": <b>{l_name}</b>"
+     print(msg_lora)
+
+     return gr.update(
+         choices=new_lora_model_list
+     ), gr.update(
+         choices=new_lora_model_list
+     ), gr.update(
+         choices=new_lora_model_list
+     ), gr.update(
+         choices=new_lora_model_list
+     ), gr.update(
+         choices=new_lora_model_list
+     ), gr.update(
+         value=msg_lora
+     )
+
+
+ def info_html(json_data, title, subtitle):
+     return f"""
+     <div style='padding: 0; border-radius: 10px;'>
+         <p style='margin: 0; font-weight: bold;'>{title}</p>
+         <details>
+             <summary>Details</summary>
+             <p style='margin: 0; font-weight: bold;'>{subtitle}</p>
+         </details>
+     </div>
+     """
+
+
+ def get_model_type(repo_id: str):
+     api = HfApi(token=os.environ.get("HF_TOKEN"))  # needed if the model is private or gated
+     default = "SD 1.5"
+     try:
+         model = api.model_info(repo_id=repo_id, timeout=5.0)
+         tags = model.tags
+         for tag in tags:
+             if tag in MODEL_TYPE_CLASS.keys(): return MODEL_TYPE_CLASS.get(tag, default)
+     except Exception:
+         return default
+     return default
+
+
+ def restart_space(repo_id: str, factory_reboot: bool):
+     api = HfApi(token=os.environ.get("HF_TOKEN"))
+     try:
+         runtime = api.get_space_runtime(repo_id=repo_id)
+         if runtime.stage == "RUNNING":
+             api.restart_space(repo_id=repo_id, factory_reboot=factory_reboot)
+             print(f"Restarting space: {repo_id}")
+         else:
+             print(f"Space {repo_id} is in stage: {runtime.stage}")
+     except Exception as e:
+         print(e)
+
+
+ def extract_exif_data(image):
+     if image is None: return ""
+
+     try:
+         metadata_keys = ['parameters', 'metadata', 'prompt', 'Comment']
+
+         for key in metadata_keys:
+             if key in image.info:
+                 return image.info[key]
+
+         return str(image.info)
+
+     except Exception as e:
+         return f"Error extracting metadata: {str(e)}"
+
+
+ def create_mask_now(img, invert):
+     import numpy as np
+     import time
+
+     time.sleep(0.5)
+
+     transparent_image = img["layers"][0]
+
+     # Extract the alpha channel
+     alpha_channel = np.array(transparent_image)[:, :, 3]
+
+     # Create a binary mask by thresholding the alpha channel
+     binary_mask = alpha_channel > 1
+
+     if invert:
+         print("Invert")
+         # Invert the binary mask so that the drawn shape is white and the rest is black
+         binary_mask = np.invert(binary_mask)
+
+     # Convert the binary mask to a 3-channel RGB mask
+     rgb_mask = np.stack((binary_mask,) * 3, axis=-1)
+
+     # Convert the mask to uint8
+     rgb_mask = rgb_mask.astype(np.uint8) * 255
+
+     return img["background"], rgb_mask
+
+
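A self-contained sketch of the same alpha-thresholding idea, using a synthetic RGBA layer instead of a real `gr.ImageEditor` payload (the `{"background", "layers"}` dict shape mirrors what the editor component passes in):

    import numpy as np

    # Synthetic 4x4 RGBA "drawn layer": one opaque pixel, the rest fully transparent.
    layer = np.zeros((4, 4, 4), dtype=np.uint8)
    layer[1, 2, 3] = 255  # alpha > 1 marks the painted region

    background = np.zeros((4, 4, 3), dtype=np.uint8)
    bg, mask = create_mask_now({"background": background, "layers": [layer]}, invert=False)
    assert mask[1, 2].tolist() == [255, 255, 255]  # the painted pixel becomes white in the mask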
+ def download_diffuser_repo(repo_name: str, model_type: str, revision: str = "main", token=True):
+     variant = None
+     if token is True and not os.environ.get("HF_TOKEN"):
+         token = None
+
+     if model_type == "SDXL":
+         info = model_info_data(
+             repo_name,
+             token=token,
+             revision=revision,
+             timeout=5.0,
+         )
+
+         filenames = {sibling.rfilename for sibling in info.siblings}
+         model_filenames, variant_filenames = variant_compatible_siblings(
+             filenames, variant="fp16"
+         )
+
+         if len(variant_filenames):
+             variant = "fp16"
+
+     cached_folder = DiffusionPipeline.download(
+         pretrained_model_name=repo_name,
+         force_download=False,
+         token=token,
+         revision=revision,
+         # mirror="https://hf-mirror.com",
+         variant=variant,
+         use_safetensors=True,
+         trust_remote_code=False,
+         timeout=5.0,
+     )
+
+     if isinstance(cached_folder, PosixPath):
+         cached_folder = cached_folder.as_posix()
+
+     # Task model
+     # from huggingface_hub import hf_hub_download
+     # hf_hub_download(
+     #     task_model,
+     #     filename="diffusion_pytorch_model.safetensors",  # fix fp16 variant
+     # )
+
+     return cached_folder
+
+
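A hedged usage sketch; the repo id is borrowed from the `EXAMPLES_GUI` presets above and assumed to be an SDXL checkpoint:

    # Prefers the fp16 variant when the repo ships one; returns the local snapshot folder.
    folder = download_diffuser_repo("votepurchase/animagine-xl-3.1", model_type="SDXL")
    print(f"Snapshot cached at: {folder}")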
+ def progress_step_bar(step, total):
+     # Calculate the percentage for the progress bar width
+     percentage = min(100, ((step / total) * 100))
+
+     return f"""
+     <div style="position: relative; width: 100%; background-color: gray; border-radius: 5px; overflow: hidden;">
+         <div style="width: {percentage}%; height: 17px; background-color: #800080; transition: width 0.5s;"></div>
+         <div style="position: absolute; width: 100%; text-align: center; color: white; top: 0; line-height: 19px; font-size: 13px;">
+             {int(percentage)}%
+         </div>
+     </div>
+     """
+
+
+ def html_template_message(msg):
+     return f"""
+     <div style="position: relative; width: 100%; background-color: gray; border-radius: 5px; overflow: hidden;">
+         <div style="width: 0%; height: 17px; background-color: #800080; transition: width 0.5s;"></div>
+         <div style="position: absolute; width: 100%; text-align: center; color: white; top: 0; line-height: 19px; font-size: 14px; font-weight: bold; text-shadow: 1px 1px 2px black;">
+             {msg}
+         </div>
+     </div>
+     """