| print("Loading...") |
|
|
import torch
|
|
# Release any cached GPU memory before building the model (no-op if CUDA
# is not initialized).
torch.cuda.empty_cache()
|
|
from datasets import load_dataset
from transformers import (
    AutoConfig,
    AutoModelForCausalLM,
    AutoTokenizer,
    DataCollatorForLanguageModeling,
    Trainer,
    TrainingArguments,
)
|
|
| MODEL_NAME = "Pin-25M" |
| DATASET_ID = "starhopp3r/TinyChat" |
| MAX_LENGTH = 256 |
| BATCH_SIZE = 32 |
|
|
| tokenizer = AutoTokenizer.from_pretrained("gpt2") |
| tokenizer.pad_token = tokenizer.eos_token |
|
|
# A scaled-down GPT-2-style config (~25M parameters): the full 12 layers and
# 12 heads, but a narrow 288-dim hidden state with a 1152-dim MLP.
config = AutoConfig.from_pretrained(
    "gpt2",
    n_layer=12,
    n_head=12,
    n_embd=288,
    n_inner=1152,
    vocab_size=len(tokenizer),
    bos_token_id=tokenizer.bos_token_id,
    eos_token_id=tokenizer.eos_token_id,
)
# Build the model from scratch (random weights) rather than loading pretrained ones.
model = AutoModelForCausalLM.from_config(config)
|
|
| print(f"Model parameters: {model.num_parameters() / 1e6:.2f}M") |
|
|
| print("Loading dataset...") |
|
|
| dataset = load_dataset(DATASET_ID, split="train") |
|
|
def tokenize_function(examples):
    # Truncate each example to MAX_LENGTH tokens; padding is deferred to the collator.
    return tokenizer(examples["text"], truncation=True, max_length=MAX_LENGTH)
|
|
tokenized_datasets = dataset.map(
    tokenize_function,
    batched=True,
    remove_columns=dataset.column_names,  # keep only the tokenizer's outputs
    num_proc=4,
)
|
|
# mlm=False gives causal-LM batches: labels are a copy of input_ids, with pad
# positions set to -100. Note that because pad_token == eos_token here, real
# EOS tokens are also masked out of the loss.
data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)
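# Optional sanity check (a sketch): collate two tokenized examples and inspect
# the padded batch; input_ids, attention_mask, and labels should all share one
# common sequence length.
_demo_batch = data_collator([tokenized_datasets[0], tokenized_datasets[1]])
print({k: tuple(v.shape) for k, v in _demo_batch.items()})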
|
|
| print("Setting up training arguments...") |
|
|
training_args = TrainingArguments(
    output_dir=f"./{MODEL_NAME}_checkpoints",
    num_train_epochs=1,             # effectively ignored: max_steps takes precedence
    max_steps=1500,
    per_device_train_batch_size=BATCH_SIZE,
    gradient_accumulation_steps=2,  # effective batch size: 32 * 2 = 64
    learning_rate=5e-4,
    weight_decay=0.01,
    logging_steps=100,
    save_steps=2500,                # above max_steps, so no intermediate checkpoints are written
    fp16=torch.cuda.is_available(), # fp16=True errors out on CPU-only runs
    push_to_hub=False,
    report_to="none",
    warmup_steps=500,               # linear warmup over the first third of training
)
|
|
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_datasets,
    data_collator=data_collator,
)
|
|
| print("Starting training...") |
| trainer.train() |
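# Optional (a sketch): Trainer keeps a running log; the last entry summarizes
# the end-of-training metrics if logging fired at least once.
if trainer.state.log_history:
    print(f"Last logged entry: {trainer.state.log_history[-1]}")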
|
|
| trainer.save_model("./" + MODEL_NAME + "-Final") |
| tokenizer.save_pretrained("./" + MODEL_NAME + "-Final") |
|
|
def chat(prompt):
    # Instruction-style prompt wrapper (assumed; whether it matches the
    # dataset's own chat template depends on TinyChat's formatting).
    formatted_prompt = f"[INST] {prompt} [/INST]"
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model.to(device)
    inputs = tokenizer(formatted_prompt, return_tensors="pt").to(device)

    model.eval()
    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=50,
            temperature=0.7,
            do_sample=True,
            pad_token_id=tokenizer.eos_token_id,
        )
    return tokenizer.decode(outputs[0], skip_special_tokens=True)
|
|
| print("\n--- Test Chat ---") |
| print(chat("Hello, how are you today?")) |