renceabishek commited on
Commit
cee126a
·
1 Parent(s): fd07624

fixing retrieve context

Browse files
Files changed (1) hide show
  1. main.py +21 -13
main.py CHANGED
@@ -41,19 +41,27 @@ index.add(embeddings)
41
  # === Retrieval function ===
42
  def retrieve_context(query, top_k=3):
43
  query_embedding = embedder.encode([query])
44
- _, indices = index.search(query_embedding, top_k)
45
- selected_chunks = [all_chunks[i] for i in indices[0]]
46
-
47
- # Filter out personal info unless query explicitly asks for it
48
- personal_keywords = ["email", "contact", "phone", "location", "website", "name"]
49
- if not any(keyword in query.lower() for keyword in personal_keywords):
50
- selected_chunks = [chunk for chunk in selected_chunks if "Personal Information" not in chunk]
51
- print("πŸ” Filtered chunks:\n", selected_chunks)
52
- # If filtering removed all chunks, fall back to original top_k
53
- if not selected_chunks:
54
- selected_chunks = [all_chunks[i] for i in indices[0]]
55
-
56
- return "\n\n".join(selected_chunks)
 
 
 
 
 
 
 
 
57
 
58
  # === Load QA model ===
59
  qa_pipeline = pipeline("question-answering", model="deepset/tinyroberta-squad2")
 
41
  # === Retrieval function ===
42
def retrieve_context(query, top_k=3):
    """Return the top_k most relevant chunks for *query*, joined by blank lines.

    Embeds the query, searches the FAISS index, drops short/noisy chunks
    (unless the query contains job-detail keywords), and falls back to the
    unfiltered top_k results when filtering removes everything.

    Parameters:
        query: free-text user question to retrieve context for.
        top_k: number of chunks to retrieve and return (default 3).

    Returns:
        A single string of the best chunks separated by blank lines.
    """
    query_embedding = embedder.encode([query])
    scores, indices = index.search(query_embedding, top_k)

    # Hoist loop-invariant work: the lowercased query and the keyword scan
    # were previously recomputed on every iteration.
    query_lower = query.lower()
    keep_short = any(k in query_lower for k in ("salary", "notice", "job", "current"))

    selected_chunks = []
    for i, score in zip(indices[0], scores[0]):
        chunk = all_chunks[i]
        # Skip short or noisy chunks unless the query asks about job details.
        if len(chunk.split()) < 10 and not keep_short:
            continue
        selected_chunks.append((chunk, score))

    # If nothing survives filtering, fall back to the original top_k.
    if not selected_chunks:
        selected_chunks = [(all_chunks[i], s) for i, s in zip(indices[0], scores[0])]

    # Sort ascending by score (lowest distance = best match).
    # NOTE(review): assumes an L2-distance FAISS index; for an inner-product
    # index the order would need reversing — confirm how `index` was built.
    selected_chunks.sort(key=lambda pair: pair[1])
    final_chunks = [chunk for chunk, _ in selected_chunks[:top_k]]

    print("🔍 Final retrieved chunks:\n", final_chunks)
    return "\n\n".join(final_chunks)
65
 
66
  # === Load QA model ===
67
  qa_pipeline = pipeline("question-answering", model="deepset/tinyroberta-squad2")