# TripAdvisor / app.py
# Uploaded by morfriden ("Upload 4 files", commit 124d1f9, verified)
import gradio as gr
from datasets import load_dataset
from sentence_transformers import SentenceTransformer, util
import faiss
import numpy as np
from transformers import pipeline
import time
import ast
import re
# --- 1. DATA LOADING AND INITIALIZATION ---
print("===== Application Startup =====")
# Start of the startup timer; the elapsed time is reported once loading ends.
start_time = time.time()
# Load the TravelPlanner "test" configuration in full -- no rows are dropped.
# FAISS result ids are later used directly as row positions in
# dataset['test'], so this presumably must match the rows (and order) that
# build_index.py embedded -- TODO confirm against that script.
print("Loading TravelPlanner dataset...")
dataset = load_dataset("osunlp/TravelPlanner", "test")
print("Dataset ready.")
# --- 2. EMBEDDING AND RECOMMENDATION ENGINE ---
print("Loading embedding model...")
# Sentence-transformers model used both for query embedding and for ranking
# generated plans; it must be the model the FAISS index was built with.
model_name = "all-mpnet-base-v2"
embedding_model = SentenceTransformer(f"sentence-transformers/{model_name}")
index_file = "trip_index.faiss"
print(f"Loading FAISS index from {index_file}...")
try:
    index = faiss.read_index(index_file)
except RuntimeError as err:
    # A missing file is the common case, but faiss may also raise
    # RuntimeError for other read problems (e.g. a corrupt file), so the
    # underlying reason is shown as well.
    print(f"Error: could not load FAISS index file '{index_file}' (not found or unreadable): {err}")
    print("Please run the `build_index.py` script first to create the index.")
    # Exit with a non-zero status so whatever launched the app sees the
    # failure; raising SystemExit does not depend on the `site` module's
    # exit() helper being installed.
    raise SystemExit(1) from err
print(f"Index is ready. Total vectors in index: {index.ntotal}")
# --- 3. PLAN FORMATTING AND SYNTHETIC GENERATION ---
def format_plan_details(plan_string):
    """
    Parse the raw plan string from the dataset and format it as Markdown.

    The string is expected to be the Python ``repr`` of a list of dicts, each
    with a ``'Description'`` (section title) and a ``'Content'`` (section
    body) key. Input that does not have that shape is returned unchanged so
    the UI can still show the raw text.

    Args:
        plan_string: Raw ``reference_information`` value from the dataset.

    Returns:
        A Markdown string with one ``####`` heading per section, each followed
        by a blank line. Returns the original ``plan_string`` unchanged when it
        is empty or not a string, does not start with ``[``, cannot be parsed,
        or is not a list of dicts.
    """
    # Only strings that look like a Python list literal are parsed.
    if not isinstance(plan_string, str) or not plan_string.strip().startswith('['):
        return plan_string
    try:
        # literal_eval only evaluates Python literals, so it is safe to use
        # on dataset text (unlike eval).
        plan_list = ast.literal_eval(plan_string)
    except (ValueError, TypeError, SyntaxError, MemoryError, RecursionError):
        # Malformed or pathologically nested input: show the raw text.
        return plan_string
    # Anything other than a list of dicts is an unexpected shape; fall back
    # to the raw text instead of crashing on `.get`.
    if not isinstance(plan_list, list) or not all(isinstance(section, dict) for section in plan_list):
        return plan_string

    # Sections whose content is a header line followed by tabular rows.
    table_keywords = ('Attractions', 'Restaurants', 'Accommodations', 'Flight')
    formatted_sections = []
    for section in plan_list:
        # `or` fallbacks also cover keys that are present but None.
        description = str(section.get('Description') or 'Details')
        content = str(section.get('Content') or '').strip()
        formatted_sections.append(f"#### {description}")
        if any(keyword in description for keyword in table_keywords):
            header, *rows = content.split('\n')
            # Empty content yields an empty header; skip it rather than
            # emitting an empty bold marker ("****").
            if header:
                formatted_sections.append(f"**{header}**")
            for row in rows:
                clean_row = ' '.join(row.split())  # collapse column padding
                if clean_row:
                    formatted_sections.append(f"- {clean_row}")
        elif 'Self-driving' in description or 'Taxi' in description:
            # Single-line travel descriptions get a transport-mode bullet.
            mode_emoji = "🚗" if 'Self-driving' in description else "🚕"
            formatted_sections.append(f"- {mode_emoji} {content}")
        else:
            formatted_sections.append(content)
        # Blank line between sections.
        formatted_sections.append("")
    return "\n".join(formatted_sections)
