Spaces:
Sleeping
Sleeping
dataprincess
committed on
Commit
•
48b4cd8
1
Parent(s):
9636641
Update app.py
Browse files
app.py
CHANGED
@@ -25,7 +25,7 @@ with open(FILE_PATH, 'r') as file:
|
|
25 |
pc = Pinecone(api_key=PINECONE_API_KEY)
|
26 |
spec = ServerlessSpec(cloud="aws", region='us-east-1')
|
27 |
existing_indexes = [index_info["name"] for index_info in pc.list_indexes()]
|
28 |
-
|
29 |
if INDEX_NAME not in existing_indexes:
|
30 |
pc.create_index(INDEX_NAME, dimension=DIMS, metric='cosine', spec=spec)
|
31 |
|
@@ -54,41 +54,32 @@ for i in tqdm(range(0, len(data['id']), BATCH_SIZE)):
|
|
54 |
index.upsert(vectors=to_upsert)
|
55 |
|
56 |
def extract_course_code(text) -> list[str]:
    """Return every course code mentioned in *text*.

    A course code is one of the known department prefixes (``ged``/``geds``,
    ``stat``/``stats``, ``math``/``maths``, ``cosc``, ``seng``, ``itgy``)
    followed by optional whitespace and exactly three digits. Matching is
    case-insensitive and anchored on word boundaries.

    Args:
        text: Free-form text, typically a user query.

    Returns:
        The matched substrings exactly as they appear in *text*, in order of
        appearance. An empty list is returned when *text* is empty/``None``
        or contains no course code, so the result always honours the
        ``list[str]`` annotation and is still falsy for ``if codes:`` checks.
    """
    if not text:
        return []
    pattern = r'\b(?:geds?|stats?|maths?|cosc|seng|itgy)\s*\d{3}\b'
    return re.findall(pattern, text, re.IGNORECASE)
|
61 |
|
62 |
def get_docs(query: str, top_k: int) -> list[str]:
    """Return up to *top_k* document contents relevant to *query*.

    Documents from ``data['metadata']`` whose content mentions a course code
    found in the query are listed first. Any remaining slots are filled with
    the results of a vector search against ``index`` using the encoded query.
    The same content is never returned twice.

    Args:
        query: The user's question.
        top_k: Maximum number of documents to return.

    Returns:
        A list of at most *top_k* distinct content strings; empty when
        *top_k* is not positive.
    """
    if top_k <= 0:
        return []

    docs: list[str] = []
    seen: set[str] = set()

    course_codes = extract_course_code(query)
    if course_codes:
        # The extraction regex tolerates any spacing between prefix and
        # number ("cosc101", "cosc 101"), so the content lookup must too;
        # a raw substring test would miss differently spaced mentions.
        code_patterns = []
        for code in course_codes:
            compact = "".join(code.split())
            prefix, digits = compact[:-3], compact[-3:]
            code_patterns.append(
                re.compile(rf"\b{re.escape(prefix)}\s*{digits}\b", re.IGNORECASE)
            )

        for entry in data['metadata']:
            content = entry['content']
            if content in seen:
                continue
            if any(p.search(content) for p in code_patterns):
                seen.add(content)
                docs.append(content)
                if len(docs) == top_k:
                    return docs

    # Ask for a full top_k: some hits may duplicate the exact matches above,
    # and requesting only the remaining slots could leave the result short.
    xq = encoder.encode(query)
    res = index.query(vector=xq.tolist(), top_k=top_k, include_metadata=True)

    for match in res["matches"]:
        content = match["metadata"]['content']
        if content in seen:
            continue
        seen.add(content)
        docs.append(content)
        if len(docs) == top_k:
            break

    return docs
