Spaces: Runtime error
paulokewunmi committed • Commit f440070
1 Parent(s): ffa8f17
Change vector db to pinecone
Browse files
- app.py +2 -1
- requirements.txt +1 -2
- src/document_utils_v2.py +151 -0
- src/wiki_search.py +37 -79
- src/wiki_search_v2.py +162 -0
app.py
CHANGED
@@ -65,13 +65,14 @@ with gr.Blocks(theme=custom_theme) as demo:
                 "Hausa",
             ],
             label="Filter results based on language",
+            value = "Yoruba"
         )
 
         with gr.Row():
             with gr.Column():
                 user_query = gr.Text(
                     label="Enter query here",
-                    placeholder="Search through all your
+                    placeholder="Search through all your study materials",
                 )
 
                 num_search_results = gr.Slider(
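Review note: the two app.py tweaks are a default language selection and a finished placeholder string. A hedged sketch of how the filter block plausibly reads after this commit; the component type (gr.Radio here) and the full choices list are assumptions, since the diff only shows the trailing arguments:

import gradio as gr

# sketch only: component type and choices are assumed, not shown in the diff
language_filter = gr.Radio(
    choices=["English", "Yoruba", "Igbo", "Hausa"],
    label="Filter results based on language",
    value="Yoruba",  # new in this commit: pre-selects a default language
)
user_query = gr.Text(
    label="Enter query here",
    placeholder="Search through all your study materials",  # completed in this commit
)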
requirements.txt
CHANGED
@@ -1,6 +1,5 @@
 cohere
-qdrant_client==0.11.0
 gradio
 langchain
 black
-
+"pinecone-client[grpc]"
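Review note: the dependency swap matches the code change: qdrant_client goes, and the Pinecone client arrives with its gRPC extra, which src/wiki_search.py needs for pinecone.GRPCIndex. One caution: the quotes around "pinecone-client[grpc]" are shell syntax; pip may reject the quoted form in a requirements file as an invalid requirement, so the line would normally read pinecone-client[grpc] unquoted. A minimal sketch of the client setup this dependency enables, reusing the env-var names and the COLLECTION value ("wiki-embed") from src/wiki_search.py:

import os
import pinecone

# one-time client setup; "wiki-embed" is the COLLECTION constant in src/wiki_search.py
pinecone.init(
    api_key=os.environ["PINECONE_API_KEY"],
    environment=os.environ["PINECONE_ENV"],
)
index = pinecone.GRPCIndex("wiki-embed")  # gRPC index client, hence the [grpc] extra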
src/document_utils_v2.py
ADDED
@@ -0,0 +1,151 @@
+import os
+import sys
+
+import pandas as pd
+from typing import List
+
+import cohere
+from langchain.embeddings.cohere import CohereEmbeddings
+from langchain.llms import Cohere
+from langchain.prompts import PromptTemplate
+from langchain.vectorstores import Qdrant
+from langchain.chains.question_answering import load_qa_chain
+
+sys.path.append(os.path.abspath('..'))
+
+from src.constants import SUMMARIZATION_MODEL, EXAMPLES_FILE_PATH
+
+
+
+QDRANT_HOST = os.environ.get("QDRANT_HOST")
+QDRANT_API_KEY = os.environ.get("QDRANT_API_KEY")
+COHERE_API_KEY = os.environ.get("COHERE_API_KEY")
+
+
+def replace_text(text):
+    if text.startswith("The answer is "):
+        text = text.replace("The answer is ", "", 1)
+    return text
+
+
+def summarize(
+    document: str,
+    summary_length: str,
+    summary_format: str,
+    extractiveness: str = "high",
+    temperature: float = 0.6,
+) -> str:
+    """
+    Generates a summary for the input document using Cohere's summarize API.
+    Args:
+        document (`str`):
+            The document given by the user for which summary must be generated.
+        summary_length (`str`):
+            A value such as 'short', 'medium', 'long' indicating the length of the summary.
+        summary_format (`str`):
+            This indicates whether the generated summary should be in 'paragraph' format or 'bullets'.
+        extractiveness (`str`, *optional*, defaults to 'high'):
+            A value such as 'low', 'medium', 'high' indicating how close the generated summary should be in meaning to the original text.
+        temperature (`str`):
+            This controls the randomness of the output. Lower values tend to generate more “predictable” output, while higher values tend to generate more “creative” output.
+    Returns:
+        generated_summary (`str`):
+            The generated summary from the summarization model.
+    """
+
+    summary_response = cohere.Client(COHERE_API_KEY).summarize(
+        text=document,
+        length=summary_length,
+        format=summary_format,
+        model=SUMMARIZATION_MODEL,
+        extractiveness=extractiveness,
+        temperature=temperature,
+    )
+    generated_summary = summary_response.summary
+    return generated_summary
+
+
+def question_answer(input_document: str, history: List) -> str:
+    """
+    Generates an appropriate answer for the question asked by the user based on the input document.
+    Args:
+        input_document (`str`):
+            The document given by the user for which summary must be generated.
+        history (`List[List[str,str]]`):
+            A list made up of pairs of input question asked by the user & corresponding generated answers. It is used to keep track of the history of the chat between the user and the model.
+    Returns:
+        answer (`str`):
+            The generated answer corresponding to the input question and document received from the user.
+    """
+    context = input_document
+    # The last element of the `history` list contains the most recent question asked by the user whose answer needs to be generated.
+    question = history[-1][0]
+    word_list = context.split()
+    # texts = [context[k : k + 256] for k in range(0, len(context.split()), 256)]
+    texts = [" ".join(word_list[k : k + 256]) for k in range(0, len(word_list), 256)]
+
+    # print(texts)
+
+    embeddings = CohereEmbeddings(
+        model="multilingual-22-12", cohere_api_key=COHERE_API_KEY
+    )
+    context_index = Qdrant.from_texts(
+        texts, embeddings, url=QDRANT_HOST, api_key=QDRANT_API_KEY
+    )
+
+    prompt_template = """Text: {context}
+    Question: {question}
+    Answer the question based on the text provided. If the text doesn't contain the answer, reply that the answer is not available."""
+
+    PROMPT = PromptTemplate(
+        template=prompt_template, input_variables=["context", "question"]
+    )
+
+    # Generate the answer given the context
+    chain = load_qa_chain(
+        Cohere(
+            model="command-xlarge-nightly", temperature=0, cohere_api_key=COHERE_API_KEY
+        ),
+        chain_type="stuff",
+        prompt=PROMPT,
+    )
+    relevant_context = context_index.similarity_search(question)
+    answer = chain.run(input_documents=relevant_context, question=question)
+    answer = answer.replace("\n", "").replace("Answer:", "")
+    answer = replace_text(answer)
+    return answer
+
+def generate_questions(input_document: str) -> str:
+    generated_response = cohere.Client(COHERE_API_KEY).generate(
+        prompt = f"Give me 5 different questions to test understanding of the following text provided. Here's the provided text: {input_document}. Now what is Questions 1 to 5 ?:",
+        max_tokens = 200,
+        temperature = 0.55
+    )
+    # prompt = f"Generate 5 different quiz questions to test the understanding of the following text. Here's the provided text: {input_document}. Whats Questions 1 to 5 of the quiz ?:"
+    # print(prompt)
+    return generated_response.generations[0].text
+
+
+def load_science():
+    examples_df = pd.read_csv(EXAMPLES_FILE_PATH)
+    science_doc = examples_df["doc"].iloc[0]
+    sample_question = examples_df["question"].iloc[0]
+    return science_doc, sample_question
+
+
+def load_history():
+    examples_df = pd.read_csv(EXAMPLES_FILE_PATH)
+    history_doc = examples_df["doc"].iloc[1]
+    sample_question = examples_df["question"].iloc[1]
+    return history_doc, sample_question
+
+
+if __name__ == "__main__":
+    with open('sample_text.txt', 'r') as file:
+        text = file.read()
+    # summary = summarize(text, summary_length="short", summary_format="bullets")
+    # print(summary)
+    # answer = question_answer(text, [["what is photosynthesis", None]])
+    # print(answer)
+    question = question_answer(text, ["Whats photosynthesis"])
+    print(question)
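Review note: this new module is still Qdrant-backed (it imports langchain.vectorstores.Qdrant and reads the QDRANT_* variables), so the Pinecone migration in this commit covers only the wiki search path. Also, the smoke test at the bottom passes a bare list of strings, but question_answer reads the latest question as history[-1][0], i.e. it expects a list of [question, answer] pairs; with ["Whats photosynthesis"] that indexing yields just the first character of the string. The commented-out call above it already has the expected shape:

# `history` is a list of [question, answer] pairs; question_answer reads
# history[-1][0], so the last entry must itself be a pair
answer = question_answer(text, [["Whats photosynthesis", None]])
print(answer)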
src/wiki_search.py
CHANGED
@@ -1,14 +1,10 @@
 import os
 import cohere
 from typing import List
+import pinecone
 
-from qdrant_client import QdrantClient
-from qdrant_client import models
-
-
-# load environment variables
-QDRANT_HOST = os.environ.get("QDRANT_HOST")
-QDRANT_API_KEY = os.environ.get("QDRANT_API_KEY")
+PINECONE_API_KEY = os.environ.get("PINECONE_API_KEY")
+PINECONE_ENV = os.environ.get("PINECONE_ENV")
 COHERE_API_KEY = os.environ.get("COHERE_API_KEY")
 
 MODEL_NAME = "multilingual-22-12"
@@ -17,12 +13,6 @@ COLLECTION = "wiki-embed"
 # create qdrant and cohere client
 cohere_client = cohere.Client(COHERE_API_KEY)
 
-qdrant_client = QdrantClient(
-    host=QDRANT_HOST,
-    api_key=QDRANT_API_KEY,
-    port = 443,
-)
-
 def embed_user_query(user_query):
 
     embeddings = cohere_client.embed(
@@ -36,10 +26,13 @@ def embed_user_query(user_query):
 def search_wiki_for_query(
     query_embedding,
     num_results = 3,
-    user_query= "",
     languages = [],
-    match_text = None,
 ):
+    pinecone.init(api_key= PINECONE_API_KEY,
+                  environment=PINECONE_ENV)
+    index = pinecone.GRPCIndex(COLLECTION)
+
+
     filters = []
 
     language_mapping = {
@@ -49,78 +42,45 @@ def search_wiki_for_query(
         "Hause": "ha",
     }
 
+    index.query(query_embedding, top_k=num_results, include_metadata=True)
+
     # prepare filters to narrow down search results
     # if the `match_text` list is not empty then create filter to find exact matching text in the documents
-    if match_text:
-        filters.append(
-            models.FieldCondition(
-                key="text",
-                match=models.MatchText(text=user_query),
-            )
-        )
-
-    # filter documents based on language before performing search:
-    if languages:
-        for lang in languages:
-            filters.append(
-                models.FieldCondition(
-                    key="lang",
-                    match=models.MatchValue(
-                        value=language_mapping[lang],
-                    ),
-                )
-            )
-
-    # perform search and get results
-    results = qdrant_client.search(
-        collection_name=COLLECTION,
-        query_filter=models.Filter(should=filters),
-        search_params=models.SearchParams(hnsw_ef=128, exact=False),
-        query_vector=query_embedding,
-        limit=num_results,
+    query_results = index.query(
+        top_k=3,
+        include_metadata=True,
+        vector= query_embedding,
+        filter={
+            'lang': {'$in': [language_mapping[lang] for lang in languages]}
+        }
     )
-    return results
+
+    metadata = [record["metadata"] for record in query_results["matches"]]
+
+    return metadata
 
 
 def cross_lingual_document_search(
     user_input: str, num_results: int, languages, text_match
 ) -> List:
-    """
-    Wrapper function for performing search on the collection of documents for the given user query.
-    Prepares query embedding, retrieves search results, checks if expected number of search results are being returned.
-    Args:
-        user_input (`str`):
-            The user input based on which search will be performed.
-        num_results (`str`):
-            The number of expected search results.
-        languages (`str`):
-            The list of languages based on which search results must be filtered.
-        text_match (`str`):
-            A field based on which it is decided whether to perform full-text-match while performing search.
-    Returns:
-        final_results (`List[str]`):
-            A list containing the final search results corresponding to the given user input.
-    """
     # create an embedding for the input query
     query_embedding, _ = embed_user_query(user_input)
 
     # retrieve search results
-    result = search_wiki_for_query(
+    metadata = search_wiki_for_query(
         query_embedding,
         num_results,
-        user_input,
         languages,
-        text_match,
     )
-    final_results = [result[i].payload["text"] for i in range(len(result))]
-
-    # check if number of search results obtained (i.e. `final_results`) is matching with number of expected search results i.e. `num_results`
-    if num_results > len(final_results):
-        remaining_inputs = num_results - len(final_results)
+
+    results = [result['title']+"\n"+result['text'] for result in metadata]
+
+    if num_results > len(results):
+        remaining_inputs = num_results - len(results)
         for input in range(remaining_inputs):
-            final_results.append("")
+            results.append("")
 
-    return final_results
+    return results
 
 def document_source(
     user_input: str, num_results: int, languages, text_match
@@ -128,22 +88,20 @@ def document_source(
     query_embedding, _ = embed_user_query(user_input)
 
     # retrieve search results
-    result = search_wiki_for_query(
+    metadata = search_wiki_for_query(
         query_embedding,
         num_results,
-        user_input,
         languages,
-        text_match,
     )
-    sources = [result[i].payload["url"] for i in range(len(result))]
-
-    # check if number of search results obtained (i.e. `final_results`) is matching with number of expected search results i.e. `num_results`
-    if num_results > len(sources):
-        remaining_inputs = num_results - len(sources)
+
+    results = [result['url'] for result in metadata]
+
+    if num_results > len(results):
+        remaining_inputs = num_results - len(results)
         for input in range(remaining_inputs):
-            sources.append("")
+            results.append("")
 
-    return sources
+    return results
 
 
 def translate_search_result():
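Review note: the Pinecone rewrite leaves a few loose ends. The first index.query(query_embedding, top_k=num_results, include_metadata=True) call discards its result, so every search runs twice; the second call hardcodes top_k=3 and ignores num_results; the lang filter is applied unconditionally, so with no language selected {'$in': []} matches no documents (the Qdrant version only added the filter if languages); filters = [] is now dead code; and the language_mapping key "Hause" will never match the "Hausa" choice that app.py offers. A hedged sketch, not the committed code, of the same function with those points tightened (signature and metadata layout follow the commit):

def search_wiki_for_query(query_embedding, num_results=3, languages=[]):
    pinecone.init(api_key=PINECONE_API_KEY, environment=PINECONE_ENV)
    index = pinecone.GRPCIndex(COLLECTION)

    language_mapping = {
        "English": "en",
        "Yoruba": "yo",
        "Igbo": "ig",
        "Hausa": "ha",  # the committed code spells this key "Hause", which would KeyError
    }

    # only filter by language when the user actually selected languages;
    # an empty {'$in': []} filter would match no documents at all
    metadata_filter = (
        {"lang": {"$in": [language_mapping[lang] for lang in languages]}}
        if languages
        else None
    )

    # single query, honoring the caller's requested result count
    query_results = index.query(
        vector=query_embedding,
        top_k=num_results,
        include_metadata=True,
        filter=metadata_filter,
    )
    return [record["metadata"] for record in query_results["matches"]]

Moving pinecone.init and the GRPCIndex construction to module scope, as the old qdrant_client was, would also avoid re-initializing the client on every query.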
src/wiki_search_v2.py
ADDED
@@ -0,0 +1,162 @@
+import os
+import cohere
+from typing import List
+
+from qdrant_client import QdrantClient
+from qdrant_client import models
+
+
+# load environment variables
+QDRANT_HOST = os.environ.get("QDRANT_HOST")
+QDRANT_API_KEY = os.environ.get("QDRANT_API_KEY")
+COHERE_API_KEY = os.environ.get("COHERE_API_KEY")
+
+MODEL_NAME = "multilingual-22-12"
+COLLECTION = "wiki-embed"
+
+# create qdrant and cohere client
+cohere_client = cohere.Client(COHERE_API_KEY)
+
+qdrant_client = QdrantClient(
+    host=QDRANT_HOST,
+    api_key=QDRANT_API_KEY,
+    port = 443,
+)
+
+def embed_user_query(user_query):
+
+    embeddings = cohere_client.embed(
+        texts=[user_query],
+        model=MODEL_NAME,
+    )
+    query_embedding = embeddings.embeddings[0]
+    return query_embedding, user_query
+
+
+def search_wiki_for_query(
+    query_embedding,
+    num_results = 3,
+    user_query= "",
+    languages = [],
+    match_text = None,
+):
+    filters = []
+
+    language_mapping = {
+        "English": "en",
+        "Yoruba": "yo",
+        "Igbo": "ig",
+        "Hause": "ha",
+    }
+
+    # prepare filters to narrow down search results
+    # if the `match_text` list is not empty then create filter to find exact matching text in the documents
+    if match_text:
+        filters.append(
+            models.FieldCondition(
+                key="text",
+                match=models.MatchText(text=user_query),
+            )
+        )
+
+    # filter documents based on language before performing search:
+    if languages:
+        for lang in languages:
+            filters.append(
+                models.FieldCondition(
+                    key="lang",
+                    match=models.MatchValue(
+                        value=language_mapping[lang],
+                    ),
+                )
+            )
+
+    # perform search and get results
+    results = qdrant_client.search(
+        collection_name=COLLECTION,
+        query_filter=models.Filter(should=filters),
+        search_params=models.SearchParams(hnsw_ef=128, exact=False),
+        query_vector=query_embedding,
+        limit=num_results,
+    )
+    return results
+
+
+def cross_lingual_document_search(
+    user_input: str, num_results: int, languages, text_match
+) -> List:
+    """
+    Wrapper function for performing search on the collection of documents for the given user query.
+    Prepares query embedding, retrieves search results, checks if expected number of search results are being returned.
+    Args:
+        user_input (`str`):
+            The user input based on which search will be performed.
+        num_results (`str`):
+            The number of expected search results.
+        languages (`str`):
+            The list of languages based on which search results must be filtered.
+        text_match (`str`):
+            A field based on which it is decided whether to perform full-text-match while performing search.
+    Returns:
+        final_results (`List[str]`):
+            A list containing the final search results corresponding to the given user input.
+    """
+    # create an embedding for the input query
+    query_embedding, _ = embed_user_query(user_input)
+
+    # retrieve search results
+    result = search_wiki_for_query(
+        query_embedding,
+        num_results,
+        user_input,
+        languages,
+        text_match,
+    )
+    final_results = [result[i].payload["text"] for i in range(len(result))]
+
+    # check if number of search results obtained (i.e. `final_results`) is matching with number of expected search results i.e. `num_results`
+    if num_results > len(final_results):
+        remaining_inputs = num_results - len(final_results)
+        for input in range(remaining_inputs):
+            final_results.append("")
+
+    return final_results
+
+def document_source(
+    user_input: str, num_results: int, languages, text_match
+) -> List:
+    query_embedding, _ = embed_user_query(user_input)
+
+    # retrieve search results
+    result = search_wiki_for_query(
+        query_embedding,
+        num_results,
+        user_input,
+        languages,
+        text_match,
+    )
+    sources = [result[i].payload["url"] for i in range(len(result))]
+
+    # check if number of search results obtained (i.e. `final_results`) is matching with number of expected search results i.e. `num_results`
+    if num_results > len(sources):
+        remaining_inputs = num_results - len(sources)
+        for input in range(remaining_inputs):
+            sources.append("")
+
+    return sources
+
+
+def translate_search_result():
+    pass
+
+if __name__ == "__main__":
+    # query_embedding, user_query = embed_user_query("Who is the president of Nigeria")
+    # result = search_wiki_for_query(query_embedding,user_query=user_query)
+
+    # for item in result:
+    #     print(item.payload["url"])
+    result = cross_lingual_document_search("Who is the president of Nigeria",
+                                           num_results=3,
+                                           languages=["Yoruba"],
+                                           text_match=False)
+    print(result, len(result))
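Review note: src/wiki_search_v2.py preserves the previous Qdrant implementation verbatim (client setup, the OR-combined models.Filter(should=...) conditions, and the full-text match_text path), presumably as a reference while the live module moves to Pinecone. Pinecone metadata filters have no full-text operator, so if the dropped match_text behaviour is wanted back, one hypothetical option is to over-fetch and post-filter on the client:

# hypothetical helper, not part of this commit: approximate the dropped
# full-text `match_text` path by filtering returned metadata client-side
def filter_matches_by_text(metadata, user_query):
    return [m for m in metadata if user_query.lower() in m["text"].lower()]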