Enhanced event matching
The simplified event matcher, though efficient, didn't work well.
app.py
CHANGED
@@ -12,199 +12,179 @@ import diskcache
 import os
 import chromadb
 
-...
 
     def __init__(self):
-        """Initialize ...
-        ...
-        self.tokenizer = AutoTokenizer.from_pretrained("microsoft/mpnet-base")
-        self.model = AutoModel.from_pretrained("microsoft/mpnet-base")
 
-        # ...
-        self. ...
-        self. ...
-        ...
-        os.makedirs(cache_dir, exist_ok=True)
-        self.cache = diskcache.Cache(cache_dir)
-
-        # Initialize vector store
-        self.chroma_client = chromadb.Client()
-
-        # Initialize date handling
-        self.eastern = pytz.timezone('America/New_York')
-        self.today = datetime.now(self.eastern).replace(hour=0, minute=0, second=0, microsecond=0)
-        self.date_range_end = self.today + timedelta(days=14)
-
-        try:
-            self.collection = self.chroma_client.create_collection(
-                name="brock_events",
-                metadata={"description": "Brock University Events Database"}
-            )
-        except Exception:
-            self.chroma_client.delete_collection("brock_events")
-            self.collection = self.chroma_client.create_collection(
-                name="brock_events",
-                metadata={"description": "Brock University Events Database"}
-            )
-
-        self.load_patterns()
-
-    def load_patterns(self):
-        """Load optimized search patterns"""
-        self.patterns = {
-            'faculty': {
-                'math': ['mathematics', 'math', 'stats', 'computer science'],
-                'humanities': ['humanities', 'language', 'literature'],
-                'business': ['goodman', 'business', 'accounting'],
-                'science': ['science', 'biology', 'chemistry', 'physics']
-            },
-            'event_type': {
-                'academic': ['lecture', 'seminar', 'workshop', 'conference'],
-                'social': ['meetup', 'social', 'gathering', 'networking'],
-                'career': ['career', 'job', 'employment', 'professional']
-            },
-            'location': {
-                'online': ['online', 'virtual', 'zoom', 'teams'],
-                'campus': ['room', 'hall', 'building', 'plaza'],
-                'library': ['library', 'learning commons', 'makerspace']
-            }
-        }
-
-    @lru_cache(maxsize=128)
-    def generate_embedding(self, text: str) -> List[float]:
-        """Generate embedding using MPNet"""
-        inputs = self.tokenizer(text, padding=True, truncation=True, return_tensors="pt")
-        with torch.no_grad():
-            outputs = self.model(**inputs)
-        embeddings = self.mean_pooling(outputs, inputs['attention_mask'])
-        return F.normalize(embeddings, p=2, dim=1)[0].tolist()
 
-    def ...
-        """...
-        ...
 
-    def ...
-        """...
-        ...
-        if cache_key in self.cache:
-            return self.cache[cache_key]
-
-        # Process query and get embeddings
-        processed_query = self.preprocess_query(query)
-        query_embedding = self.generate_embedding(processed_query)
-
-        # Get results from vector store
-        results = self.collection.query(
-            query_embeddings=[query_embedding],
-            n_results=k,
-            include=['documents', 'metadatas', 'distances']
-        )
-
-        # Process and rank results
-        processed_results = []
-        for doc, metadata, distance in zip(
-            results['documents'][0],
-            results['metadatas'][0],
-            results['distances'][0]
-        ):
-            # Calculate relevance score
-            relevance_score = self.calculate_relevance(query, doc, metadata)
-            processed_results.append({
-                'document': doc,
-                'metadata': metadata,
-                'score': relevance_score
-            })
 
-        ...
 
-        ...
 
-    def ...
-        """...
         score = 0.0
         query_lower = query.lower()
 
-        ...
-        for ...
-            ...
-            if any(keyword in categories_lower for keyword in keywords):
-                score += 0.3
-                break
-
-        # Location matching (30%)
-        if 'location' in metadata:
-            location_lower = metadata['location'].lower()
-            for keywords in self.patterns['location'].values():
-                if any(keyword in query_lower for keyword in keywords):
-                    if any(keyword in location_lower for keyword in keywords):
-                        score += 0.3
-                        break
-
         return score
 
