Spaces:
Runtime error
Runtime error
Commit
·
ef553b2
1
Parent(s):
1aabba7
Update app_training.py
Browse files- app_training.py +30 -23
app_training.py
CHANGED
@@ -13,6 +13,7 @@ from trainer import Trainer
|
|
13 |
|
14 |
def create_training_demo(trainer: Trainer,
|
15 |
pipe: InferencePipeline | None = None) -> gr.Blocks:
|
|
|
16 |
with gr.Blocks() as demo:
|
17 |
with gr.Row():
|
18 |
with gr.Column():
|
@@ -32,7 +33,8 @@ def create_training_demo(trainer: Trainer,
|
|
32 |
max_lines=1)
|
33 |
delete_existing_model = gr.Checkbox(
|
34 |
label='Delete existing model of the same name',
|
35 |
-
value=False
|
|
|
36 |
validation_prompt = gr.Text(label='Validation Prompt')
|
37 |
with gr.Box():
|
38 |
gr.Markdown('Upload Settings')
|
@@ -63,31 +65,35 @@ def create_training_demo(trainer: Trainer,
|
|
63 |
value='512',
|
64 |
label='Resolution',
|
65 |
visible=False)
|
66 |
-
|
67 |
-
|
68 |
-
|
69 |
-
|
70 |
-
|
71 |
-
label='
|
72 |
-
|
73 |
-
|
74 |
-
|
75 |
-
|
76 |
-
|
77 |
-
|
78 |
-
|
79 |
-
|
80 |
-
|
81 |
-
|
82 |
-
|
83 |
-
|
84 |
-
|
85 |
-
|
86 |
-
|
|
|
|
|
|
|
|
|
87 |
gr.Markdown('''
|
88 |
- The base model must be a model that is compatible with [diffusers](https://github.com/huggingface/diffusers) library.
|
89 |
- It takes a few minutes to download the base model first.
|
90 |
-
- Expected time to train a model for 300 steps: 20 minutes with T4
|
91 |
- It takes a few minutes to upload your trained model.
|
92 |
- You may want to try a small number of steps first, like 1, to see if everything works fine in your environment.
|
93 |
- You can check the training status by pressing the "Open logs" button if you are running this on your Space.
|
@@ -128,6 +134,7 @@ def create_training_demo(trainer: Trainer,
|
|
128 |
delete_existing_repo,
|
129 |
upload_to,
|
130 |
remove_gpu_after_training,
|
|
|
131 |
],
|
132 |
outputs=output_message)
|
133 |
return demo
|
|
|
13 |
|
14 |
def create_training_demo(trainer: Trainer,
|
15 |
pipe: InferencePipeline | None = None) -> gr.Blocks:
|
16 |
+
hf_token = os.getenv('HF_TOKEN')
|
17 |
with gr.Blocks() as demo:
|
18 |
with gr.Row():
|
19 |
with gr.Column():
|
|
|
33 |
max_lines=1)
|
34 |
delete_existing_model = gr.Checkbox(
|
35 |
label='Delete existing model of the same name',
|
36 |
+
value=False,
|
37 |
+
visible=False)
|
38 |
validation_prompt = gr.Text(label='Validation Prompt')
|
39 |
with gr.Box():
|
40 |
gr.Markdown('Upload Settings')
|
|
|
65 |
value='512',
|
66 |
label='Resolution',
|
67 |
visible=False)
|
68 |
+
|
69 |
+
token = gr.Text(label="Hugging Face Write Token", placeholder="", visible=True if hf_token else False)
|
70 |
+
with gr.Accordion("Advanced settings", open=False):
|
71 |
+
num_training_steps = gr.Number(
|
72 |
+
label='Number of Training Steps', value=300, precision=0)
|
73 |
+
learning_rate = gr.Number(label='Learning Rate',
|
74 |
+
value=0.000035)
|
75 |
+
gradient_accumulation = gr.Number(
|
76 |
+
label='Number of Gradient Accumulation',
|
77 |
+
value=1,
|
78 |
+
precision=0)
|
79 |
+
seed = gr.Slider(label='Seed',
|
80 |
+
minimum=0,
|
81 |
+
maximum=100000,
|
82 |
+
step=1,
|
83 |
+
randomize=True,
|
84 |
+
value=0)
|
85 |
+
fp16 = gr.Checkbox(label='FP16', value=True)
|
86 |
+
use_8bit_adam = gr.Checkbox(label='Use 8bit Adam', value=False)
|
87 |
+
checkpointing_steps = gr.Number(label='Checkpointing Steps',
|
88 |
+
value=1000,
|
89 |
+
precision=0)
|
90 |
+
validation_epochs = gr.Number(label='Validation Epochs',
|
91 |
+
value=100,
|
92 |
+
precision=0)
|
93 |
gr.Markdown('''
|
94 |
- The base model must be a model that is compatible with [diffusers](https://github.com/huggingface/diffusers) library.
|
95 |
- It takes a few minutes to download the base model first.
|
96 |
+
- Expected time to train a model for 300 steps: ~20 minutes with T4
|
97 |
- It takes a few minutes to upload your trained model.
|
98 |
- You may want to try a small number of steps first, like 1, to see if everything works fine in your environment.
|
99 |
- You can check the training status by pressing the "Open logs" button if you are running this on your Space.
|
|
|
134 |
delete_existing_repo,
|
135 |
upload_to,
|
136 |
remove_gpu_after_training,
|
137 |
+
token
|
138 |
],
|
139 |
outputs=output_message)
|
140 |
return demo
|