Spaces:
Sleeping
Sleeping
# @title Default title text | |
import gradio as gr | |
import feedparser | |
from bs4 import BeautifulSoup | |
from datetime import datetime, timedelta | |
import pytz | |
from typing import List, Dict, Tuple | |
from sentence_transformers import SentenceTransformer | |
import chromadb | |
from chromadb.utils.embedding_functions import SentenceTransformerEmbeddingFunction | |
from dateutil.parser import parse as dateutil_parse | |
from dateutil.parser import ParserError | |
import os | |
import json | |
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer, pipeline | |
#from langchain_community.embeddings import HuggingFaceEmbeddings | |
from langchain_community.vectorstores import Chroma | |
from langchain_core.output_parsers import StrOutputParser | |
from langchain_core.prompts import ChatPromptTemplate | |
from langchain_core.runnables import RunnablePassthrough | |
from langchain.chains import LLMChain | |
from langchain_huggingface import HuggingFacePipeline | |
from dateutil import parser | |
from langchain.embeddings import HuggingFaceEmbeddings | |
class BrockEventsRAG:
    """Retrieval-backed assistant for Brock University (ExperienceBU) events.

    On construction the object builds a ChromaDB collection of upcoming
    events from the ExperienceBU RSS feed, sets up a FLAN-T5 text2text
    pipeline, and can then answer free-text questions with a short
    LLM-written introduction followed by the best-matching events.

    Attributes set in ``__init__`` / ``setup_llm``:
        temperature, top_p, top_k: sampling settings; read on every call to
            ``format_response`` so UI changes take effect immediately.
        collection: the ChromaDB collection named ``brock_events``.
        today, date_range_end: Eastern-time window (today .. today + 25 days)
            outside of which feed entries are ignored.
        llm, vectorstore, retriever, prompt, chain: LangChain components.
        chat_history: list of ``(user_message, reply)`` tuples from the most
            recent ``process_chat`` call.
    """

    # Prompt template for the LLM-generated opening line.
    RESPONSE_TEMPLATE = """You are a helpful Brock University events assistant.
Create an engaging opening line to get students excited about events related to this query:
Query: {query}
Guidelines:
- Be friendly and enthusiastic
- Match the tone to the type of event
- Keep it brief but engaging
Examples:
- Query: Are there any business networking events coming up?
Introduction: "Get ready to connect! We've got some exciting business networking opportunities coming soon."
- Query: What workshops are happening next week?
Introduction: "Boost your skills! Check out these awesome workshops happening next week."
"""

    # First line of every successful listing produced by ``query_events``.
    # ``format_response`` uses it to tell a real listing from a status message.
    EVENTS_HEADER = "Here are some relevant events:"

    # (trigger words in the query, extra search terms, words that make a hit relevant)
    # Rules are tried in order; the first rule with a matching trigger wins.
    _CATEGORY_RULES = (
        (
            ("makerspace",),
            "maker making create creative workshop lab hands-on",
            ("makerspace", "maker", "create", "workshop", "lab"),
        ),
        (
            ("math", "science"),
            "mathematics physics chemistry biology research laboratory",
            ("math", "science", "physics", "chemistry", "biology", "research"),
        ),
        (
            ("business", "networking"),
            "business networking professional career development",
            ("business", "network", "professional", "entrepreneur"),
        ),
        (
            ("career", "job"),
            "career employment job fair hiring recruitment",
            ("career", "job", "employment", "hiring", "fair"),
        ),
    )

    def __init__(self):
        """Initialize the RAG system: vector store, date window, LLM, events."""
        print("Initializing RAG system...")

        # Sampling defaults; the chat UI overwrites these from its sliders.
        self.temperature = 0.7
        self.top_p = 0.95
        self.top_k = 50
        self.chat_history = []

        # Embedding function used by the ChromaDB collection.
        self.emodel_name = "multi-qa-MiniLM-L6-cos-v1"
        self.embedding_function = SentenceTransformerEmbeddingFunction(self.emodel_name)

        # Prefer an in-memory client; fall back to a throw-away directory.
        try:
            self.chroma_client = chromadb.Client()
            print("Using in-memory ChromaDB client")
        except Exception as e:
            print(f"Error with in-memory client: {e}")
            import tempfile
            temp_dir = tempfile.mkdtemp()
            print(f"Using temporary directory: {temp_dir}")
            self.chroma_client = chromadb.PersistentClient(path=temp_dir)

        # Create the collection, retrying a few times before giving up.
        max_retries = 3
        for attempt in range(max_retries):
            try:
                self.collection = self.chroma_client.get_or_create_collection(
                    name="brock_events",
                    embedding_function=self.embedding_function,
                    metadata={"hnsw:space": "cosine"},
                )
                print("Successfully created collection")
                break
            except Exception as e:
                print(f"Attempt {attempt + 1} failed: {e}")
                if attempt == max_retries - 1:
                    raise

        # Date window: midnight today (Eastern) through 25 days ahead.
        self.eastern = pytz.timezone('America/New_York')
        self.today = datetime.now(self.eastern).replace(hour=0, minute=0, second=0, microsecond=0)
        self.date_range_end = self.today + timedelta(days=25)

        self.setup_llm()
        self.update_database()

    def setup_llm(self):
        """Build the generation pipeline, vector-store retriever and chain.

        Raises:
            Exception: anything raised while loading the model or wiring the
                LangChain components is logged and re-raised.
        """
        try:
            print("Setting up LLM components...")
            self.model_name = "google/flan-t5-base"
            self.tokenizer = AutoTokenizer.from_pretrained(self.model_name)
            self.llm_model = AutoModelForSeq2SeqLM.from_pretrained(self.model_name)
            hf_pipeline = pipeline(
                task="text2text-generation",
                model=self.llm_model,
                tokenizer=self.tokenizer,
                do_sample=True,
                temperature=self.temperature,
                top_k=self.top_k,
                top_p=self.top_p,
                max_length=50,           # keep the introduction short
                min_length=10,
                repetition_penalty=1.5,
                no_repeat_ngram_size=3,  # never repeat a 3-token phrase
            )
            self.llm = HuggingFacePipeline(pipeline=hf_pipeline)

            # LangChain view over the same ChromaDB collection.
            self.vectorstore = Chroma(
                client=self.chroma_client,
                collection_name="brock_events",
                embedding_function=self.embedding_function,
            )
            self.retriever = self.vectorstore.as_retriever(search_kwargs={"k": 3})

            self.prompt = ChatPromptTemplate.from_template(self.RESPONSE_TEMPLATE)
            # The template's only variable is ``query``; the raw input must be
            # supplied under that key or prompt formatting fails.
            self.chain = (
                {"context": self.retriever, "query": RunnablePassthrough()}
                | self.prompt
                | self.llm
                | StrOutputParser()
            )
            print("LLM setup completed successfully")
        except Exception as e:
            print(f"Error setting up LLM: {e}")
            import traceback
            print(f"Full error: {traceback.format_exc()}")
            raise

    def fetch_rss_feed(self) -> List[Dict]:
        """Fetch the ExperienceBU RSS feed and return its entries.

        Returns:
            The parsed feed entries, or an empty list when fetching fails.
        """
        url = "https://experiencebu.brocku.ca/events.rss"
        try:
            feed = feedparser.parse(url)
            # feedparser reports fetch/parse problems through ``bozo`` instead
            # of raising, so surface them when nothing usable came back.
            if feed.get('bozo') and not feed.entries:
                print(f"Error fetching RSS feed: {feed.get('bozo_exception')}")
            print(f"Fetched {len(feed.entries)} entries from feed")
            return feed.entries
        except Exception as e:
            print(f"Error fetching RSS feed: {e}")
            return []

    def update_database(self):
        """Replace the collection's contents with events from the RSS feed.

        Errors are logged, not raised. A failure on one event does not stop
        the remaining events from being added.
        """
        try:
            print("Starting database update...")
            entries = self.fetch_rss_feed()
            if not entries:
                print("No entries fetched from RSS feed")
                return
            print(f"Processing {len(entries)} entries...")

            # Empty the collection in place. Dropping and recreating it would
            # leave the LangChain vector store built in ``setup_llm`` bound to
            # a collection that no longer exists.
            stale_ids = self.collection.get(include=[])["ids"]
            if stale_ids:
                self.collection.delete(ids=stale_ids)

            new_events = [event for event in map(self.process_event, entries) if event]
            if not new_events:
                return

            print(f"\nAdding {len(new_events)} events to database...")
            added = 0
            for position, event in enumerate(new_events, start=1):
                try:
                    event_text = event['text']
                    print(f"\nAdding event {position}/{len(new_events)}")
                    print("Event text sample:", event_text[:200])
                    # Fall back to a synthetic id when the feed gave none.
                    unique_id = event['id'] or f"event_{position - 1}_{datetime.now().timestamp()}"
                    self.collection.add(
                        documents=[event_text],
                        ids=[unique_id],
                        metadatas=[event['metadata']],
                    )
                    added += 1
                    print(f"Successfully added event {position}")
                except Exception as e:
                    print(f"Error adding event {position}: {e}")
                    import traceback
                    print(f"Full error trace for event {position}: {traceback.format_exc()}")
            print(f"\nSuccessfully added {added} events to the database")
        except Exception as e:
            print(f"Error updating database: {e}")
            import traceback
            print(f"Full error: {traceback.format_exc()}")

    def _match_category(self, query_lower: str):
        """Return the first category rule triggered by the query, or None."""
        for rule in self._CATEGORY_RULES:
            if any(trigger in query_lower for trigger in rule[0]):
                return rule
        return None

    @staticmethod
    def _event_marker(doc_lower: str) -> str:
        """Pick the marker shown in front of an event title."""
        # NOTE(review): these literals look like mis-decoded emoji; they are
        # reproduced as found. Confirm against the original file's encoding.
        if "workshop" in doc_lower:
            return "π§"
        if "makerspace" in doc_lower:
            return "π οΈ"
        if "career" in doc_lower or "job" in doc_lower:
            return "πΌ"
        if "research" in doc_lower or "science" in doc_lower:
            return "π¬"
        return "π "

    def query_events(self, query: str) -> str:
        """Search the collection and return a formatted listing.

        Args:
            query: the user's free-text question.

        Returns:
            A listing that starts with ``EVENTS_HEADER``, or a plain status
            message when the collection is empty, nothing matches, or the
            search fails.
        """
        try:
            print(f"\nProcessing query: {query}")
            collection_count = self.collection.count()
            print(f"Current collection size: {collection_count} documents")
            if collection_count == 0:
                return "No events are currently loaded in the database. Please try again later."

            rule = self._match_category(query.lower())
            enhanced_query = f"{query} {rule[1]}" if rule else query

            results = self.collection.query(
                query_texts=[enhanced_query],
                # Never ask for more neighbours than the collection holds.
                n_results=min(5, collection_count),
                include=['documents', 'metadatas'],
            )
            if not results or not results['documents'] or not results['documents'][0]:
                return "I couldn't find any events matching your query."

            events_found = []
            for doc, metadata in zip(results['documents'][0], results['metadatas'][0]):
                doc_lower = doc.lower()
                # General queries keep every hit; category queries keep only
                # hits that mention one of the category's words.
                if rule and not any(term in doc_lower for term in rule[2]):
                    continue
                events_found.append(
                    f"{self._event_marker(doc_lower)} {metadata.get('title', 'Untitled Event')}\n"
                    f"Date & Time: {metadata.get('start_time', 'Time not specified')}\n"
                    f"Hosted by: {metadata.get('host', 'No host specified')}\n"
                    f"Type: {metadata.get('categories', 'General Event')}\n"
                )
            if not events_found:
                return f"I couldn't find any events matching '{query}' at this time."
            return f"{self.EVENTS_HEADER}\n\n" + "\n".join(events_found)
        except Exception as e:
            print(f"Error querying events: {e}")
            import traceback
            print(f"Full error: {traceback.format_exc()}")
            return "I encountered an error while searching for events. Please try again."

    def process_event(self, entry) -> Dict:
        """Turn one feed entry into a document ready for the collection.

        The start time is taken, in order of preference, from a
        ``<time class="dt-start">`` element in the entry summary, the entry's
        ``start`` field, or its publication date.

        Args:
            entry: a feedparser entry.

        Returns:
            ``{"text": ..., "metadata": ..., "id": ...}``, or ``None`` when the
            entry has no usable start time, falls outside the date window, or
            cannot be processed.
        """
        try:
            start_time = None

            # 1) Start time embedded in the summary HTML.
            soup = BeautifulSoup(entry.get('summary', ''), 'html.parser')
            start_elem = soup.find('time', class_='dt-start')
            if start_elem and 'datetime' in start_elem.attrs:
                try:
                    start_time = parser.parse(start_elem['datetime'])
                except (ParserError, ValueError) as e:
                    print(f"Error parsing start time: {e}")

            # 2) The feed's own start field.
            if not start_time and 'start' in entry:
                try:
                    start_time = parser.parse(entry.start)
                except (ParserError, ValueError) as e:
                    print(f"Error parsing RSS start time: {e}")

            # 3) Publication date as a last resort. feedparser normalises
            #    ``*_parsed`` values to UTC, so tag them as UTC rather than
            #    treating the wall-clock digits as Eastern time.
            if not start_time and entry.get('published_parsed'):
                start_time = datetime(*entry.published_parsed[:6], tzinfo=pytz.utc)

            if not start_time:
                print("No valid start time found for event")
                return None

            # Naive timestamps are assumed to be local (Eastern) time.
            if not start_time.tzinfo:
                start_time = self.eastern.localize(start_time)

            if not (self.today <= start_time <= self.date_range_end):
                return None

            title = entry.get('title', 'No title')

            # Prefer the author's name; otherwise strip the "(...)" part of
            # the combined author string.
            author = None
            if 'authors' in entry and entry.authors:
                author = entry.authors[0].get('name', None)
            if not author and 'author' in entry:
                author = entry.author.split('(')[0].strip()

            categories = []
            if 'tags' in entry:
                categories = [tag.get('term', '') for tag in entry.tags]
            categories_str = '; '.join(filter(None, categories)) or "General Event"

            host = entry.get('host', 'No host specified')
            when = start_time.strftime('%B %d, %Y at %I:%M %p')

            # Document text that gets embedded for semantic search.
            event_text = "\n".join([
                f"Event: {title}",
                f"Department: {host}",
                f"Date & Time: {when}",
                f"Host: {author or host}",
                f"Type: {categories_str}",
            ])
            metadata = {
                "title": title,
                "author": author or host,
                "categories": categories_str,
                "start_time": when,
                "host": host,
                "department": self.extract_department(title, host),
            }
            return {
                "text": event_text,
                "metadata": metadata,
                "id": f"{entry.get('id', '')}",
            }
        except Exception as e:
            print(f"Error processing event {entry.get('title', 'Unknown')}: {e}")
            import traceback
            print(f"Full error: {traceback.format_exc()}")
            return None

    def extract_department(self, title: str, host: str) -> str:
        """Classify an event into a department by keyword.

        Matching is a case-insensitive substring test on ``title`` and
        ``host``. Departments are tried in the order listed and the first
        match wins.

        Returns:
            The department name, or ``'General'`` when nothing matches.
        """
        text = f"{title} {host}".lower()
        departments = {
            'Mathematics': ['math', 'mathematics', 'statistics'],
            # Must precede 'Sciences': 'political science' contains 'science'
            # and would otherwise never be reached.
            'Social Sciences': ['psychology', 'sociology', 'political science'],
            'Sciences': ['science', 'biology', 'chemistry', 'physics'],
            'Business': ['business', 'accounting', 'finance', 'management'],
            'Arts': ['arts', 'humanities', 'visual arts', 'performing arts'],
            'Engineering': ['engineering', 'technology', 'computing'],
            'International': ['international', 'global', 'abroad'],
            'Student Life': ['student life', 'campus life', 'residence'],
            'Athletics': ['athletics', 'sports', 'recreation'],
            'Career': ['career', 'professional', 'employment'],
        }
        for dept, keywords in departments.items():
            if any(keyword in text for keyword in keywords):
                return dept
        return 'General'

    def process_chat(self, message: str, history: List[Tuple[str, str]]) -> str:
        """Answer one chat message and record the exchange.

        Args:
            message: the user's question.
            history: earlier ``(user, assistant)`` turns; may be empty or None.

        Returns:
            The assistant's reply, or an apology containing the error text.
        """
        try:
            events_response = self.query_events(message)
            formatted_response = self.format_response(events_response, message)
            if not formatted_response:
                formatted_response = "I couldn't find any events matching your query."
            self.chat_history = list(history or []) + [(message, formatted_response)]
            return formatted_response
        except Exception as e:
            return f"I apologize, but I encountered an error while searching for events: {str(e)}"

    def format_response(self, events_text: str, query: str) -> str:
        """Prefix an event listing with an LLM-generated introduction.

        Args:
            events_text: the text returned by ``query_events``.
            query: the original user query.

        Returns:
            * a "try rephrasing" message when ``events_text`` is blank;
            * ``events_text`` unchanged when it is not a listing (it does not
              start with ``EVENTS_HEADER``), so status messages are not given
              an introduction;
            * otherwise the introduction, a blank line, then the listing. If
              the LLM call fails a fixed introduction is used instead.
        """
        if not events_text or not events_text.strip():
            return "I couldn't find any events matching your query. Could you try rephrasing or being more specific?"
        if not events_text.startswith(self.EVENTS_HEADER):
            return events_text
        try:
            intro_prompt = ChatPromptTemplate.from_template(self.RESPONSE_TEMPLATE)
            # Pass the current sampling settings with the call; the pipeline
            # itself was configured once in ``setup_llm``.
            tuned_llm = self.llm.bind(pipeline_kwargs={
                "temperature": float(self.temperature),
                "top_p": float(self.top_p),
                "top_k": int(self.top_k),
            })
            intro_chain = intro_prompt | tuned_llm | StrOutputParser()
            introduction = intro_chain.invoke({"query": query})
            # The listing is already decorated by ``query_events``; adding
            # markers again here would duplicate them.
            return f"{introduction.strip()}\n\n{events_text}"
        except Exception as e:
            print(f"Error in response formatting: {e}")
            # Swap the listing's own header for the fixed introduction.
            listing = events_text[len(self.EVENTS_HEADER):].lstrip("\n")
            return "Here are some events that might interest you:\n\n" + listing
