robkaandorp committed
Commit 1d59dc1
Parent: fe1a630

Update script for chat training

Files changed (1):
  train_csv_dataset_phi-2-super.py  +16 -3
train_csv_dataset_phi-2-super.py CHANGED

@@ -1,10 +1,11 @@
 import time
 import torch
 from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig, TrainingArguments, DataCollatorForLanguageModeling
+from datasets import load_dataset
 from trl import SFTTrainer
 from peft import LoraConfig, prepare_model_for_kbit_training
 
-dataset = load_dataset()
+dataset = load_dataset("csv", data_files="nowhere_training_input.csv", delimiter=";", split="train")
 
 if torch.cuda.is_available():
     print("Cuda is available")
@@ -71,7 +72,19 @@ training_args = TrainingArguments(
 )
 
 def formatting_func(data):
-    return f"[INST] {data['prompt']} [/INST]{data['completion']}{tokenizer.eos_token} "
+    # text = f"[INST] {data['prompt']} [/INST]{data['completion']}{tokenizer.eos_token} "
+    chat = [
+        { "role": "user", "content": data['prompt'] },
+        { "role": "assistant", "content": data['completion'] },
+    ]
+
+    text = tokenizer.apply_chat_template(chat, tokenize=False)
+    print(text)
+    data['text'] = text
+
+    return data
+
+dataset = dataset.map(formatting_func)
 
 trainer = SFTTrainer(
     model=model,
@@ -81,7 +94,7 @@ trainer = SFTTrainer(
     args=training_args,
     max_seq_length=1024,
     packing=True,
-    formatting_func=formatting_func
+    dataset_text_field="text"
 )
 
 model.config.use_cache = False  # silence the warnings. Please re-enable for inference!
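
The net effect of the commit is that prompt formatting moves out of the trainer: the tokenizer's chat template is applied once up front via dataset.map, and SFTTrainer then reads the prebuilt strings from the new "text" column through dataset_text_field instead of calling formatting_func per example. A minimal sketch of the templating step is below; the checkpoint name (abacaj/phi-2-super, inferred from the script's file name) and the example strings are assumptions, not taken from the diff:

import pprint
from transformers import AutoTokenizer

# Assumption: the script fine-tunes abacaj/phi-2-super; substitute whatever
# checkpoint the script actually loads.
tokenizer = AutoTokenizer.from_pretrained("abacaj/phi-2-super")

# Illustrative prompt/completion pair standing in for one CSV row.
chat = [
    {"role": "user", "content": "What is the capital of France?"},
    {"role": "assistant", "content": "Paris."},
]

# tokenize=False returns the rendered template as a plain string, which is
# exactly what formatting_func stores in the dataset's "text" column.
pprint.pprint(tokenizer.apply_chat_template(chat, tokenize=False))

Because apply_chat_template inserts the model's own role markers and end-of-turn tokens, the hand-rolled "[INST] ... [/INST]" string from the previous version (kept as a comment in the diff) is no longer needed.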