| | """ |
| | Fine-tuning script for Kat-Gen1 model |
| | """ |
| |
|
import torch
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    Trainer,
    TrainingArguments,
    DataCollatorForLanguageModeling
)
from datasets import Dataset, load_dataset
from typing import Optional


class KatGen1Trainer:
    """Utility class for fine-tuning Kat-Gen1 with the Hugging Face Trainer."""

    def __init__(
        self,
        model_name: str = "Katisim/Kat-Gen1",
        output_dir: str = "./kat-gen1-finetuned"
    ):
        """
        Initialize the training setup.

        Args:
            model_name: Base model to fine-tune
            output_dir: Directory to save the fine-tuned model
        """
        self.model_name = model_name
        self.output_dir = output_dir
        self.model = None
        self.tokenizer = None

    def load_model(self):
        """Load the model and tokenizer."""
        self.model = AutoModelForCausalLM.from_pretrained(self.model_name)
        self.tokenizer = AutoTokenizer.from_pretrained(self.model_name)

        # Causal LM tokenizers often ship without a pad token; fall back to EOS.
        if self.tokenizer.pad_token is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token
            self.model.config.pad_token_id = self.tokenizer.pad_token_id

    def prepare_dataset(
        self,
        dataset_name: str,
        text_column: str = "text",
        max_length: int = 512
    ):
        """
        Prepare a dataset for training.

        Args:
            dataset_name: Name of a dataset on the Hugging Face Hub
            text_column: Column name containing the text data
            max_length: Maximum sequence length

        Returns:
            Tokenized dataset (all splits), with the original columns removed
        """
        dataset = load_dataset(dataset_name)

        def tokenize_function(examples):
            return self.tokenizer(
                examples[text_column],
                truncation=True,
                max_length=max_length,
                padding="max_length"
            )

        # Drop the raw text columns so the Trainer only sees tokenized inputs.
        tokenized_dataset = dataset.map(
            tokenize_function,
            batched=True,
            remove_columns=dataset["train"].column_names
        )

        return tokenized_dataset

    def train(
        self,
        train_dataset: Dataset,
        eval_dataset: Optional[Dataset] = None,
        num_train_epochs: int = 3,
        per_device_train_batch_size: int = 4,
        per_device_eval_batch_size: int = 4,
        learning_rate: float = 5e-5,
        warmup_steps: int = 500,
        weight_decay: float = 0.01,
        logging_steps: int = 100,
        save_steps: int = 1000,
        eval_steps: int = 500
    ):
        """
        Fine-tune the model.

        Args:
            train_dataset: Training dataset
            eval_dataset: Evaluation dataset (optional)
            num_train_epochs: Number of training epochs
            per_device_train_batch_size: Training batch size per device
            per_device_eval_batch_size: Evaluation batch size per device
            learning_rate: Learning rate
            warmup_steps: Number of warmup steps
            weight_decay: Weight decay coefficient
            logging_steps: Log every N steps
            save_steps: Save a checkpoint every N steps
            eval_steps: Evaluate every N steps
        """
        training_args = TrainingArguments(
            output_dir=self.output_dir,
            num_train_epochs=num_train_epochs,
            per_device_train_batch_size=per_device_train_batch_size,
            per_device_eval_batch_size=per_device_eval_batch_size,
            learning_rate=learning_rate,
            warmup_steps=warmup_steps,
            weight_decay=weight_decay,
            logging_dir=f"{self.output_dir}/logs",
            logging_steps=logging_steps,
            save_steps=save_steps,
            eval_steps=eval_steps if eval_dataset is not None else None,
            evaluation_strategy="steps" if eval_dataset is not None else "no",
            save_total_limit=3,
            fp16=torch.cuda.is_available(),
            gradient_accumulation_steps=4,
            load_best_model_at_end=eval_dataset is not None
        )

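        # Caveat: with mlm=False the collator copies input_ids into labels and
        # masks pad positions with -100. Because pad_token was mapped to
        # eos_token in load_model, genuine EOS tokens are excluded from the
        # loss as well.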
        data_collator = DataCollatorForLanguageModeling(
            tokenizer=self.tokenizer,
            mlm=False
        )

        trainer = Trainer(
            model=self.model,
            args=training_args,
            train_dataset=train_dataset,
            eval_dataset=eval_dataset,
            data_collator=data_collator
        )

        trainer.train()
        trainer.save_model(self.output_dir)
        self.tokenizer.save_pretrained(self.output_dir)


def main():
    """Example training workflow."""
    trainer = KatGen1Trainer(output_dir="./kat-gen1-custom")
    trainer.load_model()
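
    # A minimal sketch of the rest of the workflow, left commented out so the
    # script runs without downloading data. "your-dataset-name" is a
    # placeholder, not a real dataset reference; substitute any Hub dataset
    # with a "train" split and a text column.
    #
    # tokenized = trainer.prepare_dataset("your-dataset-name", text_column="text")
    # trainer.train(
    #     train_dataset=tokenized["train"],
    #     eval_dataset=tokenized.get("validation")
    # )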

    print("Training setup complete. Uncomment the dataset loading above to begin training.")


if __name__ == "__main__":
    main()