#!/usr/bin/env python
# train_and_generate_0.6b.py
# Pre-trains a ~0.6B-parameter WERSA causal LM on WikiText-2 + SQuAD, then runs a short generation test.
import torch
from transformers import (
AutoTokenizer,
Trainer,
TrainingArguments,
DataCollatorForLanguageModeling,
)
from datasets import load_dataset, concatenate_datasets
import logging
# Import the custom WERSA classes from your local package
# This assumes you have run `pip install -e .` with the corrected modeling file
from wersa import WersaConfig, WersaForCausalLM
# --- Setup Logging ---
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
def main():
# --- 1. Configuration for ~0.6B Model ---
logger.info("Setting up ~0.6B model and tokenizer...")
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2-1.5B")
if tokenizer.pad_token is None:
tokenizer.pad_token = tokenizer.eos_token
# Configuration for a ~0.6B parameter model to fit in memory
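    # wersa_decomp_levels and wersa_random_features are the WERSA-specific knobs
    # (decomposition levels and number of random features, going by their names);
    # the values below are assumed to be reasonable for this scale rather than tuned.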
config = WersaConfig(
vocab_size=len(tokenizer),
pad_token_id=tokenizer.pad_token_id,
hidden_size=1024,
num_hidden_layers=24,
num_attention_heads=16,
intermediate_size=2816,
max_position_embeddings=1024,
wersa_decomp_levels=4,
wersa_random_features=128,
)
model = WersaForCausalLM(config)
logger.info(f"Model created with approximately {model.num_parameters() / 1e9:.2f}B parameters.")
# --- 2. Dataset Preparation ---
logger.info("Loading and preparing multiple datasets...")
# Dataset 1: WikiText (formal, structured text)
wikitext_dataset = load_dataset("wikitext", "wikitext-2-raw-v1", split="train")
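    # WikiText-2 uses empty lines as separators; drop them so only real text remains.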
wikitext_dataset = wikitext_dataset.filter(lambda x: len(x['text'].strip()) > 0)
    # Dataset 2: SQuAD (extractive Q&A over Wikipedia paragraphs) - replacing the old eli5 dataset
squad_dataset = load_dataset("squad", split="train")
# Format SQuAD to have a single 'text' column
def format_squad(example):
# Combine context, question, and the first answer into a single text field
if example['answers']['text'] and len(example['answers']['text'][0]) > 0:
return {'text': f"Context: {example['context']}\nQuestion: {example['question']}\nAnswer: {example['answers']['text'][0]}"}
return {'text': ""}
squad_dataset = squad_dataset.map(format_squad, remove_columns=squad_dataset.column_names)
squad_dataset = squad_dataset.filter(lambda x: len(x['text'].strip()) > 0)
# Combine the two datasets
raw_dataset = concatenate_datasets([wikitext_dataset, squad_dataset]).shuffle(seed=42)
logger.info(f"Combined dataset size: {len(raw_dataset)} samples.")
def tokenize_function(examples):
return tokenizer(examples["text"], truncation=True, max_length=config.max_position_embeddings)
tokenized_dataset = raw_dataset.map(tokenize_function, batched=True, num_proc=4, remove_columns=["text"])
block_size = config.max_position_embeddings
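    # Concatenate all token sequences and split them into contiguous blocks of
    # block_size tokens; the trailing remainder shorter than a full block is dropped.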
def group_texts(examples):
concatenated = {k: sum(examples[k], []) for k in examples.keys()}
total_length = len(concatenated[list(examples.keys())[0]])
total_length = (total_length // block_size) * block_size
result = {k: [t[i : i + block_size] for i in range(0, total_length, block_size)] for k, t in concatenated.items()}
result["labels"] = result["input_ids"].copy()
return result
lm_dataset = tokenized_dataset.map(group_texts, batched=True, num_proc=4)
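    # With mlm=False the collator prepares causal-LM batches: labels are a copy of
    # input_ids with padding positions set to -100 (all blocks are full length here).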
data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)
# --- 3. Training ---
output_dir = "./wersa-qwen-style-0.6b-final"
training_args = TrainingArguments(
output_dir=output_dir,
overwrite_output_dir=True,
num_train_epochs=1,
per_device_train_batch_size=1,
gradient_accumulation_steps=32,
save_steps=500,
save_total_limit=2,
logging_steps=50,
fp16=torch.cuda.is_available(),
)
trainer = Trainer(model=model, args=training_args, train_dataset=lm_dataset, data_collator=data_collator)
logger.info("Starting pre-training for the ~0.6B model...")
trainer.train()
trainer.save_model(output_dir)
tokenizer.save_pretrained(output_dir)
logger.info(f"Model and tokenizer saved to {output_dir}")
# --- 4. Generation Test ---
logger.info("\n" + "="*50 + "\n RUNNING GENERATION TEST\n" + "="*50 + "\n")
device = "cuda" if torch.cuda.is_available() else "cpu"
prompt = "What is the meaning of life?"
inputs = tokenizer(prompt, return_tensors="pt").to(device)
logger.info(f"PROMPT: '{prompt}'")
# Load the trained model for generation
trained_model = WersaForCausalLM.from_pretrained(output_dir)
trained_model.to(device)
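    # Greedy decoding by default; no_repeat_ngram_size=2 prevents repeated bigrams,
    # which helps a freshly pre-trained small model avoid looping.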
outputs = trained_model.generate(**inputs, max_new_tokens=100, no_repeat_ngram_size=2, pad_token_id=tokenizer.eos_token_id)
generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
logger.info("\nMODEL COMPLETION:\n")
print(generated_text)
if __name__ == "__main__":
main()