Dickson committed on
Commit
58c852b
1 Parent(s): 05d9d9c
Files changed (2)
  1. app.py +1074 -0
  2. requirements_local.txt +14 -0
app.py ADDED
@@ -0,0 +1,1074 @@
1
+ import gradio as gr
2
+ from PIL import Image
3
+ import requests
4
+ import subprocess
5
+ from transformers import Blip2Processor, Blip2ForConditionalGeneration
6
+ from huggingface_hub import snapshot_download, HfApi
7
+ import torch
8
+ import uuid
9
+ import os
10
+ import shutil
11
+ import json
12
+ import random
13
+ from slugify import slugify
14
+ import argparse
15
+ import importlib
16
+ import sys
17
+ from pathlib import Path
18
+ import spaces
19
+ import zipfile
20
+
21
+ MAX_IMAGES = 150
22
+
23
+ is_spaces = True if os.environ.get('SPACE_ID') else False
24
+
25
+ training_script_url = "https://raw.githubusercontent.com/huggingface/diffusers/ba28006f8b2a0f7ec3b6784695790422b4f80a97/examples/advanced_diffusion_training/train_dreambooth_lora_sdxl_advanced.py"
26
+ subprocess.run(['wget', '-N', training_script_url])
27
+ orchestrator_script_url = "https://huggingface.co/datasets/multimodalart/lora-ease-helper/raw/main/script.py"
28
+ subprocess.run(['wget', '-N', orchestrator_script_url])
29
+
30
+ device = "cuda" if torch.cuda.is_available() else "cpu"
31
+
32
+ FACES_DATASET_PATH = snapshot_download(repo_id="multimodalart/faces-prior-preservation", repo_type="dataset")
33
+ #Delete .gitattributes to process things properly
34
+ Path(FACES_DATASET_PATH, '.gitattributes').unlink(missing_ok=True)
35
+
36
+ processor = Blip2Processor.from_pretrained("Salesforce/blip2-opt-2.7b")
37
+ model = Blip2ForConditionalGeneration.from_pretrained(
38
+ "Salesforce/blip2-opt-2.7b", device_map={"": 0}, torch_dtype=torch.float16
39
+ )
40
+
41
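+ # Per-category training presets; only "face" enables prior preservation, drawing class images from the shared faces dataset downloaded above.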
+ training_option_settings = {
42
+ "face": {
43
+ "rank": 32,
44
+ "lr_scheduler": "constant",
45
+ "with_prior_preservation": True,
46
+ "class_prompt": "a photo of a person",
47
+ "train_steps_multiplier": 75,
48
+ "file_count": 150,
49
+ "dataset_path": FACES_DATASET_PATH
50
+ },
51
+ "style": {
52
+ "rank": 32,
53
+ "lr_scheduler": "constant",
54
+ "with_prior_preservation": False,
55
+ "class_prompt": "",
56
+ "train_steps_multiplier": 120
57
+ },
58
+ "character": {
59
+ "rank": 32,
60
+ "lr_scheduler": "constant",
61
+ "with_prior_preservation": False,
62
+ "class_prompt": "",
63
+ "train_steps_multiplier": 180
64
+ },
65
+ "object": {
66
+ "rank": 16,
67
+ "lr_scheduler": "constant",
68
+ "with_prior_preservation": False,
69
+ "class_prompt": "",
70
+ "train_steps_multiplier": 50
71
+ },
72
+ "custom": {
73
+ "rank": 32,
74
+ "lr_scheduler": "constant",
75
+ "with_prior_preservation": False,
76
+ "class_prompt": "",
77
+ "train_steps_multiplier": 150
78
+ }
79
+ }
80
+
81
+ num_images_settings = {
82
+ # more than 24 images: 1 repeat; 11-24 images: 2 repeats; 10 or fewer: 3 repeats
83
+ "repeats": [(24, 1), (10, 2), (0, 3)],
84
+ "train_steps_min": 500,
85
+ "train_steps_max": 1500
86
+ }
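+ # Example with these settings: 15 uploaded "face" images -> repeats = 2 and
+ # max_train_steps = min(max(15 * 75, 500), 1500) = 1125 (see change_defaults below).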
87
+
88
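+ # Builds the flat list of gr.update objects that maps onto `output_components`: three visibility
+ # updates for the captioning area, cost estimate and advanced-options sections, then one
+ # (row, image, caption) triplet per possible image slot.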
+ def load_captioning(uploaded_images, option):
89
+ updates = []
90
+ if len(uploaded_images) <= 1:
91
+ raise gr.Error(
92
+ "Please upload at least 2 images to train your model (the ideal number with default settings is between 4-30)"
93
+ )
94
+ elif len(uploaded_images) > MAX_IMAGES:
95
+ raise gr.Error(
96
+ f"For now, only {MAX_IMAGES} or less images are allowed for training"
97
+ )
98
+ # Update for the captioning_area
99
+ for _ in range(3):
100
+ updates.append(gr.update(visible=True))
101
+ # Update visibility and image for each captioning row and image
102
+ for i in range(1, MAX_IMAGES + 1):
103
+ # Determine if the current row and image should be visible
104
+ visible = i <= len(uploaded_images)
105
+
106
+ # Update visibility of the captioning row
107
+ updates.append(gr.update(visible=visible))
108
+
109
+ # Update for image component - display image if available, otherwise hide
110
+ image_value = uploaded_images[i - 1] if visible else None
111
+ updates.append(gr.update(value=image_value, visible=visible))
112
+
113
+ text_value = option if visible else None
114
+ updates.append(gr.update(value=text_value, visible=visible))
115
+ return updates
116
+
117
+ def check_removed_and_restart(images):
118
+ visible = len(images) > 1 if images is not None else False
119
+ if(is_spaces):
120
+ captioning_area = gr.update(visible=visible)
121
+ advanced = gr.update(visible=visible)
122
+ cost_estimation = gr.update(visible=visible)
123
+ start = gr.update(visible=False)
124
+ else:
125
+ captioning_area = gr.update(visible=visible)
126
+ advanced = gr.update(visible=visible)
127
+ cost_estimation = gr.update(visible=False)
128
+ start = gr.update(visible=True)
129
+ return captioning_area, advanced, cost_estimation, start
130
+
131
+ def make_options_visible(option):
132
+ if (option == "object") or (option == "face"):
133
+ sentence = "A photo of TOK"
134
+ elif option == "style":
135
+ sentence = "in the style of TOK"
136
+ elif option == "character":
137
+ sentence = "A TOK character"
138
+ elif option == "custom":
139
+ sentence = "TOK"
140
+ return (
141
+ gr.update(value=sentence, visible=True),
142
+ gr.update(visible=True),
143
+ )
144
+
145
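+ # Re-derives defaults whenever images are uploaded: train steps (clamped to 500-1500), repeats,
+ # scheduler, LoRA rank and prior-preservation settings; for "face" it also samples up to 150 prior images.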
+ def change_defaults(option, images):
146
+ settings = training_option_settings.get(option, training_option_settings["custom"])
147
+ num_images = len(images)
148
+
149
+ # Calculate max_train_steps
150
+ train_steps_multiplier = settings["train_steps_multiplier"]
151
+ max_train_steps = max(num_images * train_steps_multiplier, num_images_settings["train_steps_min"])
152
+ max_train_steps = min(max_train_steps, num_images_settings["train_steps_max"])
153
+
154
+ # Determine repeats based on number of images
155
+ repeats = next(repeats for num, repeats in num_images_settings["repeats"] if num_images > num)
156
+
157
+ random_files = []
158
+ if settings["with_prior_preservation"]:
159
+ directory = settings["dataset_path"]
160
+ file_count = settings["file_count"]
161
+ files = [os.path.join(directory, file) for file in os.listdir(directory) if os.path.isfile(os.path.join(directory, file))]
162
+ random_files = random.sample(files, min(len(files), file_count))
163
+
164
+ return max_train_steps, repeats, settings["lr_scheduler"], settings["rank"], settings["with_prior_preservation"], settings["class_prompt"], random_files
165
+
166
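+ # Copies the uploaded images into a fresh folder and writes a metadata.jsonl with one line per image,
+ # e.g. {"file_name": "photo1.png", "prompt": "A photo of TOK"} (hypothetical values).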
+ def create_dataset(*inputs):
167
+ print("Creating dataset")
168
+ images = inputs[0]
169
+ destination_folder = str(uuid.uuid4())
170
+ if not os.path.exists(destination_folder):
171
+ os.makedirs(destination_folder)
172
+
173
+ jsonl_file_path = os.path.join(destination_folder, 'metadata.jsonl')
174
+ with open(jsonl_file_path, 'a') as jsonl_file:
175
+ for index, image in enumerate(images):
176
+ new_image_path = shutil.copy(image, destination_folder)
177
+
178
+ original_caption = inputs[index + 1]
179
+ file_name = os.path.basename(new_image_path)
180
+
181
+ data = {"file_name": file_name, "prompt": original_caption}
182
+
183
+ jsonl_file.write(json.dumps(data) + "\n")
184
+
185
+ return destination_folder
186
+
187
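+ # Spaces path: packages the trainer script, dataset folder and a requirements.txt into a staging folder and
+ # submits it with `autotrain spacerunner`; args are passed as `key=value` pairs joined by ';' (no leading dashes).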
+ def start_training(
188
+ lora_name,
189
+ training_option,
190
+ concept_sentence,
191
+ optimizer,
192
+ use_snr_gamma,
193
+ snr_gamma,
194
+ mixed_precision,
195
+ learning_rate,
196
+ train_batch_size,
197
+ max_train_steps,
198
+ lora_rank,
199
+ repeats,
200
+ with_prior_preservation,
201
+ class_prompt,
202
+ class_images,
203
+ num_class_images,
204
+ train_text_encoder_ti,
205
+ train_text_encoder_ti_frac,
206
+ num_new_tokens_per_abstraction,
207
+ train_text_encoder,
208
+ train_text_encoder_frac,
209
+ text_encoder_learning_rate,
210
+ seed,
211
+ resolution,
212
+ num_train_epochs,
213
+ checkpointing_steps,
214
+ prior_loss_weight,
215
+ gradient_accumulation_steps,
216
+ gradient_checkpointing,
217
+ enable_xformers_memory_efficient_attention,
218
+ adam_beta1,
219
+ adam_beta2,
220
+ use_prodigy_beta3,
221
+ prodigy_beta3,
222
+ prodigy_decouple,
223
+ adam_weight_decay,
224
+ use_adam_weight_decay_text_encoder,
225
+ adam_weight_decay_text_encoder,
226
+ adam_epsilon,
227
+ prodigy_use_bias_correction,
228
+ prodigy_safeguard_warmup,
229
+ max_grad_norm,
230
+ scale_lr,
231
+ lr_num_cycles,
232
+ lr_scheduler,
233
+ lr_power,
234
+ lr_warmup_steps,
235
+ dataloader_num_workers,
236
+ local_rank,
237
+ dataset_folder,
238
+ token,
239
+ progress = gr.Progress(track_tqdm=True)
240
+ ):
241
+ if not lora_name:
242
+ raise gr.Error("You forgot to insert your LoRA name! This name has to be unique.")
243
+ print("Started training")
244
+ slugged_lora_name = slugify(lora_name)
245
+ spacerunner_folder = str(uuid.uuid4())
246
+ commands = [
247
+ "pretrained_model_name_or_path=stabilityai/stable-diffusion-xl-base-1.0",
248
+ "pretrained_vae_model_name_or_path=madebyollin/sdxl-vae-fp16-fix",
249
+ f"instance_prompt={concept_sentence}",
250
+ f"dataset_name=./{dataset_folder}",
251
+ "caption_column=prompt",
252
+ f"output_dir={slugged_lora_name}",
253
+ f"mixed_precision={mixed_precision}",
254
+ f"resolution={int(resolution)}",
255
+ f"train_batch_size={int(train_batch_size)}",
256
+ f"repeats={int(repeats)}",
257
+ f"gradient_accumulation_steps={int(gradient_accumulation_steps)}",
258
+ f"learning_rate={learning_rate}",
259
+ f"text_encoder_lr={text_encoder_learning_rate}",
260
+ f"adam_beta1={adam_beta1}",
261
+ f"adam_beta2={adam_beta2}",
262
+ f"optimizer={'adamW' if optimizer == '8bitadam' else optimizer}",
263
+ f"train_text_encoder_ti_frac={train_text_encoder_ti_frac}",
264
+ f"lr_scheduler={lr_scheduler}",
265
+ f"lr_warmup_steps={int(lr_warmup_steps)}",
266
+ f"rank={int(lora_rank)}",
267
+ f"max_train_steps={int(max_train_steps)}",
268
+ f"checkpointing_steps={int(checkpointing_steps)}",
269
+ f"seed={int(seed)}",
270
+ f"prior_loss_weight={prior_loss_weight}",
271
+ f"num_new_tokens_per_abstraction={int(num_new_tokens_per_abstraction)}",
272
+ f"num_train_epochs={int(num_train_epochs)}",
273
+ f"adam_weight_decay={adam_weight_decay}",
274
+ f"adam_epsilon={adam_epsilon}",
275
+ f"prodigy_decouple={prodigy_decouple}",
276
+ f"prodigy_use_bias_correction={prodigy_use_bias_correction}",
277
+ f"prodigy_safeguard_warmup={prodigy_safeguard_warmup}",
278
+ f"max_grad_norm={max_grad_norm}",
279
+ f"lr_num_cycles={int(lr_num_cycles)}",
280
+ f"lr_power={lr_power}",
281
+ f"dataloader_num_workers={int(dataloader_num_workers)}",
282
+ f"local_rank={int(local_rank)}",
283
+ "cache_latents",
284
+ #"push_to_hub",
285
+ ]
286
+ # Adding optional flags
287
+ if optimizer == "8bitadam":
288
+ commands.append("use_8bit_adam")
289
+ if gradient_checkpointing:
290
+ commands.append("gradient_checkpointing")
291
+
292
+ if train_text_encoder_ti:
293
+ commands.append("train_text_encoder_ti")
294
+ elif train_text_encoder:
295
+ commands.append("train_text_encoder")
296
+ commands.append(f"train_text_encoder_frac={train_text_encoder_frac}")
297
+ if enable_xformers_memory_efficient_attention:
298
+ commands.append("enable_xformers_memory_efficient_attention")
299
+ if use_snr_gamma:
300
+ commands.append(f"snr_gamma={snr_gamma}")
301
+ if scale_lr:
302
+ commands.append("scale_lr")
303
+ if with_prior_preservation:
304
+ commands.append("with_prior_preservation")
305
+ commands.append(f"class_prompt={class_prompt}")
306
+ commands.append(f"num_class_images={int(num_class_images)}")
307
+ if class_images:
308
+ class_folder = str(uuid.uuid4())
309
+ zip_path = os.path.join(spacerunner_folder, class_folder, "class_images.zip")
310
+
311
+ if not os.path.exists(os.path.join(spacerunner_folder, class_folder)):
312
+ os.makedirs(os.path.join(spacerunner_folder, class_folder))
313
+
314
+ with zipfile.ZipFile(zip_path, 'w') as zipf:
315
+ for image in class_images:
316
+ zipf.write(image, os.path.basename(image))
317
+
318
+ commands.append(f"class_data_dir={class_folder}")
319
+ if use_prodigy_beta3:
320
+ commands.append(f"prodigy_beta3={prodigy_beta3}")
321
+ if use_adam_weight_decay_text_encoder:
322
+ commands.append(f"adam_weight_decay_text_encoder={adam_weight_decay_text_encoder}")
323
+ print(commands)
324
+ # Joining the commands with ';' separator for spacerunner format
325
+ spacerunner_args = ';'.join(commands)
326
+ if not os.path.exists(spacerunner_folder):
327
+ os.makedirs(spacerunner_folder)
328
+ shutil.copy("train_dreambooth_lora_sdxl_advanced.py", f"{spacerunner_folder}/trainer.py")
329
+ shutil.copy("script.py", f"{spacerunner_folder}/script.py")
330
+ shutil.copytree(dataset_folder, f"{spacerunner_folder}/{dataset_folder}")
331
+ requirements='''peft==0.7.1
333
+ torch
334
+ git+https://github.com/huggingface/diffusers@ba28006f8b2a0f7ec3b6784695790422b4f80a97
335
+ transformers==4.36.2
336
+ accelerate==0.25.0
337
+ safetensors==0.4.1
338
+ prodigyopt==1.0
339
+ hf-transfer==0.1.4
340
+ huggingface_hub==0.20.3
341
+ git+https://github.com/huggingface/datasets.git@3f149204a2a5948287adcade5e90707aa5207a92'''
342
+ file_path = f'{spacerunner_folder}/requirements.txt'
343
+ with open(file_path, 'w') as file:
344
+ file.write(requirements)
345
+ # The subprocess call for autotrain spacerunner
346
+ api = HfApi(token=token)
347
+ username = api.whoami()["name"]
348
+ subprocess_command = ["autotrain", "spacerunner", "--project-name", slugged_lora_name, "--script-path", spacerunner_folder, "--username", username, "--token", token, "--backend", "spaces-a10g-small", "--env",f"HF_TOKEN={token};HF_HUB_ENABLE_HF_TRANSFER=1", "--args", spacerunner_args]
349
+ outcome = subprocess.run(subprocess_command)
350
+ if(outcome.returncode == 0):
351
+ return f"""# Your training has started.
352
+ ## - Training Status: <a href='https://huggingface.co/spaces/{username}/autotrain-{slugged_lora_name}?logs=container'>{username}/autotrain-{slugged_lora_name}</a> <small>(in the logs tab)</small>
353
+ ## - Model page: <a href='https://huggingface.co/{username}/{slugged_lora_name}'>{username}/{slugged_lora_name}</a> <small>(will be available when training finishes)</small>"""
354
+ else:
355
+ print("Error: ", outcome.stderr)
356
+ raise gr.Error("Something went wrong. Make sure the name of your LoRA is unique and try again")
357
+
358
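+ # Rough estimate at US$1.05/h for the A10G plus ~210 s of setup overhead. Example: 1000 iterations
+ # without prior preservation -> 1000 * 2.0 s + 210 s = 2210 s ≈ 36.8 min ≈ US$0.64.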
+ def calculate_price(iterations, with_prior_preservation):
359
+ if(with_prior_preservation):
360
+ seconds_per_iteration = 3.50
361
+ else:
362
+ seconds_per_iteration = 2.00
363
+ total_seconds = (iterations * seconds_per_iteration) + 210
364
+ cost_per_second = 1.05/60/60
365
+ cost = round(cost_per_second * total_seconds, 2)
366
+ return f'''To train this LoRA, we will duplicate the space and hook an A10G GPU under the hood.
367
+ ## Estimated to cost <b>< US$ {str(cost)}</b> for {round(int(total_seconds)/60, 2)} minutes with your current train settings <small>({int(iterations)} iterations at {seconds_per_iteration}s/it)</small>
368
+ #### ↓ to continue, grab your <b>write</b> token [here](https://huggingface.co/settings/tokens) and enter it below ↓'''
369
+
370
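+ # Local path (non-Spaces): builds conventional `--flag=value` arguments and calls the advanced
+ # DreamBooth LoRA SDXL training script in-process via its parse_args()/main() entry points.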
+ def start_training_og(
371
+ lora_name,
372
+ training_option,
373
+ concept_sentence,
374
+ optimizer,
375
+ use_snr_gamma,
376
+ snr_gamma,
377
+ mixed_precision,
378
+ learning_rate,
379
+ train_batch_size,
380
+ max_train_steps,
381
+ lora_rank,
382
+ repeats,
383
+ with_prior_preservation,
384
+ class_prompt,
385
+ class_images,
386
+ num_class_images,
387
+ train_text_encoder_ti,
388
+ train_text_encoder_ti_frac,
389
+ num_new_tokens_per_abstraction,
390
+ train_text_encoder,
391
+ train_text_encoder_frac,
392
+ text_encoder_learning_rate,
393
+ seed,
394
+ resolution,
395
+ num_train_epochs,
396
+ checkpointing_steps,
397
+ prior_loss_weight,
398
+ gradient_accumulation_steps,
399
+ gradient_checkpointing,
400
+ enable_xformers_memory_efficient_attention,
401
+ adam_beta1,
402
+ adam_beta2,
403
+ use_prodigy_beta3,
404
+ prodigy_beta3,
405
+ prodigy_decouple,
406
+ adam_weight_decay,
407
+ use_adam_weight_decay_text_encoder,
408
+ adam_weight_decay_text_encoder,
409
+ adam_epsilon,
410
+ prodigy_use_bias_correction,
411
+ prodigy_safeguard_warmup,
412
+ max_grad_norm,
413
+ scale_lr,
414
+ lr_num_cycles,
415
+ lr_scheduler,
416
+ lr_power,
417
+ lr_warmup_steps,
418
+ dataloader_num_workers,
419
+ local_rank,
420
+ dataset_folder,
421
+ token,
422
+ #progress = gr.Progress(track_tqdm=True)
423
+ ):
424
+ if not lora_name:
425
+ raise gr.Error("You forgot to insert your LoRA name!")
426
+ slugged_lora_name = slugify(lora_name)
427
+ commands = [
428
+ "--pretrained_model_name_or_path=stabilityai/stable-diffusion-xl-base-1.0",
429
+ "--pretrained_vae_model_name_or_path=madebyollin/sdxl-vae-fp16-fix",
430
+ f"--instance_prompt={concept_sentence}",
431
+ f"--dataset_name=./{dataset_folder}",
432
+ "--caption_column=prompt",
433
+ f"--output_dir={slugged_lora_name}",
434
+ f"--mixed_precision={mixed_precision}",
435
+ f"--resolution={int(resolution)}",
436
+ f"--train_batch_size={int(train_batch_size)}",
437
+ f"--repeats={int(repeats)}",
438
+ f"--gradient_accumulation_steps={int(gradient_accumulation_steps)}",
439
+ f"--learning_rate={learning_rate}",
440
+ f"--text_encoder_lr={text_encoder_learning_rate}",
441
+ f"--adam_beta1={adam_beta1}",
442
+ f"--adam_beta2={adam_beta2}",
443
+ f"--optimizer={'adamW' if optimizer == '8bitadam' else optimizer}",
444
+ f"--train_text_encoder_ti_frac={train_text_encoder_ti_frac}",
445
+ f"--lr_scheduler={lr_scheduler}",
446
+ f"--lr_warmup_steps={int(lr_warmup_steps)}",
447
+ f"--rank={int(lora_rank)}",
448
+ f"--max_train_steps={int(max_train_steps)}",
449
+ f"--checkpointing_steps={int(checkpointing_steps)}",
450
+ f"--seed={int(seed)}",
451
+ f"--prior_loss_weight={prior_loss_weight}",
452
+ f"--num_new_tokens_per_abstraction={int(num_new_tokens_per_abstraction)}",
453
+ f"--num_train_epochs={int(num_train_epochs)}",
454
+ f"--adam_weight_decay={adam_weight_decay}",
455
+ f"--adam_epsilon={adam_epsilon}",
456
+ f"--prodigy_decouple={prodigy_decouple}",
457
+ f"--prodigy_use_bias_correction={prodigy_use_bias_correction}",
458
+ f"--prodigy_safeguard_warmup={prodigy_safeguard_warmup}",
459
+ f"--max_grad_norm={max_grad_norm}",
460
+ f"--lr_num_cycles={int(lr_num_cycles)}",
461
+ f"--lr_power={lr_power}",
462
+ f"--dataloader_num_workers={int(dataloader_num_workers)}",
463
+ f"--local_rank={int(local_rank)}",
464
+ "--cache_latents"
465
+ ]
466
+ if optimizer == "8bitadam":
467
+ commands.append("--use_8bit_adam")
468
+ if gradient_checkpointing:
469
+ commands.append("--gradient_checkpointing")
470
+
471
+ if train_text_encoder_ti:
472
+ commands.append("--train_text_encoder_ti")
473
+ elif train_text_encoder:
474
+ commands.append("--train_text_encoder")
475
+ commands.append(f"--train_text_encoder_frac={train_text_encoder_frac}")
476
+ if enable_xformers_memory_efficient_attention:
477
+ commands.append("--enable_xformers_memory_efficient_attention")
478
+ if use_snr_gamma:
479
+ commands.append(f"--snr_gamma={snr_gamma}")
480
+ if scale_lr:
481
+ commands.append("--scale_lr")
482
+ if with_prior_preservation:
483
+ commands.append(f"--with_prior_preservation")
484
+ commands.append(f"--class_prompt={class_prompt}")
485
+ commands.append(f"--num_class_images={int(num_class_images)}")
486
+ if(class_images):
487
+ class_folder = str(uuid.uuid4())
488
+ if not os.path.exists(class_folder):
489
+ os.makedirs(class_folder)
490
+ for image in class_images:
491
+ shutil.copy(image, class_folder)
492
+ commands.append(f"--class_data_dir={class_folder}")
493
+ if use_prodigy_beta3:
494
+ commands.append(f"--prodigy_beta3={prodigy_beta3}")
495
+ if use_adam_weight_decay_text_encoder:
496
+ commands.append(f"--adam_weight_decay_text_encoder={adam_weight_decay_text_encoder}")
497
+ from train_dreambooth_lora_sdxl_advanced import main as train_main, parse_args as parse_train_args
498
+ args = parse_train_args(commands)
499
+
500
+ train_main(args)
501
+
502
+ return f"Your model has finished training and has been saved to the `{slugged_lora_name}` folder"
503
+
504
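+ # BLIP-2 auto-captioning on GPU: for "style" the generated caption comes first and the concept
+ # sentence is appended; for all other categories the concept sentence leads. Yields partial results as it goes.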
+ @spaces.GPU(enable_queue=True)
505
+ def run_captioning(*inputs):
506
+ model.to("cuda")
507
+ images = inputs[0]
508
+ training_option = inputs[-1]
509
+ final_captions = [""] * MAX_IMAGES
510
+ for index, image in enumerate(images):
511
+ original_caption = inputs[index + 1]
512
+ pil_image = Image.open(image)
513
+ blip_inputs = processor(images=pil_image, return_tensors="pt").to(device, torch.float16)
514
+ generated_ids = model.generate(**blip_inputs)
515
+ generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0].strip()
516
+ if training_option == "style":
517
+ final_caption = generated_text + " " + original_caption
518
+ else:
519
+ final_caption = original_caption + " " + generated_text
520
+ final_captions[index] = final_caption
521
+ yield final_captions
522
+
523
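+ # Validates the pasted token with whoami(): warns on invalid or read-only tokens, and returns the
+ # visibility updates for the "no payment method" notice and the Start-training button.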
+ def check_token(token):
524
+ try:
525
+ api = HfApi(token=token)
526
+ user_data = api.whoami()
527
+ except Exception as e:
528
+ gr.Warning("Invalid user token. Make sure to get your Hugging Face token from the settings page")
529
+ return gr.update(visible=False), gr.update(visible=False)
530
+ else:
531
+ if (user_data['auth']['accessToken']['role'] != "write"):
532
+ gr.Warning("Ops, you've uploaded a Read token. You need to use a Write token!")
533
+ else:
534
+ if user_data['canPay']:
535
+ return gr.update(visible=False), gr.update(visible=True)
536
+ else:
537
+ return gr.update(visible=True), gr.update(visible=False)
538
+
539
+ return gr.update(visible=False), gr.update(visible=False)
540
+
541
+ def check_if_tok(sentence, textual_inversion):
542
+ if "TOK" not in sentence and textual_inversion:
543
+ gr.Warning("⚠️ You've removed the special token TOK from your concept sentence. This will degrade performance as this special token is needed for textual inversion. Use TOK to describe what you are training.")
544
+
545
+ css = '''.gr-group{background-color: transparent;box-shadow: var(--block-shadow)}
546
+ .gr-group .hide-container{padding: 1em; background: var(--block-background-fill) !important}
547
+ .gr-group img{object-fit: cover}
548
+ #main_title{text-align:center}
549
+ #main_title h1 {font-size: 2.25rem}
550
+ #main_title h3, #main_title p{margin-top: 0;font-size: 1.25em}
551
+ #training_cost h2{margin-top: 10px;padding: 0.5em;border: 1px solid var(--block-border-color);font-size: 1.25em}
552
+ #training_cost h4{margin-top: 1.25em;margin-bottom: 0}
553
+ #training_cost small{font-weight: normal}
554
+ .accordion {color: var(--body-text-color)}
555
+ .main_unlogged{opacity: 0.5;pointer-events: none}
556
+ .login_logout{width: 100% !important}
557
+ #login {font-size: 0px;width: 100% !important;margin: 0 auto}
558
+ #login:after {content: 'Authorize this app to train your model';visibility: visible;display: block;font-size: var(--button-large-text-size)}
559
+ #component-3, component-697{border: 0}
560
+ '''
561
+
562
+ theme = gr.themes.Monochrome(
563
+ text_size=gr.themes.Size(lg="18px", md="15px", sm="13px", xl="22px", xs="12px", xxl="24px", xxs="9px"),
564
+ font=[gr.themes.GoogleFont('Source Sans Pro'), 'ui-sans-serif', 'system-ui', 'sans-serif'],
565
+ )
566
+
567
+ with gr.Blocks(css=css, theme=theme) as demo:
568
+ dataset_folder = gr.State()
569
+ gr.Markdown('''# LoRA Ease 🧞‍♂️
570
+ ### Train a high quality SDXL LoRA in a breeze ༄ with state-of-the-art techniques and for cheap
571
+ <small>Dreambooth with Pivotal Tuning, Prodigy and more! Use the trained LoRAs with diffusers, AUTO1111, Comfy. [blog about the training script](https://huggingface.co/blog/sdxl_lora_advanced_script), [Colab Pro](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/SDXL_Dreambooth_LoRA_advanced_example.ipynb), [run locally or in a cloud](https://github.com/huggingface/diffusers/blob/main/examples/advanced_diffusion_training/train_dreambooth_lora_sdxl_advanced.py)</small>.''', elem_id="main_title")
572
+ #gr.LoginButton(elem_classes=["login_logout"])
573
+ with gr.Tab("Train on Spaces"):
574
+ with gr.Column(elem_classes=["main_logged"]) as main_ui:
575
+ lora_name = gr.Textbox(label="The name of your LoRA", info="This has to be a unique name", placeholder="e.g.: Persian Miniature Painting style, Cat Toy")
576
+ training_option = gr.Radio(
577
+ label="What are you training?", choices=["object", "style", "character", "face", "custom"]
578
+ )
579
+ concept_sentence = gr.Textbox(
580
+ label="Concept sentence",
581
+ info="Sentence to be used in all images for captioning. TOK is a special mandatory token, used to teach the model your concept.",
582
+ placeholder="e.g.: A photo of TOK, in the style of TOK",
583
+ visible=False,
584
+ interactive=True,
585
+ )
586
+ with gr.Group(visible=False) as image_upload:
587
+ with gr.Row():
588
+ images = gr.File(
589
+ file_types=["image"],
590
+ label="Upload your images",
591
+ file_count="multiple",
592
+ interactive=True,
593
+ visible=True,
594
+ scale=1,
595
+ )
596
+ with gr.Column(scale=3, visible=False) as captioning_area:
597
+ with gr.Column():
598
+ gr.Markdown(
599
+ """# Custom captioning
600
+ To improve the quality of your outputs, you can add a custom caption for each image, describing exactly what is taking place in each of them. Including TOK is mandatory. You can leave things as is if you don't want to include captioning.
601
+ """
602
+ )
603
+ do_captioning = gr.Button("Add AI captions with BLIP-2")
604
+ output_components = [captioning_area]
605
+ caption_list = []
606
+ for i in range(1, MAX_IMAGES + 1):
607
+ locals()[f"captioning_row_{i}"] = gr.Row(visible=False)
608
+ with locals()[f"captioning_row_{i}"]:
609
+ locals()[f"image_{i}"] = gr.Image(
610
+ width=111,
611
+ height=111,
612
+ min_width=111,
613
+ interactive=False,
614
+ scale=2,
615
+ show_label=False,
616
+ show_share_button=False,
617
+ show_download_button=False
618
+ )
619
+ locals()[f"caption_{i}"] = gr.Textbox(
620
+ label=f"Caption {i}", scale=15, interactive=True
621
+ )
622
+
623
+ output_components.append(locals()[f"captioning_row_{i}"])
624
+ output_components.append(locals()[f"image_{i}"])
625
+ output_components.append(locals()[f"caption_{i}"])
626
+ caption_list.append(locals()[f"caption_{i}"])
627
+ with gr.Accordion(open=False, label="Advanced options", visible=False, elem_classes=['accordion']) as advanced:
628
+ with gr.Row():
629
+ with gr.Column():
630
+ optimizer = gr.Dropdown(
631
+ label="Optimizer",
632
+ info="Prodigy is an auto-optimizer and works good by default. If you prefer to set your own learning rates, change it to AdamW. If you don't have enough VRAM to train with AdamW, pick 8-bit Adam.",
633
+ choices=[
634
+ ("Prodigy", "prodigy"),
635
+ ("AdamW", "adamW"),
636
+ ("8-bit Adam", "8bitadam"),
637
+ ],
638
+ value="prodigy",
639
+ interactive=True,
640
+ )
641
+ use_snr_gamma = gr.Checkbox(label="Use SNR Gamma")
642
+ snr_gamma = gr.Number(
643
+ label="snr_gamma",
644
+ info="SNR weighting gamma to re-balance the loss",
645
+ value=5.000,
646
+ step=0.1,
647
+ visible=False,
648
+ )
649
+ mixed_precision = gr.Dropdown(
650
+ label="Mixed Precision",
651
+ choices=["no", "fp16", "bf16"],
652
+ value="bf16",
653
+ )
654
+ learning_rate = gr.Number(
655
+ label="UNet Learning rate",
656
+ minimum=0.0,
657
+ maximum=10.0,
658
+ step=0.0000001,
659
+ value=1.0, # For prodigy you start high and it will optimize down
660
+ )
661
+ max_train_steps = gr.Number(
662
+ label="Max train steps", minimum=1, maximum=50000, value=1000
663
+ )
664
+ lora_rank = gr.Number(
665
+ label="LoRA Rank",
666
+ info="Rank for the Low Rank Adaptation (LoRA), a higher rank produces a larger LoRA",
667
+ value=8,
668
+ step=2,
669
+ minimum=2,
670
+ maximum=1024,
671
+ )
672
+ repeats = gr.Number(
673
+ label="Repeats",
674
+ info="How many times to repeat the training data.",
675
+ value=1,
676
+ minimum=1,
677
+ maximum=200,
678
+ )
679
+ with gr.Column():
680
+ with_prior_preservation = gr.Checkbox(
681
+ label="Prior preservation loss",
682
+ info="Prior preservation helps to ground the model to things that are similar to your concept. Good for faces.",
683
+ value=False,
684
+ )
685
+ with gr.Column(visible=False) as prior_preservation_params:
686
+ with gr.Tab("prompt"):
687
+ class_prompt = gr.Textbox(
688
+ label="Class Prompt",
689
+ info="The prompt that will be used to generate your class images",
690
+ )
691
+
692
+ with gr.Tab("images"):
693
+ class_images = gr.File(
694
+ file_types=["image"],
695
+ label="Upload your images",
696
+ file_count="multiple",
697
+ )
698
+ num_class_images = gr.Number(
699
+ label="Number of class images, if there are less images uploaded then the number you put here, additional images will be sampled with Class Prompt",
700
+ value=20,
701
+ )
702
+ train_text_encoder_ti = gr.Checkbox(
703
+ label="Do textual inversion",
704
+ value=True,
705
+ info="Will train a textual inversion embedding together with the LoRA. Increases quality significantly. If untoggled, you can remove the special TOK token from the prompts.",
706
+ )
707
+ with gr.Group(visible=True) as pivotal_tuning_params:
708
+ train_text_encoder_ti_frac = gr.Number(
709
+ label="Pivot Textual Inversion",
710
+ info="% of epochs to train textual inversion for",
711
+ value=0.5,
712
+ step=0.1,
713
+ )
714
+ num_new_tokens_per_abstraction = gr.Number(
715
+ label="Tokens to train",
716
+ info="Number of tokens to train in the textual inversion",
717
+ value=2,
718
+ minimum=1,
719
+ maximum=1024,
720
+ interactive=True,
721
+ )
722
+ with gr.Group(visible=False) as text_encoder_train_params:
723
+ train_text_encoder = gr.Checkbox(
724
+ label="Train Text Encoder", value=True
725
+ )
726
+ train_text_encoder_frac = gr.Number(
727
+ label="Pivot Text Encoder",
728
+ info="% of epochs to train the text encoder for",
729
+ value=0.8,
730
+ step=0.1,
731
+ )
732
+ text_encoder_learning_rate = gr.Number(
733
+ label="Text encoder learning rate",
734
+ minimum=0.0,
735
+ maximum=10.0,
736
+ step=0.0000001,
737
+ value=1.0,
738
+ )
739
+ seed = gr.Number(label="Seed", value=42)
740
+ resolution = gr.Number(
741
+ label="Resolution",
742
+ info="Only square sizes are supported for now, the value will be width and height",
743
+ value=1024,
744
+ )
745
+
746
+ with gr.Accordion(open=False, label="Even more advanced options", elem_classes=['accordion']):
747
+ with gr.Row():
748
+ with gr.Column():
749
+ gradient_accumulation_steps = gr.Number(
750
+ info="If you change this setting, the pricing calculation will be wrong",
751
+ label="gradient_accumulation_steps",
752
+ value=1
753
+ )
754
+ train_batch_size = gr.Number(
755
+ info="If you change this setting, the pricing calculation will be wrong",
756
+ label="Train batch size",
757
+ value=2
758
+ )
759
+ num_train_epochs = gr.Number(
760
+ info="If you change this setting, the pricing calculation will be wrong",
761
+ label="num_train_epochs",
762
+ value=1
763
+ )
764
+ checkpointing_steps = gr.Number(
765
+ info="How many steps to save intermediate checkpoints",
766
+ label="checkpointing_steps",
767
+ value=100000,
768
+ visible=False #hack to not let users break this for now
769
+ )
770
+ prior_loss_weight = gr.Number(
771
+ label="prior_loss_weight",
772
+ value=1
773
+ )
774
+ gradient_checkpointing = gr.Checkbox(
775
+ label="gradient_checkpointing",
776
+ info="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass",
777
+ value=True,
778
+ )
779
+ adam_beta1 = gr.Number(
780
+ label="adam_beta1",
781
+ value=0.9,
782
+ minimum=0,
783
+ maximum=1,
784
+ step=0.01
785
+ )
786
+ adam_beta2 = gr.Number(
787
+ label="adam_beta2",
788
+ minimum=0,
789
+ maximum=1,
790
+ step=0.01,
791
+ value=0.999
792
+ )
793
+ use_prodigy_beta3 = gr.Checkbox(
794
+ label="Use Prodigy Beta 3?"
795
+ )
796
+ prodigy_beta3 = gr.Number(
797
+ label="Prodigy Beta 3",
798
+ value=None,
799
+ step=0.01,
800
+ minimum=0,
801
+ maximum=1,
802
+ )
803
+ prodigy_decouple = gr.Checkbox(
804
+ label="Prodigy Decouple",
805
+ value=True
806
+ )
807
+ adam_weight_decay = gr.Number(
808
+ label="Adam Weight Decay",
809
+ value=1e-04,
810
+ step=0.00001,
811
+ minimum=0,
812
+ maximum=1,
813
+ )
814
+ use_adam_weight_decay_text_encoder = gr.Checkbox(
815
+ label="Use Adam Weight Decay Text Encoder"
816
+ )
817
+ adam_weight_decay_text_encoder = gr.Number(
818
+ label="Adam Weight Decay Text Encoder",
819
+ value=None,
820
+ step=0.00001,
821
+ minimum=0,
822
+ maximum=1,
823
+ )
824
+ adam_epsilon = gr.Number(
825
+ label="Adam Epsilon",
826
+ value=1e-08,
827
+ step=0.00000001,
828
+ minimum=0,
829
+ maximum=1,
830
+ )
831
+ prodigy_use_bias_correction = gr.Checkbox(
832
+ label="Prodigy Use Bias Correction",
833
+ value=True
834
+ )
835
+ prodigy_safeguard_warmup = gr.Checkbox(
836
+ label="Prodigy Safeguard Warmup",
837
+ value=True
838
+ )
839
+ max_grad_norm = gr.Number(
840
+ label="Max Grad Norm",
841
+ value=1.0,
842
+ minimum=0.1,
843
+ maximum=10,
844
+ step=0.1,
845
+ )
846
+ enable_xformers_memory_efficient_attention = gr.Checkbox(
847
+ label="enable_xformers_memory_efficient_attention"
848
+ )
849
+ with gr.Column():
850
+ scale_lr = gr.Checkbox(
851
+ label="Scale learning rate",
852
+ info="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size",
853
+ )
854
+ lr_num_cycles = gr.Number(
855
+ label="lr_num_cycles",
856
+ value=1
857
+ )
858
+ lr_scheduler = gr.Dropdown(
859
+ label="lr_scheduler",
860
+ choices=[
861
+ "linear",
862
+ "cosine",
863
+ "cosine_with_restarts",
864
+ "polynomial",
865
+ "constant",
866
+ "constant_with_warmup",
867
+ ],
868
+ value="constant",
869
+ )
870
+ lr_power = gr.Number(
871
+ label="lr_power",
872
+ value=1.0,
873
+ minimum=0.1,
874
+ maximum=10
875
+ )
876
+ lr_warmup_steps = gr.Number(
877
+ label="lr_warmup_steps",
878
+ value=0
879
+ )
880
+ dataloader_num_workers = gr.Number(
881
+ label="Dataloader num workers", value=0, minimum=0, maximum=64
882
+ )
883
+ local_rank = gr.Number(
884
+ label="local_rank",
885
+ value=-1
886
+ )
887
+ with gr.Column(visible=False) as cost_estimation:
888
+ with gr.Group(elem_id="cost_box"):
889
+ training_cost_estimate = gr.Markdown(elem_id="training_cost")
890
+ token = gr.Textbox(label="Your Hugging Face write token", info="A Hugging Face write token you can obtain on the settings page", type="password", placeholder="hf_OhHiThIsIsNoTaReALToKeNGOoDTry")
891
+ with gr.Group(visible=False) as no_payment_method:
892
+ with gr.Row():
893
+ gr.HTML("<h3 style='margin: 0'>Your Hugging Face account doesn't have a payment method set up. Set one up <a href='https://huggingface.co/settings/billing/payment' target='_blank'>here</a> and come back here to train your LoRA</h3>")
894
+ payment_setup = gr.Button("I have set up a payment method")
895
+
896
+ start = gr.Button("Start training", visible=False, interactive=True)
897
+ progress_area = gr.Markdown("")
898
+ with gr.Tab("Train locally"):
899
+ gr.Markdown(f'''To use LoRA Ease locally with a UI, you can clone this repository (yes, HF Spaces are git repos!)
900
+ ```bash
901
+ git clone https://huggingface.co/spaces/multimodalart/lora-ease
902
+ ```
903
+
904
+ Install the dependencies in the `requirements_local.txt` with
905
+
906
+ ```bash
907
+ pip install -r requirements_local.txt
908
+ ```
909
+ (if you prefer, do it in a venv environment)
910
+
911
+ Now you can run LoRA Ease locally with a simple
912
+ ```py
913
+ python app.py
914
+ ```
915
+
916
+ If you prefer command line, you can run our [training script]({training_script_url}) yourself.
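+
+ As a rough sketch of a command-line run (flag names mirror the options this app passes; the dataset path, output dir and values below are placeholders to adjust), the script is typically launched with `accelerate`:
+
+ ```bash
+ accelerate launch train_dreambooth_lora_sdxl_advanced.py --pretrained_model_name_or_path=stabilityai/stable-diffusion-xl-base-1.0 --dataset_name=./my_dataset --caption_column=prompt --instance_prompt="A photo of TOK" --output_dir=my-lora --resolution=1024 --train_batch_size=2 --max_train_steps=1000 --rank=8 --optimizer=prodigy --train_text_encoder_ti
+ ```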
917
+ ''')
918
+ #gr.LogoutButton(elem_classes=["login_logout"])
919
+ output_components.insert(1, advanced)
920
+ output_components.insert(1, cost_estimation)
921
+ gr.on(
922
+ triggers=[
923
+ token.change,
924
+ payment_setup.click
925
+ ],
926
+ fn=check_token,
927
+ inputs=token,
928
+ outputs=[no_payment_method, start],
929
+ concurrency_limit=50,
930
+ )
931
+ concept_sentence.change(
932
+ check_if_tok,
933
+ inputs=[concept_sentence, train_text_encoder_ti],
934
+ concurrency_limit=50,
935
+ )
936
+ use_snr_gamma.change(
937
+ lambda x: gr.update(visible=x),
938
+ inputs=use_snr_gamma,
939
+ outputs=snr_gamma,
940
+ queue=False,
941
+ )
942
+ with_prior_preservation.change(
943
+ lambda x: gr.update(visible=x),
944
+ inputs=with_prior_preservation,
945
+ outputs=prior_preservation_params,
946
+ queue=False,
947
+ )
948
+ train_text_encoder_ti.change(
949
+ lambda x: gr.update(visible=x),
950
+ inputs=train_text_encoder_ti,
951
+ outputs=pivotal_tuning_params,
952
+ queue=False,
953
+ ).then(
954
+ lambda x: gr.update(visible=(not x)),
955
+ inputs=train_text_encoder_ti,
956
+ outputs=text_encoder_train_params,
957
+ queue=False,
958
+ ).then(
959
+ lambda x: gr.Warning("As you have disabled Pivotal Tuning, you can remove TOK from your prompts and try to find a unique token for them") if not x else None,
960
+ inputs=train_text_encoder_ti,
961
+ concurrency_limit=50,
962
+ )
963
+ train_text_encoder.change(
964
+ lambda x: [gr.update(visible=x), gr.update(visible=x)],
965
+ inputs=train_text_encoder,
966
+ outputs=[train_text_encoder_frac, text_encoder_learning_rate],
967
+ queue=False,
968
+ )
969
+ class_images.change(
970
+ lambda x: gr.update(value=len(x)),
971
+ inputs=class_images,
972
+ outputs=num_class_images,
973
+ queue=False
974
+ )
975
+ images.upload(
976
+ load_captioning,
977
+ inputs=[images, concept_sentence],
978
+ outputs=output_components,
979
+ queue=False
980
+ ).success(
981
+ change_defaults,
982
+ inputs=[training_option, images],
983
+ outputs=[max_train_steps, repeats, lr_scheduler, lora_rank, with_prior_preservation, class_prompt, class_images],
984
+ queue=False
985
+ )
986
+ images.change(
987
+ check_removed_and_restart,
988
+ inputs=[images],
989
+ outputs=[captioning_area, advanced, cost_estimation, start],
990
+ queue=False
991
+ )
992
+ training_option.change(
993
+ make_options_visible,
994
+ inputs=training_option,
995
+ outputs=[concept_sentence, image_upload],
996
+ queue=False
997
+ )
998
+ max_train_steps.change(
999
+ calculate_price,
1000
+ inputs=[max_train_steps, with_prior_preservation],
1001
+ outputs=[training_cost_estimate],
1002
+ queue=False
1003
+ )
1004
+ start.click(
1005
+ fn=create_dataset,
1006
+ inputs=[images] + caption_list,
1007
+ outputs=dataset_folder,
1008
+ queue=False
1009
+ ).then(
1010
+ fn=start_training if is_spaces else start_training_og,
1011
+ inputs=[
1012
+ lora_name,
1013
+ training_option,
1014
+ concept_sentence,
1015
+ optimizer,
1016
+ use_snr_gamma,
1017
+ snr_gamma,
1018
+ mixed_precision,
1019
+ learning_rate,
1020
+ train_batch_size,
1021
+ max_train_steps,
1022
+ lora_rank,
1023
+ repeats,
1024
+ with_prior_preservation,
1025
+ class_prompt,
1026
+ class_images,
1027
+ num_class_images,
1028
+ train_text_encoder_ti,
1029
+ train_text_encoder_ti_frac,
1030
+ num_new_tokens_per_abstraction,
1031
+ train_text_encoder,
1032
+ train_text_encoder_frac,
1033
+ text_encoder_learning_rate,
1034
+ seed,
1035
+ resolution,
1036
+ num_train_epochs,
1037
+ checkpointing_steps,
1038
+ prior_loss_weight,
1039
+ gradient_accumulation_steps,
1040
+ gradient_checkpointing,
1041
+ enable_xformers_memory_efficient_attention,
1042
+ adam_beta1,
1043
+ adam_beta2,
1044
+ use_prodigy_beta3,
1045
+ prodigy_beta3,
1046
+ prodigy_decouple,
1047
+ adam_weight_decay,
1048
+ use_adam_weight_decay_text_encoder,
1049
+ adam_weight_decay_text_encoder,
1050
+ adam_epsilon,
1051
+ prodigy_use_bias_correction,
1052
+ prodigy_safeguard_warmup,
1053
+ max_grad_norm,
1054
+ scale_lr,
1055
+ lr_num_cycles,
1056
+ lr_scheduler,
1057
+ lr_power,
1058
+ lr_warmup_steps,
1059
+ dataloader_num_workers,
1060
+ local_rank,
1061
+ dataset_folder,
1062
+ token
1063
+ ],
1064
+ outputs = progress_area,
1065
+ queue=False
1066
+ )
1067
+
1068
+ do_captioning.click(
1069
+ fn=run_captioning, inputs=[images] + caption_list + [training_option], outputs=caption_list
1070
+ )
1071
+ #demo.load(fn=swap_opacity, outputs=[main_ui], queue=False, concurrency_limit=50)
1072
+ if __name__ == "__main__":
1073
+ demo.queue()
1074
+ demo.launch(share=True)
requirements_local.txt ADDED
@@ -0,0 +1,14 @@
1
+ torch
2
+ python-slugify
3
+ uuid
4
+ peft==0.7.1
5
+ huggingface-hub==0.23.4
6
+ diffusers==0.29.2
7
+ transformers==4.42.3
8
+ accelerate==0.31.0
9
+ safetensors==0.4.3
10
+ prodigyopt==1.0
11
+ hf-transfer==0.1.4
12
+ datasets==2.20.0
13
+ spaces
14
+ gradio