AryanJh commited on
Commit
6520230
Β·
verified Β·
1 Parent(s): 6e0003d

Enchanced event matching

Browse files

The simplified event matcher, though efficient, didn't work nicely

Files changed (1) hide show
  1. app.py +154 -174
app.py CHANGED
@@ -12,199 +12,179 @@ import diskcache
12
  import os
13
  import chromadb
14
 
15
- class SimplifiedBrockEventsRAG:
 
 
 
 
 
 
16
  def __init__(self):
17
- """Initialize simplified RAG system for CPU environment"""
18
- print("Initializing simplified RAG system...")
19
-
20
- # Initialize embedding model
21
- self.tokenizer = AutoTokenizer.from_pretrained("microsoft/mpnet-base")
22
- self.model = AutoModel.from_pretrained("microsoft/mpnet-base")
23
 
24
- # Force CPU usage
25
- self.device = torch.device('cpu')
26
- self.model.to(self.device)
27
-
28
- # Set up disk cache
29
- cache_dir = os.path.join(os.getcwd(), "cache")
30
- os.makedirs(cache_dir, exist_ok=True)
31
- self.cache = diskcache.Cache(cache_dir)
32
-
33
- # Initialize vector store
34
- self.chroma_client = chromadb.Client()
35
-
36
- # Initialize date handling
37
- self.eastern = pytz.timezone('America/New_York')
38
- self.today = datetime.now(self.eastern).replace(hour=0, minute=0, second=0, microsecond=0)
39
- self.date_range_end = self.today + timedelta(days=14)
40
-
41
- try:
42
- self.collection = self.chroma_client.create_collection(
43
- name="brock_events",
44
- metadata={"description": "Brock University Events Database"}
45
- )
46
- except Exception:
47
- self.chroma_client.delete_collection("brock_events")
48
- self.collection = self.chroma_client.create_collection(
49
- name="brock_events",
50
- metadata={"description": "Brock University Events Database"}
51
- )
52
-
53
- self.load_patterns()
54
-
55
- def load_patterns(self):
56
- """Load optimized search patterns"""
57
- self.patterns = {
58
- 'faculty': {
59
- 'math': ['mathematics', 'math', 'stats', 'computer science'],
60
- 'humanities': ['humanities', 'language', 'literature'],
61
- 'business': ['goodman', 'business', 'accounting'],
62
- 'science': ['science', 'biology', 'chemistry', 'physics']
63
- },
64
- 'event_type': {
65
- 'academic': ['lecture', 'seminar', 'workshop', 'conference'],
66
- 'social': ['meetup', 'social', 'gathering', 'networking'],
67
- 'career': ['career', 'job', 'employment', 'professional']
68
- },
69
- 'location': {
70
- 'online': ['online', 'virtual', 'zoom', 'teams'],
71
- 'campus': ['room', 'hall', 'building', 'plaza'],
72
- 'library': ['library', 'learning commons', 'makerspace']
73
- }
74
- }
75
-
76
- @lru_cache(maxsize=128)
77
- def generate_embedding(self, text: str) -> List[float]:
78
- """Generate embedding using MPNet"""
79
- inputs = self.tokenizer(text, padding=True, truncation=True, return_tensors="pt")
80
- with torch.no_grad():
81
- outputs = self.model(**inputs)
82
- embeddings = self.mean_pooling(outputs, inputs['attention_mask'])
83
- return F.normalize(embeddings, p=2, dim=1)[0].tolist()
84
 
85
- def mean_pooling(self, model_output, attention_mask):
86
- """Perform mean pooling on token embeddings"""
87
- token_embeddings = model_output[0]
88
- input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
89
- return torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp(input_mask_expanded.sum(1), min=1e-9)
90
-
91
- @lru_cache(maxsize=128)
92
- def preprocess_query(self, query: str) -> str:
93
- """Efficient query preprocessing"""
94
- query = re.sub(r'[^\w\s]', ' ', query.lower())
95
- replacements = {
96
- 'fms': 'faculty of mathematics and science',
97
- 'gsb': 'goodman school of business',
98
- 'foh': 'faculty of humanities'
99
- }
100
- for abbr, full in replacements.items():
101
- query = query.replace(abbr, full)
102
- return query.strip()
 
 
 
103
 
104
- def semantic_search(self, query: str, k: int = 3) -> List[Dict]:
105
- """Optimized semantic search"""
106
- # Check cache first
107
- cache_key = f"search_{query}_{k}"
108
- if cache_key in self.cache:
109
- return self.cache[cache_key]
110
-
111
- # Process query and get embeddings
112
- processed_query = self.preprocess_query(query)
113
- query_embedding = self.generate_embedding(processed_query)
114
-
115
- # Get results from vector store
116
- results = self.collection.query(
117
- query_embeddings=[query_embedding],
118
- n_results=k,
119
- include=['documents', 'metadatas', 'distances']
120
- )
121
-
122
- # Process and rank results
123
- processed_results = []
124
- for doc, metadata, distance in zip(
125
- results['documents'][0],
126
- results['metadatas'][0],
127
- results['distances'][0]
128
- ):
129
- # Calculate relevance score
130
- relevance_score = self.calculate_relevance(query, doc, metadata)
131
- processed_results.append({
132
- 'document': doc,
133
- 'metadata': metadata,
134
- 'score': relevance_score
135
- })
136
 
137
- # Sort by relevance
138
- processed_results.sort(key=lambda x: x['score'], reverse=True)
 
 
 
139
 
140
- # Cache results
141
- self.cache[cache_key] = processed_results
142
- return processed_results
 
 
 
143
 
144
- def calculate_relevance(self, query: str, document: str, metadata: Dict) -> float:
145
- """Calculate optimized relevance score"""
146
  score = 0.0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
147
  query_lower = query.lower()
148
 
149
- # Title similarity (40%)
150
- title_similarity = fuzz.ratio(query_lower, metadata['title'].lower()) / 100
151
- score += title_similarity * 0.4
152
-
153
- # Category matching (30%)
154
- if 'categories' in metadata:
155
- categories_lower = metadata['categories'].lower()
156
- for category_type in self.patterns.values():
157
- for keywords in category_type.values():
158
- if any(keyword in query_lower for keyword in keywords):
159
- if any(keyword in categories_lower for keyword in keywords):
160
- score += 0.3
161
- break
162
-
163
- # Location matching (30%)
164
- if 'location' in metadata:
165
- location_lower = metadata['location'].lower()
166
- for keywords in self.patterns['location'].values():
167
- if any(keyword in query_lower for keyword in keywords):
168
- if any(keyword in location_lower for keyword in keywords):
169
- score += 0.3
170
- break
171
-
172
  return score
173
 
174
- def generate_response(self, query: str, results: List[Dict]) -> str:
175
- """Generate optimized response"""
176
- if not results:
177
- return "I couldn't find any events matching your query. Try asking in a different way!"
178
-
179
- response = "Here are some relevant events I found:\n\n"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
180
 
181
- # Format results
182
- for i, result in enumerate(results, 1):
183
- metadata = result['metadata']
184
- location = metadata['location']
185
- is_online = any(term in location.lower()
186
- for term in self.patterns['location']['online'])
187
 
188
- response += f"{i}. **{metadata['title']}**\n"
189
- response += f"πŸ“… {metadata['date']} at {metadata['time']}\n"
190
- response += f"{'πŸ“±' if is_online else 'πŸ“'} {location}\n"
191
- if 'categories' in metadata:
192
- response += f"🏷️ {metadata['categories']}\n"
193
- response += f"πŸ”— More info: {metadata['link']}\n\n"
 
 
 
 
194
 
195
- # Add contextual suggestion
196
- if any(keyword in query.lower() for keyword in self.patterns['faculty']['math']):
197
- response += "You can ask about events from other faculties too!\n"
198
- elif any(keyword in query.lower() for keyword in self.patterns['location']['library']):
199
- response += "I can help you find events in other locations too!\n"
200
- else:
201
- response += "Feel free to ask about specific types of events!\n"
202
 
203
- return response
 
204
 
205
  def create_demo():
206
  """Create optimized Gradio interface"""
207
- rag_system = SimplifiedBrockEventsRAG()
208
 
209
  def process_query(message: str, history: list) -> Tuple[str, list]:
210
  """Process query and generate response"""
 
12
  import os
13
  import chromadb
14
 
15
+ from transformers import T5Tokenizer, T5ForConditionalGeneration
16
+ import torch
17
+ from typing import List, Dict, Tuple
18
+ import pytz
19
+ from fuzzywuzzy import fuzz
20
+
21
+ class EnhancedEventMatcher:
22
  def __init__(self):
23
+ """Initialize the enhanced event matcher with T5"""
24
+ # Initialize T5 for response enhancement
25
+ self.tokenizer = T5Tokenizer.from_pretrained("t5-small")
26
+ self.t5_model = T5ForConditionalGeneration.from_pretrained("t5-small")
 
 
27
 
28
+ # Initialize pattern learning
29
+ self.known_categories = set()
30
+ self.known_hosts = set()
31
+ self.known_locations = set()
32
+ self.faculty_patterns = {}
33
+ self.category_patterns = {}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
34
 
35
+ def learn_from_events(self, events: List[Event]):
36
+ """Learn patterns from existing events"""
37
+ # Original DynamicEventMatcher learning logic
38
+ for event in events:
39
+ self.known_categories.update(event.categories)
40
+ self.known_hosts.update(event.hosts)
41
+ self.known_locations.add(event.location)
42
+
43
+ # Learn faculty associations
44
+ for host in event.hosts:
45
+ for category in event.categories:
46
+ key = (host, category)
47
+ if 'faculty' in host.lower():
48
+ self.faculty_patterns[key] = self.faculty_patterns.get(key, 0) + 1
49
+
50
+ # Learn category associations
51
+ for cat1 in event.categories:
52
+ for cat2 in event.categories:
53
+ if cat1 != cat2:
54
+ key = (cat1, cat2)
55
+ self.category_patterns[key] = self.category_patterns.get(key, 0) + 1
56
 
57
+ def get_faculty_score(self, event: Event, query: str) -> float:
58
+ """Original faculty scoring logic"""
59
+ score = 0.0
60
+ query_lower = query.lower()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
61
 
62
+ for host in event.hosts:
63
+ if 'faculty' in host.lower():
64
+ ratio = fuzz.partial_ratio(query_lower, host.lower())
65
+ if ratio > 80:
66
+ score += 2.0 * (ratio / 100)
67
 
68
+ for category in event.categories:
69
+ for (host, cat), count in self.faculty_patterns.items():
70
+ if category == cat and fuzz.partial_ratio(query_lower, host.lower()) > 80:
71
+ score += 1.0 * (count / max(self.faculty_patterns.values()))
72
+
73
+ return score
74
 
75
+ def get_category_score(self, event: Event, query_type: str) -> float:
76
+ """Original category scoring logic"""
77
  score = 0.0
78
+ if not query_type:
79
+ return score
80
+
81
+ for category in event.categories:
82
+ ratio = fuzz.partial_ratio(query_type.lower(), category.lower())
83
+ if ratio > 80:
84
+ score += 1.5 * (ratio / 100)
85
+
86
+ for (cat1, cat2), count in self.category_patterns.items():
87
+ if category == cat1 and fuzz.partial_ratio(query_type.lower(), cat2.lower()) > 80:
88
+ score += 0.5 * (count / max(self.category_patterns.values()))
89
+
90
+ return score
91
+
92
+ def get_location_score(self, event: Event, query: str) -> float:
93
+ """Original location scoring logic"""
94
+ score = 0.0
95
+ location_lower = event.location.lower()
96
  query_lower = query.lower()
97
 
98
+ online_terms = {'online', 'virtual', 'teams', 'zoom'}
99
+ if any(term in query_lower for term in online_terms):
100
+ if any(term in location_lower for term in online_terms):
101
+ score += 1.5
102
+
103
+ campus_terms = {'room', 'hall', 'building', 'plaza', 'campus'}
104
+ if any(term in query_lower for term in {'in-person', 'campus', 'building'}):
105
+ if any(term in location_lower for term in campus_terms):
106
+ score += 1.5
107
+
 
 
 
 
 
 
 
 
 
 
 
 
 
108
  return score
109
 
110
+ def enhance_response(self, matched_events: List[Tuple[Event, float]], query: str) -> str:
111
+ """Use T5 to enhance response generation"""
112
+ # Format events for T5 input
113
+ events_text = ""
114
+ for event, score in matched_events:
115
+ events_text += f"""
116
+ Event: {event.title}
117
+ Date: {event.start_time.strftime('%A, %B %d, %Y')}
118
+ Time: {event.start_time.strftime('%I:%M %p')}
119
+ Location: {event.location}
120
+ Categories: {', '.join(event.categories)}
121
+ Score: {score:.2f}
122
+ """
123
+
124
+ # Create prompt for T5
125
+ prompt = f"""
126
+ Query: {query}
127
+ Available Events:
128
+ {events_text}
129
+
130
+ Generate a natural response highlighting the most relevant events and their details.
131
+ """
132
+
133
+ # Generate enhanced response
134
+ inputs = self.tokenizer.encode(prompt, return_tensors="pt", max_length=512, truncation=True)
135
+ outputs = self.t5_model.generate(inputs, max_length=300, num_beams=4, temperature=0.7)
136
+ enhanced_response = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
137
+
138
+ # Format final response with event details
139
+ final_response = enhanced_response + "\n\n"
140
+ for event, score in matched_events:
141
+ location_icon = "πŸ“±" if any(term in event.location.lower()
142
+ for term in ['teams', 'zoom', 'online']) else "πŸ“"
143
+
144
+ final_response += f"""
145
+ **{event.title}** {'🌟' * int(min(score, 5))}
146
+ πŸ“… {event.start_time.strftime('%A, %B %d, %Y')} at {event.start_time.strftime('%I:%M %p')}
147
+ {location_icon} {event.location}
148
+ πŸ‘₯ Hosted by: {', '.join(event.hosts)}
149
+ 🏷️ Categories: {', '.join(event.categories)}
150
+ πŸ”— {event.link}
151
+ """
152
+
153
+ return final_response
154
+
155
+ def match_and_respond(self, events: List[Event], query: str, query_info: Dict) -> str:
156
+ """Main method to match events and generate response"""
157
+ # Learn patterns if not already learned
158
+ self.learn_from_events(events)
159
 
160
+ # Match events using original logic
161
+ matched_events = []
162
+ for event in events:
163
+ faculty_score = self.get_faculty_score(event, query_info['original_query'])
164
+ category_score = self.get_category_score(event, query_info['event_type'])
165
+ location_score = self.get_location_score(event, query_info['original_query'])
166
 
167
+ total_score = (faculty_score * 1.5 +
168
+ category_score * 1.2 +
169
+ location_score * 1.0)
170
+
171
+ if total_score > 0:
172
+ matched_events.append((event, total_score))
173
+
174
+ # Sort and get top matches
175
+ matched_events.sort(key=lambda x: x[1], reverse=True)
176
+ top_matches = matched_events[:3]
177
 
178
+ if not top_matches:
179
+ return f"I couldn't find any events matching your query for {query_info['faculty'] or 'any faculty'} " \
180
+ f"and {query_info['event_type'] or 'any event type'}. Try broadening your search."
 
 
 
 
181
 
182
+ # Generate enhanced response using T5
183
+ return self.enhance_response(top_matches, query)
184
 
185
  def create_demo():
186
  """Create optimized Gradio interface"""
187
+ rag_system = EnhancedEventMatcher()
188
 
189
  def process_query(message: str, history: list) -> Tuple[str, list]:
190
  """Process query and generate response"""