dataprincess commited on
Commit
48b4cd8
1 Parent(s): 9636641

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +1 -12
app.py CHANGED
@@ -25,7 +25,7 @@ with open(FILE_PATH, 'r') as file:
25
  pc = Pinecone(api_key=PINECONE_API_KEY)
26
  spec = ServerlessSpec(cloud="aws", region='us-east-1')
27
  existing_indexes = [index_info["name"] for index_info in pc.list_indexes()]
28
- # Check if index already exists; if not, create it
29
  if INDEX_NAME not in existing_indexes:
30
  pc.create_index(INDEX_NAME, dimension=DIMS, metric='cosine', spec=spec)
31
 
@@ -54,41 +54,32 @@ for i in tqdm(range(0, len(data['id']), BATCH_SIZE)):
54
  index.upsert(vectors=to_upsert)
55
 
56
  def extract_course_code(text) -> list[str]:
57
- # Improved pattern with correct case insensitivity and spacing allowance
58
  pattern = r'\b(?:geds?|stats?|maths?|cosc|seng|itgy)\s*\d{3}\b'
59
  match = re.findall(pattern, text, re.IGNORECASE)
60
  return match if match else None
61
 
62
  def get_docs(query: str, top_k: int) -> list[str]:
63
- # Extract course code(s) from the query
64
  course_code = extract_course_code(query)
65
  exact_matches = []
66
 
67
  if course_code:
68
- # Normalize course_code to lowercase for case-insensitive matching
69
  course_code = [code.lower() for code in course_code]
70
 
71
- # Check for exact match in metadata
72
  exact_matches = [
73
  x['content'] for x in data['metadata']
74
  if any(code in x['content'].lower() for code in course_code)
75
  ]
76
 
77
- # Calculate remaining slots if we have fewer than top_k exact matches
78
  remaining_slots = top_k - len(exact_matches)
79
 
80
  if remaining_slots > 0:
81
- # Perform embedding search for either the entire top_k if no exact match, or the remaining slots
82
  xq = encoder.encode(query)
83
  res = index.query(vector=xq.tolist(), top_k=remaining_slots if exact_matches else top_k, include_metadata=True)
84
 
85
- # Add embedding-based matches (avoiding duplicates)
86
  embedding_matches = [x["metadata"]['content'] for x in res["matches"]]
87
 
88
- # Combine exact matches with embedding matches
89
  exact_matches.extend(embedding_matches)
90
 
91
- # Return the first top_k results
92
  return exact_matches[:top_k]
93
 
94
  def get_response(query: str, docs: list[str]) -> str:
@@ -114,10 +105,8 @@ def get_response(query: str, docs: list[str]) -> str:
114
 
115
  def handle_query(user_query: str):
116
 
117
- # Get relevant documents
118
  docs = get_docs(user_query, top_k=5)
119
 
120
- # Generate and return response
121
  response = get_response(user_query, docs=docs)
122
 
123
  for word in response.split():
 
25
  pc = Pinecone(api_key=PINECONE_API_KEY)
26
  spec = ServerlessSpec(cloud="aws", region='us-east-1')
27
  existing_indexes = [index_info["name"] for index_info in pc.list_indexes()]
28
+
29
  if INDEX_NAME not in existing_indexes:
30
  pc.create_index(INDEX_NAME, dimension=DIMS, metric='cosine', spec=spec)
31
 
 
54
  index.upsert(vectors=to_upsert)
55
 
56
  def extract_course_code(text) -> list[str]:
 
57
  pattern = r'\b(?:geds?|stats?|maths?|cosc|seng|itgy)\s*\d{3}\b'
58
  match = re.findall(pattern, text, re.IGNORECASE)
59
  return match if match else None
60
 
61
  def get_docs(query: str, top_k: int) -> list[str]:
 
62
  course_code = extract_course_code(query)
63
  exact_matches = []
64
 
65
  if course_code:
 
66
  course_code = [code.lower() for code in course_code]
67
 
 
68
  exact_matches = [
69
  x['content'] for x in data['metadata']
70
  if any(code in x['content'].lower() for code in course_code)
71
  ]
72
 
 
73
  remaining_slots = top_k - len(exact_matches)
74
 
75
  if remaining_slots > 0:
 
76
  xq = encoder.encode(query)
77
  res = index.query(vector=xq.tolist(), top_k=remaining_slots if exact_matches else top_k, include_metadata=True)
78
 
 
79
  embedding_matches = [x["metadata"]['content'] for x in res["matches"]]
80
 
 
81
  exact_matches.extend(embedding_matches)
82
 
 
83
  return exact_matches[:top_k]
84
 
85
  def get_response(query: str, docs: list[str]) -> str:
 
105
 
106
  def handle_query(user_query: str):
107
 
 
108
  docs = get_docs(user_query, top_k=5)
109
 
 
110
  response = get_response(user_query, docs=docs)
111
 
112
  for word in response.split():