tykiww commited on
Commit
a30c450
1 Parent(s): df4fd98

Update utilities/modeling.py

Browse files
Files changed (1) hide show
  1. utilities/modeling.py +6 -2
utilities/modeling.py CHANGED
@@ -37,7 +37,9 @@ def get_peft(model, peft, max_seq_length, random_seed):
37
  return model
38
 
39
 
40
- def get_trainer(model, tokenizer, dataset, sft, data_field, max_seq_length, random_seed):
 
 
41
 
42
  trainer = SFTTrainer(
43
  model = model,
@@ -68,6 +70,7 @@ def get_trainer(model, tokenizer, dataset, sft, data_field, max_seq_length, rand
68
 
69
 
70
  def prepare_trainer(model_name, max_seq_length, random_seed,
 
71
  peft, sft, dataset, data_field):
72
 
73
  print("Loading Model")
@@ -77,7 +80,8 @@ def prepare_trainer(model_name, max_seq_length, random_seed,
77
  model = get_peft(model, peft, max_seq_length, random_seed)
78
 
79
  print("Getting Trainer Model")
80
- trainer = get_trainer(model, tokenizer, dataset, data_field, max_seq_length, random_seed)
 
81
 
82
  return trainer
83
 
 
37
  return model
38
 
39
 
40
+ def get_trainer(model, tokenizer, dataset, sft,
41
+ data_field, max_seq_length, random_seed,
42
+ num_epochs, max_steps):
43
 
44
  trainer = SFTTrainer(
45
  model = model,
 
70
 
71
 
72
  def prepare_trainer(model_name, max_seq_length, random_seed,
73
+ num_epochs, max_steps,
74
  peft, sft, dataset, data_field):
75
 
76
  print("Loading Model")
 
80
  model = get_peft(model, peft, max_seq_length, random_seed)
81
 
82
  print("Getting Trainer Model")
83
+ trainer = get_trainer(model, tokenizer, dataset, data_field, max_seq_length, random_seed,
84
+ num_epochs, max_steps)
85
 
86
  return trainer
87