File size: 6,469 Bytes
7bdf2e1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a86a6db
7bdf2e1
 
 
 
 
 
a86a6db
7bdf2e1
 
 
 
 
 
 
 
 
a86a6db
 
 
 
7bdf2e1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a86a6db
7bdf2e1
 
 
 
a86a6db
7bdf2e1
 
 
 
a86a6db
7bdf2e1
 
 
 
 
 
 
 
 
a86a6db
7bdf2e1
 
 
 
a86a6db
 
7bdf2e1
 
 
a86a6db
7bdf2e1
 
a86a6db
7bdf2e1
a86a6db
7bdf2e1
 
 
 
 
 
a86a6db
 
 
7bdf2e1
 
a86a6db
7bdf2e1
 
 
 
 
a86a6db
7bdf2e1
 
 
 
 
a86a6db
7bdf2e1
a86a6db
 
7bdf2e1
a86a6db
 
 
 
 
7bdf2e1
 
a86a6db
7bdf2e1
 
a86a6db
7bdf2e1
a86a6db
7bdf2e1
 
 
a86a6db
7bdf2e1
 
 
 
 
a86a6db
7bdf2e1
 
 
a86a6db
7bdf2e1
 
 
 
 
 
 
a86a6db
 
7bdf2e1
a86a6db
7bdf2e1
 
 
 
 
 
 
 
a86a6db
 
7bdf2e1
 
 
a86a6db
7bdf2e1
a86a6db
7bdf2e1
 
a86a6db
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7bdf2e1
a86a6db
 
7bdf2e1
 
a86a6db
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
import os
import sys
import torch
import pandas as pd
import streamlit as st
from datetime import datetime
from transformers import (
    T5ForConditionalGeneration, 
    T5Tokenizer,
    Trainer, 
    TrainingArguments,
    DataCollatorForSeq2Seq
)
from torch.utils.data import Dataset
import random

# Seed both Python's and PyTorch's global RNGs with a fixed value so that
# repeated runs start from the same random state.
random.seed(42)
torch.manual_seed(42)

# Environment setup.
# NOTE(review): presumably works around the Intel OpenMP "duplicate runtime"
# abort; it is set after `torch` is already imported above — confirm it still
# takes effect for the runtime in use.
os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE"

class TravelDataset(Dataset):
    """Map-style dataset pairing each ``query`` with its ``reference_information``.

    Each item is a dict of 1-D tensors for a seq2seq model:
    ``input_ids`` and ``attention_mask`` come from the query, and ``labels``
    come from the reference text. Label positions holding the tokenizer's
    padding id are overwritten with ``label_pad_token_id`` (default ``-100``,
    the value Hugging Face seq2seq models ignore in the loss), so padding does
    not contribute to training.

    Args:
        data: DataFrame with ``query`` and ``reference_information`` columns.
        tokenizer: Hugging Face-style tokenizer callable.
        max_length: Pad/truncate length used for both inputs and labels.
        label_pad_token_id: Value written over padded label positions.
    """

    def __init__(self, data, tokenizer, max_length=512, label_pad_token_id=-100):
        self.tokenizer = tokenizer
        self.data = data
        self.max_length = max_length
        self.label_pad_token_id = label_pad_token_id

        print(f"Dataset loaded with {len(data)} samples")
        print("Columns:", list(data.columns))

    def __len__(self):
        return len(self.data)

    def _encode(self, text):
        """Tokenize *text* padded/truncated to ``max_length`` as a batch of one."""
        return self.tokenizer(
            text,
            max_length=self.max_length,
            padding='max_length',
            truncation=True,
            return_tensors='pt'
        )

    def __getitem__(self, idx):
        row = self.data.iloc[idx]

        # Input: query; target: reference_information.
        input_encodings = self._encode(row['query'])
        target_encodings = self._encode(row['reference_information'])

        # squeeze(0) drops only the batch dimension; a bare squeeze() would
        # also collapse the sequence dimension when max_length == 1.
        labels = target_encodings['input_ids'].squeeze(0).clone()
        pad_id = getattr(self.tokenizer, 'pad_token_id', None)
        if pad_id is not None:
            # Padded label positions must not be scored by the loss.
            labels[labels == pad_id] = self.label_pad_token_id

        return {
            'input_ids': input_encodings['input_ids'].squeeze(0),
            'attention_mask': input_encodings['attention_mask'].squeeze(0),
            'labels': labels,
        }

def load_dataset(path="hf://datasets/osunlp/TravelPlanner/train.csv"):
    """
    Load the travel planning dataset from CSV.

    Args:
        path: Any location ``pandas.read_csv`` accepts. Defaults to the
            TravelPlanner training split on the Hugging Face Hub; reading an
            ``hf://`` URL requires an fsspec filesystem for that protocol
            (e.g. from ``huggingface_hub``) to be installed.

    Returns:
        The loaded DataFrame, which contains the ``query`` and
        ``reference_information`` columns.

    Exits the process with status 1 (``sys.exit(1)``) if the file cannot be
    read or any required column is missing.
    """
    required_columns = ['query', 'reference_information']
    try:
        data = pd.read_csv(path)

        # Report every missing column at once instead of only the first.
        missing = [col for col in required_columns if col not in data.columns]
        if missing:
            raise ValueError(f"Missing required column(s): {', '.join(missing)}")

        print(f"Dataset loaded successfully with {len(data)} rows.")
        return data
    except Exception as e:
        # Diagnostics go to stderr so they are not mixed with normal output.
        print(f"Error loading dataset: {e}", file=sys.stderr)
        sys.exit(1)

def train_model():
    """
    Fine-tune ``t5-base`` to map travel queries to reference information.

    Loads the data with ``load_dataset``, shuffles it with a fixed seed,
    holds out 20% for evaluation, trains with the Hugging Face ``Trainer``,
    and saves the model and tokenizer to ``./trained_travel_planner``.

    Returns:
        ``(model, tokenizer)`` on success, or ``(None, None)`` if any step
        raises an ``Exception`` (the traceback is printed to stderr).
    """
    import inspect
    import traceback

    output_dir = "./trained_travel_planner"
    try:
        # Load dataset
        data = load_dataset()

        # Initialize model and tokenizer
        print("Initializing T5 model and tokenizer...")
        tokenizer = T5Tokenizer.from_pretrained('t5-base', legacy=False)
        model = T5ForConditionalGeneration.from_pretrained('t5-base')

        # Shuffle (reproducibly) before the 80/20 split so the validation
        # rows are a random sample rather than simply the last rows of the file.
        shuffled = data.sample(frac=1.0, random_state=42).reset_index(drop=True)
        train_size = int(0.8 * len(shuffled))
        train_data = shuffled.iloc[:train_size]
        val_data = shuffled.iloc[train_size:]

        train_dataset = TravelDataset(train_data, tokenizer)
        val_dataset = TravelDataset(val_data, tokenizer)

        # `evaluation_strategy` was renamed to `eval_strategy` in newer
        # transformers releases and the old name was later removed; use
        # whichever keyword the installed TrainingArguments accepts.
        arg_names = inspect.signature(TrainingArguments).parameters
        strategy_key = (
            "eval_strategy" if "eval_strategy" in arg_names else "evaluation_strategy"
        )

        training_args = TrainingArguments(
            output_dir=output_dir,
            num_train_epochs=3,
            per_device_train_batch_size=4,
            per_device_eval_batch_size=4,
            eval_steps=50,
            save_steps=100,
            weight_decay=0.01,
            logging_dir="./logs",
            logging_steps=10,
            load_best_model_at_end=True,
            **{strategy_key: "steps"},
        )

        data_collator = DataCollatorForSeq2Seq(
            tokenizer=tokenizer,
            model=model,
            padding=True
        )

        trainer = Trainer(
            model=model,
            args=training_args,
            train_dataset=train_dataset,
            eval_dataset=val_dataset,
            data_collator=data_collator
        )

        print("Training model...")
        trainer.train()

        model.save_pretrained(output_dir)
        tokenizer.save_pretrained(output_dir)

        print("Model training complete!")
        return model, tokenizer
    except Exception as e:
        # Best-effort: callers handle (None, None), but keep the full
        # traceback so failures are diagnosable.
        print(f"Training error: {e}")
        traceback.print_exc()
        return None, None

