multimodalart (HF staff) committed on
Commit
53c1e6e
1 Parent(s): 6a2b290

Create app.py

Files changed (1)
app.py +819 -0
app.py ADDED
@@ -0,0 +1,819 @@
import gradio as gr
from PIL import Image
import requests
import subprocess
from transformers import Blip2Processor, Blip2ForConditionalGeneration
from huggingface_hub import HfApi
import torch
import uuid
import os
import shutil
import json
import random
from slugify import slugify
import argparse
import importlib
import sys

MAX_IMAGES = 50

device = "cuda" if torch.cuda.is_available() else "cpu"

processor = Blip2Processor.from_pretrained("Salesforce/blip2-opt-2.7b")
model = Blip2ForConditionalGeneration.from_pretrained(
    "Salesforce/blip2-opt-2.7b", device_map={"": 0}, torch_dtype=torch.float16
)
# Run a first captioning pass, as it apparently makes the subsequent ones faster
pil_image = Image.new('RGB', (512, 512), 'black')
blip_inputs = processor(images=pil_image, return_tensors="pt").to(device, torch.float16)
generated_ids = model.generate(**blip_inputs)
generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0].strip()

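# Note on the setup above: device_map={"": 0} places the whole BLIP-2 model on GPU 0, and the
# dummy black-image pass warms up generation so the first real caption request is not slower.
# (Explanatory comment only; behavior is unchanged.)
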
def load_captioning(uploaded_images, option):
    updates = []
    if len(uploaded_images) > MAX_IMAGES:
        raise gr.Error(
            f"Error: for now, only {MAX_IMAGES} or fewer images are allowed for training"
        )
    # Updates for the captioning area, the start button and the advanced options accordion
    for _ in range(3):
        updates.append(gr.update(visible=True))
    # Update visibility and image for each captioning row and image
    for i in range(1, MAX_IMAGES + 1):
        # Determine if the current row and image should be visible
        visible = i <= len(uploaded_images)

        # Update visibility of the captioning row
        updates.append(gr.update(visible=visible))

        # Update for the image component - display the image if available, otherwise hide it
        image_value = uploaded_images[i - 1] if visible else None
        updates.append(gr.update(value=image_value, visible=visible))

        text_value = option if visible else None
        updates.append(gr.update(value=text_value, visible=visible))
    return updates
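# The list returned above is consumed positionally by images.upload(...): the first three
# updates toggle the captioning area, the start button and the advanced accordion, and each
# following triple (row, image, caption) maps to one entry of output_components built below.
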
def check_removed_and_restart(images):
    visible = bool(images)
    return [gr.update(visible=visible) for _ in range(3)]

def make_options_visible(option):
    if (option == "object") or (option == "face"):
        sentence = "A photo of TOK"
    elif option == "style":
        sentence = "in the style of TOK"
    elif option == "custom":
        sentence = "TOK"
    return (
        gr.update(value=sentence, visible=True),
        gr.update(visible=True),
    )
def change_defaults(option, images):
    num_images = len(images)
    max_train_steps = num_images * 150
    max_train_steps = 500 if max_train_steps < 500 else max_train_steps
    random_files = []
    with_prior_preservation = False
    class_prompt = ""
    if num_images > 24:
        repeats = 1
    elif num_images > 10:
        repeats = 2
    else:
        repeats = 3
    if max_train_steps > 2400:
        max_train_steps = 2400

    if option == "face":
        rank = 64
        max_train_steps = num_images * 100
        lr_scheduler = "constant"
        # Take 150 random faces for the prior preservation loss
        directory = "faces"
        file_count = 150
        files = [os.path.join(directory, file) for file in os.listdir(directory) if os.path.isfile(os.path.join(directory, file))]
        random_files = random.sample(files, min(len(files), file_count))
        with_prior_preservation = True
        class_prompt = "a photo of a person"
    elif option == "style":
        rank = 16
        lr_scheduler = "polynomial"
    elif option == "object":
        rank = 8
        repeats = 1
        lr_scheduler = "constant"
    else:
        rank = 32
        lr_scheduler = "constant"

    return max_train_steps, repeats, lr_scheduler, rank, with_prior_preservation, class_prompt, random_files
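# Rough defaults picked above, per training option (illustrative summary of the code):
#   face   -> rank 64, constant scheduler, steps = images * 100, prior preservation on,
#             up to 150 sample faces from ./faces, class prompt "a photo of a person"
#   style  -> rank 16, polynomial scheduler
#   object -> rank 8, constant scheduler, repeats forced to 1
#   custom -> rank 32, constant scheduler
# Otherwise steps default to images * 150, clamped to the 500-2400 range, with 1-3 repeats
# depending on dataset size.
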
def create_dataset(*inputs):
    images = inputs[0]
    destination_folder = str(uuid.uuid4())
    print(destination_folder)
    if not os.path.exists(destination_folder):
        os.makedirs(destination_folder)

    jsonl_file_path = os.path.join(destination_folder, 'metadata.jsonl')
    with open(jsonl_file_path, 'a') as jsonl_file:
        for index, image in enumerate(images):
            new_image_path = shutil.copy(image, destination_folder)

            original_caption = inputs[index + 1]
            file_name = os.path.basename(new_image_path)

            data = {"file_name": file_name, "prompt": original_caption}

            jsonl_file.write(json.dumps(data) + "\n")

    return destination_folder
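# Each line of the metadata.jsonl written above describes one copied image, e.g.
# (illustrative values):
# {"file_name": "a1b2c3.png", "prompt": "A photo of TOK wearing a red hat"}
# This is the layout the training script reads via caption_column=prompt.
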
def start_training(
    lora_name,
    training_option,
    concept_sentence,
    optimizer,
    use_snr_gamma,
    snr_gamma,
    mixed_precision,
    learning_rate,
    train_batch_size,
    max_train_steps,
    lora_rank,
    repeats,
    with_prior_preservation,
    class_prompt,
    class_images,
    num_class_images,
    train_text_encoder_ti,
    train_text_encoder_ti_frac,
    num_new_tokens_per_abstraction,
    train_text_encoder,
    train_text_encoder_frac,
    text_encoder_learning_rate,
    seed,
    resolution,
    num_train_epochs,
    checkpointing_steps,
    prior_loss_weight,
    gradient_accumulation_steps,
    gradient_checkpointing,
    enable_xformers_memory_efficient_attention,
    adam_beta1,
    adam_beta2,
    prodigy_beta3,
    prodigy_decouple,
    adam_weight_decay,
    adam_weight_decay_text_encoder,
    adam_epsilon,
    prodigy_use_bias_correction,
    prodigy_safeguard_warmup,
    max_grad_norm,
    scale_lr,
    lr_num_cycles,
    lr_scheduler,
    lr_power,
    lr_warmup_steps,
    dataloader_num_workers,
    local_rank,
    dataset_folder,
    token,
    progress=gr.Progress(track_tqdm=True)
):
    slugged_lora_name = slugify(lora_name)
    spacerunner_folder = str(uuid.uuid4())
    commands = [
        "pretrained_model_name_or_path=stabilityai/stable-diffusion-xl-base-1.0",
        "pretrained_vae_model_name_or_path=madebyollin/sdxl-vae-fp16-fix",
        f"instance_prompt={concept_sentence}",
        f"dataset_name=./{dataset_folder}",
        "caption_column=prompt",
        f"output_dir={slugged_lora_name}",
        f"mixed_precision={mixed_precision}",
        f"resolution={int(resolution)}",
        f"train_batch_size={int(train_batch_size)}",
        f"repeats={int(repeats)}",
        f"gradient_accumulation_steps={int(gradient_accumulation_steps)}",
        f"learning_rate={learning_rate}",
        f"text_encoder_lr={text_encoder_learning_rate}",
        f"adam_beta1={adam_beta1}",
        f"adam_beta2={adam_beta2}",
        f"optimizer={'adamW' if optimizer == '8bitadam' else optimizer}",
        f"train_text_encoder_ti_frac={train_text_encoder_ti_frac}",
        f"lr_scheduler={lr_scheduler}",
        f"lr_warmup_steps={int(lr_warmup_steps)}",
        f"rank={int(lora_rank)}",
        f"max_train_steps={int(max_train_steps)}",
        f"checkpointing_steps={int(checkpointing_steps)}",
        f"seed={int(seed)}",
        f"prior_loss_weight={prior_loss_weight}",
        f"num_new_tokens_per_abstraction={int(num_new_tokens_per_abstraction)}",
        f"num_train_epochs={int(num_train_epochs)}",
        f"prodigy_beta3={prodigy_beta3}",
        f"adam_weight_decay={adam_weight_decay}",
        f"adam_weight_decay_text_encoder={adam_weight_decay_text_encoder}",
        f"adam_epsilon={adam_epsilon}",
        f"prodigy_decouple={prodigy_decouple}",
        f"prodigy_use_bias_correction={prodigy_use_bias_correction}",
        f"prodigy_safeguard_warmup={prodigy_safeguard_warmup}",
        f"max_grad_norm={max_grad_norm}",
        f"lr_num_cycles={int(lr_num_cycles)}",
        f"lr_power={lr_power}",
        f"dataloader_num_workers={int(dataloader_num_workers)}",
        f"local_rank={int(local_rank)}",
        "cache_latents",
        "push_to_hub",
    ]
    # Add optional flags
    if optimizer == "8bitadam":
        commands.append("use_8bit_adam")
    if gradient_checkpointing:
        commands.append("gradient_checkpointing")

    if train_text_encoder_ti:
        commands.append("train_text_encoder_ti")
    elif train_text_encoder:
        commands.append("train_text_encoder")
        commands.append(f"train_text_encoder_frac={train_text_encoder_frac}")
    if enable_xformers_memory_efficient_attention:
        commands.append("enable_xformers_memory_efficient_attention")
    if use_snr_gamma:
        commands.append(f"snr_gamma={snr_gamma}")
    if scale_lr:
        commands.append("scale_lr")
    if with_prior_preservation:
        commands.append("with_prior_preservation")
        commands.append(f"class_prompt={class_prompt}")
        commands.append(f"num_class_images={int(num_class_images)}")
        if class_images:
            class_folder = str(uuid.uuid4())
            if not os.path.exists(class_folder):
                os.makedirs(class_folder)
            for image in class_images:
                shutil.copy(image, class_folder)
            commands.append(f"class_data_dir={class_folder}")
            shutil.copytree(class_folder, f"{spacerunner_folder}/{class_folder}")
    # Join the commands with the ';' separator expected by the spacerunner args format
    spacerunner_args = ';'.join(commands)
    if not os.path.exists(spacerunner_folder):
        os.makedirs(spacerunner_folder)
    shutil.copy("train_dreambooth_lora_sdxl_advanced.py", f"{spacerunner_folder}/script.py")
    shutil.copytree(dataset_folder, f"{spacerunner_folder}/{dataset_folder}")
    requirements = '''peft
torch
git+https://github.com/huggingface/diffusers@c05d71be04345b18a5120542c363f6e4a3f99b05
transformers
accelerate
safetensors
prodigyopt
hf-transfer
git+https://github.com/huggingface/datasets.git'''
    file_path = f'{spacerunner_folder}/requirements.txt'
    with open(file_path, 'w') as file:
        file.write(requirements)
    # The subprocess call for autotrain spacerunner
    api = HfApi(token=token)
    username = api.whoami()["name"]
    subprocess_command = ["autotrain", "spacerunner", "--project-name", slugged_lora_name, "--script-path", spacerunner_folder, "--username", username, "--token", token, "--backend", "spaces-a10gl", "--env", f"HF_TOKEN={token};HF_HUB_ENABLE_HF_TRANSFER=1", "--args", spacerunner_args]
    print(subprocess_command)
    subprocess.run(subprocess_command)
    return f"Your training has started. Run over to <a href='https://huggingface.co/spaces/{username}/{slugged_lora_name}'>{username}/{slugged_lora_name}</a> to check the status (click the logs tab)"

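# The args string handed to `autotrain spacerunner` above is a single ';'-joined list of
# key=value pairs and bare flags, e.g. (truncated, illustrative):
# "pretrained_model_name_or_path=stabilityai/stable-diffusion-xl-base-1.0;resolution=1024;rank=8;cache_latents;push_to_hub"
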
def start_training_og(
    lora_name,
    training_option,
    concept_sentence,
    optimizer,
    use_snr_gamma,
    snr_gamma,
    mixed_precision,
    learning_rate,
    train_batch_size,
    max_train_steps,
    lora_rank,
    repeats,
    with_prior_preservation,
    class_prompt,
    class_images,
    num_class_images,
    train_text_encoder_ti,
    train_text_encoder_ti_frac,
    num_new_tokens_per_abstraction,
    train_text_encoder,
    train_text_encoder_frac,
    text_encoder_learning_rate,
    seed,
    resolution,
    num_train_epochs,
    checkpointing_steps,
    prior_loss_weight,
    gradient_accumulation_steps,
    gradient_checkpointing,
    enable_xformers_memory_efficient_attention,
    adam_beta1,
    adam_beta2,
    prodigy_beta3,
    prodigy_decouple,
    adam_weight_decay,
    adam_weight_decay_text_encoder,
    adam_epsilon,
    prodigy_use_bias_correction,
    prodigy_safeguard_warmup,
    max_grad_norm,
    scale_lr,
    lr_num_cycles,
    lr_scheduler,
    lr_power,
    lr_warmup_steps,
    dataloader_num_workers,
    local_rank,
    dataset_folder,
    progress=gr.Progress(track_tqdm=True)
):
    slugged_lora_name = slugify(lora_name)
    print(train_text_encoder_ti_frac)
    commands = [
        "--pretrained_model_name_or_path=stabilityai/stable-diffusion-xl-base-1.0",
        "--pretrained_vae_model_name_or_path=madebyollin/sdxl-vae-fp16-fix",
        f"--instance_prompt={concept_sentence}",
        f"--dataset_name=./{dataset_folder}",
        "--caption_column=prompt",
        f"--output_dir={slugged_lora_name}",
        f"--mixed_precision={mixed_precision}",
        f"--resolution={int(resolution)}",
        f"--train_batch_size={int(train_batch_size)}",
        f"--repeats={int(repeats)}",
        f"--gradient_accumulation_steps={int(gradient_accumulation_steps)}",
        f"--learning_rate={learning_rate}",
        f"--text_encoder_lr={text_encoder_learning_rate}",
        f"--adam_beta1={adam_beta1}",
        f"--adam_beta2={adam_beta2}",
        f"--optimizer={'adamW' if optimizer == '8bitadam' else optimizer}",
        f"--train_text_encoder_ti_frac={train_text_encoder_ti_frac}",
        f"--lr_scheduler={lr_scheduler}",
        f"--lr_warmup_steps={int(lr_warmup_steps)}",
        f"--rank={int(lora_rank)}",
        f"--max_train_steps={int(max_train_steps)}",
        f"--checkpointing_steps={int(checkpointing_steps)}",
        f"--seed={int(seed)}",
        f"--prior_loss_weight={prior_loss_weight}",
        f"--num_new_tokens_per_abstraction={int(num_new_tokens_per_abstraction)}",
        f"--num_train_epochs={int(num_train_epochs)}",
        f"--prodigy_beta3={prodigy_beta3}",
        f"--adam_weight_decay={adam_weight_decay}",
        f"--adam_weight_decay_text_encoder={adam_weight_decay_text_encoder}",
        f"--adam_epsilon={adam_epsilon}",
        f"--prodigy_decouple={prodigy_decouple}",
        f"--prodigy_use_bias_correction={prodigy_use_bias_correction}",
        f"--prodigy_safeguard_warmup={prodigy_safeguard_warmup}",
        f"--max_grad_norm={max_grad_norm}",
        f"--lr_num_cycles={int(lr_num_cycles)}",
        f"--lr_power={lr_power}",
        f"--dataloader_num_workers={int(dataloader_num_workers)}",
        f"--local_rank={int(local_rank)}",
        "--cache_latents"
    ]
    if optimizer == "8bitadam":
        commands.append("--use_8bit_adam")
    if gradient_checkpointing:
        commands.append("--gradient_checkpointing")

    if train_text_encoder_ti:
        commands.append("--train_text_encoder_ti")
    elif train_text_encoder:
        commands.append("--train_text_encoder")
        commands.append(f"--train_text_encoder_frac={train_text_encoder_frac}")
    if enable_xformers_memory_efficient_attention:
        commands.append("--enable_xformers_memory_efficient_attention")
    if use_snr_gamma:
        commands.append(f"--snr_gamma={snr_gamma}")
    if scale_lr:
        commands.append("--scale_lr")
    if with_prior_preservation:
        commands.append("--with_prior_preservation")
        commands.append(f"--class_prompt={class_prompt}")
        commands.append(f"--num_class_images={int(num_class_images)}")
        if class_images:
            class_folder = str(uuid.uuid4())
            if not os.path.exists(class_folder):
                os.makedirs(class_folder)
            for image in class_images:
                shutil.copy(image, class_folder)
            commands.append(f"--class_data_dir={class_folder}")

    print(commands)
    from train_dreambooth_lora_sdxl_advanced import main as train_main, parse_args as parse_train_args
    args = parse_train_args(commands)
    train_main(args)
    # print(commands)
    # subprocess.run(commands)
    return "ok!"
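# Unlike start_training (which packages the dataset and dispatches an autotrain spacerunner
# job to a separate Space), start_training_og above runs the advanced DreamBooth LoRA script
# in-process by importing its parse_args/main, which assumes this app itself is running on
# hardware capable of training.
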
def run_captioning(*inputs):
    print(inputs)
    images = inputs[0]
    training_option = inputs[-1]
    print(training_option)
    final_captions = [""] * MAX_IMAGES
    for index, image in enumerate(images):
        original_caption = inputs[index + 1]
        pil_image = Image.open(image)
        blip_inputs = processor(images=pil_image, return_tensors="pt").to(device, torch.float16)
        generated_ids = model.generate(**blip_inputs)
        generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0].strip()
        if training_option == "style":
            final_caption = generated_text + " " + original_caption
        else:
            final_caption = original_caption + " " + generated_text
        final_captions[index] = final_caption
        yield final_captions
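# run_captioning above is a generator: it yields the full caption list after every image, so
# the caption textboxes bound to it in do_captioning.click(...) update progressively while
# BLIP-2 is still working through the remaining images.
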
with gr.Blocks() as demo:
    dataset_folder = gr.State()
    gr.Markdown("# SDXL LoRA Dreambooth Training")
    lora_name = gr.Textbox(label="The name of your LoRA", placeholder="e.g.: Persian Miniature Painting style, Cat Toy")
    training_option = gr.Radio(
        label="What are you training?", choices=["object", "style", "face", "custom"]
    )
    concept_sentence = gr.Textbox(
        label="Concept sentence",
        info="A common sentence to be used in all images as your captioning structure. TOK is a special mandatory token that will be used to teach the model your concept.",
        placeholder="e.g.: A photo of TOK, in the style of TOK",
        visible=False,
        interactive=True,
    )
    with gr.Group(visible=False) as image_upload:
        with gr.Row():
            images = gr.File(
                file_types=["image"],
                label="Upload your images",
                file_count="multiple",
                interactive=True,
                visible=True,
                scale=1,
            )
            with gr.Column(scale=3, visible=False) as captioning_area:
                with gr.Column():
                    gr.Markdown(
                        """# Custom captioning
To improve the quality of your outputs, you can add a custom caption for each image, describing exactly what is taking place in each of them. Including TOK is mandatory. You can leave things as is if you don't want to include captioning.
"""
                    )
                    do_captioning = gr.Button("Add AI captions with BLIP-2")
                    output_components = [captioning_area]
                    caption_list = []
                    for i in range(1, MAX_IMAGES + 1):
                        locals()[f"captioning_row_{i}"] = gr.Row(visible=False)
                        with locals()[f"captioning_row_{i}"]:
                            locals()[f"image_{i}"] = gr.Image(
                                width=64,
                                height=64,
                                min_width=64,
                                interactive=False,
                                scale=1,
                                show_label=False,
                            )
                            locals()[f"caption_{i}"] = gr.Textbox(
                                label=f"Caption {i}", scale=4
                            )

                        output_components.append(locals()[f"captioning_row_{i}"])
                        output_components.append(locals()[f"image_{i}"])
                        output_components.append(locals()[f"caption_{i}"])
                        caption_list.append(locals()[f"caption_{i}"])
    with gr.Accordion(open=False, label="Advanced options", visible=False) as advanced:
        with gr.Row():
            with gr.Column():
                optimizer = gr.Dropdown(
                    label="Optimizer",
                    info="Prodigy is an auto-optimizer and works well by default. If you prefer to set your own learning rates, change it to AdamW. If you don't have enough VRAM to train with AdamW, pick 8-bit Adam.",
                    choices=[
                        ("Prodigy", "prodigy"),
                        ("AdamW", "adamW"),
                        ("8-bit Adam", "8bitadam"),
                    ],
                    value="prodigy",
                    interactive=True,
                )
                use_snr_gamma = gr.Checkbox(label="Use SNR Gamma")
                snr_gamma = gr.Number(
                    label="snr_gamma",
                    info="SNR weighting gamma to re-balance the loss",
                    value=5.000,
                    step=0.1,
                    visible=False,
                )
                mixed_precision = gr.Dropdown(
                    label="Mixed Precision",
                    choices=["no", "fp16", "bf16"],
                    value="bf16",
                )
                learning_rate = gr.Number(
                    label="UNet Learning rate",
                    minimum=0.0,
                    maximum=10.0,
                    step=0.0000001,
                    value=1.0,  # For Prodigy you start high and it will optimize down
                )
                train_batch_size = gr.Number(label="Train batch size", value=2)
                max_train_steps = gr.Number(
                    label="Max train steps", minimum=1, maximum=50000, value=1000
                )
                lora_rank = gr.Number(
                    label="LoRA Rank",
                    info="Rank for the Low Rank Adaptation (LoRA); a higher rank produces a larger LoRA",
                    value=8,
                    step=2,
                    minimum=2,
                    maximum=1024,
                )
                repeats = gr.Number(
                    label="Repeats",
                    info="How many times to repeat the training data.",
                    value=1,
                    minimum=1,
                    maximum=200,
                )
            with gr.Column():
                with_prior_preservation = gr.Checkbox(
                    label="Prior preservation loss",
                    info="Prior preservation helps to ground the model to things that are similar to your concept. Good for faces.",
                    value=False,
                )
                with gr.Column(visible=False) as prior_preservation_params:
                    with gr.Tab("prompt"):
                        class_prompt = gr.Textbox(
                            label="Class Prompt",
                            info="The prompt that will be used to generate your class images",
                        )

                    with gr.Tab("images"):
                        class_images = gr.File(
                            file_types=["image"],
                            label="Upload your images",
                            file_count="multiple",
                        )
                    num_class_images = gr.Number(
                        label="Number of class images; if fewer images are uploaded than the number you set here, additional images will be sampled with the Class Prompt",
                        value=20,
                    )
                train_text_encoder_ti = gr.Checkbox(
                    label="Do textual inversion",
                    value=True,
                    info="Will train a textual inversion embedding together with the LoRA. Increases quality significantly.",
                )
                with gr.Group(visible=True) as pivotal_tuning_params:
                    train_text_encoder_ti_frac = gr.Number(
                        label="Pivot Textual Inversion",
                        info="% of epochs to train textual inversion for",
                        value=0.5,
                        step=0.1,
                    )
                    num_new_tokens_per_abstraction = gr.Number(
                        label="Tokens to train",
                        info="Number of tokens to train in the textual inversion",
                        value=2,
                        minimum=1,
                        maximum=1024,
                        interactive=True,
                    )
                with gr.Group(visible=False) as text_encoder_train_params:
                    train_text_encoder = gr.Checkbox(
                        label="Train Text Encoder", value=True
                    )
                    train_text_encoder_frac = gr.Number(
                        label="Pivot Text Encoder",
                        info="% of epochs to train the text encoder for",
                        value=0.8,
                        step=0.1,
                    )
                text_encoder_learning_rate = gr.Number(
                    label="Text encoder learning rate",
                    minimum=0.0,
                    maximum=10.0,
                    step=0.0000001,
                    value=1.0,
                )
                seed = gr.Number(label="Seed", value=42)
                resolution = gr.Number(
                    label="Resolution",
                    info="Only square sizes are supported for now; the value will be used for both width and height",
                    value=1024,
                )
        with gr.Accordion(open=False, label="Even more advanced options"):
            with gr.Row():
                with gr.Column():
                    num_train_epochs = gr.Number(label="num_train_epochs", value=1)
                    checkpointing_steps = gr.Number(
                        label="checkpointing_steps", value=5000
                    )
                    prior_loss_weight = gr.Number(label="prior_loss_weight", value=1)
                    gradient_accumulation_steps = gr.Number(
                        label="gradient_accumulation_steps", value=1
                    )
                    gradient_checkpointing = gr.Checkbox(
                        label="gradient_checkpointing",
                        info="Whether or not to use gradient checkpointing to save memory at the expense of a slower backward pass",
                        value=True,
                    )
                    enable_xformers_memory_efficient_attention = gr.Checkbox(
                        label="enable_xformers_memory_efficient_attention"
                    )
                    adam_beta1 = gr.Number(
                        label="adam_beta1", value=0.9, minimum=0, maximum=1, step=0.01
                    )
                    adam_beta2 = gr.Number(
                        label="adam_beta2", minimum=0, maximum=1, step=0.01, value=0.99
                    )
                    prodigy_beta3 = gr.Number(
                        label="Prodigy Beta 3",
                        value=None,
                        step=0.01,
                        minimum=0,
                        maximum=1,
                    )
                    prodigy_decouple = gr.Checkbox(label="Prodigy Decouple")
                    adam_weight_decay = gr.Number(
                        label="Adam Weight Decay",
                        value=1e-04,
                        step=0.00001,
                        minimum=0,
                        maximum=1,
                    )
                    adam_weight_decay_text_encoder = gr.Number(
                        label="Adam Weight Decay Text Encoder",
                        value=None,
                        step=0.00001,
                        minimum=0,
                        maximum=1,
                    )
                    adam_epsilon = gr.Number(
                        label="Adam Epsilon",
                        value=1e-08,
                        step=0.00000001,
                        minimum=0,
                        maximum=1,
                    )
                    prodigy_use_bias_correction = gr.Checkbox(
                        label="Prodigy Use Bias Correction", value=True
                    )
                    prodigy_safeguard_warmup = gr.Checkbox(
                        label="Prodigy Safeguard Warmup", value=True
                    )
                    max_grad_norm = gr.Number(
                        label="Max Grad Norm",
                        value=1.0,
                        minimum=0.1,
                        maximum=10,
                        step=0.1,
                    )
                with gr.Column():
                    scale_lr = gr.Checkbox(
                        label="Scale learning rate",
                        info="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size",
                    )
                    lr_num_cycles = gr.Number(label="lr_num_cycles", value=1)
                    lr_scheduler = gr.Dropdown(
                        label="lr_scheduler",
                        choices=[
                            "linear",
                            "cosine",
                            "cosine_with_restarts",
                            "polynomial",
                            "constant",
                            "constant_with_warmup",
                        ],
                        value="constant",
                    )
                    lr_power = gr.Number(
                        label="lr_power", value=1.0, minimum=0.1, maximum=10
                    )
                    lr_warmup_steps = gr.Number(label="lr_warmup_steps", value=0)
                    dataloader_num_workers = gr.Number(
                        label="Dataloader num workers", value=0, minimum=0, maximum=64
                    )
                    local_rank = gr.Number(label="local_rank", value=-1)
    token = gr.TextArea(label="Your Hugging Face write token", info="A Hugging Face write token you can obtain on the settings page.")
    start = gr.Button("Start training", visible=False)
    progress_area = gr.HTML()
    output_components.insert(1, advanced)
    output_components.insert(1, start)
    use_snr_gamma.change(
        lambda x: gr.update(visible=x),
        inputs=use_snr_gamma,
        outputs=snr_gamma,
        queue=False,
    )
    with_prior_preservation.change(
        lambda x: gr.update(visible=x),
        inputs=with_prior_preservation,
        outputs=prior_preservation_params,
        queue=False,
    )
    train_text_encoder_ti.change(
        lambda x: gr.update(visible=x),
        inputs=train_text_encoder_ti,
        outputs=pivotal_tuning_params,
        queue=False,
    ).then(
        lambda x: gr.update(visible=(not x)),
        inputs=train_text_encoder_ti,
        outputs=text_encoder_train_params,
        queue=False,
    )
    train_text_encoder.change(
        lambda x: [gr.update(visible=x), gr.update(visible=x)],
        inputs=train_text_encoder,
        outputs=[train_text_encoder_frac, text_encoder_learning_rate],
        queue=False,
    )
    class_images.change(
        lambda x: gr.update(value=len(x)),
        inputs=class_images,
        outputs=num_class_images,
        queue=False
    )
    images.upload(
        load_captioning, inputs=[images, concept_sentence], outputs=output_components
    ).then(
        change_defaults,
        inputs=[training_option, images],
        outputs=[max_train_steps, repeats, lr_scheduler, lora_rank, with_prior_preservation, class_prompt, class_images]
    )
    images.change(
        check_removed_and_restart,
        inputs=[images],
        outputs=[captioning_area, advanced, start],
    )
    training_option.change(
        make_options_visible,
        inputs=training_option,
        outputs=[concept_sentence, image_upload],
    )
    start.click(
        fn=create_dataset,
        inputs=[images] + caption_list,
        outputs=dataset_folder
    ).then(
        fn=start_training,
        inputs=[
            lora_name,
            training_option,
            concept_sentence,
            optimizer,
            use_snr_gamma,
            snr_gamma,
            mixed_precision,
            learning_rate,
            train_batch_size,
            max_train_steps,
            lora_rank,
            repeats,
            with_prior_preservation,
            class_prompt,
            class_images,
            num_class_images,
            train_text_encoder_ti,
            train_text_encoder_ti_frac,
            num_new_tokens_per_abstraction,
            train_text_encoder,
            train_text_encoder_frac,
            text_encoder_learning_rate,
            seed,
            resolution,
            num_train_epochs,
            checkpointing_steps,
            prior_loss_weight,
            gradient_accumulation_steps,
            gradient_checkpointing,
            enable_xformers_memory_efficient_attention,
            adam_beta1,
            adam_beta2,
            prodigy_beta3,
            prodigy_decouple,
            adam_weight_decay,
            adam_weight_decay_text_encoder,
            adam_epsilon,
            prodigy_use_bias_correction,
            prodigy_safeguard_warmup,
            max_grad_norm,
            scale_lr,
            lr_num_cycles,
            lr_scheduler,
            lr_power,
            lr_warmup_steps,
            dataloader_num_workers,
            local_rank,
            dataset_folder,
            token
        ],
        outputs=progress_area
    )

    do_captioning.click(
        fn=run_captioning, inputs=[images] + caption_list + [training_option], outputs=caption_list
    )

if __name__ == "__main__":
    demo.queue()
    demo.launch(share=True)