import json
from typing import Dict, List

import torch
from transformers import (
    T5ForConditionalGeneration,
    T5Tokenizer,
    TrainingArguments,
    Trainer,
    DataCollatorForSeq2Seq,
)
class ItineraryDataset(torch.utils.data.Dataset):
    """Tokenizes (prompt, itinerary) pairs for seq2seq fine-tuning."""

    def __init__(self, data_path: str, tokenizer, max_length: int = 512):
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.examples = self._load_data(data_path)

    def _load_data(self, data_path: str) -> List[Dict]:
        with open(data_path, "r") as f:
            return json.load(f)

    def __len__(self):
        return len(self.examples)

    def __getitem__(self, idx):
        example = self.examples[idx]
        prompt = (
            f"Generate a detailed travel itinerary for {example['destination']} "
            f"for {example['duration']} days.\n"
            f"Preferences: {example['preferences']}\n"
            f"Budget: {example['budget']}"
        )
        target = example["itinerary"]

        # T5 is an encoder-decoder model: the prompt feeds the encoder and the
        # target is tokenized separately as labels. Concatenating prompt and
        # target into one sequence with labels = input_ids is the causal-LM
        # recipe and would train the model to echo the prompt.
        model_inputs = self.tokenizer(
            prompt,
            truncation=True,
            max_length=self.max_length,
        )
        labels = self.tokenizer(
            text_target=target,
            truncation=True,
            max_length=self.max_length,
        )
        # Leave sequences unpadded here; DataCollatorForSeq2Seq pads each
        # batch dynamically and masks label padding with -100.
        model_inputs["labels"] = labels["input_ids"]
        return model_inputs
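
# Expected shape of `data/itineraries.json`: a JSON list of objects with the
# keys read in __getitem__ above. A minimal illustrative example (the field
# values below are made up; only the keys matter):
#
# [
#   {
#     "destination": "Kyoto, Japan",
#     "duration": 3,
#     "preferences": "temples, local food, walkable neighborhoods",
#     "budget": "$150/day",
#     "itinerary": "Day 1: ...\nDay 2: ...\nDay 3: ..."
#   }
# ]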
def train_itinerary_model(
    model_name: str = "google/flan-t5-base",
    data_path: str = "data/itineraries.json",
    output_dir: str = "output",
    num_epochs: int = 3,
    batch_size: int = 4,
    learning_rate: float = 2e-5,
):
    # Initialize tokenizer and model
    tokenizer = T5Tokenizer.from_pretrained(model_name)
    model = T5ForConditionalGeneration.from_pretrained(
        model_name,
        device_map="auto",
    )

    # Load dataset
    dataset = ItineraryDataset(data_path, tokenizer)

    # Training arguments. T5 checkpoints are known to overflow in fp16, so
    # train in bf16 where the hardware supports it (and fp32 otherwise)
    # rather than fp16.
    training_args = TrainingArguments(
        output_dir=output_dir,
        num_train_epochs=num_epochs,
        per_device_train_batch_size=batch_size,
        gradient_accumulation_steps=4,
        learning_rate=learning_rate,
        warmup_steps=100,
        logging_steps=10,
        save_steps=100,
        bf16=torch.cuda.is_available() and torch.cuda.is_bf16_supported(),
        report_to="tensorboard",
    )

    # Initialize trainer. Passing `model` to the collator lets it build
    # decoder_input_ids from the labels for each batch.
    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=dataset,
        data_collator=DataCollatorForSeq2Seq(tokenizer, model=model),
    )

    # Train, then save the final model and tokenizer
    trainer.train()
    trainer.save_model()
    tokenizer.save_pretrained(output_dir)
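
# Minimal inference sketch (an illustrative helper, not part of the training
# pipeline above): loads the fine-tuned checkpoint from `output/` and
# generates one itinerary. The generation settings are assumed defaults.
def generate_itinerary(
    destination: str,
    duration: int,
    preferences: str,
    budget: str,
    model_dir: str = "output",
) -> str:
    tokenizer = T5Tokenizer.from_pretrained(model_dir)
    model = T5ForConditionalGeneration.from_pretrained(model_dir)
    # Mirror the prompt template used in ItineraryDataset
    prompt = (
        f"Generate a detailed travel itinerary for {destination} "
        f"for {duration} days.\n"
        f"Preferences: {preferences}\n"
        f"Budget: {budget}"
    )
    inputs = tokenizer(prompt, return_tensors="pt")
    output_ids = model.generate(
        **inputs,
        max_new_tokens=512,
        num_beams=4,
    )
    return tokenizer.decode(output_ids[0], skip_special_tokens=True)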
if __name__ == "__main__":
    train_itinerary_model()