Abdulla Fahem committed on
Commit
a86a6db
1 Parent(s): 23d2d4b

Add application file

Browse files
Files changed (1) hide show
  1. app.py +68 -304
app.py CHANGED
@@ -19,23 +19,14 @@ torch.manual_seed(42)
19
  random.seed(42)
20
 
21
  # Environment setup
22
- os.environ['KMP_DUPLICATE_LIB_OK']='TRUE'
23
 
24
  class TravelDataset(Dataset):
25
  def __init__(self, data, tokenizer, max_length=512):
26
- """
27
- Initialize the dataset for travel planning
28
-
29
- Parameters:
30
- - data: DataFrame containing travel planning data
31
- - tokenizer: Tokenizer for encoding input and output
32
- - max_length: Maximum sequence length
33
- """
34
  self.tokenizer = tokenizer
35
  self.data = data
36
  self.max_length = max_length
37
-
38
- # Print dataset information
39
  print(f"Dataset loaded with {len(data)} samples")
40
  print("Columns:", list(data.columns))
41
 
@@ -43,18 +34,12 @@ class TravelDataset(Dataset):
43
  return len(self.data)
44
 
45
  def __getitem__(self, idx):
46
- """
47
- Prepare an individual training sample
48
-
49
- Returns a dictionary with input_ids, attention_mask, and labels
50
- """
51
  row = self.data.iloc[idx]
52
 
53
- # Prepare input text
54
- input_text = self.format_input_text(row)
55
-
56
- # Prepare target text (travel plan)
57
- target_text = row['target']
58
 
59
  # Tokenize inputs
60
  input_encodings = self.tokenizer(
@@ -79,160 +64,100 @@ class TravelDataset(Dataset):
79
  'attention_mask': input_encodings['attention_mask'].squeeze(),
80
  'labels': target_encodings['input_ids'].squeeze()
81
  }
82
-
83
- @staticmethod
84
- def format_input_text(row):
85
- """
86
- Format input text for the model
87
-
88
- This method creates a prompt that the model will use to generate a travel plan
89
- """
90
- # Format the input text based on available columns
91
- destination = row.get('dest', 'Unknown')
92
- days = row.get('days', 3)
93
- budget = row.get('budget', 'Moderate')
94
- interests = row.get('interests', 'Culture, Food')
95
-
96
- return f"Plan a trip to {destination} for {days} days with a {budget} budget. Include activities related to: {interests}"
97
 
98
  def load_dataset():
99
  """
100
- Load the travel planning dataset from HuggingFace
101
-
102
- Returns:
103
- - pandas DataFrame with the dataset
104
  """
105
  try:
106
- # Load dataset from CSV
107
  data = pd.read_csv("hf://datasets/osunlp/TravelPlanner/train.csv")
108
 
109
- # Basic data validation
110
- required_columns = ['dest', 'days', 'budget', 'interests', 'target']
111
  for col in required_columns:
112
  if col not in data.columns:
113
  raise ValueError(f"Missing required column: {col}")
114
 
115
- # Print dataset info
116
- print("Dataset successfully loaded")
117
- print(f"Total samples: {len(data)}")
118
- print("Columns:", list(data.columns))
119
-
120
  return data
121
  except Exception as e:
122
  print(f"Error loading dataset: {e}")
123
  sys.exit(1)
124
 
125
  def train_model():
126
- """
127
- Train the T5 model for travel planning
128
-
129
- Returns:
130
- - Trained model
131
- - Tokenizer
132
- """
133
  try:
134
  # Load dataset
135
  data = load_dataset()
136
-
137
  # Initialize model and tokenizer
138
  print("Initializing T5 model and tokenizer...")
139
  tokenizer = T5Tokenizer.from_pretrained('t5-base', legacy=False)
140
  model = T5ForConditionalGeneration.from_pretrained('t5-base')
141
-
142
- # Split data into training and validation sets
143
  train_size = int(0.8 * len(data))
144
  train_data = data[:train_size]
145
  val_data = data[train_size:]
146
-
147
- print(f"Training set size: {len(train_data)}")
148
- print(f"Validation set size: {len(val_data)}")
149
-
150
- # Create datasets
151
  train_dataset = TravelDataset(train_data, tokenizer)
152
  val_dataset = TravelDataset(val_data, tokenizer)
153
-
154
- # Training arguments
155
  training_args = TrainingArguments(
156
- output_dir=f"./travel_planner_model_{datetime.now().strftime('%Y%m%d_%H%M%S')}",
157
  num_train_epochs=3,
158
  per_device_train_batch_size=4,
159
  per_device_eval_batch_size=4,
160
- warmup_steps=500,
161
- weight_decay=0.01,
162
- logging_dir="./logs",
163
- logging_steps=10,
164
  evaluation_strategy="steps",
165
  eval_steps=50,
166
  save_steps=100,
 
 
 
167
  load_best_model_at_end=True,
168
  )
169
-
170
- # Data collator
171
  data_collator = DataCollatorForSeq2Seq(
172
  tokenizer=tokenizer,
173
  model=model,
174
  padding=True
175
  )
176
-
177
- # Initialize trainer
178
  trainer = Trainer(
179
  model=model,
180
  args=training_args,
181
  train_dataset=train_dataset,
182
  eval_dataset=val_dataset,
183
- data_collator=data_collator,
184
  )
185
-
186
- # Train the model
187
- print("Starting model training...")
188
  trainer.train()
189
-
190
- # Save the model and tokenizer
191
- model_path = "./trained_travel_planner"
192
- model.save_pretrained(model_path)
193
- tokenizer.save_pretrained(model_path)
194
-
195
- print("Model training completed and saved!")
196
  return model, tokenizer
197
-
198
  except Exception as e:
199
- print(f"Error during model training: {str(e)}")
200
  return None, None
201
 
202
- def generate_travel_plan(destination, days, interests, budget, model, tokenizer):
203
  """
204
- Generate a travel plan using the trained model
205
-
206
- Parameters:
207
- - destination: Travel destination
208
- - days: Trip duration
209
- - interests: User's interests
210
- - budget: Trip budget level
211
- - model: Trained T5 model
212
- - tokenizer: Model tokenizer
213
-
214
- Returns:
215
- - Generated travel plan
216
  """
217
  try:
218
- # Format input prompt
219
- prompt = f"Plan a trip to {destination} for {days} days with a {budget} budget. Include activities related to: {', '.join(interests)}"
220
-
221
- # Tokenize input
222
  inputs = tokenizer(
223
- prompt,
224
  return_tensors="pt",
225
  max_length=512,
226
  padding="max_length",
227
  truncation=True
228
  )
229
-
230
- # Move to GPU if available
231
  if torch.cuda.is_available():
232
  inputs = {k: v.cuda() for k, v in inputs.items()}
233
  model = model.cuda()
234
-
235
- # Generate output
236
  outputs = model.generate(
237
  **inputs,
238
  max_length=512,
@@ -240,14 +165,10 @@ def generate_travel_plan(destination, days, interests, budget, model, tokenizer)
240
  no_repeat_ngram_size=3,
241
  num_return_sequences=1
242
  )
243
-
244
- # Decode and return the travel plan
245
- travel_plan = tokenizer.decode(outputs[0], skip_special_tokens=True)
246
- return travel_plan
247
-
248
  except Exception as e:
249
- print(f"Error generating travel plan: {e}")
250
- return "Could not generate travel plan."
251
 
252
  def main():
253
  st.set_page_config(
@@ -255,201 +176,44 @@ def main():
255
  page_icon="✈️",
256
  layout="wide"
257
  )
258
-
259
  st.title("✈️ AI Travel Planner")
260
- st.markdown("### Plan your perfect trip with AI assistance!")
261
-
262
- # Add training button in sidebar only
263
  with st.sidebar:
264
  st.header("Model Management")
265
  if st.button("Retrain Model"):
266
- with st.spinner("Training new model... This will take a while..."):
267
  model, tokenizer = train_model()
268
- if model is not None:
269
  st.session_state['model'] = model
270
  st.session_state['tokenizer'] = tokenizer
271
- st.success("Model training completed!")
272
-
273
- # Add model information
274
- st.markdown("### Model Information")
275
- if 'model' in st.session_state:
276
- st.success("✓ Model loaded")
277
- st.info("""
278
- This model was trained on travel plans for:
279
- - Destinations from HuggingFace dataset
280
- - Flexible days duration
281
- - Multiple budget levels
282
- - Various interest combinations
283
- """)
284
-
285
- # Load or train model
286
- if 'model' not in st.session_state:
287
- with st.spinner("Loading AI model... Please wait..."):
288
- model, tokenizer = train_model() # Changed from load_or_train_model
289
- if model is None or tokenizer is None:
290
- st.error("Failed to load/train the AI model. Please try again.")
291
- return
292
- st.session_state.model = model
293
- st.session_state.tokenizer = tokenizer
294
-
295
- # Create two columns for input form
296
- col1, col2 = st.columns([2, 1])
297
-
298
- with col1:
299
- # Input form in a card-like container
300
- with st.container():
301
- st.markdown("### 🎯 Plan Your Trip")
302
-
303
- # Destination and Duration row
304
- dest_col, days_col = st.columns(2)
305
- with dest_col:
306
- destination = st.text_input(
307
- "🌍 Destination",
308
- placeholder="e.g., Paris, Tokyo, New York...",
309
- help="Enter the city you want to visit"
310
- )
311
-
312
- with days_col:
313
- days = st.slider(
314
- "📅 Number of days",
315
- min_value=1,
316
- max_value=14,
317
- value=3,
318
- help="Select the duration of your trip"
319
- )
320
-
321
- # Budget and Interests row
322
- budget_col, interests_col = st.columns(2)
323
- with budget_col:
324
- budget = st.selectbox(
325
- "💰 Budget Level",
326
- ["Budget", "Moderate", "Luxury"],
327
- help="Select your preferred budget level"
328
- )
329
-
330
- with interests_col:
331
- interests = st.multiselect(
332
- "🎯 Interests",
333
- ["Culture", "History", "Food", "Nature", "Shopping",
334
- "Adventure", "Relaxation", "Art", "Museums"],
335
- ["Culture", "Food"],
336
- help="Select up to three interests to personalize your plan"
337
  )
338
-
339
- with col2:
340
- # Tips and information
341
- st.markdown("### 💡 Travel Tips")
342
- st.info("""
343
- - Choose up to 3 interests for best results
344
- - Consider your travel season
345
- - Budget levels affect activity suggestions
346
- - Plans are customizable after generation
347
- """)
348
-
349
- # Generate button centered
350
- col1, col2, col3 = st.columns([1, 2, 1])
351
- with col2:
352
- generate_button = st.button(
353
- "🎨 Generate Travel Plan",
354
- type="primary",
355
- use_container_width=True
356
- )
357
-
358
- if generate_button:
359
- if not destination:
360
- st.error("Please enter a destination!")
361
- return
362
-
363
- if not interests:
364
- st.error("Please select at least one interest!")
365
- return
366
-
367
- if len(interests) > 3:
368
- st.warning("For best results, please select up to 3 interests.")
369
-
370
- with st.spinner("🤖 Creating your personalized travel plan..."):
371
- travel_plan = generate_travel_plan(
372
- destination,
373
- days,
374
- interests,
375
- budget,
376
- st.session_state.model,
377
- st.session_state.tokenizer
378
- )
379
-
380
- # Create an expander for the success message with trip overview
381
- with st.expander("✨ Your travel plan is ready! Click to see trip overview", expanded=True):
382
- col1, col2, col3 = st.columns(3)
383
- with col1:
384
- st.metric("Destination", destination)
385
- with col2:
386
- if days == 1:
387
- st.metric("Duration", f"{days} day")
388
- else:
389
- st.metric("Duration", f"{days} days")
390
- with col3:
391
- st.metric("Budget", budget)
392
-
393
- st.write("**Selected Interests:**", ", ".join(interests))
394
-
395
- # Display the plan in tabs with improved styling
396
- plan_tab, summary_tab = st.tabs(["📋 Detailed Itinerary", "ℹ️ Trip Summary"])
397
-
398
- with plan_tab:
399
- # Add a container for better spacing
400
- with st.container():
401
- # Add trip title
402
- st.markdown(f"## 🌍 {days}-Day Trip to {destination}")
403
- st.markdown("---")
404
-
405
- # Display the formatted plan
406
- st.markdown(travel_plan)
407
-
408
- # Add export options in a nice container
409
- with st.container():
410
- st.markdown("---")
411
- col1, col2 = st.columns([1, 4])
412
- with col1:
413
- st.download_button(
414
- label="📥 Download Plan",
415
- data=travel_plan,
416
- file_name=f"travel_plan_{destination.lower().replace(' ', '_')}.md",
417
- mime="text/markdown",
418
- use_container_width=True
419
- )
420
-
421
- with summary_tab:
422
- # Create three columns for summary information with cards
423
- with st.container():
424
- st.markdown("## Trip Overview")
425
- sum_col1, sum_col2, sum_col3 = st.columns(3)
426
-
427
- with sum_col1:
428
- with st.container():
429
- st.markdown("### 📍 Destination Details")
430
- st.markdown(f"**Location:** {destination}")
431
- if days == 1:
432
- st.markdown(f"**Duration:** {days} day")
433
- else:
434
- st.markdown(f"**Duration:** {days} days")
435
- st.markdown(f"**Budget Level:** {budget}")
436
-
437
- with sum_col2:
438
- with st.container():
439
- st.markdown("### 🎯 Trip Focus")
440
- st.markdown("**Selected Interests:**")
441
- for interest in interests:
442
- st.markdown(f"- {interest}")
443
-
444
- with sum_col3:
445
- with st.container():
446
- st.markdown("### ⚠️ Travel Tips")
447
- st.info(
448
- "• Verify opening hours\n"
449
- "• Check current prices\n"
450
- "• Confirm availability\n"
451
- "• Consider seasonal factors"
452
- )
453
 
454
  if __name__ == "__main__":
455
- main()
 
19
  random.seed(42)
20
 
21
  # Environment setup
22
+ os.environ['KMP_DUPLICATE_LIB_OK'] = 'TRUE'
23
 
24
  class TravelDataset(Dataset):
25
  def __init__(self, data, tokenizer, max_length=512):
 
 
 
 
 
 
 
 
26
  self.tokenizer = tokenizer
27
  self.data = data
28
  self.max_length = max_length
29
+
 
30
  print(f"Dataset loaded with {len(data)} samples")
31
  print("Columns:", list(data.columns))
32
 
 
34
  return len(self.data)
35
 
36
  def __getitem__(self, idx):
 
 
 
 
 
37
  row = self.data.iloc[idx]
38
 
39
+ # Input: query
40
+ input_text = row['query']
41
+ # Target: reference_information
42
+ target_text = row['reference_information']
 
43
 
44
  # Tokenize inputs
45
  input_encodings = self.tokenizer(
 
64
  'attention_mask': input_encodings['attention_mask'].squeeze(),
65
  'labels': target_encodings['input_ids'].squeeze()
66
  }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
67
 
68
  def load_dataset():
69
  """
70
+ Load the travel planning dataset from CSV.
 
 
 
71
  """
72
  try:
 
73
  data = pd.read_csv("hf://datasets/osunlp/TravelPlanner/train.csv")
74
 
75
+ required_columns = ['query', 'reference_information']
 
76
  for col in required_columns:
77
  if col not in data.columns:
78
  raise ValueError(f"Missing required column: {col}")
79
 
80
+ print(f"Dataset loaded successfully with {len(data)} rows.")
 
 
 
 
81
  return data
82
  except Exception as e:
83
  print(f"Error loading dataset: {e}")
84
  sys.exit(1)
85
 
86
  def train_model():
 
 
 
 
 
 
 
87
  try:
88
  # Load dataset
89
  data = load_dataset()
90
+
91
  # Initialize model and tokenizer
92
  print("Initializing T5 model and tokenizer...")
93
  tokenizer = T5Tokenizer.from_pretrained('t5-base', legacy=False)
94
  model = T5ForConditionalGeneration.from_pretrained('t5-base')
95
+
96
+ # Split data
97
  train_size = int(0.8 * len(data))
