robkaandorp committed · Commit 1d59dc1 · Parent: fe1a630

Update script for chat training

train_csv_dataset_phi-2-super.py CHANGED
```diff
@@ -1,10 +1,11 @@
 import time
 import torch
 from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig, TrainingArguments, DataCollatorForLanguageModeling
+from datasets import load_dataset
 from trl import SFTTrainer
 from peft import LoraConfig, prepare_model_for_kbit_training
 
-dataset = load_dataset()
+dataset = load_dataset("csv", data_files="nowhere_training_input.csv", delimiter=";", split="train")
 
 if torch.cuda.is_available():
     print("Cuda is available")
@@ -71,7 +72,19 @@ training_args = TrainingArguments(
 )
 
 def formatting_func(data):
-    text = f"[INST] {data['prompt']} [/INST]{data['completion']}{tokenizer.eos_token} "
+    # text = f"[INST] {data['prompt']} [/INST]{data['completion']}{tokenizer.eos_token} "
+    chat = [
+        { "role": "user", "content": data['prompt'] },
+        { "role": "assistant", "content": data['completion'] },
+    ]
+
+    text = tokenizer.apply_chat_template(chat, tokenize=False)
+    print(text)
+    data['text'] = text
+
+    return data
+
+dataset = dataset.map(formatting_func)
 
 trainer = SFTTrainer(
     model=model,
@@ -81,7 +94,7 @@ trainer = SFTTrainer(
     args=training_args,
     max_seq_length=1024,
     packing=True,
-
+    dataset_text_field="text"
 )
 
 model.config.use_cache = False # silence the warnings. Please re-enable for inference!
```
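
For reference, the new loader call expects a semicolon-delimited CSV whose header row names the two columns that `formatting_func` reads, `prompt` and `completion`. A minimal sketch, with invented placeholder rows (the real contents of `nowhere_training_input.csv` are not part of this commit):

```python
# Hypothetical sketch: build a CSV shaped the way the training script expects.
# The rows are invented examples; only the column names and the ";" delimiter
# are taken from the script itself.
import csv

from datasets import load_dataset

rows = [
    {"prompt": "What is the capital of France?", "completion": "Paris."},
    {"prompt": "Name a prime number.", "completion": "7."},
]

with open("nowhere_training_input.csv", "w", newline="") as f:
    writer = csv.DictWriter(f, fieldnames=["prompt", "completion"], delimiter=";")
    writer.writeheader()
    writer.writerows(rows)

# Same call as in the script: each record comes back as a dict with the
# "prompt" and "completion" keys that formatting_func consumes.
dataset = load_dataset("csv", data_files="nowhere_training_input.csv", delimiter=";", split="train")
print(dataset[0])  # {'prompt': 'What is the capital of France?', 'completion': 'Paris.'}
```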
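
The exact string `apply_chat_template` returns is governed by the chat template stored with the tokenizer of whatever checkpoint the script loads; the model id does not appear in this diff, so the sketch below assumes `abacaj/phi-2-super` as the base checkpoint and shows a ChatML-style rendering purely for illustration:

```python
# Sketch of the new formatting path in isolation. "abacaj/phi-2-super" is an
# assumption about the base checkpoint; the diff does not show the model id,
# and the rendered output depends entirely on that tokenizer's chat_template.
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("abacaj/phi-2-super")

chat = [
    {"role": "user", "content": "What is the capital of France?"},
    {"role": "assistant", "content": "Paris."},
]
print(tokenizer.apply_chat_template(chat, tokenize=False))
# If the template is ChatML-style, this prints something like:
# <|im_start|>user
# What is the capital of France?<|im_end|>
# <|im_start|>assistant
# Paris.<|im_end|>
```

Since the trainer is constructed with `packing=True` and `dataset_text_field="text"`, SFTTrainer (in the TRL versions that accept these constructor arguments) concatenates the pre-rendered `text` column into constant-length blocks of `max_seq_length` tokens, which is why the commit bakes the formatted text into the dataset with `dataset.map` up front instead of formatting on the fly.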