# AI Travel Planner — fine-tunes t5-base on the TravelPlanner dataset and serves it via Streamlit.
| import os | |
| import sys | |
| import torch | |
| import pandas as pd | |
| import streamlit as st | |
| from datetime import datetime | |
| from transformers import ( | |
| T5ForConditionalGeneration, | |
| T5Tokenizer, | |
| Trainer, | |
| TrainingArguments, | |
| DataCollatorForSeq2Seq | |
| ) | |
| from torch.utils.data import Dataset | |
| import random | |
# Seed both the Python and PyTorch RNGs so runs are reproducible.
_SEED = 42
torch.manual_seed(_SEED)
random.seed(_SEED)

# Tolerate duplicate OpenMP runtimes (common torch/MKL clash on some platforms).
os.environ['KMP_DUPLICATE_LIB_OK'] = 'TRUE'
class TravelDataset(Dataset):
    """Seq2seq dataset pairing a travel `query` with its `reference_information`.

    Each item is a dict of fixed-length (max_length) tensors:
    `input_ids`, `attention_mask`, and `labels`, where padding positions in
    `labels` are set to -100 so the loss ignores them (the convention used by
    HF seq2seq models).
    """

    def __init__(self, data, tokenizer, max_length=512):
        # data: pandas DataFrame with 'query' and 'reference_information' columns.
        self.tokenizer = tokenizer
        self.data = data
        self.max_length = max_length
        print(f"Dataset loaded with {len(data)} samples")
        print("Columns:", list(data.columns))

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        row = self.data.iloc[idx]
        # str() guards against non-string cells (e.g. NaN read from the CSV),
        # which would otherwise crash the tokenizer.
        input_text = str(row['query'])
        target_text = str(row['reference_information'])

        input_encodings = self.tokenizer(
            input_text,
            max_length=self.max_length,
            padding='max_length',
            truncation=True,
            return_tensors='pt'
        )
        target_encodings = self.tokenizer(
            target_text,
            max_length=self.max_length,
            padding='max_length',
            truncation=True,
            return_tensors='pt'
        )

        # Mask padding in the labels: positions equal to pad_token_id must be
        # -100 so CrossEntropyLoss ignores them; leaving the pad id (0) would
        # train the model to emit padding.
        labels = target_encodings['input_ids'].squeeze(0).clone()
        labels[labels == self.tokenizer.pad_token_id] = -100

        return {
            'input_ids': input_encodings['input_ids'].squeeze(0),
            'attention_mask': input_encodings['attention_mask'].squeeze(0),
            'labels': labels
        }
def load_dataset():
    """
    Load the travel planning dataset from CSV.

    Reads the TravelPlanner train split from the HF hub, validates that the
    two columns the pipeline needs are present, and returns the DataFrame.
    Exits the process on any failure (network, parse, or schema error).
    """
    try:
        frame = pd.read_csv("hf://datasets/osunlp/TravelPlanner/train.csv")
        # Fail fast if the schema the dataset/model code relies on is missing.
        for column in ['query', 'reference_information']:
            if column not in frame.columns:
                raise ValueError(f"Missing required column: {column}")
        print(f"Dataset loaded successfully with {len(frame)} rows.")
        return frame
    except Exception as exc:
        print(f"Error loading dataset: {exc}")
        sys.exit(1)
