ShaoTengLiu commited on
Commit
f527f9c
1 Parent(s): 7fef50a

update two buttons

Browse files
Files changed (2) hide show
  1. app_training.py +10 -1
  2. trainer.py +2 -1
app_training.py CHANGED
@@ -40,6 +40,15 @@ def create_training_demo(trainer: Trainer,
40
  value='512',
41
  label='Resolution',
42
  visible=False)
 
 
 
 
 
 
 
 
 
43
 
44
  input_token = gr.Text(label='Hugging Face Write Token',
45
  placeholder='',
@@ -153,7 +162,7 @@ def create_training_demo(trainer: Trainer,
153
  gradient_accumulation, seed, fp16, use_8bit_adam,
154
  checkpointing_steps, validation_epochs, upload_to_hub,
155
  use_private_repo, delete_existing_repo, upload_to,
156
- remove_gpu_after_training, input_token, blend_word_1, blend_word_2, eq_params_1, eq_params_2
157
  ],
158
  outputs=output_message)
159
  return demo
 
40
  value='512',
41
  label='Resolution',
42
  visible=False)
43
+ with gr.Row():
44
+ tuned_model = gr.Text(
45
+ label='Path to tuned model',
46
+ value='xxx/xxx,
47
+ max_lines=1)
48
+ resolution = gr.Dropdown(choices=['512', '768'],
49
+ value='512',
50
+ label='Resolution',
51
+ visible=False)
52
 
53
  input_token = gr.Text(label='Hugging Face Write Token',
54
  placeholder='',
 
162
  gradient_accumulation, seed, fp16, use_8bit_adam,
163
  checkpointing_steps, validation_epochs, upload_to_hub,
164
  use_private_repo, delete_existing_repo, upload_to,
165
+ remove_gpu_after_training, input_token, blend_word_1, blend_word_2, eq_params_1, eq_params_2, tuned_model
166
  ],
167
  outputs=output_message)
168
  return demo
trainer.py CHANGED
@@ -207,6 +207,7 @@ class Trainer:
207
  blend_word_2: str,
208
  eq_params_1: str,
209
  eq_params_2: str,
 
210
  ) -> str:
211
  # if SPACE_ID == ORIGINAL_SPACE_ID:
212
  # raise gr.Error(
@@ -239,7 +240,7 @@ class Trainer:
239
  self.hf_token if self.hf_token else input_token)
240
 
241
  config = OmegaConf.load('Video-P2P/configs/man-skiing.yaml')
242
- config.pretrained_model_path = self.download_base_model(base_model)
243
  config.output_dir = output_dir.as_posix()
244
  config.train_data.video_path = training_video.name # type: ignore
245
  config.train_data.prompt = training_prompt
 
207
  blend_word_2: str,
208
  eq_params_1: str,
209
  eq_params_2: str,
210
+ tuned_model: str = None,
211
  ) -> str:
212
  # if SPACE_ID == ORIGINAL_SPACE_ID:
213
  # raise gr.Error(
 
240
  self.hf_token if self.hf_token else input_token)
241
 
242
  config = OmegaConf.load('Video-P2P/configs/man-skiing.yaml')
243
+ config.pretrained_model_path = self.download_base_model(tuned_model)
244
  config.output_dir = output_dir.as_posix()
245
  config.train_data.video_path = training_video.name # type: ignore
246
  config.train_data.prompt = training_prompt