cloudyu committed on
Commit
94ad50a
1 Parent(s): 6841785

Update README.md

Browse files
Files changed (1) hide show
  1. README.md +67 -1
README.md CHANGED
@@ -1,4 +1,70 @@
1
  ---
2
  license: apache-2.0
3
  ---
4
- this is a demo how to pretrain a mistral architecture model by SFT Trainer within tens of lines Python code.
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  ---
2
  license: apache-2.0
3
  ---
4
+ This is a demo of how to pretrain a Mistral-architecture model with SFT Trainer, and it needs only 70 lines of Python code.
5
+
6
+ ```
7
+ import torch
8
+ from transformers import TrainingArguments, MistralForCausalLM, MistralModel, MistralConfig, AutoTokenizer
9
+ from datasets import load_dataset
10
+ from trl import SFTTrainer
11
+
12
+ configuration = MistralConfig(vocab_size=32000,
13
+ hidden_size=2048,
14
+ intermediate_size=7168,
15
+ num_hidden_layers=24,
16
+ num_attention_heads=32,
17
+ num_key_value_heads=8,
18
+ hidden_act="silu",
19
+ max_position_embeddings=4096,
20
+ pad_token_id=2,
21
+ bos_token_id=1,
22
+ eos_token_id=2)
23
+
24
+ model = MistralForCausalLM(configuration)
25
+ #model = MistralForCausalLM.from_pretrained("./6B_code_outputs/checkpoint-10000")
26
+ tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-Instruct-v0.2", local_files_only=False)
27
+ tokenizer.pad_token = tokenizer.eos_token
28
+
29
+ dataset = load_dataset('HuggingFaceTB/cosmopedia-20k', split="train")
30
+ #dataset = load_dataset('Elriggs/openwebtext-100k', split="train")
31
+ dataset = dataset.shuffle(seed=42)
32
+ print(f'Number of prompts: {len(dataset)}')
33
+ print(f'Column names are: {dataset.column_names}')
34
+
35
+ def create_prompt_formats(sample):
36
+ """
37
+ Collect the 'text' field of each sample in the batch, unchanged
37
+ (no prompt template is applied and no fields are concatenated)
38
+ :param sample: Batch dictionary whose 'text' entry holds one item per sample
40
+ """
41
+ output_texts = []
42
+ for i in range(len(sample['text'])):
43
+ formatted_prompt = sample['text'][i]
44
+ output_texts.append(formatted_prompt)
45
+ #print(output_texts)
46
+ return output_texts
47
+
48
+
49
+ trainer = SFTTrainer(
50
+ model,
51
+ train_dataset=dataset,
52
+ tokenizer = tokenizer,
53
+ max_seq_length=2048,
54
+ formatting_func=create_prompt_formats,
55
+ args=TrainingArguments(
56
+ per_device_train_batch_size=2,
57
+ gradient_accumulation_steps=1,
58
+ warmup_steps=2,
59
+ max_steps=10000,
60
+ learning_rate=1e-4,
61
+ logging_steps=1,
62
+ output_dir="6B_outputs", overwrite_output_dir=True,save_steps=1000,
63
+ optim="paged_adamw_32bit",report_to="none"
64
+ )
65
+ )
66
+ trainer.train()
67
+ trainer.model.save_pretrained("6B-final", dtype=torch.float32)
68
+ trainer.tokenizer.save_pretrained("6B-final")
69
+
70
+ ```