import torch
from transformers import TrainingArguments, MistralForCausalLM, MistralConfig, AutoTokenizer
from datasets import load_dataset
from trl import SFTTrainer

# Mistral-style architecture trained from scratch (roughly 1.4B parameters
# with these dimensions).
configuration = MistralConfig(
    vocab_size=32000,
    hidden_size=2048,
    intermediate_size=7168,
    num_hidden_layers=24,
    num_attention_heads=32,
    num_key_value_heads=8,
    hidden_act="silu",
    max_position_embeddings=4096,
    pad_token_id=2,
    bos_token_id=1,
    eos_token_id=2,
)

# Initialize with random weights; uncomment the second line instead to
# resume from a saved checkpoint.
model = MistralForCausalLM(configuration)
# model = MistralForCausalLM.from_pretrained("./6B_code_outputs/checkpoint-10000")

# Reuse the Mistral-7B-Instruct tokenizer (same 32k vocabulary as the config above).
tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-Instruct-v0.2", local_files_only=False)
tokenizer.pad_token = tokenizer.eos_token

dataset = load_dataset("HuggingFaceTB/cosmopedia-20k", split="train")
# dataset = load_dataset("Elriggs/openwebtext-100k", split="train")
dataset = dataset.shuffle(seed=42)
print(f"Number of prompts: {len(dataset)}")
print(f"Column names are: {dataset.column_names}")


def create_prompt_formats(sample):
    """
    Return the raw 'text' field of each example in the batch unchanged,
    as the list of strings expected by SFTTrainer's formatting_func.

    :param sample: batch dictionary produced by the datasets library
    """
    output_texts = []
    for i in range(len(sample["text"])):
        formatted_prompt = sample["text"][i]
        output_texts.append(formatted_prompt)
    # print(output_texts)
    return output_texts


trainer = SFTTrainer(
    model,
    train_dataset=dataset,
    tokenizer=tokenizer,
    max_seq_length=2048,
    formatting_func=create_prompt_formats,
    args=TrainingArguments(
        per_device_train_batch_size=2,
        gradient_accumulation_steps=1,
        warmup_steps=2,
        max_steps=10000,
        learning_rate=1e-4,
        logging_steps=1,
        output_dir="6B_outputs",
        overwrite_output_dir=True,
        save_steps=1000,
        optim="paged_adamw_32bit",  # requires bitsandbytes
        report_to="none",
    ),
)
trainer.train()

# save_pretrained() has no dtype argument: cast the weights to full
# precision first, then save model and tokenizer together.
trainer.model.to(torch.float32)
trainer.model.save_pretrained("6B-final")
trainer.tokenizer.save_pretrained("6B-final")
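
# Optional sanity check (a minimal sketch, not part of the original script):
# reload the checkpoint saved to "6B-final" above and generate a few tokens
# to confirm the export round-trips. The prompt string and sampling settings
# are illustrative assumptions.
from transformers import AutoModelForCausalLM

reloaded_model = AutoModelForCausalLM.from_pretrained("6B-final")
reloaded_tokenizer = AutoTokenizer.from_pretrained("6B-final")
inputs = reloaded_tokenizer("Once upon a time", return_tensors="pt")
outputs = reloaded_model.generate(**inputs, max_new_tokens=50, do_sample=True, top_p=0.9)
print(reloaded_tokenizer.decode(outputs[0], skip_special_tokens=True))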