ShaoTengLiu commited on
Commit
8963583
1 Parent(s): ca3f715

ready to release

Browse files
Files changed (3) hide show
  1. app.py +6 -6
  2. app_training.py +8 -6
  3. trainer.py +2 -0
app.py CHANGED
@@ -59,12 +59,12 @@ pipe = InferencePipeline(HF_TOKEN)
59
  trainer = Trainer(HF_TOKEN)
60
 
61
  with gr.Blocks(css='style.css') as demo:
62
- # if SPACE_ID == ORIGINAL_SPACE_ID:
63
- # show_warning(SHARED_UI_WARNING)
64
- # elif not torch.cuda.is_available():
65
- # show_warning(CUDA_NOT_AVAILABLE_WARNING)
66
- # elif (not 'T4' in GPU_DATA):
67
- # show_warning(INVALID_GPU_WARNING)
68
 
69
  gr.Markdown(TITLE)
70
  with gr.Tabs():
 
59
  trainer = Trainer(HF_TOKEN)
60
 
61
  with gr.Blocks(css='style.css') as demo:
62
+ if SPACE_ID == ORIGINAL_SPACE_ID:
63
+ show_warning(SHARED_UI_WARNING)
64
+ elif not torch.cuda.is_available():
65
+ show_warning(CUDA_NOT_AVAILABLE_WARNING)
66
+ elif (not 'T4' in GPU_DATA):
67
+ show_warning(INVALID_GPU_WARNING)
68
 
69
  gr.Markdown(TITLE)
70
  with gr.Tabs():
app_training.py CHANGED
@@ -43,7 +43,7 @@ def create_training_demo(trainer: Trainer,
43
  with gr.Row():
44
  tuned_model = gr.Text(
45
  label='Path to tuned model',
46
- value='xxx/xxx',
47
  max_lines=1)
48
  resolution = gr.Dropdown(choices=['512', '768'],
49
  value='512',
@@ -60,6 +60,8 @@ def create_training_demo(trainer: Trainer,
60
  precision=0)
61
  learning_rate = gr.Number(label='Learning Rate',
62
  value=0.000035)
 
 
63
  gradient_accumulation = gr.Number(
64
  label='Number of Gradient Accumulation',
65
  value=1,
@@ -78,7 +80,7 @@ def create_training_demo(trainer: Trainer,
78
  value=1000,
79
  precision=0)
80
  validation_epochs = gr.Number(
81
- label='Validation Epochs', value=100, precision=0)
82
  gr.Markdown('''
83
  - The base model must be a Stable Diffusion model compatible with [diffusers](https://github.com/huggingface/diffusers) library.
84
  - Expected time to complete: ~20 minutes with T4.
@@ -89,8 +91,8 @@ def create_training_demo(trainer: Trainer,
89
  with gr.Row():
90
  with gr.Column():
91
  gr.Markdown('Output Model')
92
- output_model_name = gr.Text(label='Name of your model',
93
- placeholder='The skiing man',
94
  max_lines=1)
95
  validation_prompt = gr.Text(
96
  label='Validation Prompt',
@@ -111,7 +113,7 @@ def create_training_demo(trainer: Trainer,
111
  eq_params_2 = gr.Text(
112
  label='reweight_value',
113
  placeholder=
114
- '8')
115
  with gr.Column():
116
  gr.Markdown('Upload Settings')
117
  with gr.Row():
@@ -162,7 +164,7 @@ def create_training_demo(trainer: Trainer,
162
  gradient_accumulation, seed, fp16, use_8bit_adam,
163
  checkpointing_steps, validation_epochs, upload_to_hub,
164
  use_private_repo, delete_existing_repo, upload_to,
165
- remove_gpu_after_training, input_token, blend_word_1, blend_word_2, eq_params_1, eq_params_2, tuned_model
166
  ],
167
  outputs=output_message)
168
  return demo
 
43
  with gr.Row():
44
  tuned_model = gr.Text(
45
  label='Path to tuned model',
46
+ value='xxx/ski-lego',
47
  max_lines=1)
48
  resolution = gr.Dropdown(choices=['512', '768'],
49
  value='512',
 
60
  precision=0)
61
  learning_rate = gr.Number(label='Learning Rate',
62
  value=0.000035)
63
+ cross_replace = gr.Number(label='Cross attention replace ratio',
64
+ value=0.2)
65
  gradient_accumulation = gr.Number(
66
  label='Number of Gradient Accumulation',
67
  value=1,
 
80
  value=1000,
81
  precision=0)
82
  validation_epochs = gr.Number(
83
+ label='Validation Epochs', value=300, precision=0)
84
  gr.Markdown('''
85
  - The base model must be a Stable Diffusion model compatible with [diffusers](https://github.com/huggingface/diffusers) library.
86
  - Expected time to complete: ~20 minutes with T4.
 
91
  with gr.Row():
92
  with gr.Column():
93
  gr.Markdown('Output Model')
94
+ output_model_name = gr.Text(label='Path to save your tuned model',
95
+ placeholder='ski-lego',
96
  max_lines=1)
97
  validation_prompt = gr.Text(
98
  label='Validation Prompt',
 
113
  eq_params_2 = gr.Text(
114
  label='reweight_value',
115
  placeholder=
116
+ '4')
117
  with gr.Column():
118
  gr.Markdown('Upload Settings')
119
  with gr.Row():
 
164
  gradient_accumulation, seed, fp16, use_8bit_adam,
165
  checkpointing_steps, validation_epochs, upload_to_hub,
166
  use_private_repo, delete_existing_repo, upload_to,
167
+ remove_gpu_after_training, input_token, blend_word_1, blend_word_2, eq_params_1, eq_params_2, tuned_model, cross_replace
168
  ],
169
  outputs=output_message)
170
  return demo
trainer.py CHANGED
@@ -217,6 +217,7 @@ class Trainer:
217
  eq_params_1: str,
218
  eq_params_2: str,
219
  tuned_model: str,
 
220
  ) -> str:
221
  # if SPACE_ID == ORIGINAL_SPACE_ID:
222
  # raise gr.Error(
@@ -280,6 +281,7 @@ class Trainer:
280
  config.is_word_swap = True
281
  else:
282
  config.is_word_swap = False
 
283
 
284
  config_path = output_dir / 'config.yaml'
285
  with open(config_path, 'w') as f:
 
217
  eq_params_1: str,
218
  eq_params_2: str,
219
  tuned_model: str,
220
+ cross_replace: float,
221
  ) -> str:
222
  # if SPACE_ID == ORIGINAL_SPACE_ID:
223
  # raise gr.Error(
 
281
  config.is_word_swap = True
282
  else:
283
  config.is_word_swap = False
284
+ config.cross_replace_steps = cross_replace
285
 
286
  config_path = output_dir / 'config.yaml'
287
  with open(config_path, 'w') as f: