# Scraped page header (not part of the program): "Spaces: Sleeping"
import gradio as gr | |
from datasets import load_dataset | |
from sentence_transformers import SentenceTransformer, util | |
import faiss | |
import numpy as np | |
from transformers import pipeline | |
import time | |
import ast | |
import re | |
# --- 1. DATA LOADING AND INITIALIZATION ---
print("===== Application Startup =====")
start_time = time.time()

# Load the TravelPlanner dataset. No row limit is applied here; the FAISS ids
# returned at query time are used as row positions in dataset['test'], so the
# index is presumably built from these same rows in the same order.
print("Loading TravelPlanner dataset...")
dataset = load_dataset("osunlp/TravelPlanner", "test")
print("Dataset ready.")

# --- 2. EMBEDDING AND RECOMMENDATION ENGINE ---
print("Loading embedding model...")
model_name = "all-mpnet-base-v2"
embedding_model = SentenceTransformer(f"sentence-transformers/{model_name}")

index_file = "trip_index.faiss"
print(f"Loading FAISS index from {index_file}...")
try:
    index = faiss.read_index(index_file)
except RuntimeError as err:
    # faiss reports every read failure (missing file, corrupt file, ...) as a
    # RuntimeError, so show the underlying message instead of assuming the
    # file is missing.
    print(f"Error: could not load FAISS index file '{index_file}': {err}")
    print("Please run the `build_index.py` script first to create the index.")
    # Stop with a non-zero status so the failure is visible to the launcher;
    # the bare `exit()` helper is meant for interactive sessions only.
    raise SystemExit(1)
else:
    print(f"Index is ready. Total vectors in index: {index.ntotal}")
# --- 3. SYNTHETIC GENERATION --- | |
def format_plan_details(plan_string):
    """
    Parse the raw plan string from the dataset and format it as readable Markdown.

    The plan is expected to be the string form of a list of dictionaries, each
    holding a 'Description' and a 'Content' entry. Every dictionary becomes a
    "#### <description>" heading followed by its content:

    * Attractions / Restaurants / Accommodations / Flight sections: the first
      content line becomes a bold header and each remaining non-empty line a
      bullet with its whitespace collapsed.
    * Self-driving / Taxi sections: a single bullet prefixed with an emoji.
    * Anything else: the content as-is.

    Args:
        plan_string: Raw plan text. Values that are not strings, or that do not
            start with '[', are returned unchanged.

    Returns:
        The Markdown string, or the original `plan_string` when it cannot be
        parsed as a Python literal.
    """
    # Anything that is not a list literal is passed through untouched.
    if not isinstance(plan_string, str) or not plan_string.strip().startswith('['):
        return plan_string

    try:
        # literal_eval only accepts Python literals, so it is safe on dataset text.
        plan_list = ast.literal_eval(plan_string.strip())
    except (ValueError, TypeError, SyntaxError, MemoryError, RecursionError):
        # If parsing fails, return the original string to avoid crashing.
        return plan_string

    list_style_keywords = ('Attractions', 'Restaurants', 'Accommodations', 'Flight')
    formatted_sections = []
    for section in plan_list:
        if not isinstance(section, dict):
            # Unexpected entry: keep its text rather than failing on .get().
            formatted_sections.extend([str(section), ""])
            continue

        description = str(section.get('Description') or 'Details')
        content = str(section.get('Content') or '').strip()

        # Heading for each section.
        formatted_sections.append(f"#### {description}")

        if any(keyword in description for keyword in list_style_keywords):
            header, *items = content.split('\n')
            header = header.strip()
            # Skip the bold header when there is no content, otherwise an
            # empty "****" would be emitted.
            if header:
                formatted_sections.append(f"**{header}**")
            # Format the rest of the lines as a clean, bulleted list.
            for item in items:
                clean_item = ' '.join(item.split())  # Remove extra whitespace
                if clean_item:
                    formatted_sections.append(f"- {clean_item}")
        elif 'Self-driving' in description or 'Taxi' in description:
            # Make simple travel descriptions more readable.
            mode_emoji = "🚗" if 'Self-driving' in description else "🚕"
            formatted_sections.append(f"- {mode_emoji} {content}")
        else:
            # Default formatting for any other type of content.
            formatted_sections.append(content)

        # Blank entry for spacing between sections.
        formatted_sections.append("")

    return "\n".join(formatted_sections)
# Shared GPT-2 pipeline, created on first use by _get_generator().
_generator = None


def _get_generator():
    """Return the shared text-generation pipeline, building it on first use."""
    global _generator
    if _generator is None:
        print("Loading generative model...")
        _generator = pipeline('text-generation', model='gpt2')
    return _generator