|
93 |
|
94 |
def get_response(query: str, docs: list[str]) -> str:
|
@@ -114,10 +105,8 @@ def get_response(query: str, docs: list[str]) -> str:
|
|
114 |
|
115 |
def handle_query(user_query: str):
|
116 |
|
117 |
-
# Get relevant documents
|
118 |
docs = get_docs(user_query, top_k=5)
|
119 |
|
120 |
-
# Generate and return response
|
121 |
response = get_response(user_query, docs=docs)
|
122 |
|
123 |
for word in response.split():
|
|
|
25 |
pc = Pinecone(api_key=PINECONE_API_KEY)
|
26 |
spec = ServerlessSpec(cloud="aws", region='us-east-1')
|
27 |
existing_indexes = [index_info["name"] for index_info in pc.list_indexes()]
|
28 |
+
|
29 |
if INDEX_NAME not in existing_indexes:
|
30 |
pc.create_index(INDEX_NAME, dimension=DIMS, metric='cosine', spec=spec)
|
31 |
|
|
|
54 |
index.upsert(vectors=to_upsert)
|
55 |
|
56 |
def extract_course_code(text) -> list[str]:
    """Return every course code mentioned in *text*.

    A course code is one of the known department prefixes (``ged``/``geds``,
    ``stat``/``stats``, ``math``/``maths``, ``cosc``, ``seng``, ``itgy``)
    followed by optional whitespace and exactly three digits. Matching is
    case-insensitive and anchored on word boundaries.

    Args:
        text: Free-form text, typically a user query.

    Returns:
        The matched substrings exactly as they appear in *text*, in order of
        appearance. An empty list is returned when *text* is empty/``None``
        or contains no course code, so the result always honours the
        ``list[str]`` annotation and is still falsy for ``if codes:`` checks.
    """
    if not text:
        return []
    pattern = r'\b(?:geds?|stats?|maths?|cosc|seng|itgy)\s*\d{3}\b'
    return re.findall(pattern, text, re.IGNORECASE)
|
60 |
|
61 |
def get_docs(query: str, top_k: int) -> list[str]:
    """Return up to *top_k* document contents relevant to *query*.

    Documents from ``data['metadata']`` whose content mentions a course code
    found in the query are listed first. Any remaining slots are filled with
    the results of a vector search against ``index`` using the encoded query.
    The same content is never returned twice.

    Args:
        query: The user's question.
        top_k: Maximum number of documents to return.

    Returns:
        A list of at most *top_k* distinct content strings; empty when
        *top_k* is not positive.
    """
    if top_k <= 0:
        return []

    docs: list[str] = []
    seen: set[str] = set()

    course_codes = extract_course_code(query)
    if course_codes:
        # The extraction regex tolerates any spacing between prefix and
        # number ("cosc101", "cosc 101"), so the content lookup must too;
        # a raw substring test would miss differently spaced mentions.
        code_patterns = []
        for code in course_codes:
            compact = "".join(code.split())
            prefix, digits = compact[:-3], compact[-3:]
            code_patterns.append(
                re.compile(rf"\b{re.escape(prefix)}\s*{digits}\b", re.IGNORECASE)
            )

        for entry in data['metadata']:
            content = entry['content']
            if content in seen:
                continue
            if any(p.search(content) for p in code_patterns):
                seen.add(content)
                docs.append(content)
                if len(docs) == top_k:
                    return docs

    # Ask for a full top_k: some hits may duplicate the exact matches above,
    # and requesting only the remaining slots could leave the result short.
    xq = encoder.encode(query)
    res = index.query(vector=xq.tolist(), top_k=top_k, include_metadata=True)

    for match in res["matches"]:
        content = match["metadata"]['content']
        if content in seen:
            continue
        seen.add(content)
        docs.append(content)
        if len(docs) == top_k:
            break

    return docs
|
84 |
|
85 |
def get_response(query: str, docs: list[str]) -> str:
|
|
|
105 |
|
106 |
def handle_query(user_query: str):
|
107 |
|
|
|
108 |
docs = get_docs(user_query, top_k=5)
|
109 |
|
|
|
110 |
response = get_response(user_query, docs=docs)
|
111 |
|
112 |
for word in response.split():
|