Update
- app.py +14 -2
- trainer.py +13 -12
app.py
CHANGED
@@ -73,7 +73,6 @@ def create_training_demo(trainer: Trainer,
             gr.Markdown('Training Data')
             concept_images = gr.Files(label='Images for your concept')
             concept_prompt = gr.Textbox(label='Concept Prompt',
-                                        value='sks',
                                         max_lines=1)
             gr.Markdown('''
             - Upload images of the style you are planning on training on.
@@ -84,8 +83,14 @@ def create_training_demo(trainer: Trainer,
             num_training_steps = gr.Number(
                 label='Number of Training Steps', value=1000, precision=0)
             learning_rate = gr.Number(label='Learning Rate', value=0.0001)
+            gradient_accumulation = gr.Number(
+                label='Number of Gradient Accumulation',
+                value=1,
+                precision=0)
+            fp16 = gr.Checkbox(label='FP16', value=True)
+            use_8bit_adam = gr.Checkbox(label='Use 8bit Adam', value=True)
             gr.Markdown('''
-            - It will take about
+            - It will take about 8 minutes to train for 1000 steps with a T4 GPU.
             - You may want to try a small number of steps first, like 1, to see if everything works fine in your environment.
             ''')

@@ -108,6 +113,9 @@ def create_training_demo(trainer: Trainer,
                     concept_prompt,
                     num_training_steps,
                     learning_rate,
+                    gradient_accumulation,
+                    fp16,
+                    use_8bit_adam,
                 ],
                 outputs=[
                     training_status,
@@ -175,6 +183,10 @@ def create_inference_demo(pipe: InferencePipeline) -> gr.Blocks:
                     value=7)

                 run_button = gr.Button('Generate')
+
+                gr.Markdown('''
+                - The pretrained models are trained with the concept prompt "style of sks".
+                ''')
             with gr.Column():
                 result = gr.Image(label='Result')

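For context, here is a minimal, self-contained sketch (not the Space's actual code) of how the three new controls could be wired into a Blocks callback. The stub fake_run and the status Textbox are illustrative stand-ins for Trainer.run and the real outputs; only gr.Number, gr.Checkbox, gr.Button and Button.click are assumed from the Gradio API. It also illustrates why the hunk at -108 has to extend the inputs list: Gradio passes input components to the callback positionally, so their order must match the function parameters.

import gradio as gr


def fake_run(num_training_steps: int, learning_rate: float,
             gradient_accumulation: int, fp16: bool,
             use_8bit_adam: bool) -> str:
    # Stand-in for Trainer.run: just echo the settings it received.
    return (f'steps={num_training_steps}, lr={learning_rate}, '
            f'grad_accum={gradient_accumulation}, fp16={fp16}, '
            f'8bit_adam={use_8bit_adam}')


with gr.Blocks() as demo:
    num_training_steps = gr.Number(label='Number of Training Steps',
                                   value=1000, precision=0)
    learning_rate = gr.Number(label='Learning Rate', value=0.0001)
    # The three controls added in this commit:
    gradient_accumulation = gr.Number(label='Number of Gradient Accumulation',
                                      value=1, precision=0)
    fp16 = gr.Checkbox(label='FP16', value=True)
    use_8bit_adam = gr.Checkbox(label='Use 8bit Adam', value=True)
    run_button = gr.Button('Train')
    status = gr.Textbox(label='Status')

    # Inputs are passed to the callback in list order, which is why the
    # inputs list in app.py is extended with the new components.
    run_button.click(fn=fake_run,
                     inputs=[
                         num_training_steps,
                         learning_rate,
                         gradient_accumulation,
                         fp16,
                         use_8bit_adam,
                     ],
                     outputs=status)

if __name__ == '__main__':
    demo.launch()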
trainer.py
CHANGED
@@ -54,15 +54,10 @@ class Trainer:
            out_path = self.instance_data_dir / f'{i:03d}.jpg'
            image.save(out_path, format='JPEG', quality=100)

-    def run(
-        self,
-        base_model: str,
-        resolution_s: str,
-        concept_images: list | None,
-        concept_prompt: str,
-        n_steps: int,
-        learning_rate: float,
-    ) -> tuple[dict, str]:
+    def run(self, base_model: str, resolution_s: str,
+            concept_images: list | None, concept_prompt: str, n_steps: int,
+            learning_rate: float, gradient_accumulation: int, fp16: bool,
+            use_8bit_adam: bool) -> tuple[dict, str]:
        if not torch.cuda.is_available():
            raise gr.Error('CUDA is not available.')

@@ -80,24 +75,30 @@ class Trainer:
        self.cleanup_dirs()
        self.prepare_dataset(concept_images, resolution)

-        self.is_running = True
        command = f'''
        accelerate launch lora/train_lora_dreambooth.py \
          --pretrained_model_name_or_path={base_model} \
          --instance_data_dir={self.instance_data_dir} \
          --output_dir={self.output_dir} \
-          --instance_prompt="
+          --instance_prompt="{concept_prompt}" \
          --resolution={resolution} \
          --train_batch_size=1 \
-          --gradient_accumulation_steps=
+          --gradient_accumulation_steps={gradient_accumulation} \
          --learning_rate={learning_rate} \
          --lr_scheduler=constant \
          --lr_warmup_steps=0 \
          --max_train_steps={n_steps}
        '''
+        if fp16:
+            command += ' --mixed_precision fp16 '
+        if use_8bit_adam:
+            command += ' --use_8bit_adam'
+
        with open(self.output_dir / 'train.sh', 'w') as f:
            command_s = ' '.join(command.split())
            f.write(command_s)
+
+        self.is_running = True
        res = subprocess.run(shlex.split(command))
        self.is_running = False

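To make the command assembly easier to follow, here is a standalone sketch (assumptions, not the Space's actual values) of what run() builds after this change: the model name, directories and prompt below are placeholders, and only the flag-appending and whitespace-collapsing behaviour mirrors the diff.

import shlex

# Placeholder values; in the Space these come from the Gradio inputs and
# from self.instance_data_dir / self.output_dir.
base_model = 'runwayml/stable-diffusion-v1-5'
concept_prompt = 'style of sks'
resolution = 512
learning_rate = 0.0001
gradient_accumulation = 1
n_steps = 1000
fp16 = True
use_8bit_adam = True

command = f'''
accelerate launch lora/train_lora_dreambooth.py \
  --pretrained_model_name_or_path={base_model} \
  --instance_data_dir=training_data \
  --output_dir=results \
  --instance_prompt="{concept_prompt}" \
  --resolution={resolution} \
  --train_batch_size=1 \
  --gradient_accumulation_steps={gradient_accumulation} \
  --learning_rate={learning_rate} \
  --lr_scheduler=constant \
  --lr_warmup_steps=0 \
  --max_train_steps={n_steps}
'''
# The two checkboxes only append optional flags; everything else is always
# present in the f-string above.
if fp16:
    command += ' --mixed_precision fp16 '
if use_8bit_adam:
    command += ' --use_8bit_adam'

# ' '.join(command.split()) collapses the multi-line string into the
# single-line form written to train.sh; the Space itself launches training
# with subprocess.run(shlex.split(command)).
command_s = ' '.join(command.split())
print(command_s)
print(shlex.split(command_s)[:4])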