def get_recommendations_and_generate(query_text, k=3):
    """
    Look up similar trips in the dataset and generate one new trip idea.

    Args:
        query_text: Free-text description of the desired trip.
        k: Number of nearest neighbours requested from the FAISS index. Only
            the first three usable hits are returned.

    Returns:
        A 4-tuple `(rec1, rec2, rec3, generated_text)`. Each `rec` is a dict
        with the keys 'dest', 'days' and 'reference_information'; missing
        hits are filled with a "No trip plan found" placeholder.
        `generated_text` is the generated candidate most similar to the query.
    """
    # 1. Get Recommendations from existing data
    query_vector = np.asarray(embedding_model.encode([query_text]), dtype=np.float32)
    _distances, indices = index.search(query_vector, k)

    trips = dataset['test']
    trip_count = len(trips)
    results = []
    for raw_idx in indices[0]:
        idx = int(raw_idx)
        # FAISS returns -1 when it has fewer than k neighbours; used as a row
        # index that would silently pick the last row. Ids beyond the dataset
        # (index built from different data) are skipped as well.
        if not 0 <= idx < trip_count:
            continue
        row = trips[idx]  # fetch the row once instead of one full column per field
        results.append({
            "dest": row['dest'],
            "days": row['days'],
            "reference_information": row['reference_information'],
        })
    while len(results) < 3:
        results.append({"dest": "No trip plan found", "days": "", "reference_information": ""})

    # 2. Create a prompt for the generative model
    prompt = f"Write a complete travel plan that includes a title and a day-by-day itinerary. The trip must be about: {query_text}."
    generator = _get_generator()

    # 3. Generate 10 new, creative trip ideas
    print("Generating 10 synthetic trip ideas...")
    generated_outputs = generator(
        prompt,
        max_new_tokens=250,  # Increased tokens for more detailed plans
        num_return_sequences=10,
        pad_token_id=50256,
    )

    # 4. Find the best trip out of the 10 generated
    print("Finding the most relevant generated trip...")
    # The pipeline output echoes the prompt; drop it so only the new text remains.
    generated_texts = [output['generated_text'].replace(prompt, "").strip() for output in generated_outputs]
    # Embed all generated texts and score each against the user's query.
    generated_embeddings = embedding_model.encode(generated_texts)
    similarities = util.cos_sim(query_vector, generated_embeddings)
    # There is a single query row, so the flat argmax is the candidate position.
    best_trip_index = int(similarities.argmax())
    best_generated_trip = generated_texts[best_trip_index]

    return results[0], results[1], results[2], best_generated_trip
# --- 4. GRADIO USER INTERFACE --- | |
def format_trip_plan(trip):
    """
    Render one recommended trip as Markdown.

    Args:
        trip: Dict with the keys 'dest', 'days' and 'reference_information',
            or a falsy value when there is no recommendation.

    Returns:
        A Markdown heading plus the formatted plan, or a "not found" heading
        when `trip` is empty or carries no plan text. The latter covers the
        placeholder entries whose 'reference_information' is an empty string.
    """
    if not trip or not trip.get('reference_information'):
        return "### No similar trip plan found."
    formatted_plan = format_plan_details(trip['reference_information'])
    return f"### {trip['days']}-days trip to {trip['dest'].upper()}\n**Suggested Plan:**\n{formatted_plan}"
def format_generated_trip(trip_text):
    """Pass the generated trip text through without modification."""
    unchanged_text = trip_text
    return unchanged_text
def trip_planner_wizard(destination, days):
    """
    Gradio click handler: validate the inputs, then fetch and format trip plans.

    Args:
        destination: Destination text from the UI.
        days: Trip length from the UI; may be None when the field is cleared.

    Returns:
        Four strings: three formatted dataset recommendations and the
        generated plan. When the input is unusable, the first string is a
        short Markdown message and the remaining three are empty.
    """
    destination = str(destination or "").strip()
    if not destination:
        return "### Please enter a destination.", "", "", ""

    try:
        days = int(days)  # Ensure days is an integer for the f-string
    except (TypeError, ValueError, OverflowError):
        # None (cleared field), NaN or infinity cannot be converted.
        days = 0
    if days < 1:
        return "### Please enter a trip length of at least 1 day.", "", "", ""

    # Combine user inputs into a single query for processing
    query_text = f"a {days}-day trip to {destination}"
    rec1, rec2, rec3, gen_rec_text = get_recommendations_and_generate(query_text)
    return format_trip_plan(rec1), format_trip_plan(rec2), format_trip_plan(rec3), format_generated_trip(gen_rec_text)
# Report how long the import-time setup above took.
end_time = time.time()
print(f"Models and data loaded in {end_time - start_time:.2f} seconds.")

# Gradio Interface
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown("# ✈️ TripPlanner AI")
    gr.Markdown("Enter your destination and desired trip length, and get plan recommendations plus a new AI-generated idea!")

    # Inputs
    with gr.Row():
        destination_input = gr.Textbox(label="Destination", placeholder="e.g., Paris")
        # The handler converts the value with int(); restrict the widget to
        # whole, positive numbers so a fractional entry is not silently truncated.
        days_input = gr.Number(label="Number of Days", value=3, precision=0, minimum=1)
    with gr.Row():
        submit_btn = gr.Button("Get Trip Plans", variant="primary")

    # Outputs: dataset recommendations on the left, generated idea on the right.
    with gr.Row():
        with gr.Column(scale=2):
            gr.Markdown("### Recommended Trip Plans from Dataset")
            output_rec1 = gr.Markdown()
            output_rec2 = gr.Markdown()
            output_rec3 = gr.Markdown()
        with gr.Column(scale=1):
            gr.Markdown("### ✨ New AI-Generated Idea")
            output_gen = gr.Textbox(label="AI Generated Trip Plan", lines=20, interactive=False)

    # The handler returns four strings, matched positionally to these outputs.
    submit_btn.click(
        fn=trip_planner_wizard,
        inputs=[destination_input, days_input],
        outputs=[output_rec1, output_rec2, output_rec3, output_gen],
    )

    # Clickable sample inputs that fill the two input fields.
    gr.Examples(
        examples=[
            ["Paris", 3],
            ["Orlando", 7],
            ["Tokyo", 5],
            ["the Greek Islands", 10],
        ],
        inputs=[destination_input, days_input],
    )

demo.launch(ssr_mode=False)