# Lazily created GPT-2 text-generation pipeline, reused by every request.
_generator = None


def _get_generator():
    """Return the GPT-2 text-generation pipeline, creating it on first use."""
    global _generator
    if _generator is None:
        print("Loading generative model...")
        _generator = pipeline('text-generation', model='gpt2')
    return _generator


def _retrieve_similar_trips(query_vector, k):
    """Return up to ``k`` dataset trips nearest to ``query_vector`` in the FAISS index."""
    _, indices = index.search(query_vector, k)
    test_split = dataset['test']
    trips = []
    for idx_numpy in indices[0]:
        idx = int(idx_numpy)
        # FAISS pads results with -1 when fewer than k vectors exist; a
        # negative id would otherwise silently select rows from the end.
        # Ids beyond the dataset mean the index and dataset are out of sync.
        if idx < 0 or idx >= len(test_split):
            continue
        row = test_split[idx]  # one row, instead of materializing whole columns
        trips.append({
            "dest": row['dest'],
            "days": row['days'],
            "reference_information": row['reference_information'],
        })
    return trips


def _generate_best_trip(query_text, query_vector, num_candidates=10):
    """Generate ``num_candidates`` GPT-2 plans and return the one most similar to the query."""
    prompt = f"Write a complete travel plan that includes a title and a day-by-day itinerary. The trip must be about: {query_text}."
    generator = _get_generator()
    print(f"Generating {num_candidates} synthetic trip ideas...")
    generated_outputs = generator(
        prompt,
        max_new_tokens=250,  # room for a multi-day itinerary
        num_return_sequences=num_candidates,
        do_sample=True,  # sampling is needed for distinct candidates
        pad_token_id=50256,  # GPT-2's end-of-text id; GPT-2 has no pad token
    )
    print("Finding the most relevant generated trip...")
    generated_texts = []
    for output in generated_outputs:
        text = output['generated_text']
        # The pipeline echoes the prompt at the start; strip only that copy
        # so text the model itself wrote is kept intact.
        if text.startswith(prompt):
            text = text[len(prompt):]
        generated_texts.append(text.strip())
    # Rank candidates by cosine similarity to the query embedding.
    generated_embeddings = embedding_model.encode(generated_texts)
    similarities = util.cos_sim(query_vector, generated_embeddings)
    best_index = int(np.argmax(similarities))
    return generated_texts[best_index]


def get_recommendations_and_generate(query_text, k=3):
    """
    Retrieve similar dataset trips and generate a new trip plan for a query.

    Args:
        query_text: Natural-language description of the trip.
        k: Number of nearest neighbours to request from the FAISS index.

    Returns:
        A 4-tuple: three trip dicts (keys ``'dest'``, ``'days'``,
        ``'reference_information'``) followed by the best generated plan text.
        Missing recommendation slots are filled with a placeholder dict whose
        ``'reference_information'`` is empty.
    """
    # FAISS expects a float32 matrix of shape (num_queries, dim).
    query_vector = np.array(embedding_model.encode([query_text]), dtype=np.float32)
    results = _retrieve_similar_trips(query_vector, k)
    # The UI always shows three recommendation slots.
    while len(results) < 3:
        results.append({"dest": "No trip plan found", "days": "", "reference_information": ""})
    best_generated_trip = _generate_best_trip(query_text, query_vector)
    return results[0], results[1], results[2], best_generated_trip
