import os
import sys
import random
from datetime import datetime

import pandas as pd
import streamlit as st
import torch
from torch.utils.data import Dataset
from transformers import (
    T5ForConditionalGeneration,
    T5Tokenizer,
    Trainer,
    TrainingArguments,
    DataCollatorForSeq2Seq
)

# Ensure reproducibility
torch.manual_seed(42)
random.seed(42)

# Environment setup: tolerate duplicate OpenMP runtimes (common with torch
# on macOS/Windows; without this the process can abort at import time).
os.environ['KMP_DUPLICATE_LIB_OK'] = 'TRUE'


class TravelDataset(Dataset):
    """Seq2seq dataset mapping a travel ``query`` to its ``reference_information``.

    Each item is tokenized into fixed-length tensors suitable for
    T5 conditional-generation training.
    """

    def __init__(self, data, tokenizer, max_length=512):
        """
        Args:
            data: pandas DataFrame with at least the columns
                ``query`` and ``reference_information``.
            tokenizer: a T5 (or compatible) tokenizer.
            max_length: maximum token length for both inputs and targets.
        """
        self.tokenizer = tokenizer
        self.data = data
        self.max_length = max_length
        print(f"Dataset loaded with {len(data)} samples")
        print("Columns:", list(data.columns))

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        """Return tokenized input/target tensors for one row.

        Returns a dict with ``input_ids``, ``attention_mask`` and ``labels``,
        each a 1-D tensor of length ``max_length``.
        """
        row = self.data.iloc[idx]

        # Input: the user's travel query.
        input_text = row['query']
        # Target: the reference plan/information to generate.
        target_text = row['reference_information']

        input_encodings = self.tokenizer(
            input_text,
            max_length=self.max_length,
            padding='max_length',
            truncation=True,
            return_tensors='pt'
        )
        target_encodings = self.tokenizer(
            target_text,
            max_length=self.max_length,
            padding='max_length',
            truncation=True,
            return_tensors='pt'
        )

        labels = target_encodings['input_ids'].squeeze()
        # BUGFIX: mask padding positions with -100 so the cross-entropy loss
        # ignores them. Without this the model is also trained to predict
        # <pad> tokens, which degrades generation quality.
        labels = labels.masked_fill(labels == self.tokenizer.pad_token_id, -100)

        return {
            'input_ids': input_encodings['input_ids'].squeeze(),
            'attention_mask': input_encodings['attention_mask'].squeeze(),
            'labels': labels
        }


def load_dataset():
    """Load the travel planning dataset from CSV.

    Returns:
        A pandas DataFrame containing at least the columns
        ``query`` and ``reference_information``.

    Exits the process (status 1) if loading or validation fails.
    """
    try:
        data = pd.read_csv("hf://datasets/osunlp/TravelPlanner/train.csv")

        required_columns = ['query', 'reference_information']
        for col in required_columns:
            if col not in data.columns:
                raise ValueError(f"Missing required column: {col}")

        print(f"Dataset loaded successfully with {len(data)} rows.")
        return data
    except Exception as e:
        print(f"Error loading dataset: {e}")
        sys.exit(1)


