bstraehle committed (verified)
Commit bf28b8c · Parent: 98c38e4

Update app.py

Files changed (1): app.py (+13 -5)
app.py CHANGED

@@ -5,7 +5,8 @@ import os, torch
 from datasets import load_dataset
 from huggingface_hub import HfApi, login
 #from peft import LoraConfig, TaskType, get_peft_model, prepare_model_for_kbit_training
-from transformers import AutoModelForCausalLM, AutoTokenizer, Seq2SeqTrainer, Seq2SeqTrainingArguments, pipeline
+from peft import PeftModel
+from transformers import AutoModelForCausalLM, AutoTokenizer, Seq2SeqTrainer, Seq2SeqTrainingArguments, Trainer, TrainingArguments, pipeline
 
 ACTION_1 = "Prompt base model"
 ACTION_2 = "Fine-tune base model"
@@ -79,7 +80,7 @@ def fine_tune_model(base_model_name, dataset_name):
 
     # Configure training arguments
 
-    training_args = Seq2SeqTrainingArguments(
+    training_args = TrainingArguments(
         output_dir=f"./{FT_MODEL_NAME}",
         num_train_epochs=3, # 37,500 steps
         max_steps=1, # overwrites num_train_epochs
@@ -110,11 +111,18 @@ def fine_tune_model(base_model_name, dataset_name):
     #print("### PEFT")
     #model.print_trainable_parameters() # trainable params: 6,815,744 || all params: 8,037,076,992 || trainable%: 0.0848
     #print("###")
-
+
+    peft_model = PeftModel.from_pretrained(
+        BASE_MODEL_NAME,
+        tokenizer=tokenizer,
+        adapter_name="lora",
+        adapter_dim=16,
+    )
+
     # Create trainer
 
-    trainer = Seq2SeqTrainer(
-        model=model,
+    trainer = Trainer(
+        model=peft_model,
         args=training_args,
         train_dataset=train_dataset,
         eval_dataset=eval_dataset,
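
Note: this commit swaps the Seq2Seq trainer classes for plain Trainer/TrainingArguments and wraps the base model with peft before training. For orientation only, below is a minimal sketch of the LoRA pattern that the commented-out import (LoraConfig, TaskType, get_peft_model) points at; it is not part of this commit, and the model id, rank, and output directory are placeholder assumptions.

# Sketch only (not from this commit): attach a LoRA adapter with peft, then hand the
# wrapped model to a plain transformers Trainer.
from peft import LoraConfig, TaskType, get_peft_model
from transformers import AutoModelForCausalLM, Trainer, TrainingArguments

base_model_name = "your-org/your-base-model"  # placeholder; the app's BASE_MODEL_NAME is not shown here
base_model = AutoModelForCausalLM.from_pretrained(base_model_name)

lora_config = LoraConfig(
    task_type=TaskType.CAUSAL_LM,  # causal-LM fine-tuning
    r=16,                          # adapter rank (assumed, mirroring adapter_dim=16 above)
    lora_alpha=32,
    lora_dropout=0.05,
)

peft_model = get_peft_model(base_model, lora_config)
peft_model.print_trainable_parameters()  # reports trainable vs. total parameter counts

training_args = TrainingArguments(
    output_dir="./ft-model",  # placeholder output directory
    max_steps=1,
)

trainer = Trainer(
    model=peft_model,
    args=training_args,
    # train_dataset=..., eval_dataset=... as in the diff above
)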