VideoAditor commited on
Commit
3ff0cf3
·
verified ·
1 Parent(s): de59d95

Upload app.py

Browse files
Files changed (1) hide show
  1. app.py +1116 -4
app.py CHANGED
@@ -1,7 +1,1119 @@
 
 
 
 
 
 
 
1
  import gradio as gr
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2
 
3
- def greet(name):
4
- return "Hello " + name + "!!"
5
 
6
- demo = gr.Interface(fn=greet, inputs="text", outputs="text")
7
- demo.launch()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import sys
3
+ os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"
4
+ os.environ['GRADIO_ANALYTICS_ENABLED'] = '0'
5
+ sys.path.insert(0, os.getcwd())
6
+ sys.path.append(os.path.join(os.path.dirname(__file__), 'sd-scripts'))
7
+ import subprocess
8
  import gradio as gr
9
+ from PIL import Image
10
+ import torch
11
+ import uuid
12
+ import shutil
13
+ import json
14
+ import yaml
15
+ from slugify import slugify
16
+ from transformers import AutoProcessor, AutoModelForCausalLM
17
+ from gradio_logsview import LogsView, LogsViewRunner
18
+ from huggingface_hub import hf_hub_download, HfApi
19
+ from library import flux_train_utils, huggingface_util
20
+ from argparse import Namespace
21
+ import train_network
22
+ import toml
23
+ import re
24
+ MAX_IMAGES = 150
25
 
26
+ with open('models.yaml', 'r') as file:
27
+ models = yaml.safe_load(file)
28
 
