|
|
import json
|
|
|
from pathlib import Path
|
|
|
import logging
|
|
|
import torch
|
|
|
from torch.utils.data import Dataset, DataLoader
|
|
|
from transformers import (
|
|
|
AutoModelForCausalLM,
|
|
|
AutoTokenizer,
|
|
|
TrainingArguments,
|
|
|
Trainer,
|
|
|
DataCollatorForLanguageModeling
|
|
|
)
|
|
|
import wandb
|
|
|
import numpy as np
|
|
|
from datasets import load_dataset
|
|
|
from typing import Dict, List, Any
|
|
|
|
|
|
|
|
|
# Configure root logging at import time: INFO level, timestamped records.
# logging.basicConfig does nothing if the root logger already has handlers,
# so an importing application's own configuration takes precedence.
logging.basicConfig(
    format='%(asctime)s - %(levelname)s - %(message)s',
    level=logging.INFO,
)

# Module-level logger used by the dataset and trainer code in this file.
logger = logging.getLogger(__name__)
|
|
|
|
|
|
class BengaliCodeDataset(Dataset):
    """Map-style dataset of fixed-length causal-LM examples loaded from a JSON file.

    The file must hold a JSON array of objects, each with a text field (``'text'``
    by default). Examples are tokenized lazily in ``__getitem__``, truncated and
    padded to exactly ``max_length`` tokens.
    """

    def __init__(self, data_path: Path, tokenizer, max_length: int = 2048, text_key: str = 'text'):
        """Load every record from ``data_path`` into memory.

        Args:
            data_path: JSON file containing a list of example records.
            tokenizer: Hugging Face-style tokenizer callable. It must define a pad
                token, because every example is padded to ``max_length``.
            max_length: Token length each example is truncated/padded to.
            text_key: Record field that holds the raw text (default ``'text'``).
        """
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.text_key = text_key

        with open(data_path, 'r', encoding='utf-8') as f:
            self.data = json.load(f)

        logger.info("Loaded %d examples from %s", len(self.data), data_path)

    def __len__(self):
        """Return the number of records loaded from the JSON file."""
        return len(self.data)

    def __getitem__(self, idx) -> Dict[str, torch.Tensor]:
        """Tokenize record ``idx`` and return 1-D ``input_ids``/``attention_mask``/``labels``.

        ``labels`` equals ``input_ids`` except that padding positions
        (``attention_mask == 0``) are set to -100, so padding does not contribute
        to the language-modeling loss.
        """
        text = self.data[idx][self.text_key]

        encodings = self.tokenizer(
            text,
            max_length=self.max_length,
            padding='max_length',
            truncation=True,
            return_tensors='pt',
        )

        # return_tensors='pt' yields a batch of one sequence; drop that leading dim.
        input_ids = encodings['input_ids'][0]
        attention_mask = encodings['attention_mask'][0]

        # -100 is the default ignore_index of torch.nn.CrossEntropyLoss, so masked
        # padding positions are excluded from the loss.
        labels = input_ids.clone()
        labels[attention_mask == 0] = -100

        return {
            'input_ids': input_ids,
            'attention_mask': attention_mask,
            'labels': labels,
        }