def train_model():
    """Fine-tune T5-base on the travel planning dataset.

    Returns:
        (model, tokenizer) on success, or (None, None) on failure.
        The trained artifacts are also saved to ``./trained_travel_planner``.
    """
    try:
        # Load dataset
        data = load_dataset()

        # Initialize model and tokenizer
        print("Initializing T5 model and tokenizer...")
        tokenizer = T5Tokenizer.from_pretrained('t5-base', legacy=False)
        model = T5ForConditionalGeneration.from_pretrained('t5-base')

        # 80/20 positional train/validation split.
        train_size = int(0.8 * len(data))
        train_data = data[:train_size]
        val_data = data[train_size:]

        train_dataset = TravelDataset(train_data, tokenizer)
        val_dataset = TravelDataset(val_data, tokenizer)

        # NOTE(review): `evaluation_strategy` was renamed `eval_strategy` in
        # newer transformers releases — confirm against the pinned version.
        # save_steps (100) is a multiple of eval_steps (50), as required by
        # load_best_model_at_end.
        training_args = TrainingArguments(
            output_dir="./trained_travel_planner",
            num_train_epochs=3,
            per_device_train_batch_size=4,
            per_device_eval_batch_size=4,
            evaluation_strategy="steps",
            eval_steps=50,
            save_steps=100,
            weight_decay=0.01,
            logging_dir="./logs",
            logging_steps=10,
            load_best_model_at_end=True,
        )

        data_collator = DataCollatorForSeq2Seq(
            tokenizer=tokenizer,
            model=model,
            padding=True
        )

        trainer = Trainer(
            model=model,
            args=training_args,
            train_dataset=train_dataset,
            eval_dataset=val_dataset,
            data_collator=data_collator
        )

        print("Training model...")
        trainer.train()

        model.save_pretrained("./trained_travel_planner")
        tokenizer.save_pretrained("./trained_travel_planner")
        print("Model training complete!")
        return model, tokenizer
    except Exception as e:
        print(f"Training error: {e}")
        return None, None


def generate_travel_plan(query, model, tokenizer):
    """Generate a travel plan using the trained model.

    Args:
        query: free-text trip request from the user.
        model: a fine-tuned T5ForConditionalGeneration instance.
        tokenizer: the matching tokenizer.

    Returns:
        The decoded plan string, or an error message on failure.
    """
    try:
        inputs = tokenizer(
            query,
            return_tensors="pt",
            max_length=512,
            padding="max_length",
            truncation=True
        )

        # Move both inputs and model to GPU when available.
        if torch.cuda.is_available():
            inputs = {k: v.cuda() for k, v in inputs.items()}
            model = model.cuda()

        outputs = model.generate(
            **inputs,
            max_length=512,
            num_beams=4,
            no_repeat_ngram_size=3,
            num_return_sequences=1
        )
        return tokenizer.decode(outputs[0], skip_special_tokens=True)
    except Exception as e:
        return f"Error generating travel plan: {e}"


def main():
    """Streamlit entry point: model management sidebar + plan generation UI."""
    st.set_page_config(
        page_title="AI Travel Planner",
        page_icon="✈️",
        layout="wide"
    )
    st.title("✈️ AI Travel Planner")

    # Sidebar to (re)train the model on demand.
    with st.sidebar:
        st.header("Model Management")
        if st.button("Retrain Model"):
            with st.spinner("Training the model..."):
                model, tokenizer = train_model()
                if model:
                    st.session_state['model'] = model
                    st.session_state['tokenizer'] = tokenizer
                    st.success("Model retrained successfully!")
                else:
                    st.error("Model retraining failed.")

    # Train on first load if no model is cached in the session yet.
    if 'model' not in st.session_state:
        with st.spinner("Loading model..."):
            model, tokenizer = train_model()
            st.session_state['model'] = model
            st.session_state['tokenizer'] = tokenizer

    # Input query
    st.subheader("Plan Your Trip")
    query = st.text_area(
        "Enter your trip query (e.g., 'Plan a 3-day trip to Paris focusing on culture and food')"
    )

    if st.button("Generate Plan"):
        if not query:
            st.error("Please enter a query.")
        elif st.session_state.get('model') is None:
            # BUGFIX: train_model() may have failed and stored None; the
            # original code would crash calling generate on a None model.
            st.error("Model is not available. Please retrain the model first.")
        else:
            with st.spinner("Generating your travel plan..."):
                travel_plan = generate_travel_plan(
                    query,
                    st.session_state['model'],
                    st.session_state['tokenizer']
                )
                st.subheader("Your Travel Plan")
                st.write(travel_plan)


if __name__ == "__main__":
    main()