paulokewunmi commited on
Commit
f440070
1 Parent(s): ffa8f17

Change vector db to pinecone

Browse files
Files changed (5) hide show
  1. app.py +2 -1
  2. requirements.txt +1 -2
  3. src/document_utils_v2.py +151 -0
  4. src/wiki_search.py +37 -79
  5. src/wiki_search_v2.py +162 -0
app.py CHANGED
@@ -65,13 +65,14 @@ with gr.Blocks(theme=custom_theme) as demo:
65
  "Hausa",
66
  ],
67
  label="Filter results based on language",
 
68
  )
69
 
70
  with gr.Row():
71
  with gr.Column():
72
  user_query = gr.Text(
73
  label="Enter query here",
74
- placeholder="Search through all your documents",
75
  )
76
 
77
  num_search_results = gr.Slider(
 
65
  "Hausa",
66
  ],
67
  label="Filter results based on language",
68
+ value = "Yoruba"
69
  )
70
 
71
  with gr.Row():
72
  with gr.Column():
73
  user_query = gr.Text(
74
  label="Enter query here",
75
+ placeholder="Search through all your study materials",
76
  )
77
 
78
  num_search_results = gr.Slider(
requirements.txt CHANGED
@@ -1,6 +1,5 @@
1
  cohere
2
- qdrant_client==0.11.0
3
  gradio
4
  langchain
5
  black
6
- python-dotenv
 
1
  cohere
 
2
  gradio
3
  langchain
4
  black
5
+ "pinecone-client[grpc]"
src/document_utils_v2.py ADDED
@@ -0,0 +1,151 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import sys
3
+
4
+ import pandas as pd
5
+ from typing import List
6
+
7
+ import cohere
8
+ from langchain.embeddings.cohere import CohereEmbeddings
9
+ from langchain.llms import Cohere
10
+ from langchain.prompts import PromptTemplate
11
+ from langchain.vectorstores import Qdrant
12
+ from langchain.chains.question_answering import load_qa_chain
13
+
14
+ sys.path.append(os.path.abspath('..'))
15
+
16
+ from src.constants import SUMMARIZATION_MODEL, EXAMPLES_FILE_PATH
17
+
18
+
19
+
20
+ QDRANT_HOST = os.environ.get("QDRANT_HOST")
21
+ QDRANT_API_KEY = os.environ.get("QDRANT_API_KEY")
22
+ COHERE_API_KEY = os.environ.get("COHERE_API_KEY")
23
+
24
+
25
def replace_text(text):
    """Strip a leading "The answer is " prefix from *text*, if present."""
    prefix = "The answer is "
    return text.replace(prefix, "", 1) if text.startswith(prefix) else text
29
+
30
+
31
def summarize(
    document: str,
    summary_length: str,
    summary_format: str,
    extractiveness: str = "high",
    temperature: float = 0.6,
) -> str:
    """Summarize *document* with Cohere's summarize endpoint.

    Args:
        document: Text supplied by the user to be summarized.
        summary_length: Desired summary size: 'short', 'medium' or 'long'.
        summary_format: Output layout, either 'paragraph' or 'bullets'.
        extractiveness: How closely the summary should stick to the original
            wording: 'low', 'medium' or 'high'.
        temperature: Sampling randomness; lower values give more predictable
            output, higher values more creative output.

    Returns:
        The summary text produced by the model named in SUMMARIZATION_MODEL.
    """
    client = cohere.Client(COHERE_API_KEY)
    response = client.summarize(
        text=document,
        length=summary_length,
        format=summary_format,
        model=SUMMARIZATION_MODEL,
        extractiveness=extractiveness,
        temperature=temperature,
    )
    return response.summary
66
+
67
+
68
def question_answer(input_document: str, history: List) -> str:
    """
    Generates an appropriate answer for the question asked by the user based on the input document.
    Args:
        input_document (`str`):
            The document given by the user for which summary must be generated.
        history (`List[List[str,str]]`):
            A list made up of pairs of input question asked by the user & corresponding generated answers. It is used to keep track of the history of the chat between the user and the model.
    Returns:
        answer (`str`):
            The generated answer corresponding to the input question and document received from the user.
    """
    context = input_document
    # The last element of the `history` list contains the most recent question asked by the user whose answer needs to be generated.
    question = history[-1][0]
    word_list = context.split()
    # Chunk the document into 256-word windows so each piece fits the embedding model.
    # texts = [context[k : k + 256] for k in range(0, len(context.split()), 256)]
    texts = [" ".join(word_list[k : k + 256]) for k in range(0, len(word_list), 256)]

    # print(texts)

    # Embed the chunks and index them in Qdrant so relevant ones can be retrieved.
    embeddings = CohereEmbeddings(
        model="multilingual-22-12", cohere_api_key=COHERE_API_KEY
    )
    context_index = Qdrant.from_texts(
        texts, embeddings, url=QDRANT_HOST, api_key=QDRANT_API_KEY
    )

    prompt_template = """Text: {context}
    Question: {question}
    Answer the question based on the text provided. If the text doesn't contain the answer, reply that the answer is not available."""

    PROMPT = PromptTemplate(
        template=prompt_template, input_variables=["context", "question"]
    )

    # Generate the answer given the context
    chain = load_qa_chain(
        Cohere(
            model="command-xlarge-nightly", temperature=0, cohere_api_key=COHERE_API_KEY
        ),
        chain_type="stuff",
        prompt=PROMPT,
    )
    # Retrieve only the chunks most similar to the question, then run the QA chain on them.
    relevant_context = context_index.similarity_search(question)
    answer = chain.run(input_documents=relevant_context, question=question)
    # Clean up model boilerplate before returning the answer to the UI.
    answer = answer.replace("\n", "").replace("Answer:", "")
    answer = replace_text(answer)
    return answer
117
+
118
def generate_questions(input_document: str) -> str:
    """Ask the Cohere generation model for five comprehension questions about *input_document*."""
    prompt = (
        "Give me 5 different questions to test understanding of the following "
        f"text provided. Here's the provided text: {input_document}. Now what is Questions 1 to 5 ?:"
    )
    client = cohere.Client(COHERE_API_KEY)
    response = client.generate(
        prompt=prompt,
        max_tokens=200,
        temperature=0.55,
    )
    return response.generations[0].text
127
+
128
+
129
def _load_example(row_index: int):
    """Return the (document, sample question) pair stored at *row_index* of the examples CSV."""
    examples_df = pd.read_csv(EXAMPLES_FILE_PATH)
    return examples_df["doc"].iloc[row_index], examples_df["question"].iloc[row_index]


def load_science():
    """Load the science example document and its sample question (examples row 0)."""
    return _load_example(0)


def load_history():
    """Load the history example document and its sample question (examples row 1)."""
    return _load_example(1)
141
+
142
+
143
if __name__ == "__main__":
    # Smoke-test the QA helper against a local sample document.
    with open("sample_text.txt", "r", encoding="utf-8") as file:
        text = file.read()
    # summary = summarize(text, summary_length="short", summary_format="bullets")
    # print(summary)

    # `question_answer` expects a chat history: a list of [question, answer]
    # pairs, and reads the latest question from history[-1][0]. Passing a bare
    # list of strings would make it index into the string itself ("W").
    answer = question_answer(text, [["Whats photosynthesis", None]])
    print(answer)
src/wiki_search.py CHANGED
@@ -1,14 +1,10 @@
1
  import os
2
  import cohere
3
  from typing import List
 
4
 
5
- from qdrant_client import QdrantClient
6
- from qdrant_client import models
7
-
8
-
9
- # load environment variables
10
- QDRANT_HOST = os.environ.get("QDRANT_HOST")
11
- QDRANT_API_KEY = os.environ.get("QDRANT_API_KEY")
12
  COHERE_API_KEY = os.environ.get("COHERE_API_KEY")
13
 
14
  MODEL_NAME = "multilingual-22-12"
@@ -17,12 +13,6 @@ COLLECTION = "wiki-embed"
17
  # create qdrant and cohere client
18
  cohere_client = cohere.Client(COHERE_API_KEY)
19
 
20
- qdrant_client = QdrantClient(
21
- host=QDRANT_HOST,
22
- api_key=QDRANT_API_KEY,
23
- port = 443,
24
- )
25
-
26
  def embed_user_query(user_query):
27
 
28
  embeddings = cohere_client.embed(
@@ -36,10 +26,13 @@ def embed_user_query(user_query):
36
  def search_wiki_for_query(
37
  query_embedding,
38
  num_results = 3,
39
- user_query= "",
40
  languages = [],
41
- match_text = None,
42
  ):
 
 
 
 
 
43
  filters = []
44
 
45
  language_mapping = {
@@ -49,78 +42,45 @@ def search_wiki_for_query(
49
  "Hause": "ha",
50
  }
51
 
 
 
52
  # prepare filters to narrow down search results
53
  # if the `match_text` list is not empty then create filter to find exact matching text in the documents
54
- if match_text:
55
- filters.append(
56
- models.FieldCondition(
57
- key="text",
58
- match=models.MatchText(text=user_query),
59
- )
60
- )
61
-
62
- # filter documents based on language before performing search:
63
- if languages:
64
- for lang in languages:
65
- filters.append(
66
- models.FieldCondition(
67
- key="lang",
68
- match=models.MatchValue(
69
- value=language_mapping[lang],
70
- ),
71
- )
72
- )
73
-
74
- # perform search and get results
75
- results = qdrant_client.search(
76
- collection_name=COLLECTION,
77
- query_filter=models.Filter(should=filters),
78
- search_params=models.SearchParams(hnsw_ef=128, exact=False),
79
- query_vector=query_embedding,
80
- limit=num_results,
81
  )
82
- return results
 
 
 
83
 
84
 
85
  def cross_lingual_document_search(
86
  user_input: str, num_results: int, languages, text_match
87
  ) -> List:
88
- """
89
- Wrapper function for performing search on the collection of documents for the given user query.
90
- Prepares query embedding, retrieves search results, checks if expected number of search results are being returned.
91
- Args:
92
- user_input (`str`):
93
- The user input based on which search will be performed.
94
- num_results (`str`):
95
- The number of expected search results.
96
- languages (`str`):
97
- The list of languages based on which search results must be filtered.
98
- text_match (`str`):
99
- A field based on which it is decided whether to perform full-text-match while performing search.
100
- Returns:
101
- final_results (`List[str]`):
102
- A list containing the final search results corresponding to the given user input.
103
- """
104
  # create an embedding for the input query
105
  query_embedding, _ = embed_user_query(user_input)
106
 
107
  # retrieve search results
108
- result = search_wiki_for_query(
109
  query_embedding,
110
  num_results,
111
- user_input,
112
  languages,
113
- text_match,
114
  )
115
- final_results = [result[i].payload["text"] for i in range(len(result))]
116
-
117
- # check if number of search results obtained (i.e. `final_results`) is matching with number of expected search results i.e. `num_results`
118
- if num_results > len(final_results):
119
- remaining_inputs = num_results - len(final_results)
120
  for input in range(remaining_inputs):
121
- final_results.append("")
122
 
123
- return final_results
124
 
125
  def document_source(
126
  user_input: str, num_results: int, languages, text_match
@@ -128,22 +88,20 @@ def document_source(
128
  query_embedding, _ = embed_user_query(user_input)
129
 
130
  # retrieve search results
131
- result = search_wiki_for_query(
132
  query_embedding,
133
  num_results,
134
- user_input,
135
  languages,
136
- text_match,
137
  )
138
- sources = [result[i].payload["url"] for i in range(len(result))]
139
-
140
- # check if number of search results obtained (i.e. `final_results`) is matching with number of expected search results i.e. `num_results`
141
- if num_results > len(sources):
142
- remaining_inputs = num_results - len(sources)
143
  for input in range(remaining_inputs):
144
- sources.append("")
145
 
146
- return sources
147
 
148
 
149
  def translate_search_result():
 
1
  import os
2
  import cohere
3
  from typing import List
4
+ import pinecone
5
 
6
+ PINECONE_API_KEY = os.environ.get("PINECONE_API_KEY")
7
+ PINECONE_ENV = os.environ.get("PINECONE_ENV")
 
 
 
 
 
8
  COHERE_API_KEY = os.environ.get("COHERE_API_KEY")
9
 
10
  MODEL_NAME = "multilingual-22-12"
 
13
  # create qdrant and cohere client
14
  cohere_client = cohere.Client(COHERE_API_KEY)
15
 
 
 
 
 
 
 
16
  def embed_user_query(user_query):
17
 
18
  embeddings = cohere_client.embed(
 
26
def search_wiki_for_query(
    query_embedding,
    num_results=3,
    languages=None,
):
    """Query the Pinecone wiki index for documents similar to *query_embedding*.

    Args:
        query_embedding: Embedding vector of the user's query.
        num_results: Number of matches to request (``top_k``).
        languages: Optional list of UI language labels; when non-empty the
            search is restricted to those languages via a metadata filter.

    Returns:
        A list of metadata dicts, one per match.
    """
    pinecone.init(api_key=PINECONE_API_KEY, environment=PINECONE_ENV)
    index = pinecone.GRPCIndex(COLLECTION)

    # UI label -> wiki language code. "Hausa" matches the label offered in the
    # app's language filter; the legacy "Hause" spelling is kept so older
    # callers don't break.
    language_mapping = {
        "English": "en",
        "Yoruba": "yo",
        "Igbo": "ig",
        "Hausa": "ha",
        "Hause": "ha",
    }

    # Only constrain by language when the caller actually selected some:
    # an empty `$in` list would match no documents at all.
    metadata_filter = None
    if languages:
        metadata_filter = {
            "lang": {"$in": [language_mapping[lang] for lang in languages]}
        }

    # Single query; honour the caller's requested result count instead of a
    # hard-coded top_k, and don't issue a duplicate throwaway request.
    query_results = index.query(
        vector=query_embedding,
        top_k=num_results,
        include_metadata=True,
        filter=metadata_filter,
    )

    return [record["metadata"] for record in query_results["matches"]]
61
 
62
 
63
def cross_lingual_document_search(
    user_input: str, num_results: int, languages, text_match
) -> List:
    """Search the wiki index for *user_input* and return exactly *num_results* entries.

    Each entry is the match's title and text joined by a newline; the list is
    padded with empty strings when fewer matches come back. ``text_match`` is
    accepted for interface compatibility but is not used by this backend.
    """
    # create an embedding for the input query
    embedding, _ = embed_user_query(user_input)

    # retrieve search results
    hits = search_wiki_for_query(
        embedding,
        num_results,
        languages,
    )

    formatted = ["\n".join((hit["title"], hit["text"])) for hit in hits]

    # Pad so the UI always receives the number of slots it asked for.
    shortfall = max(0, num_results - len(formatted))
    return formatted + [""] * shortfall
84
 
85
def document_source(
    user_input: str, num_results: int, languages, text_match
) -> List:
    """Return the source URLs of the top matches for *user_input*.

    The list is padded with empty strings up to *num_results* entries.
    ``text_match`` is accepted for interface compatibility but unused here.
    """
    # create an embedding for the input query
    embedding, _ = embed_user_query(user_input)

    # retrieve search results
    hits = search_wiki_for_query(
        embedding,
        num_results,
        languages,
    )

    urls = [hit["url"] for hit in hits]

    # Pad so callers always get `num_results` entries back.
    shortfall = max(0, num_results - len(urls))
    return urls + [""] * shortfall
105
 
106
 
107
  def translate_search_result():
src/wiki_search_v2.py ADDED
@@ -0,0 +1,162 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import cohere
3
+ from typing import List
4
+
5
+ from qdrant_client import QdrantClient
6
+ from qdrant_client import models
7
+
8
+
9
+ # load environment variables
10
+ QDRANT_HOST = os.environ.get("QDRANT_HOST")
11
+ QDRANT_API_KEY = os.environ.get("QDRANT_API_KEY")
12
+ COHERE_API_KEY = os.environ.get("COHERE_API_KEY")
13
+
14
+ MODEL_NAME = "multilingual-22-12"
15
+ COLLECTION = "wiki-embed"
16
+
17
+ # create qdrant and cohere client
18
+ cohere_client = cohere.Client(COHERE_API_KEY)
19
+
20
+ qdrant_client = QdrantClient(
21
+ host=QDRANT_HOST,
22
+ api_key=QDRANT_API_KEY,
23
+ port = 443,
24
+ )
25
+
26
def embed_user_query(user_query):
    """Embed *user_query* with Cohere's multilingual embedding model.

    Returns:
        A tuple of (embedding vector, the original query string).
    """
    embeddings = cohere_client.embed(
        texts=[user_query],
        model=MODEL_NAME,
    )
    # embed() was called with a single text, so the first (only) vector is ours.
    query_embedding = embeddings.embeddings[0]
    return query_embedding, user_query
34
+
35
+
36
def search_wiki_for_query(
    query_embedding,
    num_results=3,
    user_query="",
    languages=None,
    match_text=None,
):
    """Search the Qdrant wiki collection for documents similar to *query_embedding*.

    Args:
        query_embedding: Embedding vector of the user's query.
        num_results: Maximum number of results to return.
        user_query: Raw query text, used for full-text matching when
            *match_text* is truthy.
        languages: Optional list of UI language labels used to filter results.
        match_text: When truthy, add an exact-text-match condition on the
            document body.

    Returns:
        The raw Qdrant search results (scored points with payloads).
    """
    filters = []

    # UI label -> wiki language code. Both the "Hausa" spelling used by the
    # app's language filter and the legacy "Hause" key are accepted, so a
    # "Hausa" selection no longer raises KeyError.
    language_mapping = {
        "English": "en",
        "Yoruba": "yo",
        "Igbo": "ig",
        "Hausa": "ha",
        "Hause": "ha",
    }

    # prepare filters to narrow down search results
    # if `match_text` is truthy then create a filter to find exact matching text in the documents
    if match_text:
        filters.append(
            models.FieldCondition(
                key="text",
                match=models.MatchText(text=user_query),
            )
        )

    # filter documents based on language before performing search
    if languages:
        for lang in languages:
            filters.append(
                models.FieldCondition(
                    key="lang",
                    match=models.MatchValue(
                        value=language_mapping[lang],
                    ),
                )
            )

    # perform search and get results
    results = qdrant_client.search(
        collection_name=COLLECTION,
        query_filter=models.Filter(should=filters),
        search_params=models.SearchParams(hnsw_ef=128, exact=False),
        query_vector=query_embedding,
        limit=num_results,
    )
    return results
83
+
84
+
85
def cross_lingual_document_search(
    user_input: str, num_results: int, languages, text_match
) -> List:
    """Run a cross-lingual search over the wiki collection for *user_input*.

    Embeds the query, retrieves matches (optionally filtered by *languages*
    and exact-text matching via *text_match*), and returns the matched
    document texts, padded with empty strings so the result always has
    exactly *num_results* entries.

    Args:
        user_input: The user's search query.
        num_results: Number of entries the caller expects back.
        languages: List of UI language labels to filter results by.
        text_match: Whether to require an exact text match during search.

    Returns:
        A list of *num_results* document texts (possibly padded with "").
    """
    # Embed the query once; the helper also echoes the query string back.
    embedding, query = embed_user_query(user_input)

    hits = search_wiki_for_query(
        embedding,
        num_results,
        query,
        languages,
        text_match,
    )
    texts = [hit.payload["text"] for hit in hits]

    # Pad with empty strings so callers always get `num_results` entries.
    texts.extend("" for _ in range(num_results - len(texts)))

    return texts
124
+
125
def document_source(
    user_input: str, num_results: int, languages, text_match
) -> List:
    """Return the source URLs for the top wiki matches of *user_input*.

    Mirrors ``cross_lingual_document_search`` but extracts each match's
    ``url`` payload field instead of its text. The list is padded with empty
    strings so it always contains exactly *num_results* entries.
    """
    embedding, query = embed_user_query(user_input)

    hits = search_wiki_for_query(
        embedding,
        num_results,
        query,
        languages,
        text_match,
    )
    urls = [hit.payload["url"] for hit in hits]

    # Pad with empty strings so callers always get `num_results` entries.
    urls.extend("" for _ in range(num_results - len(urls)))

    return urls
147
+
148
+
149
def translate_search_result():
    """Translate a search result into the user's language.

    Not implemented yet — placeholder kept so the module's public API is
    stable for callers that already reference it.
    """
    pass
151
+
152
if __name__ == "__main__":
    # Manual smoke test: run a cross-lingual search and print the padded results.
    # query_embedding, user_query = embed_user_query("Who is the president of Nigeria")
    # result = search_wiki_for_query(query_embedding,user_query=user_query)

    # for item in result:
    #     print(item.payload["url"])
    result = cross_lingual_document_search("Who is the president of Nigeria",
        num_results=3,
        languages=["Yoruba"],
        text_match=False)
    print(result, len(result))