-    def ...
-        """...
-        ...
 
-        # ...
-        ...
 
-        ...
 
-        ...
-        elif any(keyword in query.lower() for keyword in self.patterns['location']['library']):
-            response += "I can help you find events in other locations too!\n"
-        else:
-            response += "Feel free to ask about specific types of events!\n"
 
-        ...
 
 def create_demo():
     """Create optimized Gradio interface"""
-    rag_system = ...
 
     def process_query(message: str, history: list) -> Tuple[str, list]:
         """Process query and generate response"""
 import os
 import chromadb
 
+from transformers import T5Tokenizer, T5ForConditionalGeneration
+import torch
+from typing import List, Dict, Tuple
+import pytz
+from fuzzywuzzy import fuzz
+
+class EnhancedEventMatcher:
     def __init__(self):
+        """Initialize the enhanced event matcher with T5"""
+        # Initialize T5 for response enhancement
+        self.tokenizer = T5Tokenizer.from_pretrained("t5-small")
+        self.t5_model = T5ForConditionalGeneration.from_pretrained("t5-small")
 
+        # Initialize pattern learning
+        self.known_categories = set()
+        self.known_hosts = set()
+        self.known_locations = set()
+        self.faculty_patterns = {}
+        self.category_patterns = {}
 
+    def learn_from_events(self, events: List[Event]):
+        """Learn patterns from existing events"""
+        # Original DynamicEventMatcher learning logic
+        for event in events:
+            self.known_categories.update(event.categories)
+            self.known_hosts.update(event.hosts)
+            self.known_locations.add(event.location)
+
+            # Learn faculty associations
+            for host in event.hosts:
+                for category in event.categories:
+                    key = (host, category)
+                    if 'faculty' in host.lower():
+                        self.faculty_patterns[key] = self.faculty_patterns.get(key, 0) + 1
+
+            # Learn category associations
+            for cat1 in event.categories:
+                for cat2 in event.categories:
+                    if cat1 != cat2:
+                        key = (cat1, cat2)
+                        self.category_patterns[key] = self.category_patterns.get(key, 0) + 1
 
+    def get_faculty_score(self, event: Event, query: str) -> float:
+        """Original faculty scoring logic"""
+        score = 0.0
+        query_lower = query.lower()
 
+        for host in event.hosts:
+            if 'faculty' in host.lower():
+                ratio = fuzz.partial_ratio(query_lower, host.lower())
+                if ratio > 80:
+                    score += 2.0 * (ratio / 100)
 
+        for category in event.categories:
+            for (host, cat), count in self.faculty_patterns.items():
+                if category == cat and fuzz.partial_ratio(query_lower, host.lower()) > 80:
+                    score += 1.0 * (count / max(self.faculty_patterns.values()))
+
+        return score
 
+    def get_category_score(self, event: Event, query_type: str) -> float:
+        """Original category scoring logic"""
         score = 0.0
+        if not query_type:
+            return score
+
+        for category in event.categories:
+            ratio = fuzz.partial_ratio(query_type.lower(), category.lower())
+            if ratio > 80:
+                score += 1.5 * (ratio / 100)
+
+        for (cat1, cat2), count in self.category_patterns.items():
+            if category == cat1 and fuzz.partial_ratio(query_type.lower(), cat2.lower()) > 80:
+                score += 0.5 * (count / max(self.category_patterns.values()))
+
+        return score
+
+    def get_location_score(self, event: Event, query: str) -> float:
+        """Original location scoring logic"""
+        score = 0.0
+        location_lower = event.location.lower()
         query_lower = query.lower()
 
+        online_terms = {'online', 'virtual', 'teams', 'zoom'}
+        if any(term in query_lower for term in online_terms):
+            if any(term in location_lower for term in online_terms):
+                score += 1.5
+
+        campus_terms = {'room', 'hall', 'building', 'plaza', 'campus'}
+        if any(term in query_lower for term in {'in-person', 'campus', 'building'}):
+            if any(term in location_lower for term in campus_terms):
+                score += 1.5
+
         return score
 
+    def enhance_response(self, matched_events: List[Tuple[Event, float]], query: str) -> str:
+        """Use T5 to enhance response generation"""
+        # Format events for T5 input
+        events_text = ""
+        for event, score in matched_events:
+            events_text += f"""
+            Event: {event.title}
+            Date: {event.start_time.strftime('%A, %B %d, %Y')}
+            Time: {event.start_time.strftime('%I:%M %p')}
+            Location: {event.location}
+            Categories: {', '.join(event.categories)}
+            Score: {score:.2f}
+            """
+
+        # Create prompt for T5
+        prompt = f"""
+        Query: {query}
+        Available Events:
+        {events_text}
+
+        Generate a natural response highlighting the most relevant events and their details.
+        """
+
+        # Generate enhanced response
+        inputs = self.tokenizer.encode(prompt, return_tensors="pt", max_length=512, truncation=True)
+        outputs = self.t5_model.generate(inputs, max_length=300, num_beams=4, temperature=0.7)
+        enhanced_response = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
+
+        # Format final response with event details
+        final_response = enhanced_response + "\n\n"
+        for event, score in matched_events:
+            location_icon = "📱" if any(term in event.location.lower()
+                                        for term in ['teams', 'zoom', 'online']) else "📍"
+
+            final_response += f"""
+            **{event.title}** {'🌟' * int(min(score, 5))}
+            📅 {event.start_time.strftime('%A, %B %d, %Y')} at {event.start_time.strftime('%I:%M %p')}
+            {location_icon} {event.location}
+            👥 Hosted by: {', '.join(event.hosts)}
+            🏷️ Categories: {', '.join(event.categories)}
+            🔗 {event.link}
+            """
+
+        return final_response
+
+    def match_and_respond(self, events: List[Event], query: str, query_info: Dict) -> str:
+        """Main method to match events and generate response"""
+        # Learn patterns if not already learned
+        self.learn_from_events(events)
 
+        # Match events using original logic
+        matched_events = []
+        for event in events:
+            faculty_score = self.get_faculty_score(event, query_info['original_query'])
+            category_score = self.get_category_score(event, query_info['event_type'])
+            location_score = self.get_location_score(event, query_info['original_query'])
 
+            total_score = (faculty_score * 1.5 +
+                           category_score * 1.2 +
+                           location_score * 1.0)
+
+            if total_score > 0:
+                matched_events.append((event, total_score))
+
+        # Sort and get top matches
+        matched_events.sort(key=lambda x: x[1], reverse=True)
+        top_matches = matched_events[:3]
 
+        if not top_matches:
+            return f"I couldn't find any events matching your query for {query_info['faculty'] or 'any faculty'} " \
+                   f"and {query_info['event_type'] or 'any event type'}. Try broadening your search."
 
+        # Generate enhanced response using T5
+        return self.enhance_response(top_matches, query)
 
 def create_demo():
     """Create optimized Gradio interface"""
+    rag_system = EnhancedEventMatcher()
 
     def process_query(message: str, history: list) -> Tuple[str, list]:
         """Process query and generate response"""