multimodalart HF staff commited on
Commit
ef553b2
1 Parent(s): 1aabba7

Update app_training.py

Browse files
Files changed (1) hide show
  1. app_training.py +30 -23
app_training.py CHANGED
@@ -13,6 +13,7 @@ from trainer import Trainer
13
 
14
  def create_training_demo(trainer: Trainer,
15
  pipe: InferencePipeline | None = None) -> gr.Blocks:
 
16
  with gr.Blocks() as demo:
17
  with gr.Row():
18
  with gr.Column():
@@ -32,7 +33,8 @@ def create_training_demo(trainer: Trainer,
32
  max_lines=1)
33
  delete_existing_model = gr.Checkbox(
34
  label='Delete existing model of the same name',
35
- value=False)
 
36
  validation_prompt = gr.Text(label='Validation Prompt')
37
  with gr.Box():
38
  gr.Markdown('Upload Settings')
@@ -63,31 +65,35 @@ def create_training_demo(trainer: Trainer,
63
  value='512',
64
  label='Resolution',
65
  visible=False)
66
- num_training_steps = gr.Number(
67
- label='Number of Training Steps', value=300, precision=0)
68
- learning_rate = gr.Number(label='Learning Rate',
69
- value=0.000035)
70
- gradient_accumulation = gr.Number(
71
- label='Number of Gradient Accumulation',
72
- value=1,
73
- precision=0)
74
- seed = gr.Slider(label='Seed',
75
- minimum=0,
76
- maximum=100000,
77
- step=1,
78
- value=0)
79
- fp16 = gr.Checkbox(label='FP16', value=True)
80
- use_8bit_adam = gr.Checkbox(label='Use 8bit Adam', value=False)
81
- checkpointing_steps = gr.Number(label='Checkpointing Steps',
82
- value=1000,
83
- precision=0)
84
- validation_epochs = gr.Number(label='Validation Epochs',
85
- value=100,
86
- precision=0)
 
 
 
 
87
  gr.Markdown('''
88
  - The base model must be a model that is compatible with [diffusers](https://github.com/huggingface/diffusers) library.
89
  - It takes a few minutes to download the base model first.
90
- - Expected time to train a model for 300 steps: 20 minutes with T4, 8 minutes with A10G, (4 minutes with A100)
91
  - It takes a few minutes to upload your trained model.
92
  - You may want to try a small number of steps first, like 1, to see if everything works fine in your environment.
93
  - You can check the training status by pressing the "Open logs" button if you are running this on your Space.
@@ -128,6 +134,7 @@ def create_training_demo(trainer: Trainer,
128
  delete_existing_repo,
129
  upload_to,
130
  remove_gpu_after_training,
 
131
  ],
132
  outputs=output_message)
133
  return demo
 
13
 
14
  def create_training_demo(trainer: Trainer,
15
  pipe: InferencePipeline | None = None) -> gr.Blocks:
16
+ hf_token = os.getenv('HF_TOKEN')
17
  with gr.Blocks() as demo:
18
  with gr.Row():
19
  with gr.Column():
 
33
  max_lines=1)
34
  delete_existing_model = gr.Checkbox(
35
  label='Delete existing model of the same name',
36
+ value=False,
37
+ visible=False)
38
  validation_prompt = gr.Text(label='Validation Prompt')
39
  with gr.Box():
40
  gr.Markdown('Upload Settings')
 
65
  value='512',
66
  label='Resolution',
67
  visible=False)
68
+
69
+ token = gr.Text(label="Hugging Face Write Token", placeholder="", visible=True if hf_token else False)
70
+ with gr.Accordion("Advanced settings", open=False):
71
+ num_training_steps = gr.Number(
72
+ label='Number of Training Steps', value=300, precision=0)
73
+ learning_rate = gr.Number(label='Learning Rate',
74
+ value=0.000035)
75
+ gradient_accumulation = gr.Number(
76
+ label='Number of Gradient Accumulation',
77
+ value=1,
78
+ precision=0)
79
+ seed = gr.Slider(label='Seed',
80
+ minimum=0,
81
+ maximum=100000,
82
+ step=1,
83
+ randomize=True,
84
+ value=0)
85
+ fp16 = gr.Checkbox(label='FP16', value=True)
86
+ use_8bit_adam = gr.Checkbox(label='Use 8bit Adam', value=False)
87
+ checkpointing_steps = gr.Number(label='Checkpointing Steps',
88
+ value=1000,
89
+ precision=0)
90
+ validation_epochs = gr.Number(label='Validation Epochs',
91
+ value=100,
92
+ precision=0)
93
  gr.Markdown('''
94
  - The base model must be a model that is compatible with [diffusers](https://github.com/huggingface/diffusers) library.
95
  - It takes a few minutes to download the base model first.
96
+ - Expected time to train a model for 300 steps: ~20 minutes with T4
97
  - It takes a few minutes to upload your trained model.
98
  - You may want to try a small number of steps first, like 1, to see if everything works fine in your environment.
99
  - You can check the training status by pressing the "Open logs" button if you are running this on your Space.
 
134
  delete_existing_repo,
135
  upload_to,
136
  remove_gpu_after_training,
137
+ token
138
  ],
139
  outputs=output_message)
140
  return demo