def generate_travel_plan(query, model, tokenizer):
    """
    Generate a travel plan using the trained model.

    Args:
        query: Free-text trip request.
        model: Seq2seq model exposing ``generate``, or ``None`` when no
            model could be loaded.
        tokenizer: Tokenizer matching *model*, or ``None``.

    Returns:
        The decoded plan text. Failures are not raised; they are returned as
        a string beginning with ``"Error generating travel plan:"``.
    """
    if model is None or tokenizer is None:
        return "Error generating travel plan: model is not loaded"
    try:
        # A single query needs no padding; padding it to 512 tokens only
        # adds encoder work.
        inputs = tokenizer(
            query,
            return_tensors="pt",
            max_length=512,
            truncation=True
        )

        if torch.cuda.is_available():
            inputs = {k: v.cuda() for k, v in inputs.items()}
            model = model.cuda()

        # A model handed over straight from Trainer may still be in training
        # mode; disable dropout so generation is deterministic.
        model.eval()

        outputs = model.generate(
            **inputs,
            max_length=512,
            num_beams=4,
            no_repeat_ngram_size=3,
            num_return_sequences=1
        )

        return tokenizer.decode(outputs[0], skip_special_tokens=True)
    except Exception as e:
        return f"Error generating travel plan: {e}"

def _load_or_train_model(model_dir="./trained_travel_planner"):
    """Return ``(model, tokenizer)``, preferring a model saved in *model_dir*.

    ``train_model`` saves its result to ``./trained_travel_planner``; reusing
    it avoids fine-tuning ``t5-base`` again for every new Streamlit session.
    Falls back to ``train_model()`` when no saved model is found or loading
    it fails, so the result may be ``(None, None)``.
    """
    # config.json is written by save_pretrained; its absence means only
    # intermediate Trainer checkpoints (or nothing) exist in model_dir.
    if os.path.isfile(os.path.join(model_dir, "config.json")):
        try:
            tokenizer = T5Tokenizer.from_pretrained(model_dir, legacy=False)
            model = T5ForConditionalGeneration.from_pretrained(model_dir)
            return model, tokenizer
        except Exception as e:
            print(f"Could not load saved model from {model_dir}: {e}")
    return train_model()


def main():
    """Streamlit entry point: manage the model and turn queries into plans."""
    st.set_page_config(
        page_title="AI Travel Planner",
        page_icon="✈️",
        layout="wide"
    )
    st.title("✈️ AI Travel Planner")

    # Sidebar to train model
    with st.sidebar:
        st.header("Model Management")
        if st.button("Retrain Model"):
            with st.spinner("Training the model..."):
                model, tokenizer = train_model()
                if model is not None:
                    st.session_state['model'] = model
                    st.session_state['tokenizer'] = tokenizer
                    st.success("Model retrained successfully!")
                else:
                    st.error("Model retraining failed.")

    # Load the model once per session, reusing a previously saved model
    # instead of retraining from scratch.
    if 'model' not in st.session_state:
        with st.spinner("Loading model..."):
            model, tokenizer = _load_or_train_model()
            st.session_state['model'] = model
            st.session_state['tokenizer'] = tokenizer

    # Without a model the planner cannot work; say so instead of rendering
    # a form that can only produce errors.
    if st.session_state['model'] is None:
        st.error("No model is available. Use 'Retrain Model' in the sidebar to try again.")
        st.stop()

    # Input query
    st.subheader("Plan Your Trip")
    query = st.text_area("Enter your trip query (e.g., 'Plan a 3-day trip to Paris focusing on culture and food')")

    if st.button("Generate Plan"):
        if not query:
            st.error("Please enter a query.")
        else:
            with st.spinner("Generating your travel plan..."):
                travel_plan = generate_travel_plan(
                    query,
                    st.session_state['model'],
                    st.session_state['tokenizer']
                )
                st.subheader("Your Travel Plan")
                st.write(travel_plan)


if __name__ == "__main__":
    main()