29
+ def readme(base_model, lora_name, instance_prompt, sample_prompts):
30
+
31
+ # model license
32
+ model_config = models[base_model]
33
+ model_file = model_config["file"]
34
+ base_model_name = model_config["base"]
35
+ license = None
36
+ license_name = None
37
+ license_link = None
38
+ license_items = []
39
+ if "license" in model_config:
40
+ license = model_config["license"]
41
+ license_items.append(f"license: {license}")
42
+ if "license_name" in model_config:
43
+ license_name = model_config["license_name"]
44
+ license_items.append(f"license_name: {license_name}")
45
+ if "license_link" in model_config:
46
+ license_link = model_config["license_link"]
47
+ license_items.append(f"license_link: {license_link}")
48
+ license_str = "\n".join(license_items)
49
+ print(f"license_items={license_items}")
50
+ print(f"license_str = {license_str}")
51
+
52
+ # tags
53
+ tags = [ "text-to-image", "flux", "lora", "diffusers", "template:sd-lora", "fluxgym" ]
54
+
55
+ # widgets
56
+ widgets = []
57
+ sample_image_paths = []
58
+ output_name = slugify(lora_name)
59
+ samples_dir = resolve_path_without_quotes(f"outputs/{output_name}/sample")
60
+ try:
61
+ for filename in os.listdir(samples_dir):
62
+ # Filename Schema: [name]_[steps]_[index]_[timestamp].png
63
+ match = re.search(r"_(\d+)_(\d+)_(\d+)\.png$", filename)
64
+ if match:
65
+ steps, index, timestamp = int(match.group(1)), int(match.group(2)), int(match.group(3))
66
+ sample_image_paths.append((steps, index, f"sample/{filename}"))
67
+
68
+ # Sort by numeric index
69
+ sample_image_paths.sort(key=lambda x: x[0], reverse=True)
70
+
71
+ final_sample_image_paths = sample_image_paths[:len(sample_prompts)]
72
+ final_sample_image_paths.sort(key=lambda x: x[1])
73
+ for i, prompt in enumerate(sample_prompts):
74
+ _, _, image_path = final_sample_image_paths[i]
75
+ widgets.append(
76
+ {
77
+ "text": prompt,
78
+ "output": {
79
+ "url": image_path
80
+ },
81
+ }
82
+ )
83
+ except:
84
+ print(f"no samples")
85
+ dtype = "torch.bfloat16"
86
+ # Construct the README content
87
+ readme_content = f"""---
88
+ tags:
89
+ {yaml.dump(tags, indent=4).strip()}
90
+ {"widget:" if os.path.isdir(samples_dir) else ""}
91
+ {yaml.dump(widgets, indent=4).strip() if widgets else ""}
92
+ base_model: {base_model_name}
93
+ {"instance_prompt: " + instance_prompt if instance_prompt else ""}
94
+ {license_str}
95
+ ---
96
+
97
+ # {lora_name}
98
+
99
+ A Flux LoRA trained on a local computer with [Fluxgym](https://github.com/cocktailpeanut/fluxgym)
100
+
101
+ <Gallery />
102
+
103
+ ## Trigger words
104
+
105
+ {"You should use `" + instance_prompt + "` to trigger the image generation." if instance_prompt else "No trigger words defined."}
106
+
107
+ ## Download model and use it with ComfyUI, AUTOMATIC1111, SD.Next, Invoke AI, Forge, etc.
108
+
109
+ Weights for this model are available in Safetensors format.
110
+
111
+ """
112
+ return readme_content
113
+
114
+ def account_hf():
115
+ try:
116
+ with open("HF_TOKEN", "r") as file:
117
+ token = file.read()
118
+ api = HfApi(token=token)
119
+ try:
120
+ account = api.whoami()
121
+ return { "token": token, "account": account['name'] }
122
+ except:
123
+ return None
124
+ except:
125
+ return None
126
+
127
+ """
128
+ hf_logout.click(fn=logout_hf, outputs=[hf_token, hf_login, hf_logout, repo_owner])
129
+ """
130
+ def logout_hf():
131
+ os.remove("HF_TOKEN")
132
+ global current_account
133
+ current_account = account_hf()
134
+ print(f"current_account={current_account}")
135
+ return gr.update(value=""), gr.update(visible=True), gr.update(visible=False), gr.update(value="", visible=False)
136
+
137
+
138
+ """
139
+ hf_login.click(fn=login_hf, inputs=[hf_token], outputs=[hf_token, hf_login, hf_logout, repo_owner])
140
+ """
141
+ def login_hf(hf_token):
142
+ api = HfApi(token=hf_token)
143
+ try:
144
+ account = api.whoami()
145
+ if account != None:
146
+ if "name" in account:
147
+ with open("HF_TOKEN", "w") as file:
148
+ file.write(hf_token)
149
+ global current_account
150
+ current_account = account_hf()
151
+ return gr.update(visible=True), gr.update(visible=False), gr.update(visible=True), gr.update(value=current_account["account"], visible=True)
152
+ return gr.update(), gr.update(), gr.update(), gr.update()
153
+ except:
154
+ print(f"incorrect hf_token")
155
+ return gr.update(), gr.update(), gr.update(), gr.update()
156
+
157
+ def upload_hf(base_model, lora_rows, repo_owner, repo_name, repo_visibility, hf_token):
158
+ src = lora_rows
159
+ repo_id = f"{repo_owner}/{repo_name}"
160
+ gr.Info(f"Uploading to Huggingface. Please Stand by...", duration=None)
161
+ args = Namespace(
162
+ huggingface_repo_id=repo_id,
163
+ huggingface_repo_type="model",
164
+ huggingface_repo_visibility=repo_visibility,
165
+ huggingface_path_in_repo="",
166
+ huggingface_token=hf_token,
167
+ async_upload=False
168
+ )
169
+ print(f"upload_hf args={args}")
170
+ huggingface_util.upload(args=args, src=src)
171
+ gr.Info(f"[Upload Complete] https://huggingface.co/{repo_id}", duration=None)
172
+
173
+ def load_captioning(uploaded_files, concept_sentence):
174
+ uploaded_images = [file for file in uploaded_files if not file.endswith('.txt')]
175
+ txt_files = [file for file in uploaded_files if file.endswith('.txt')]
176
+ txt_files_dict = {os.path.splitext(os.path.basename(txt_file))[0]: txt_file for txt_file in txt_files}
177
+ updates = []
178
+ if len(uploaded_images) <= 1:
179
+ raise gr.Error(
180
+ "Please upload at least 2 images to train your model (the ideal number with default settings is between 4-30)"
181
+ )
182
+ elif len(uploaded_images) > MAX_IMAGES:
183
+ raise gr.Error(f"For now, only {MAX_IMAGES} or less images are allowed for training")
184
+ # Update for the captioning_area
185
+ # for _ in range(3):
186
+ updates.append(gr.update(visible=True))
187
+ # Update visibility and image for each captioning row and image
188
+ for i in range(1, MAX_IMAGES + 1):
189
+ # Determine if the current row and image should be visible
190
+ visible = i <= len(uploaded_images)
191
+
192
+ # Update visibility of the captioning row
193
+ updates.append(gr.update(visible=visible))
194
+
195
+ # Update for image component - display image if available, otherwise hide
196
+ image_value = uploaded_images[i - 1] if visible else None
197
+ updates.append(gr.update(value=image_value, visible=visible))
198
+
199
+ corresponding_caption = False
200
+ if(image_value):
201
+ base_name = os.path.splitext(os.path.basename(image_value))[0]
202
+ if base_name in txt_files_dict:
203
+ with open(txt_files_dict[base_name], 'r') as file:
204
+ corresponding_caption = file.read()
205
+
206
+ # Update value of captioning area
207
+ text_value = corresponding_caption if visible and corresponding_caption else concept_sentence if visible and concept_sentence else None
208
+ updates.append(gr.update(value=text_value, visible=visible))
209
+
210
+ # Update for the sample caption area
211
+ updates.append(gr.update(visible=True))
212
+ updates.append(gr.update(visible=True))
213
+
214
+ return updates
215
+
216
+ def hide_captioning():
217
+ return gr.update(visible=False), gr.update(visible=False)
218
+
219
+ def resize_image(image_path, output_path, size):
220
+ with Image.open(image_path) as img:
221
+ width, height = img.size
222
+ if width < height:
223
+ new_width = size
224
+ new_height = int((size/width) * height)
225
+ else:
226
+ new_height = size
227
+ new_width = int((size/height) * width)
228
+ print(f"resize {image_path} : {new_width}x{new_height}")
229
+ img_resized = img.resize((new_width, new_height), Image.Resampling.LANCZOS)
230
+ img_resized.save(output_path)
231
+
232
+ def create_dataset(destination_folder, size, *inputs):
233
+ print("Creating dataset")
234
+ images = inputs[0]
235
+ if not os.path.exists(destination_folder):
236
+ os.makedirs(destination_folder)
237
+
238
+ for index, image in enumerate(images):
239
+ # copy the images to the datasets folder
240
+ new_image_path = shutil.copy(image, destination_folder)
241
+
242
+ # if it's a caption text file skip the next bit
243
+ ext = os.path.splitext(new_image_path)[-1].lower()
244
+ if ext == '.txt':
245
+ continue
246
+
247
+ # resize the images
248
+ resize_image(new_image_path, new_image_path, size)
249
+
250
+ # copy the captions
251
+
252
+ original_caption = inputs[index + 1]
253
+
254
+ image_file_name = os.path.basename(new_image_path)
255
+ caption_file_name = os.path.splitext(image_file_name)[0] + ".txt"
256
+ caption_path = resolve_path_without_quotes(os.path.join(destination_folder, caption_file_name))
257
+ print(f"image_path={new_image_path}, caption_path = {caption_path}, original_caption={original_caption}")
258
+ # if caption_path exists, do not write
259
+ if os.path.exists(caption_path):
260
+ print(f"{caption_path} already exists. use the existing .txt file")
261
+ else:
262
+ print(f"{caption_path} create a .txt caption file")
263
+ with open(caption_path, 'w') as file:
264
+ file.write(original_caption)
265
+
266
+ print(f"destination_folder {destination_folder}")
267
+ return destination_folder
268
+
269
+
270
+ def run_captioning(images, concept_sentence, *captions):
271
+ print(f"run_captioning")
272
+ print(f"concept sentence {concept_sentence}")
273
+ print(f"captions {captions}")
274
+ #Load internally to not consume resources for training
275
+ device = "cuda" if torch.cuda.is_available() else "cpu"
276
+ print(f"device={device}")
277
+ torch_dtype = torch.float16
278
+ model = AutoModelForCausalLM.from_pretrained(
279
+ "multimodalart/Florence-2-large-no-flash-attn", torch_dtype=torch_dtype, trust_remote_code=True
280
+ ).to(device)
281
+ processor = AutoProcessor.from_pretrained("multimodalart/Florence-2-large-no-flash-attn", trust_remote_code=True)
282
+
283
+ captions = list(captions)
284
+ for i, image_path in enumerate(images):
285
+ print(captions[i])
286
+ if isinstance(image_path, str): # If image is a file path
287
+ image = Image.open(image_path).convert("RGB")
288
+
289
+ prompt = "<DETAILED_CAPTION>"
290
+ inputs = processor(text=prompt, images=image, return_tensors="pt").to(device, torch_dtype)
291
+ print(f"inputs {inputs}")
292
+
293
+ generated_ids = model.generate(
294
+ input_ids=inputs["input_ids"], pixel_values=inputs["pixel_values"], max_new_tokens=1024, num_beams=3
295
+ )
296
+ print(f"generated_ids {generated_ids}")
297
+
298
+ generated_text = processor.batch_decode(generated_ids, skip_special_tokens=False)[0]
299
+ print(f"generated_text: {generated_text}")
300
+ parsed_answer = processor.post_process_generation(
301
+ generated_text, task=prompt, image_size=(image.width, image.height)
302
+ )
303
+ print(f"parsed_answer = {parsed_answer}")
304
+ caption_text = parsed_answer["<DETAILED_CAPTION>"].replace("The image shows ", "")
305
+ print(f"caption_text = {caption_text}, concept_sentence={concept_sentence}")
306
+ if concept_sentence:
307
+ caption_text = f"{concept_sentence} {caption_text}"
308
+ captions[i] = caption_text
309
+
310
+ yield captions
311
+ model.to("cpu")
312
+ del model
313
+ del processor
314
+ if torch.cuda.is_available():
315
+ torch.cuda.empty_cache()
316
+
317
+ def recursive_update(d, u):
318
+ for k, v in u.items():
319
+ if isinstance(v, dict) and v:
320
+ d[k] = recursive_update(d.get(k, {}), v)
321
+ else:
322
+ d[k] = v
323
+ return d
324
+
325
+ def download(base_model):
326
+ model = models[base_model]
327
+ model_file = model["file"]
328
+ repo = model["repo"]
329
+
330
+ # download unet
331
+ if base_model == "flux-dev" or base_model == "flux-schnell":
332
+ unet_folder = "models/unet"
333
+ else:
334
+ unet_folder = f"models/unet/{repo}"
335
+ unet_path = os.path.join(unet_folder, model_file)
336
+ if not os.path.exists(unet_path):
337
+ os.makedirs(unet_folder, exist_ok=True)
338
+ gr.Info(f"Downloading base model: {base_model}. Please wait. (You can check the terminal for the download progress)", duration=None)
339
+ print(f"download {base_model}")
340
+ hf_hub_download(repo_id=repo, local_dir=unet_folder, filename=model_file)
341
+
342
+ # download vae
343
+ vae_folder = "models/vae"
344
+ vae_path = os.path.join(vae_folder, "ae.sft")
345
+ if not os.path.exists(vae_path):
346
+ os.makedirs(vae_folder, exist_ok=True)
347
+ gr.Info(f"Downloading vae")
348
+ print(f"downloading ae.sft...")
349
+ hf_hub_download(repo_id="cocktailpeanut/xulf-dev", local_dir=vae_folder, filename="ae.sft")
350
+
351
+ # download clip
352
+ clip_folder = "models/clip"
353
+ clip_l_path = os.path.join(clip_folder, "clip_l.safetensors")
354
+ if not os.path.exists(clip_l_path):
355
+ os.makedirs(clip_folder, exist_ok=True)
356
+ gr.Info(f"Downloading clip...")
357
+ print(f"download clip_l.safetensors")
358
+ hf_hub_download(repo_id="comfyanonymous/flux_text_encoders", local_dir=clip_folder, filename="clip_l.safetensors")
359
+
360
+ # download t5xxl
361
+ t5xxl_path = os.path.join(clip_folder, "t5xxl_fp16.safetensors")
362
+ if not os.path.exists(t5xxl_path):
363
+ print(f"download t5xxl_fp16.safetensors")
364
+ gr.Info(f"Downloading t5xxl...")
365
+ hf_hub_download(repo_id="comfyanonymous/flux_text_encoders", local_dir=clip_folder, filename="t5xxl_fp16.safetensors")
366
+
367
+
368
+ def resolve_path(p):
369
+ current_dir = os.path.dirname(os.path.abspath(__file__))
370
+ norm_path = os.path.normpath(os.path.join(current_dir, p))
371
+ return f"\"{norm_path}\""
372
+ def resolve_path_without_quotes(p):
373
+ current_dir = os.path.dirname(os.path.abspath(__file__))
374
+ norm_path = os.path.normpath(os.path.join(current_dir, p))
375
+ return norm_path
376
+
377
+ def gen_sh(
378
+ base_model,
379
+ output_name,
380
+ resolution,
381
+ seed,
382
+ workers,
383
+ learning_rate,
384
+ network_dim,
385
+ max_train_epochs,
386
+ save_every_n_epochs,
387
+ timestep_sampling,
388
+ guidance_scale,
389
+ vram,
390
+ sample_prompts,
391
+ sample_every_n_steps,
392
+ *advanced_components
393
+ ):
394
+
395
+ print(f"gen_sh: network_dim:{network_dim}, max_train_epochs={max_train_epochs}, save_every_n_epochs={save_every_n_epochs}, timestep_sampling={timestep_sampling}, guidance_scale={guidance_scale}, vram={vram}, sample_prompts={sample_prompts}, sample_every_n_steps={sample_every_n_steps}")
396
+
397
+ output_dir = resolve_path(f"outputs/{output_name}")
398
+ sample_prompts_path = resolve_path(f"outputs/{output_name}/sample_prompts.txt")
399
+
400
+ line_break = "\\"
401
+ file_type = "sh"
402
+ if sys.platform == "win32":
403
+ line_break = "^"
404
+ file_type = "bat"
405
+
406
+ ############# Sample args ########################
407
+ sample = ""
408
+ if len(sample_prompts) > 0 and sample_every_n_steps > 0:
409
+ sample = f"""--sample_prompts={sample_prompts_path} --sample_every_n_steps="{sample_every_n_steps}" {line_break}"""
410
+
411
+
412
+ ############# Optimizer args ########################
413
+ # if vram == "8G":
414
+ # optimizer = f"""--optimizer_type adafactor {line_break}
415
+ # --optimizer_args "relative_step=False" "scale_parameter=False" "warmup_init=False" {line_break}
416
+ # --split_mode {line_break}
417
+ # --network_args "train_blocks=single" {line_break}
418
+ # --lr_scheduler constant_with_warmup {line_break}
419
+ # --max_grad_norm 0.0 {line_break}"""
420
+ if vram == "16G":
421
+ # 16G VRAM
422
+ optimizer = f"""--optimizer_type adafactor {line_break}
423
+ --optimizer_args "relative_step=False" "scale_parameter=False" "warmup_init=False" {line_break}
424
+ --lr_scheduler constant_with_warmup {line_break}
425
+ --max_grad_norm 0.0 {line_break}"""
426
+ elif vram == "12G":
427
+ # 12G VRAM
428
+ optimizer = f"""--optimizer_type adafactor {line_break}
429
+ --optimizer_args "relative_step=False" "scale_parameter=False" "warmup_init=False" {line_break}
430
+ --split_mode {line_break}
431
+ --network_args "train_blocks=single" {line_break}
432
+ --lr_scheduler constant_with_warmup {line_break}
433
+ --max_grad_norm 0.0 {line_break}"""
434
+ else:
435
+ # 20G+ VRAM
436
+ optimizer = f"--optimizer_type adamw8bit {line_break}"
437
+
438
+
439
+ #######################################################
440
+ model_config = models[base_model]
441
+ model_file = model_config["file"]
442
+ repo = model_config["repo"]
443
+ if base_model == "flux-dev" or base_model == "flux-schnell":
444
+ model_folder = "models/unet"
445
+ else:
446
+ model_folder = f"models/unet/{repo}"
447
+ model_path = os.path.join(model_folder, model_file)
448
+ pretrained_model_path = resolve_path(model_path)
449
+
450
+ clip_path = resolve_path("models/clip/clip_l.safetensors")
451
+ t5_path = resolve_path("models/clip/t5xxl_fp16.safetensors")
452
+ ae_path = resolve_path("models/vae/ae.sft")
453
+ sh = f"""accelerate launch {line_break}
454
+ --mixed_precision bf16 {line_break}
455
+ --num_cpu_threads_per_process 1 {line_break}
456
+ sd-scripts/flux_train_network.py {line_break}
457
+ --pretrained_model_name_or_path {pretrained_model_path} {line_break}
458
+ --clip_l {clip_path} {line_break}
459
+ --t5xxl {t5_path} {line_break}
460
+ --ae {ae_path} {line_break}
461
+ --cache_latents_to_disk {line_break}
462
+ --save_model_as safetensors {line_break}
463
+ --sdpa --persistent_data_loader_workers {line_break}
464
+ --max_data_loader_n_workers {workers} {line_break}
465
+ --seed {seed} {line_break}
466
+ --gradient_checkpointing {line_break}
467
+ --mixed_precision bf16 {line_break}
468
+ --save_precision bf16 {line_break}
469
+ --network_module networks.lora_flux {line_break}
470
+ --network_dim {network_dim} {line_break}
471
+ {optimizer}{sample}
472
+ --learning_rate {learning_rate} {line_break}
473
+ --cache_text_encoder_outputs {line_break}
474
+ --cache_text_encoder_outputs_to_disk {line_break}
475
+ --fp8_base {line_break}
476
+ --highvram {line_break}
477
+ --max_train_epochs {max_train_epochs} {line_break}
478
+ --save_every_n_epochs {save_every_n_epochs} {line_break}
479
+ --dataset_config {resolve_path(f"outputs/{output_name}/dataset.toml")} {line_break}
480
+ --output_dir {output_dir} {line_break}
481
+ --output_name {output_name} {line_break}
482
+ --timestep_sampling {timestep_sampling} {line_break}
483
+ --discrete_flow_shift 3.1582 {line_break}
484
+ --model_prediction_type raw {line_break}
485
+ --guidance_scale {guidance_scale} {line_break}
486
+ --loss_type l2 {line_break}"""
487
+
488
+
489
+
490
+ ############# Advanced args ########################
491
+ global advanced_component_ids
492
+ global original_advanced_component_values
493
+
494
+ # check dirty
495
+ print(f"original_advanced_component_values = {original_advanced_component_values}")
496
+ advanced_flags = []
497
+ for i, current_value in enumerate(advanced_components):
498
+ # print(f"compare {advanced_component_ids[i]}: old={original_advanced_component_values[i]}, new={current_value}")
499
+ if original_advanced_component_values[i] != current_value:
500
+ # dirty
501
+ if current_value == True:
502
+ # Boolean
503
+ advanced_flags.append(advanced_component_ids[i])
504
+ else:
505
+ # string
506
+ advanced_flags.append(f"{advanced_component_ids[i]} {current_value}")
507
+
508
+ if len(advanced_flags) > 0:
509
+ advanced_flags_str = f" {line_break}\n ".join(advanced_flags)
510
+ sh = sh + "\n " + advanced_flags_str
511
+
512
+ return sh
513
+
514
+ def gen_toml(
515
+ dataset_folder,
516
+ resolution,
517
+ class_tokens,
518
+ num_repeats
519
+ ):
520
+ toml = f"""[general]
521
+ shuffle_caption = false
522
+ caption_extension = '.txt'
523
+ keep_tokens = 1
524
+
525
+ [[datasets]]
526
+ resolution = {resolution}
527
+ batch_size = 1
528
+ keep_tokens = 1
529
+
530
+ [[datasets.subsets]]
531
+ image_dir = '{resolve_path_without_quotes(dataset_folder)}'
532
+ class_tokens = '{class_tokens}'
533
+ num_repeats = {num_repeats}"""
534
+ return toml
535
+
536
+ def update_total_steps(max_train_epochs, num_repeats, images):
537
+ try:
538
+ num_images = len(images)
539
+ total_steps = max_train_epochs * num_images * num_repeats
540
+ print(f"max_train_epochs={max_train_epochs} num_images={num_images}, num_repeats={num_repeats}, total_steps={total_steps}")
541
+ return gr.update(value = total_steps)
542
+ except:
543
+ print("")
544
+
545
+ def set_repo(lora_rows):
546
+ selected_name = os.path.basename(lora_rows)
547
+ return gr.update(value=selected_name)
548
+
549
+ def get_loras():
550
+ try:
551
+ outputs_path = resolve_path_without_quotes(f"outputs")
552
+ files = os.listdir(outputs_path)
553
+ folders = [os.path.join(outputs_path, item) for item in files if os.path.isdir(os.path.join(outputs_path, item)) and item != "sample"]
554
+ folders.sort(key=lambda file: os.path.getctime(file), reverse=True)
555
+ return folders
556
+ except Exception as e:
557
+ return []
558
+
559
+ def get_samples(lora_name):
560
+ output_name = slugify(lora_name)
561
+ try:
562
+ samples_path = resolve_path_without_quotes(f"outputs/{output_name}/sample")
563
+ files = [os.path.join(samples_path, file) for file in os.listdir(samples_path)]
564
+ files.sort(key=lambda file: os.path.getctime(file), reverse=True)
565
+ return files
566
+ except:
567
+ return []
568
+
569
+ def start_training(
570
+ base_model,
571
+ lora_name,
572
+ train_script,
573
+ train_config,
574
+ sample_prompts,
575
+ ):
576
+ # write custom script and toml
577
+ if not os.path.exists("models"):
578
+ os.makedirs("models", exist_ok=True)
579
+ if not os.path.exists("outputs"):
580
+ os.makedirs("outputs", exist_ok=True)
581
+ output_name = slugify(lora_name)
582
+ output_dir = resolve_path_without_quotes(f"outputs/{output_name}")
583
+ if not os.path.exists(output_dir):
584
+ os.makedirs(output_dir, exist_ok=True)
585
+
586
+ download(base_model)
587
+
588
+ file_type = "sh"
589
+ if sys.platform == "win32":
590
+ file_type = "bat"
591
+
592
+ sh_filename = f"train.{file_type}"
593
+ sh_filepath = resolve_path_without_quotes(f"outputs/{output_name}/{sh_filename}")
594
+ with open(sh_filepath, 'w', encoding="utf-8") as file:
595
+ file.write(train_script)
596
+ gr.Info(f"Generated train script at {sh_filename}")
597
+
598
+
599
+ dataset_path = resolve_path_without_quotes(f"outputs/{output_name}/dataset.toml")
600
+ with open(dataset_path, 'w', encoding="utf-8") as file:
601
+ file.write(train_config)
602
+ gr.Info(f"Generated dataset.toml")
603
+
604
+ sample_prompts_path = resolve_path_without_quotes(f"outputs/{output_name}/sample_prompts.txt")
605
+ with open(sample_prompts_path, 'w', encoding='utf-8') as file:
606
+ file.write(sample_prompts)
607
+ gr.Info(f"Generated sample_prompts.txt")
608
+
609
+ # Train
610
+ if sys.platform == "win32":
611
+ command = sh_filepath
612
+ else:
613
+ command = f"bash \"{sh_filepath}\""
614
+
615
+ # Use Popen to run the command and capture output in real-time
616
+ env = os.environ.copy()
617
+ env['PYTHONIOENCODING'] = 'utf-8'
618
+ env['LOG_LEVEL'] = 'DEBUG'
619
+ runner = LogsViewRunner()
620
+ cwd = os.path.dirname(os.path.abspath(__file__))
621
+ gr.Info(f"Started training")
622
+ yield from runner.run_command([command], cwd=cwd)
623
+ yield runner.log(f"Runner: {runner}")
624
+
625
+ # Generate Readme
626
+ config = toml.loads(train_config)
627
+ concept_sentence = config['datasets'][0]['subsets'][0]['class_tokens']
628
+ print(f"concept_sentence={concept_sentence}")
629
+ print(f"lora_name {lora_name}, concept_sentence={concept_sentence}, output_name={output_name}")
630
+ sample_prompts_path = resolve_path_without_quotes(f"outputs/{output_name}/sample_prompts.txt")
631
+ with open(sample_prompts_path, "r", encoding="utf-8") as f:
632
+ lines = f.readlines()
633
+ sample_prompts = [line.strip() for line in lines if len(line.strip()) > 0 and line[0] != "#"]
634
+ md = readme(base_model, lora_name, concept_sentence, sample_prompts)
635
+ readme_path = resolve_path_without_quotes(f"outputs/{output_name}/README.md")
636
+ with open(readme_path, "w", encoding="utf-8") as f:
637
+ f.write(md)
638
+
639
+ gr.Info(f"Training Complete. Check the outputs folder for the LoRA files.", duration=None)
640
+
641
+
642
+ def update(
643
+ base_model,
644
+ lora_name,
645
+ resolution,
646
+ seed,
647
+ workers,
648
+ class_tokens,
649
+ learning_rate,
650
+ network_dim,
651
+ max_train_epochs,
652
+ save_every_n_epochs,
653
+ timestep_sampling,
654
+ guidance_scale,
655
+ vram,
656
+ num_repeats,
657
+ sample_prompts,
658
+ sample_every_n_steps,
659
+ *advanced_components,
660
+ ):
661
+ output_name = slugify(lora_name)
662
+ dataset_folder = str(f"datasets/{output_name}")
663
+ sh = gen_sh(
664
+ base_model,
665
+ output_name,
666
+ resolution,
667
+ seed,
668
+ workers,
669
+ learning_rate,
670
+ network_dim,
671
+ max_train_epochs,
672
+ save_every_n_epochs,
673
+ timestep_sampling,
674
+ guidance_scale,
675
+ vram,
676
+ sample_prompts,
677
+ sample_every_n_steps,
678
+ *advanced_components,
679
+ )
680
+ toml = gen_toml(
681
+ dataset_folder,
682
+ resolution,
683
+ class_tokens,
684
+ num_repeats
685
+ )
686
+ return gr.update(value=sh), gr.update(value=toml), dataset_folder
687
+
688
+ """
689
+ demo.load(fn=loaded, js=js, outputs=[hf_token, hf_login, hf_logout, hf_account])
690
+ """
691
+ def loaded():
692
+ global current_account
693
+ current_account = account_hf()
694
+ print(f"current_account={current_account}")
695
+ if current_account != None:
696
+ return gr.update(value=current_account["token"]), gr.update(visible=False), gr.update(visible=True), gr.update(value=current_account["account"], visible=True)
697
+ else:
698
+ return gr.update(value=""), gr.update(visible=True), gr.update(visible=False), gr.update(value="", visible=False)
699
+
700
+ def update_sample(concept_sentence):
701
+ return gr.update(value=concept_sentence)
702
+
703
+ def refresh_publish_tab():
704
+ loras = get_loras()
705
+ return gr.Dropdown(label="Trained LoRAs", choices=loras)
706
+
707
+ def init_advanced():
708
+ # if basic_args
709
+ basic_args = {
710
+ 'pretrained_model_name_or_path',
711
+ 'clip_l',
712
+ 't5xxl',
713
+ 'ae',
714
+ 'cache_latents_to_disk',
715
+ 'save_model_as',
716
+ 'sdpa',
717
+ 'persistent_data_loader_workers',
718
+ 'max_data_loader_n_workers',
719
+ 'seed',
720
+ 'gradient_checkpointing',
721
+ 'mixed_precision',
722
+ 'save_precision',
723
+ 'network_module',
724
+ 'network_dim',
725
+ 'learning_rate',
726
+ 'cache_text_encoder_outputs',
727
+ 'cache_text_encoder_outputs_to_disk',
728
+ 'fp8_base',
729
+ 'highvram',
730
+ 'max_train_epochs',
731
+ 'save_every_n_epochs',
732
+ 'dataset_config',
733
+ 'output_dir',
734
+ 'output_name',
735
+ 'timestep_sampling',
736
+ 'discrete_flow_shift',
737
+ 'model_prediction_type',
738
+ 'guidance_scale',
739
+ 'loss_type',
740
+ 'optimizer_type',
741
+ 'optimizer_args',
742
+ 'lr_scheduler',
743
+ 'sample_prompts',
744
+ 'sample_every_n_steps',
745
+ 'max_grad_norm',
746
+ 'split_mode',
747
+ 'network_args'
748
+ }
749
+
750
+ # generate a UI config
751
+ # if not in basic_args, create a simple form
752
+ parser = train_network.setup_parser()
753
+ flux_train_utils.add_flux_train_arguments(parser)
754
+ args_info = {}
755
+ for action in parser._actions:
756
+ if action.dest != 'help': # Skip the default help argument
757
+ # if the dest is included in basic_args
758
+ args_info[action.dest] = {
759
+ "action": action.option_strings, # Option strings like '--use_8bit_adam'
760
+ "type": action.type, # Type of the argument
761
+ "help": action.help, # Help message
762
+ "default": action.default, # Default value, if any
763
+ "required": action.required # Whether the argument is required
764
+ }
765
+ temp = []
766
+ for key in args_info:
767
+ temp.append({ 'key': key, 'action': args_info[key] })
768
+ temp.sort(key=lambda x: x['key'])
769
+ advanced_component_ids = []
770
+ advanced_components = []
771
+ for item in temp:
772
+ key = item['key']
773
+ action = item['action']
774
+ if key in basic_args:
775
+ print("")
776
+ else:
777
+ action_type = str(action['type'])
778
+ component = None
779
+ with gr.Column(min_width=300):
780
+ if action_type == "None":
781
+ # radio
782
+ component = gr.Checkbox()
783
+ # elif action_type == "<class 'str'>":
784
+ # component = gr.Textbox()
785
+ # elif action_type == "<class 'int'>":
786
+ # component = gr.Number(precision=0)
787
+ # elif action_type == "<class 'float'>":
788
+ # component = gr.Number()
789
+ # elif "int_or_float" in action_type:
790
+ # component = gr.Number()
791
+ else:
792
+ component = gr.Textbox(value="")
793
+ if component != None:
794
+ component.interactive = True
795
+ component.elem_id = action['action'][0]
796
+ component.label = component.elem_id
797
+ component.elem_classes = ["advanced"]
798
+ if action['help'] != None:
799
+ component.info = action['help']
800
+ advanced_components.append(component)
801
+ advanced_component_ids.append(component.elem_id)
802
+ return advanced_components, advanced_component_ids
803
+
804
+
805
+ theme = gr.themes.Monochrome(
806
+ text_size=gr.themes.Size(lg="18px", md="15px", sm="13px", xl="22px", xs="12px", xxl="24px", xxs="9px"),
807
+ font=[gr.themes.GoogleFont("Source Sans Pro"), "ui-sans-serif", "system-ui", "sans-serif"],
808
+ )
809
+ css = """
810
+ @keyframes rotate {
811
+ 0% {
812
+ transform: rotate(0deg);
813
+ }
814
+ 100% {
815
+ transform: rotate(360deg);
816
+ }
817
+ }
818
+ #advanced_options .advanced:nth-child(even) { background: rgba(0,0,100,0.04) !important; }
819
+ h1{font-family: georgia; font-style: italic; font-weight: bold; font-size: 30px; letter-spacing: -1px;}
820
+ h3{margin-top: 0}
821
+ .tabitem{border: 0px}
822
+ .group_padding{}
823
+ nav{position: fixed; top: 0; left: 0; right: 0; z-index: 1000; text-align: center; padding: 10px; box-sizing: border-box; display: flex; align-items: center; backdrop-filter: blur(10px); }
824
+ nav button { background: none; color: firebrick; font-weight: bold; border: 2px solid firebrick; padding: 5px 10px; border-radius: 5px; font-size: 14px; }
825
+ nav img { height: 40px; width: 40px; border-radius: 40px; }
826
+ nav img.rotate { animation: rotate 2s linear infinite; }
827
+ .flexible { flex-grow: 1; }
828
+ .tast-details { margin: 10px 0 !important; }
829
+ .toast-wrap { bottom: var(--size-4) !important; top: auto !important; border: none !important; backdrop-filter: blur(10px); }
830
+ .toast-title, .toast-text, .toast-icon, .toast-close { color: black !important; font-size: 14px; }
831
+ .toast-body { border: none !important; }
832
+ #terminal { box-shadow: none !important; margin-bottom: 25px; background: rgba(0,0,0,0.03); }
833
+ #terminal .generating { border: none !important; }
834
+ #terminal label { position: absolute !important; }
835
+ .tabs { margin-top: 50px; }
836
+ .hidden { display: none !important; }
837
+ .codemirror-wrapper .cm-line { font-size: 12px !important; }
838
+ label { font-weight: bold !important; }
839
+ #start_training.clicked { background: silver; color: black; }
840
+ """
841
+
842
+ js = """
843
+ function() {
844
+ let autoscroll = document.querySelector("#autoscroll")
845
+ if (window.iidxx) {
846
+ window.clearInterval(window.iidxx);
847
+ }
848
+ window.iidxx = window.setInterval(function() {
849
+ let text=document.querySelector(".codemirror-wrapper .cm-line").innerText.trim()
850
+ let img = document.querySelector("#logo")
851
+ if (text.length > 0) {
852
+ autoscroll.classList.remove("hidden")
853
+ if (autoscroll.classList.contains("on")) {
854
+ autoscroll.textContent = "Autoscroll ON"
855
+ window.scrollTo(0, document.body.scrollHeight, { behavior: "smooth" });
856
+ img.classList.add("rotate")
857
+ } else {
858
+ autoscroll.textContent = "Autoscroll OFF"
859
+ img.classList.remove("rotate")
860
+ }
861
+ }
862
+ }, 500);
863
+ console.log("autoscroll", autoscroll)
864
+ autoscroll.addEventListener("click", (e) => {
865
+ autoscroll.classList.toggle("on")
866
+ })
867
+ function debounce(fn, delay) {
868
+ let timeoutId;
869
+ return function(...args) {
870
+ clearTimeout(timeoutId);
871
+ timeoutId = setTimeout(() => fn(...args), delay);
872
+ };
873
+ }
874
+
875
+ function handleClick() {
876
+ console.log("refresh")
877
+ document.querySelector("#refresh").click();
878
+ }
879
+ const debouncedClick = debounce(handleClick, 1000);
880
+ document.addEventListener("input", debouncedClick);
881
+
882
+ document.querySelector("#start_training").addEventListener("click", (e) => {
883
+ e.target.classList.add("clicked")
884
+ e.target.innerHTML = "Training..."
885
+ })
886
+
887
+ }
888
+ """
889
+
890
+ current_account = account_hf()
891
+ print(f"current_account={current_account}")
892
+
893
+ with gr.Blocks(elem_id="app", theme=theme, css=css, fill_width=True) as demo:
894
+ with gr.Tabs() as tabs:
895
+ with gr.TabItem("Gym"):
896
+ output_components = []
897
+ with gr.Row():
898
+ gr.HTML("""<nav>
899
+ <img id='logo' src='/file=icon.png' width='80' height='80'>
900
+ <div class='flexible'></div>
901
+ <button id='autoscroll' class='on hidden'></button>
902
+ </nav>
903
+ """)
904
+ with gr.Row(elem_id='container'):
905
+ with gr.Column():
906
+ gr.Markdown(
907
+ """# Step 1. LoRA Info
908
+ <p style="margin-top:0">Configure your LoRA train settings.</p>
909
+ """, elem_classes="group_padding")
910
+ lora_name = gr.Textbox(
911
+ label="The name of your LoRA",
912
+ info="This has to be a unique name",
913
+ placeholder="e.g.: Persian Miniature Painting style, Cat Toy",
914
+ )
915
+ concept_sentence = gr.Textbox(
916
+ elem_id="--concept_sentence",
917
+ label="Trigger word/sentence",
918
+ info="Trigger word or sentence to be used",
919
+ placeholder="uncommon word like p3rs0n or trtcrd, or sentence like 'in the style of CNSTLL'",
920
+ interactive=True,
921
+ )
922
+ model_names = list(models.keys())
923
+ print(f"model_names={model_names}")
924
+ base_model = gr.Dropdown(label="Base model (edit the models.yaml file to add more to this list)", choices=model_names, value=model_names[0])
925
+ vram = gr.Radio(["20G", "16G", "12G" ], value="20G", label="VRAM", interactive=True)
926
+ num_repeats = gr.Number(value=10, precision=0, label="Repeat trains per image", interactive=True)
927
+ max_train_epochs = gr.Number(label="Max Train Epochs", value=16, interactive=True)
928
+ total_steps = gr.Number(0, interactive=False, label="Expected training steps")
929
+ sample_prompts = gr.Textbox("", lines=5, label="Sample Image Prompts (Separate with new lines)", interactive=True)
930
+ sample_every_n_steps = gr.Number(0, precision=0, label="Sample Image Every N Steps", interactive=True)
931
+ resolution = gr.Number(value=512, precision=0, label="Resize dataset images")
932
+ with gr.Column():
933
+ gr.Markdown(
934
+ """# Step 2. Dataset
935
+ <p style="margin-top:0">Make sure the captions include the trigger word.</p>
936
+ """, elem_classes="group_padding")
937
+ with gr.Group():
938
+ images = gr.File(
939
+ file_types=["image", ".txt"],
940
+ label="Upload your images",
941
+ #info="If you want, you can also manually upload caption files that match the image names (example: img0.png => img0.txt)",
942
+ file_count="multiple",
943
+ interactive=True,
944
+ visible=True,
945
+ scale=1,
946
+ )
947
+ with gr.Group(visible=False) as captioning_area:
948
+ do_captioning = gr.Button("Add AI captions with Florence-2")
949
+ output_components.append(captioning_area)
950
+ #output_components = [captioning_area]
951
+ caption_list = []
952
+ for i in range(1, MAX_IMAGES + 1):
953
+ locals()[f"captioning_row_{i}"] = gr.Row(visible=False)
954
+ with locals()[f"captioning_row_{i}"]:
955
+ locals()[f"image_{i}"] = gr.Image(
956
+ type="filepath",
957
+ width=111,
958
+ height=111,
959
+ min_width=111,
960
+ interactive=False,
961
+ scale=2,
962
+ show_label=False,
963
+ show_share_button=False,
964
+ show_download_button=False,
965
+ )
966
+ locals()[f"caption_{i}"] = gr.Textbox(
967
+ label=f"Caption {i}", scale=15, interactive=True
968
+ )
969
+
970
+ output_components.append(locals()[f"captioning_row_{i}"])
971
+ output_components.append(locals()[f"image_{i}"])
972
+ output_components.append(locals()[f"caption_{i}"])
973
+ caption_list.append(locals()[f"caption_{i}"])
974
+ with gr.Column():
975
+ gr.Markdown(
976
+ """# Step 3. Train
977
+ <p style="margin-top:0">Press start to start training.</p>
978
+ """, elem_classes="group_padding")
979
+ refresh = gr.Button("Refresh", elem_id="refresh", visible=False)
980
+ start = gr.Button("Start training", visible=False, elem_id="start_training")
981
+ output_components.append(start)
982
+ train_script = gr.Textbox(label="Train script", max_lines=100, interactive=True)
983
+ train_config = gr.Textbox(label="Train config", max_lines=100, interactive=True)
984
+ with gr.Accordion("Advanced options", elem_id='advanced_options', open=False):
985
+ with gr.Row():
986
+ with gr.Column(min_width=300):
987
+ seed = gr.Number(label="--seed", info="Seed", value=42, interactive=True)
988
+ with gr.Column(min_width=300):
989
+ workers = gr.Number(label="--max_data_loader_n_workers", info="Number of Workers", value=2, interactive=True)
990
+ with gr.Column(min_width=300):
991
+ learning_rate = gr.Textbox(label="--learning_rate", info="Learning Rate", value="8e-4", interactive=True)
992
+ with gr.Column(min_width=300):
993
+ save_every_n_epochs = gr.Number(label="--save_every_n_epochs", info="Save every N epochs", value=4, interactive=True)
994
+ with gr.Column(min_width=300):
995
+ guidance_scale = gr.Number(label="--guidance_scale", info="Guidance Scale", value=1.0, interactive=True)
996
+ with gr.Column(min_width=300):
997
+ timestep_sampling = gr.Textbox(label="--timestep_sampling", info="Timestep Sampling", value="shift", interactive=True)
998
+ with gr.Column(min_width=300):
999
+ network_dim = gr.Number(label="--network_dim", info="LoRA Rank", value=4, minimum=4, maximum=128, step=4, interactive=True)
1000
+ advanced_components, advanced_component_ids = init_advanced()
1001
+ with gr.Row():
1002
+ terminal = LogsView(label="Train log", elem_id="terminal")
1003
+ with gr.Row():
1004
+ gallery = gr.Gallery(get_samples, inputs=[lora_name], label="Samples", every=10, columns=6)
1005
+
1006
+ with gr.TabItem("Publish") as publish_tab:
1007
+ hf_token = gr.Textbox(label="Huggingface Token")
1008
+ hf_login = gr.Button("Login")
1009
+ hf_logout = gr.Button("Logout")
1010
+ with gr.Row() as row:
1011
+ gr.Markdown("**LoRA**")
1012
+ gr.Markdown("**Upload**")
1013
+ loras = get_loras()
1014
+ with gr.Row():
1015
+ lora_rows = refresh_publish_tab()
1016
+ with gr.Column():
1017
+ with gr.Row():
1018
+ repo_owner = gr.Textbox(label="Account", interactive=False)
1019
+ repo_name = gr.Textbox(label="Repository Name")
1020
+ repo_visibility = gr.Textbox(label="Repository Visibility ('public' or 'private')", value="public")
1021
+ upload_button = gr.Button("Upload to HuggingFace")
1022
+ upload_button.click(
1023
+ fn=upload_hf,
1024
+ inputs=[
1025
+ base_model,
1026
+ lora_rows,
1027
+ repo_owner,
1028
+ repo_name,
1029
+ repo_visibility,
1030
+ hf_token,
1031
+ ]
1032
+ )
1033
+ hf_login.click(fn=login_hf, inputs=[hf_token], outputs=[hf_token, hf_login, hf_logout, repo_owner])
1034
+ hf_logout.click(fn=logout_hf, outputs=[hf_token, hf_login, hf_logout, repo_owner])
1035
+
1036
+
1037
+ publish_tab.select(refresh_publish_tab, outputs=lora_rows)
1038
+ lora_rows.select(fn=set_repo, inputs=[lora_rows], outputs=[repo_name])
1039
+
1040
+ dataset_folder = gr.State()
1041
+
1042
+ listeners = [
1043
+ base_model,
1044
+ lora_name,
1045
+ resolution,
1046
+ seed,
1047
+ workers,
1048
+ concept_sentence,
1049
+ learning_rate,
1050
+ network_dim,
1051
+ max_train_epochs,
1052
+ save_every_n_epochs,
1053
+ timestep_sampling,
1054
+ guidance_scale,
1055
+ vram,
1056
+ num_repeats,
1057
+ sample_prompts,
1058
+ sample_every_n_steps,
1059
+ *advanced_components
1060
+ ]
1061
+ advanced_component_ids = [x.elem_id for x in advanced_components]
1062
+ original_advanced_component_values = [comp.value for comp in advanced_components]
1063
+ images.upload(
1064
+ load_captioning,
1065
+ inputs=[images, concept_sentence],
1066
+ outputs=output_components
1067
+ )
1068
+ images.delete(
1069
+ load_captioning,
1070
+ inputs=[images, concept_sentence],
1071
+ outputs=output_components
1072
+ )
1073
+ images.clear(
1074
+ hide_captioning,
1075
+ outputs=[captioning_area, start]
1076
+ )
1077
+ max_train_epochs.change(
1078
+ fn=update_total_steps,
1079
+ inputs=[max_train_epochs, num_repeats, images],
1080
+ outputs=[total_steps]
1081
+ )
1082
+ num_repeats.change(
1083
+ fn=update_total_steps,
1084
+ inputs=[max_train_epochs, num_repeats, images],
1085
+ outputs=[total_steps]
1086
+ )
1087
+ images.upload(
1088
+ fn=update_total_steps,
1089
+ inputs=[max_train_epochs, num_repeats, images],
1090
+ outputs=[total_steps]
1091
+ )
1092
+ images.delete(
1093
+ fn=update_total_steps,
1094
+ inputs=[max_train_epochs, num_repeats, images],
1095
+ outputs=[total_steps]
1096
+ )
1097
+ images.clear(
1098
+ fn=update_total_steps,
1099
+ inputs=[max_train_epochs, num_repeats, images],
1100
+ outputs=[total_steps]
1101
+ )
1102
+ concept_sentence.change(fn=update_sample, inputs=[concept_sentence], outputs=sample_prompts)
1103
+ start.click(fn=create_dataset, inputs=[dataset_folder, resolution, images] + caption_list, outputs=dataset_folder).then(
1104
+ fn=start_training,
1105
+ inputs=[
1106
+ base_model,
1107
+ lora_name,
1108
+ train_script,
1109
+ train_config,
1110
+ sample_prompts,
1111
+ ],
1112
+ outputs=terminal,
1113
+ )
1114
+ do_captioning.click(fn=run_captioning, inputs=[images, concept_sentence] + caption_list, outputs=caption_list)
1115
+ demo.load(fn=loaded, js=js, outputs=[hf_token, hf_login, hf_logout, repo_owner])
1116
+ refresh.click(update, inputs=listeners, outputs=[train_script, train_config, dataset_folder])
1117
+ if __name__ == "__main__":
1118
+ cwd = os.path.dirname(os.path.abspath(__file__))
1119
+ demo.launch(debug=True, show_error=True, allowed_paths=[cwd])