# --- 4. GRADIO USER INTERFACE ---
def format_trip_plan(trip):
    """
    Render one recommended trip as Markdown.

    Args:
        trip: Dict with ``'dest'``, ``'days'`` and ``'reference_information'``
            keys, as built by ``get_recommendations_and_generate`` (including
            its placeholder entries with empty ``'reference_information'``).

    Returns:
        A heading plus the formatted plan, or a "no plan" heading when
        ``trip`` is missing or carries no plan text.
    """
    # Placeholder slots have an empty 'reference_information'; treat them like
    # a missing trip instead of rendering "-days trip to NO TRIP PLAN FOUND".
    if not trip or not trip.get('reference_information'):
        return "### No similar trip plan found."
    formatted_plan = format_plan_details(trip['reference_information'])
    destination = str(trip.get('dest', '')).upper()
    return f"### {trip.get('days', '')}-days trip to {destination}\n**Suggested Plan:**\n{formatted_plan}"
def format_generated_trip(trip_text):
    """
    Prepare the generated trip text for display.

    This is a pass-through hook: the value is returned unchanged, so any
    later post-processing of generated output has a single place to live.
    """
    display_text = trip_text
    return display_text
def trip_planner_wizard(destination, days):
    """
    Gradio callback: validate the form, run retrieval + generation, format outputs.

    Args:
        destination: Free-text destination from the textbox.
        days: Trip length from the ``gr.Number`` input. This can be ``None``
            when the field is cleared, and may be a float.

    Returns:
        Three Markdown strings (dataset recommendations) and the generated
        plan text, in the order of the UI output components.

    Raises:
        gr.Error: If the destination is blank, the day count is missing or
            not numeric, or it is below one day after truncation to an int.
            Gradio shows the message in the UI.
    """
    destination = (destination or "").strip()
    if not destination:
        raise gr.Error("Please enter a destination.")
    try:
        days = int(days)  # the f-string below needs a whole number of days
    except (TypeError, ValueError, OverflowError):
        raise gr.Error("Please enter the number of days.") from None
    if days < 1:
        raise gr.Error("The trip must be at least 1 day long.")
    # Combine user inputs into a single query for retrieval and generation.
    query_text = f"a {days}-day trip to {destination}"
    rec1, rec2, rec3, gen_rec_text = get_recommendations_and_generate(query_text)
    return format_trip_plan(rec1), format_trip_plan(rec2), format_trip_plan(rec3), format_generated_trip(gen_rec_text)
# Report how long module-level startup (dataset, embedding model and FAISS
# index loading) took, measured from start_time.
end_time = time.time()
startup_seconds = end_time - start_time
print(f"Models and data loaded in {startup_seconds:.2f} seconds.")
# Gradio Interface
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown("# ✈️ TripPlanner AI")
    gr.Markdown("Enter your destination and desired trip length, and get plan recommendations plus a new AI-generated idea!")
    with gr.Row():
        destination_input = gr.Textbox(label="Destination", placeholder="e.g., Paris")
        # precision=0 hands the callback a whole number; minimum=1 makes
        # Gradio reject trip lengths below one day.
        days_input = gr.Number(label="Number of Days", value=3, precision=0, minimum=1)
    with gr.Row():
        submit_btn = gr.Button("Get Trip Plans", variant="primary")
    with gr.Row():
        # Left: three Markdown slots for dataset recommendations.
        with gr.Column(scale=2):
            gr.Markdown("### Recommended Trip Plans from Dataset")
            output_rec1 = gr.Markdown()
            output_rec2 = gr.Markdown()
            output_rec3 = gr.Markdown()
        # Right: read-only textbox for the generated plan.
        with gr.Column(scale=1):
            gr.Markdown("### ✨ New AI-Generated Idea")
            output_gen = gr.Textbox(label="AI Generated Trip Plan", lines=20, interactive=False)
    # Output order must match the tuple returned by trip_planner_wizard.
    submit_btn.click(
        fn=trip_planner_wizard,
        inputs=[destination_input, days_input],
        outputs=[output_rec1, output_rec2, output_rec3, output_gen]
    )
    # Clicking an example only fills the inputs; the user still submits.
    gr.Examples(
        examples=[
            ["Paris", 3],
            ["Orlando", 7],
            ["Tokyo", 5],
            ["the Greek Islands", 10]
        ],
        inputs=[destination_input, days_input]
    )

# Launch only when run as a script (e.g. `python app.py`), so importing the
# module or using Gradio's reload mode does not start a second server.
if __name__ == "__main__":
    demo.launch(ssr_mode=False)