"""Train a SASOK causal language model on WikiText-2 with the Hugging Face Trainer.

Run as a script. Checkpoints are written to ``./sasok_output`` and logs to
``./logs``.
"""
import inspect

from datasets import load_dataset
from transformers import DataCollatorForLanguageModeling, Trainer, TrainingArguments

from model import SASOKConfig, SASOKModel
from tokenizer import tokenizer

# Explicit cap on tokenized sequence length. Relying on the tokenizer's
# ``model_max_length`` is unsafe: custom tokenizers frequently leave it at a huge
# sentinel value, so truncation never kicks in and fixed-length padding explodes.
# NOTE(review): should not exceed the maximum positions configured in SASOKConfig — confirm.
MAX_LENGTH = 512


def _accepts_kwarg(func, name):
    """Return True if *func* (a callable or class) declares a parameter named *name*.

    Used to choose between old and new keyword names across transformers versions.
    """
    try:
        params = inspect.signature(func).parameters
    except (TypeError, ValueError):
        # Signature not introspectable; assume the newer keyword is unavailable.
        return False
    return name in params


def _ensure_pad_token(tok):
    """Give *tok* a pad token if it has none, reusing its EOS token.

    DataCollatorForLanguageModeling must pad variable-length batches, which fails
    when the tokenizer defines no pad token (common for GPT-style tokenizers).
    NOTE(review): the padded positions are masked out of the loss by the collator;
    with pad == eos, genuine EOS tokens are masked too — confirm this is acceptable.
    """
    if tok.pad_token is None and tok.eos_token is not None:
        tok.pad_token = tok.eos_token


def _tokenize(batch):
    """Tokenize a batch of raw text lines, truncating each to MAX_LENGTH tokens.

    No padding is applied here: the data collator pads each batch dynamically to
    its longest member, avoiding wasted compute on fixed-length padding.
    """
    return tokenizer(batch["text"], truncation=True, max_length=MAX_LENGTH)


def main():
    """Load and tokenize WikiText-2, build the SASOK model, and run training."""
    _ensure_pad_token(tokenizer)

    raw = load_dataset("wikitext", "wikitext-2-raw-v1")
    # WikiText contains many blank lines; they tokenize to empty sequences that
    # carry no training signal, so drop them before tokenizing.
    raw = raw.filter(lambda example: example["text"].strip() != "")
    tokenized = raw.map(
        _tokenize,
        batched=True,
        # Remove the raw "text" strings; the collator can only tensorize token ids.
        remove_columns=raw["train"].column_names,
    )

    config = SASOKConfig()
    model = SASOKModel(config)

    # ``evaluation_strategy`` was renamed ``eval_strategy`` in newer transformers
    # releases (the old name is later rejected); use whichever is supported.
    strategy_key = (
        "eval_strategy"
        if _accepts_kwarg(TrainingArguments, "eval_strategy")
        else "evaluation_strategy"
    )
    training_args = TrainingArguments(
        output_dir="./sasok_output",
        eval_steps=500,
        per_device_train_batch_size=4,
        num_train_epochs=3,
        save_steps=1000,
        logging_dir="./logs",
        **{strategy_key: "steps"},
    )

    # mlm=False -> causal LM objective: labels are the input ids (padding masked).
    data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)

    # ``Trainer(tokenizer=...)`` was superseded by ``processing_class`` in newer
    # transformers releases; pass the tokenizer under whichever name is supported.
    tokenizer_key = (
        "processing_class" if _accepts_kwarg(Trainer, "processing_class") else "tokenizer"
    )
    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=tokenized["train"],
        eval_dataset=tokenized["validation"],
        data_collator=data_collator,
        **{tokenizer_key: tokenizer},
    )
    trainer.train()


if __name__ == "__main__":
    main()