def train_model():
    """
    Fine-tune t5-base on the TravelPlanner data.

    Returns (model, tokenizer) on success, or (None, None) if anything in the
    load/train/save pipeline fails.
    """
    try:
        data = load_dataset()

        print("Initializing T5 model and tokenizer...")
        tokenizer = T5Tokenizer.from_pretrained('t5-base', legacy=False)
        model = T5ForConditionalGeneration.from_pretrained('t5-base')

        # 80/20 train/validation split; rows keep their file order (no shuffle).
        split_at = int(0.8 * len(data))
        train_dataset = TravelDataset(data[:split_at], tokenizer)
        val_dataset = TravelDataset(data[split_at:], tokenizer)

        training_args = TrainingArguments(
            output_dir="./trained_travel_planner",
            num_train_epochs=3,
            per_device_train_batch_size=4,
            per_device_eval_batch_size=4,
            evaluation_strategy="steps",
            eval_steps=50,
            save_steps=100,
            weight_decay=0.01,
            logging_dir="./logs",
            logging_steps=10,
            load_best_model_at_end=True,
        )

        collator = DataCollatorForSeq2Seq(
            tokenizer=tokenizer,
            model=model,
            padding=True,
        )
        trainer = Trainer(
            model=model,
            args=training_args,
            train_dataset=train_dataset,
            eval_dataset=val_dataset,
            data_collator=collator,
        )

        print("Training model...")
        trainer.train()

        # Persist both weights and tokenizer so the model can be reloaded later.
        model.save_pretrained("./trained_travel_planner")
        tokenizer.save_pretrained("./trained_travel_planner")
        print("Model training complete!")
        return model, tokenizer
    except Exception as e:
        print(f"Training error: {e}")
        return None, None
def generate_travel_plan(query, model, tokenizer):
    """
    Generate a travel plan using the trained model.

    Tokenizes the query, moves everything to GPU when available, runs beam
    search, and returns the decoded plan. On failure, returns the error as a
    string rather than raising.
    """
    try:
        encoded = tokenizer(
            query,
            return_tensors="pt",
            max_length=512,
            padding="max_length",
            truncation=True,
        )
        # Model and inputs must live on the same device.
        if torch.cuda.is_available():
            model = model.cuda()
            encoded = {name: tensor.cuda() for name, tensor in encoded.items()}
        generated = model.generate(
            **encoded,
            max_length=512,
            num_beams=4,
            no_repeat_ngram_size=3,
            num_return_sequences=1,
        )
        return tokenizer.decode(generated[0], skip_special_tokens=True)
    except Exception as e:
        return f"Error generating travel plan: {e}"
def main():
    """Streamlit entry point: manages the model in session state and serves plans.

    Fix: the original stored whatever train_model() returned — including
    (None, None) on failure — in session_state, so a later "Generate Plan"
    click crashed with AttributeError on None. The model is now only stored
    when training succeeded, and generation is guarded.
    """
    st.set_page_config(
        page_title="AI Travel Planner",
        page_icon="✈️",
        layout="wide"
    )
    st.title("✈️ AI Travel Planner")

    # Sidebar to train model
    with st.sidebar:
        st.header("Model Management")
        if st.button("Retrain Model"):
            with st.spinner("Training the model..."):
                model, tokenizer = train_model()
            if model is not None:
                st.session_state['model'] = model
                st.session_state['tokenizer'] = tokenizer
                st.success("Model retrained successfully!")
            else:
                st.error("Model retraining failed.")

    # Load model on first run; only cache it when training succeeded.
    if 'model' not in st.session_state:
        with st.spinner("Loading model..."):
            model, tokenizer = train_model()
        if model is not None:
            st.session_state['model'] = model
            st.session_state['tokenizer'] = tokenizer
        else:
            st.error("Model loading failed. Use 'Retrain Model' to try again.")

    # Input query
    st.subheader("Plan Your Trip")
    query = st.text_area("Enter your trip query (e.g., 'Plan a 3-day trip to Paris focusing on culture and food')")
    if st.button("Generate Plan"):
        if not query:
            st.error("Please enter a query.")
        elif st.session_state.get('model') is None:
            # Guard: don't call generate with a missing/failed model.
            st.error("Model is not available. Please retrain the model first.")
        else:
            with st.spinner("Generating your travel plan..."):
                travel_plan = generate_travel_plan(
                    query,
                    st.session_state['model'],
                    st.session_state['tokenizer']
                )
            st.subheader("Your Travel Plan")
            st.write(travel_plan)

if __name__ == "__main__":
    main()