98
  train_data = data[:train_size]
99
  val_data = data[train_size:]
100
+
 
 
 
 
101
  train_dataset = TravelDataset(train_data, tokenizer)
102
  val_dataset = TravelDataset(val_data, tokenizer)
103
+
 
104
  training_args = TrainingArguments(
105
+ output_dir="./trained_travel_planner",
106
  num_train_epochs=3,
107
  per_device_train_batch_size=4,
108
  per_device_eval_batch_size=4,
 
 
 
 
109
  evaluation_strategy="steps",
110
  eval_steps=50,
111
  save_steps=100,
112
+ weight_decay=0.01,
113
+ logging_dir="./logs",
114
+ logging_steps=10,
115
  load_best_model_at_end=True,
116
  )
117
+
 
118
  data_collator = DataCollatorForSeq2Seq(
119
  tokenizer=tokenizer,
120
  model=model,
121
  padding=True
122
  )
123
+
 
124
  trainer = Trainer(
125
  model=model,
126
  args=training_args,
127
  train_dataset=train_dataset,
128
  eval_dataset=val_dataset,
129
+ data_collator=data_collator
130
  )
131
+
132
+ print("Training model...")
 
133
  trainer.train()
134
+
135
+ model.save_pretrained("./trained_travel_planner")
136
+ tokenizer.save_pretrained("./trained_travel_planner")
137
+
138
+ print("Model training complete!")
 
 
139
  return model, tokenizer
 
140
  except Exception as e:
141
+ print(f"Training error: {e}")
142
  return None, None
143
 
144
+ def generate_travel_plan(query, model, tokenizer):
145
  """
146
+ Generate a travel plan using the trained model.
 
 
 
 
 
 
 
 
 
 
 
147
  """
148
  try:
 
 
 
 
149
  inputs = tokenizer(
150
+ query,
151
  return_tensors="pt",
152
  max_length=512,
153
  padding="max_length",
154
  truncation=True
155
  )
156
+
 
157
  if torch.cuda.is_available():
158
  inputs = {k: v.cuda() for k, v in inputs.items()}
159
  model = model.cuda()
160
+
 
161
  outputs = model.generate(
162
  **inputs,
163
  max_length=512,
 
165
  no_repeat_ngram_size=3,
166
  num_return_sequences=1
167
  )
168
+
169
+ return tokenizer.decode(outputs[0], skip_special_tokens=True)
 
 
 
170
  except Exception as e:
171
+ return f"Error generating travel plan: {e}"
 
172
 
173
  def main():
174
  st.set_page_config(
 
176
  page_icon="✈️",
177
  layout="wide"
178
  )
 
179
  st.title("✈️ AI Travel Planner")
180
+
181
+ # Sidebar to train model
 
182
  with st.sidebar:
183
  st.header("Model Management")
184
  if st.button("Retrain Model"):
185
+ with st.spinner("Training the model..."):
186
  model, tokenizer = train_model()
187
+ if model:
188
  st.session_state['model'] = model
189
  st.session_state['tokenizer'] = tokenizer
190
+ st.success("Model retrained successfully!")
191
+ else:
192
+ st.error("Model retraining failed.")
193
+
194
+ # Load model if not already loaded
195
+ if 'model' not in st.session_state:
196
+ with st.spinner("Loading model..."):
197
+ model, tokenizer = train_model()
198
+ st.session_state['model'] = model
199
+ st.session_state['tokenizer'] = tokenizer
200
+
201
+ # Input query
202
+ st.subheader("Plan Your Trip")
203
+ query = st.text_area("Enter your trip query (e.g., 'Plan a 3-day trip to Paris focusing on culture and food')")
204
+
205
+ if st.button("Generate Plan"):
206
+ if not query:
207
+ st.error("Please enter a query.")
208
+ else:
209
+ with st.spinner("Generating your travel plan..."):
210
+ travel_plan = generate_travel_plan(
211
+ query,
212
+ st.session_state['model'],
213
+ st.session_state['tokenizer']
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
214
  )
215
+ st.subheader("Your Travel Plan")
216
+ st.write(travel_plan)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
217
 
218
  if __name__ == "__main__":
219
+ main()