fffiloni committed
Commit 05fd390
1 Parent(s): 7ac9941

Create app.py

Files changed (1)
  1. app.py +84 -0
app.py ADDED
@@ -0,0 +1,84 @@
import gradio as gr
import subprocess
from huggingface_hub import snapshot_download

def set_accelerate_default_config():
    # Write the default accelerate config so "accelerate launch" can run non-interactively.
    try:
        subprocess.run(["accelerate", "config", "default"], check=True)
        print("Accelerate default config set successfully!")
    except subprocess.CalledProcessError as e:
        print(f"An error occurred: {e}")

def train_dreambooth_lora_sdxl(instance_data_dir):

    script_filename = "train_dreambooth_lora_sdxl.py"  # Assuming it's in the same folder

    command = [
        "accelerate",
        "launch",
        script_filename,  # Use the local script
        "--pretrained_model_name_or_path=stabilityai/stable-diffusion-xl-base-1.0",
        "--pretrained_vae_model_name_or_path=madebyollin/sdxl-vae-fp16-fix",
        f"--instance_data_dir={instance_data_dir}",
        "--output_dir=lora-trained-xl-colab",
        "--mixed_precision=fp16",
        "--instance_prompt=egnestl",
        "--resolution=1024",
        "--train_batch_size=2",
        "--gradient_accumulation_steps=2",
        "--gradient_checkpointing",
        "--learning_rate=1e-4",
        "--lr_scheduler=constant",
        "--lr_warmup_steps=0",
        "--enable_xformers_memory_efficient_attention",
        "--use_8bit_adam",
        "--max_train_steps=25",
        "--checkpointing_steps=717",
        "--seed=0",
        "--push_to_hub"
    ]

    try:
        subprocess.run(command, check=True)
        print("Training is finished!")
    except subprocess.CalledProcessError as e:
        print(f"An error occurred: {e}")

def main(dataset_url):
    # Download the dataset, configure accelerate, then launch training.
    dataset_repo = dataset_url

    # Automatically set local_dir based on the last part of dataset_repo
    repo_parts = dataset_repo.split("/")
    local_dir = f"./{repo_parts[-1]}"  # Use the last part of the split

    gr.Info("Downloading dataset ...")

    snapshot_download(
        dataset_repo,
        local_dir=local_dir,
        repo_type="dataset",
        ignore_patterns=".gitattributes",
    )

    set_accelerate_default_config()

    gr.Info("Training begins ...")
    train_dreambooth_lora_sdxl(instance_data_dir=repo_parts[-1])

    return "Done"

with gr.Blocks() as demo:
    with gr.Column():
        dataset_id = gr.Textbox(label="Dataset ID")
        train_button = gr.Button("Train !")
        status = gr.Textbox(label="Training status")

    train_button.click(
        fn=main,
        inputs=[dataset_id],
        outputs=[status]
    )

demo.queue().launch()
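
Note: the --push_to_hub flag presumably requires an authenticated Hugging Face session (for example an HF_TOKEN secret on the Space), and train_dreambooth_lora_sdxl.py is assumed to sit next to app.py. As a quick sanity check outside the UI, the dataset download step that main() performs can be reproduced on its own. A minimal sketch, assuming a placeholder dataset repo id that is not part of the commit:

# Standalone check that a dataset repo downloads the way main() expects.
# "your-username/your-dataset" is a hypothetical placeholder id.
from huggingface_hub import snapshot_download

local_dir = snapshot_download(
    "your-username/your-dataset",  # placeholder dataset repo id
    repo_type="dataset",
    local_dir="./your-dataset",
    ignore_patterns=".gitattributes",
)
print(f"Instance images downloaded to {local_dir}")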