LH-Tech-AI commited on
Commit
859492a
·
verified ·
1 Parent(s): 0ccc517

Create train.py

Browse files
Files changed (1) hide show
  1. train.py +101 -0
train.py ADDED
@@ -0,0 +1,101 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ print("Loading...")
2
+
3
+ import torch
4
+
5
+ torch.cuda.empty_cache()
6
+
7
+ from datasets import load_dataset
8
+ from transformers import (
9
+ AutoConfig,
10
+ AutoModelForCausalLM,
11
+ AutoTokenizer,
12
+ DataCollatorForLanguageModeling,
13
+ Trainer,
14
+ TrainingArguments,
15
+ )
16
+
17
+ MODEL_NAME = "Pin-25M"
18
+ DATASET_ID = "starhopp3r/TinyChat"
19
+ MAX_LENGTH = 256
20
+ BATCH_SIZE = 32
21
+
22
+ tokenizer = AutoTokenizer.from_pretrained("gpt2")
23
+ tokenizer.pad_token = tokenizer.eos_token
24
+
25
+ config = AutoConfig.from_pretrained(
26
+ "gpt2",
27
+ n_layer=12,
28
+ n_head=12,
29
+ n_embd=288,
30
+ n_inner=1152,
31
+ vocab_size=len(tokenizer),
32
+ bos_token_id=tokenizer.bos_token_id,
33
+ eos_token_id=tokenizer.eos_token_id,
34
+ )
35
+ model = AutoModelForCausalLM.from_config(config)
36
+
37
+ print(f"Model parameters: {model.num_parameters() / 1e6:.2f}M")
38
+
39
+ print("Loading dataset...")
40
+
41
+ dataset = load_dataset(DATASET_ID, split="train")
42
+
43
+ def tokenize_function(examples):
44
+ return tokenizer(examples["text"], truncation=True, max_length=MAX_LENGTH)
45
+
46
+ tokenized_datasets = dataset.map(
47
+ tokenize_function,
48
+ batched=True,
49
+ remove_columns=dataset.column_names,
50
+ num_proc=4
51
+ )
52
+
53
+ data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)
54
+
55
+ print("Setting up training arguments...")
56
+
57
+ training_args = TrainingArguments(
58
+ output_dir="./" + MODEL_NAME + "_checkpoints",
59
+ num_train_epochs=1,
60
+ max_steps=1500,
61
+ per_device_train_batch_size=BATCH_SIZE,
62
+ gradient_accumulation_steps=2,
63
+ learning_rate=5e-4,
64
+ weight_decay=0.01,
65
+ logging_steps=100,
66
+ save_steps=2500,
67
+ fp16=True,
68
+ push_to_hub=False,
69
+ report_to="none",
70
+ warmup_steps=500,
71
+ )
72
+
73
+ trainer = Trainer(
74
+ model=model,
75
+ args=training_args,
76
+ train_dataset=tokenized_datasets,
77
+ data_collator=data_collator,
78
+ )
79
+
80
+ print("Starting training...")
81
+ trainer.train()
82
+
83
+ trainer.save_model("./" + MODEL_NAME + "-Final")
84
+ tokenizer.save_pretrained("./" + MODEL_NAME + "-Final")
85
+
86
+ def chat(prompt):
87
+ formatted_prompt = f"[INST] {prompt} [/INST]"
88
+ inputs = tokenizer(formatted_prompt, return_tensors="pt").to("cuda")
89
+ model.to("cuda")
90
+
91
+ outputs = model.generate(
92
+ **inputs,
93
+ max_new_tokens=50,
94
+ temperature=0.7,
95
+ do_sample=True,
96
+ pad_token_id=tokenizer.eos_token_id
97
+ )
98
+ return tokenizer.decode(outputs[0], skip_special_tokens=True)
99
+
100
+ print("\n--- Test Chat ---")
101
+ print(chat("Hello, how are you today?"))