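"""Pre-train a ~0.6B-parameter WERSA causal language model.

Builds the model from scratch on top of the Qwen2 tokenizer, pre-trains it
for one epoch on a WikiText-2 + SQuAD mix with the Hugging Face Trainer,
saves the checkpoint, and reloads it for a quick generation sanity check.
"""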
import logging

import torch
from datasets import concatenate_datasets, load_dataset
from transformers import (
    AutoTokenizer,
    DataCollatorForLanguageModeling,
    Trainer,
    TrainingArguments,
)

from wersa import WersaConfig, WersaForCausalLM

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


def main():
    logger.info("Setting up ~0.6B model and tokenizer...")

    tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2-1.5B")
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

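    # ~0.6B-parameter, Qwen-style configuration; wersa_decomp_levels and
    # wersa_random_features are the WERSA-specific attention hyperparameters.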
    config = WersaConfig(
        vocab_size=len(tokenizer),
        pad_token_id=tokenizer.pad_token_id,
        hidden_size=1024,
        num_hidden_layers=24,
        num_attention_heads=16,
        intermediate_size=2816,
        max_position_embeddings=1024,
        wersa_decomp_levels=4,
        wersa_random_features=128,
    )

    model = WersaForCausalLM(config)
    logger.info(f"Model created with approximately {model.num_parameters() / 1e9:.2f}B parameters.")

logger.info("Loading and preparing multiple datasets...") |
|
|
|
|
|
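    # WikiText-2 (raw) provides general prose; drop its empty rows.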
    wikitext_dataset = load_dataset("wikitext", "wikitext-2-raw-v1", split="train")
    wikitext_dataset = wikitext_dataset.filter(lambda x: len(x["text"].strip()) > 0)

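    # SQuAD is flattened into plain "Context/Question/Answer" strings so it
    # can be mixed into causal-LM pre-training; unanswered rows are dropped.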
    squad_dataset = load_dataset("squad", split="train")

    def format_squad(example):
        if example["answers"]["text"] and len(example["answers"]["text"][0]) > 0:
            return {
                "text": (
                    f"Context: {example['context']}\n"
                    f"Question: {example['question']}\n"
                    f"Answer: {example['answers']['text'][0]}"
                )
            }
        return {"text": ""}

    squad_dataset = squad_dataset.map(format_squad, remove_columns=squad_dataset.column_names)
    squad_dataset = squad_dataset.filter(lambda x: len(x["text"].strip()) > 0)

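    # Combine the two corpora and shuffle into one mixed training stream.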
    raw_dataset = concatenate_datasets([wikitext_dataset, squad_dataset]).shuffle(seed=42)
    logger.info(f"Combined dataset size: {len(raw_dataset)} samples.")

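    # Tokenize everything, truncating at the model's maximum context length.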
    def tokenize_function(examples):
        return tokenizer(examples["text"], truncation=True, max_length=config.max_position_embeddings)

    tokenized_dataset = raw_dataset.map(
        tokenize_function, batched=True, num_proc=4, remove_columns=["text"]
    )

    block_size = config.max_position_embeddings

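    # Standard causal-LM packing: concatenate all token lists, drop the
    # ragged tail, and slice into contiguous block_size chunks. labels is a
    # copy of input_ids; per the usual HF causal-LM convention, the model
    # shifts labels internally when computing the next-token loss.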
    def group_texts(examples):
        concatenated = {k: sum(examples[k], []) for k in examples.keys()}
        total_length = len(concatenated[list(examples.keys())[0]])
        total_length = (total_length // block_size) * block_size
        result = {
            k: [t[i : i + block_size] for i in range(0, total_length, block_size)]
            for k, t in concatenated.items()
        }
        result["labels"] = result["input_ids"].copy()
        return result

    lm_dataset = tokenized_dataset.map(group_texts, batched=True, num_proc=4)
    data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)

    output_dir = "./wersa-qwen-style-0.6b-final"
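
    # Effective batch size is 1 x 32 = 32 sequences per optimizer step.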
    training_args = TrainingArguments(
        output_dir=output_dir,
        overwrite_output_dir=True,
        num_train_epochs=1,
        per_device_train_batch_size=1,
        gradient_accumulation_steps=32,
        save_steps=500,
        save_total_limit=2,
        logging_steps=50,
        fp16=torch.cuda.is_available(),
    )

    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=lm_dataset,
        data_collator=data_collator,
    )

    logger.info("Starting pre-training for the ~0.6B model...")
    trainer.train()
    trainer.save_model(output_dir)
    tokenizer.save_pretrained(output_dir)
    logger.info(f"Model and tokenizer saved to {output_dir}")

    logger.info("\n" + "=" * 50 + "\n RUNNING GENERATION TEST\n" + "=" * 50 + "\n")
    device = "cuda" if torch.cuda.is_available() else "cpu"
    prompt = "What is the meaning of life?"
    inputs = tokenizer(prompt, return_tensors="pt").to(device)
    logger.info(f"PROMPT: '{prompt}'")
    trained_model = WersaForCausalLM.from_pretrained(output_dir)
    trained_model.to(device)

    outputs = trained_model.generate(
        **inputs,
        max_new_tokens=100,
        no_repeat_ngram_size=2,
        pad_token_id=tokenizer.eos_token_id,
    )
    generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)

    logger.info("\nMODEL COMPLETION:\n")
    print(generated_text)


if __name__ == "__main__":
    main()