def create_chat_interface():
    """Build the Gradio chat UI around a freshly initialised BrockEventsRAG.

    Constructing the assistant loads the models and the event feed, so this
    call does the slow start-up work.

    Returns:
        The ``gr.Blocks`` app, not yet launched.
    """
    chat_rag = BrockEventsRAG()
    custom_theme = gr.themes.Soft().set(
        input_background_fill="*primary",
        body_text_color="*secondary",
    )

    def respond(message, chat_history, temp, p, k):
        """Handle one submission: returns (cleared textbox, updated history)."""
        chat_history = list(chat_history or [])
        # A blank submission would otherwise run a full search and add an
        # empty turn to the conversation.
        if not message or not message.strip():
            return "", chat_history
        chat_rag.temperature = temp
        chat_rag.top_p = p
        chat_rag.top_k = k
        bot_message = chat_rag.process_chat(message, chat_history)
        chat_history.append((message, bot_message))
        return "", chat_history

    with gr.Blocks(theme=custom_theme) as demo:
        # Header section
        with gr.Row():
            with gr.Column():
                gr.Markdown("# π RAG-gy Brock University Events Assistant")
                gr.Markdown("Ask me about upcoming events, workshops, or activities!")
                # The collection size is read once, when the page is built.
                gr.Markdown(f"""
### System Information
- **Embeddings Model**: {chat_rag.emodel_name}
- **LLM Model**: {chat_rag.model_name}
- **Collection Size**: {chat_rag.collection.count()} documents
""")

        # Sampling controls
        temperature = gr.Slider(
            minimum=0.1, maximum=1.0, value=0.7, step=0.1,
            label="Response Creativity (Temperature)"
        )
        top_p = gr.Slider(
            minimum=0.1, maximum=1.0, value=0.95, step=0.05,
            label="Response Focus (Top P)"
        )
        top_k = gr.Slider(
            minimum=1, maximum=100, value=50, step=1,
            label="Response Diversity (Top K)"
        )

        # Chat components
        chatbot = gr.Chatbot(
            label="Chat History",
            height=400,
            bubble_full_width=False
        )
        with gr.Row():
            msg = gr.Textbox(
                label="Your Question",
                placeholder="e.g., What events are happening this week?",
                scale=4
            )
            submit = gr.Button("Ask", scale=1, variant="primary")
        clear = gr.Button("Clear Chat")

        # The button and the Enter key share one handler and wiring.
        for trigger in (submit.click, msg.submit):
            trigger(
                respond,
                inputs=[msg, chatbot, temperature, top_p, top_k],
                outputs=[msg, chatbot]
            )
        clear.click(lambda: None, None, chatbot)

        # Examples
        gr.Examples(
            examples=[
                "What workshops are happening next week?",
                "Are there any business networking events coming up?",
                "Tell me about math and science events",
                "What's happening at the makerspace?",
                "Are there any career fairs scheduled?"
            ],
            inputs=msg
        )
    return demo
if __name__ == "__main__":
    # Script entry point: build the UI, then serve it with a public share
    # link and debug output enabled.
    launch_options = {"share": True, "debug": True}
    demo = create_chat_interface()
    demo.launch(**launch_options)