|
|
|
|
|
|
class ModelTrainer:
    """Fine-tune a causal language model on the Bengali code corpus.

    ``train`` runs the whole pipeline:
    1. start a Weights & Biases run;
    2. load the tokenizer from ``outputs/tokenizer`` and the base model ``self.model_name``;
    3. split ``data/raw/processed_data.json`` into train/validation JSON files;
    4. fit with ``transformers.Trainer``;
    5. save the final model and tokenizer to ``outputs/model/final``.

    All paths are relative to the current working directory.
    """

    def __init__(self):
        self.data_dir = Path('data/raw')
        self.tokenizer_dir = Path('outputs/tokenizer')
        self.output_dir = Path('outputs/model')
        self.output_dir.mkdir(parents=True, exist_ok=True)

        self.model_name = "TinyLlama/TinyLlama-1.1B-intermediate-step-1431k-3T"
        self.max_length = 2048
        self.batch_size = 4  # per device, for both training and evaluation
        self.gradient_accumulation_steps = 4
        self.learning_rate = 2e-5
        self.num_train_epochs = 3
        self.warmup_steps = 100
        # load_best_model_at_end=True requires save_steps to be a multiple of eval_steps.
        self.save_steps = 1000
        self.eval_steps = 500
        # Share of processed_data.json used for training; the remainder is validation.
        self.train_fraction = 0.9
        # Seed for the reproducible train/validation shuffle.
        self.seed = 42

    def setup_wandb(self):
        """Start a Weights & Biases run recording the main hyperparameters."""
        wandb.init(
            project="bengali-code-llm",
            name="tinyllama-bengali-code",
            config={
                "model_name": self.model_name,
                "max_length": self.max_length,
                "batch_size": self.batch_size,
                "gradient_accumulation_steps": self.gradient_accumulation_steps,
                "learning_rate": self.learning_rate,
                "warmup_steps": self.warmup_steps,
                "num_epochs": self.num_train_epochs,
            },
        )

    @staticmethod
    def _use_bf16() -> bool:
        """Return True when a CUDA device with bfloat16 support is available."""
        return torch.cuda.is_available() and torch.cuda.is_bf16_supported()

    @staticmethod
    def _accepts_kwarg(func, name: str) -> bool:
        """Return True if ``func`` (a callable or class) declares a parameter called ``name``."""
        import inspect

        try:
            return name in inspect.signature(func).parameters
        except (TypeError, ValueError):
            # Some builtins/extension callables expose no introspectable signature.
            return False

    def prepare_model_and_tokenizer(self):
        """Load the tokenizer and base model, ensuring a pad token exists.

        Returns:
            ``(model, tokenizer)``. The model's embedding matrix is resized to the
            tokenizer vocabulary, including any pad token added here.
        """
        logger.info("Loading tokenizer and model")

        tokenizer = AutoTokenizer.from_pretrained(
            self.tokenizer_dir,
            model_max_length=self.max_length,
        )
        if tokenizer.pad_token is None:
            # The dataset pads every example to max_length, which raises without a pad
            # token. A dedicated token (rather than reusing EOS) keeps EOS in the loss,
            # because DataCollatorForLanguageModeling masks pad ids to -100.
            tokenizer.add_special_tokens({'pad_token': '<pad>'})

        # Keep weight dtype and the Trainer's mixed-precision mode consistent:
        # bf16 weights only when bf16 is supported, otherwise fp32 weights (with fp16
        # autocast on CUDA, see create_training_arguments).
        dtype = torch.bfloat16 if self._use_bf16() else torch.float32
        model = AutoModelForCausalLM.from_pretrained(
            self.model_name,
            # NOTE(review): trust_remote_code executes code shipped with the model repo;
            # if this checkpoint uses a built-in transformers architecture, drop it.
            trust_remote_code=True,
            torch_dtype=dtype,
        )

        # The tokenizer comes from a separate directory and may have a different
        # vocabulary size than the base model.
        model.resize_token_embeddings(len(tokenizer))
        model.config.pad_token_id = tokenizer.pad_token_id

        return model, tokenizer

    @staticmethod
    def _split_records(records: List[Any], train_fraction: float = 0.9, seed: int = 42):
        """Shuffle a copy of ``records`` reproducibly and split it into ``(train, validation)``.

        A private ``np.random.RandomState`` is used, so the global NumPy RNG is left
        untouched while producing the same permutation as ``np.random.seed(seed)``
        followed by ``np.random.shuffle``. The split index is clamped so that both
        parts are non-empty.

        Raises:
            ValueError: if fewer than two records are given.
        """
        if len(records) < 2:
            raise ValueError(
                f"Need at least 2 records to build train/validation splits, got {len(records)}"
            )
        shuffled = list(records)
        np.random.RandomState(seed).shuffle(shuffled)
        split_idx = min(max(int(len(shuffled) * train_fraction), 1), len(shuffled) - 1)
        return shuffled[:split_idx], shuffled[split_idx:]

    @staticmethod
    def _write_json(path: Path, records: List[Any]) -> None:
        """Write ``records`` to ``path`` as indented UTF-8 JSON, keeping non-ASCII text verbatim."""
        with open(path, 'w', encoding='utf-8') as f:
            json.dump(records, f, ensure_ascii=False, indent=2)

    def create_datasets(self, tokenizer):
        """Split the processed corpus, persist both splits, and wrap them as datasets.

        Reads ``<data_dir>/processed_data.json`` and writes ``train.json`` and
        ``validation.json`` next to it.

        Returns:
            ``(train_dataset, val_dataset)`` as ``BengaliCodeDataset`` instances.

        Raises:
            TypeError: if the JSON file does not contain a list.
            ValueError: if it contains fewer than two records.
        """
        logger.info("Creating datasets")

        data_path = self.data_dir / 'processed_data.json'
        with open(data_path, 'r', encoding='utf-8') as f:
            all_data = json.load(f)
        if not isinstance(all_data, list):
            raise TypeError(f"{data_path} must contain a JSON array of records")

        train_data, val_data = self._split_records(all_data, self.train_fraction, self.seed)

        train_path = self.data_dir / 'train.json'
        val_path = self.data_dir / 'validation.json'
        self._write_json(train_path, train_data)
        self._write_json(val_path, val_data)

        train_dataset = BengaliCodeDataset(train_path, tokenizer, self.max_length)
        val_dataset = BengaliCodeDataset(val_path, tokenizer, self.max_length)
        return train_dataset, val_dataset

    def create_training_arguments(self):
        """Build the ``TrainingArguments`` for step-based evaluation and checkpointing."""
        use_bf16 = self._use_bf16()
        # `evaluation_strategy` was renamed to `eval_strategy` in newer transformers
        # releases (and the old name later removed); use whichever this install accepts.
        eval_key = (
            'eval_strategy'
            if self._accepts_kwarg(TrainingArguments, 'eval_strategy')
            else 'evaluation_strategy'
        )
        return TrainingArguments(
            output_dir=str(self.output_dir),
            num_train_epochs=self.num_train_epochs,
            per_device_train_batch_size=self.batch_size,
            per_device_eval_batch_size=self.batch_size,
            gradient_accumulation_steps=self.gradient_accumulation_steps,
            **{eval_key: "steps"},
            eval_steps=self.eval_steps,
            save_strategy="steps",
            save_steps=self.save_steps,
            learning_rate=self.learning_rate,
            warmup_steps=self.warmup_steps,
            weight_decay=0.01,
            logging_dir=str(self.output_dir / 'logs'),
            logging_steps=100,
            report_to="wandb",
            save_total_limit=3,
            load_best_model_at_end=True,
            metric_for_best_model="eval_loss",
            greater_is_better=False,
            # Matches the weight dtype chosen in prepare_model_and_tokenizer.
            bf16=use_bf16,
            fp16=torch.cuda.is_available() and not use_bf16,
            remove_unused_columns=False,
        )

    def train(self):
        """Run the full fine-tuning pipeline and save the final model and tokenizer.

        Any exception is logged with its traceback and re-raised. The W&B run is
        closed in every case.
        """
        try:
            self.setup_wandb()
            model, tokenizer = self.prepare_model_and_tokenizer()
            train_dataset, val_dataset = self.create_datasets(tokenizer)
            training_args = self.create_training_arguments()

            # mlm=False: causal-LM collation; labels are rebuilt from input_ids with
            # pad positions masked.
            data_collator = DataCollatorForLanguageModeling(
                tokenizer=tokenizer,
                mlm=False,
            )

            # Newer transformers releases deprecate Trainer(tokenizer=...) in favor of
            # processing_class=...; pass it under whichever name is supported.
            tokenizer_key = (
                'processing_class'
                if self._accepts_kwarg(Trainer, 'processing_class')
                else 'tokenizer'
            )
            trainer = Trainer(
                model=model,
                args=training_args,
                train_dataset=train_dataset,
                eval_dataset=val_dataset,
                data_collator=data_collator,
                **{tokenizer_key: tokenizer},
            )

            logger.info("Starting model training")
            trainer.train()

            final_dir = str(self.output_dir / 'final')
            trainer.save_model(final_dir)
            tokenizer.save_pretrained(final_dir)

            logger.info("Model training completed successfully")

        except Exception as e:
            logger.exception(f"Model training failed: {e}")
            raise

        finally:
            # Single place that closes the W&B run, on success or failure.
            if wandb.run is not None:
                wandb.finish()
|
|
|
|
|
|
if __name__ == "__main__":
    # Script entry point: run the pipeline with the settings hard-coded in ModelTrainer.__init__.
    ModelTrainer().train()
|
|
|
|