import torch
from transformers import TrainingArguments, MistralForCausalLM, MistralConfig, AutoTokenizer
from datasets import load_dataset
from trl import SFTTrainer

# Configure a small Mistral-style model to be trained from scratch.
configuration = MistralConfig(
    vocab_size=32000,
    hidden_size=2048,
    intermediate_size=7168,
    num_hidden_layers=24,
    num_attention_heads=32,
    num_key_value_heads=8,  # grouped-query attention: 32 query heads share 8 KV heads
    hidden_act="silu",
    max_position_embeddings=4096,
    pad_token_id=2,
    bos_token_id=1,
    eos_token_id=2,
)

model = MistralForCausalLM(configuration)
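
# Optional sanity check (not in the original script): report the parameter
# count of the freshly initialized model before training starts.
n_params = sum(p.numel() for p in model.parameters())
print(f'Model parameters: {n_params / 1e9:.2f}B')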

# Borrow the Mistral-7B-Instruct tokenizer; only the tokenizer is reused here,
# the model weights above are randomly initialized.
tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-Instruct-v0.2", local_files_only=False)
tokenizer.pad_token = tokenizer.eos_token
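
# Quick consistency check (illustrative, not in the original script): the
# special token ids set in the config above should match this tokenizer's.
assert tokenizer.bos_token_id == configuration.bos_token_id
assert tokenizer.eos_token_id == configuration.eos_token_id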

dataset = load_dataset('HuggingFaceTB/cosmopedia-20k', split="train")

dataset = dataset.shuffle(seed=42)
print(f'Number of prompts: {len(dataset)}')
print(f'Column names are: {dataset.column_names}')
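
# Illustrative peek at one example (the 'text' column is what gets trained on):
print(dataset[0]['text'][:200])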

def create_prompt_formats(sample):
    """
    Return each example's 'text' field unchanged. SFTTrainer calls this
    function on a batch of samples and expects a list of strings back.
    :param sample: Batched sample dictionary
    """
    output_texts = []
    for i in range(len(sample['text'])):
        formatted_prompt = sample['text'][i]
        output_texts.append(formatted_prompt)
    return output_texts

# Note: passing tokenizer and max_seq_length directly matches older trl
# releases; newer ones expect them via SFTConfig / processing_class instead.
trainer = SFTTrainer(
    model,
    train_dataset=dataset,
    tokenizer=tokenizer,
    max_seq_length=2048,
    formatting_func=create_prompt_formats,
    args=TrainingArguments(
        per_device_train_batch_size=2,
        gradient_accumulation_steps=1,
        warmup_steps=2,
        max_steps=10000,
        learning_rate=1e-4,
        logging_steps=1,
        output_dir="6B_outputs",
        overwrite_output_dir=True,
        save_steps=1000,
        optim="paged_adamw_32bit",  # paged optimizer from bitsandbytes
        report_to="none",
    ),
)

trainer.train()

# save_pretrained() takes no dtype argument; the model is saved in its
# current precision (float32, the default for a freshly initialized model).
trainer.model.save_pretrained("6B-final")
tokenizer.save_pretrained("6B-final")
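
# A minimal inference sketch (illustrative, not part of the original script):
# reload the saved checkpoint and sample a short continuation.
loaded_model = MistralForCausalLM.from_pretrained("6B-final")
loaded_tokenizer = AutoTokenizer.from_pretrained("6B-final")
torch.manual_seed(0)  # make the sampling below reproducible
inputs = loaded_tokenizer("Once upon a time", return_tensors="pt")
output_ids = loaded_model.generate(**inputs, max_new_tokens=50, do_sample=True, top_p=0.9)
print(loaded_tokenizer.decode(output_ids[0], skip_special_tokens=True))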