wangjin2000 committed on
Commit
36de84f
·
verified ·
1 Parent(s): b7e3e8b

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +5 -4
app.py CHANGED
@@ -196,11 +196,12 @@ def train_function_no_sweeps(base_model_path): #, train_dataset, test_dataset)
196
  train_dataset = accelerator.prepare(train_dataset)
197
  test_dataset = accelerator.prepare(test_dataset)
198
 
 
199
  timestamp = datetime.now().strftime('%Y-%m-%d_%H-%M-%S')
200
 
201
  # Training setup
202
  training_args = TrainingArguments(
203
- output_dir=f"esm2_t12_35M-lora-binding-sites_{timestamp}",
204
  learning_rate=config["lr"],
205
  lr_scheduler_type=config["lr_scheduler_type"],
206
  gradient_accumulation_steps=1,
@@ -241,9 +242,9 @@ def train_function_no_sweeps(base_model_path): #, train_dataset, test_dataset)
241
 
242
  # Train and Save Model
243
  trainer.train()
244
- save_path = os.path.join("lora_binding_sites", f"best_model_esm2_t12_35M_lora_{timestamp}")
245
- trainer.save_model(save_path)
246
- tokenizer.save_pretrained(save_path)
247
 
248
  return save_path
249
 
 
196
  train_dataset = accelerator.prepare(train_dataset)
197
  test_dataset = accelerator.prepare(test_dataset)
198
 
199
+ model_name_base = base_model_path.split("/")[1]
200
  timestamp = datetime.now().strftime('%Y-%m-%d_%H-%M-%S')
201
 
202
  # Training setup
203
  training_args = TrainingArguments(
204
+ output_dir=f"{model_name_base}-lora-binding-sites_{timestamp}",
205
  learning_rate=config["lr"],
206
  lr_scheduler_type=config["lr_scheduler_type"],
207
  gradient_accumulation_steps=1,
 
242
 
243
  # Train and Save Model
244
  trainer.train()
245
+ #save_path = os.path.join("lora_binding_sites", f"best_model_esm2_t12_35M_lora_{timestamp}")
246
+ #trainer.save_model(save_path)
247
+ #tokenizer.save_pretrained(save_path)
248
 
249
  return save_path
250