manfredmichael committed on
Commit 966108f
1 Parent(s): 462639d

Initial commit

.gitattributes CHANGED
@@ -33,3 +33,5 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.zip filter=lfs diff=lfs merge=lfs -text
 *.zst filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
+ multilingual-e5-large filter=lfs diff=lfs merge=lfs -text
+ multilingual-e5-large/* filter=lfs diff=lfs merge=lfs -text
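The two new patterns route the bundled multilingual-e5-large embedding model (the directory entry and its contents) through Git LFS, so the large model weights are stored as LFS objects rather than in the regular git history.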
app.py ADDED
@@ -0,0 +1,80 @@
+ import streamlit as st
+ from dotenv import load_dotenv
+ import os, time
+
+ from retrieval_pipeline import get_retriever, get_compression_retriever
+
+
+ def get_result(query, compression_retriever):
+     t0 = time.time()
+     retrieved_chunks = compression_retriever.get_relevant_documents(query)
+     latency = time.time() - t0
+     return retrieved_chunks, latency
+
+ st.set_page_config(
+     layout="wide",
+     page_title="Retrieval Demo"
+ )
+
+ def setup():
+     load_dotenv()
+     ELASTICSEARCH_URL = os.getenv('ELASTICSEARCH_URL')
+
+     retriever = get_retriever(index='masa.ai', elasticsearch_url=ELASTICSEARCH_URL)
+     compression_retriever = get_compression_retriever(retriever)
+     return compression_retriever
+
+
+ def main():
+     st.title("Part 3: Search")
+
+     st.sidebar.info("""
+     **Model Options:**
+     - **Nano**: ~4MB, blazing-fast model with competitive ranking precision.
+     - **Small**: ~34MB, slightly slower, with the best ranking precision.
+     - **Medium**: ~110MB, slower model with the best zero-shot ranking precision.
+     - **Large**: ~150MB, slower model with competitive ranking precision for 100+ languages.
+     """)
+
+     with st.spinner('Setting up...'):
+         compression_retriever = setup()
+
+     with st.expander("Tech Stack Used"):
+         st.markdown("""
+         **Flash Rank**: Ultra-lite & super-fast Python library for search & retrieval re-ranking.
+
+         - **Ultra-lite**: No heavy dependencies. Runs on CPU with a tiny ~4MB reranking model.
+         - **Super-fast**: Speed depends on the number of tokens in passages and query, plus model depth.
+         - **Cost-efficient**: Ideal for serverless deployments with low memory and time requirements.
+         - **Based on state-of-the-art cross-encoders**: Includes models like ms-marco-TinyBERT-L-2-v2 (default), ms-marco-MiniLM-L-12-v2, rank-T5-flan, and ms-marco-MultiBERT-L-12.
+         - **Sleek models for efficiency**: Designed for minimal overhead in user-facing scenarios.
+
+         _Flash Rank is tailored for scenarios requiring efficient and effective reranking, balancing performance with resource usage._
+         """)
+
+     with st.form(key='input_form'):
+         query_input = st.text_area("Query Input")
+         submit_button = st.form_submit_button(label='Retrieve')
+
+     if submit_button:
+         st.session_state.submitted = True
+
+     if 'submitted' in st.session_state:
+         with st.spinner('Processing...'):
+             result, latency = get_result(query_input, compression_retriever)
+         st.subheader("Please find the retrieved documents below 👇")
+         # time.time() measures seconds, so convert before labelling as ms
+         st.write("latency:", round(latency * 1000, 1), "ms")
+         # Document objects are not JSON-serializable; unpack them first
+         st.json([
+             {"page_content": doc.page_content, "metadata": doc.metadata}
+             for doc in result
+         ])
+
+
+ if __name__ == "__main__":
+     main()
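Assuming an .env file that provides ELASTICSEARCH_URL (as setup() expects), the demo can be launched locally with `streamlit run app.py`.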
benchmark.py ADDED
@@ -0,0 +1,47 @@
+ import pandas as pd
+ import tqdm, time
+
+
+ TOP_N = 3
+
+ def get_benchmark_result(path, retriever):
+     df = pd.read_csv(path)
+     retrieval_result = []
+     # one column of retrieved texts per rank position
+     query_result = [[] for _ in range(TOP_N)]
+     retrieval_latency = []
+
+     for _, row in tqdm.tqdm(df.iterrows(), total=len(df)):
+         query = row['query']
+         target = row['body']
+
+         t0 = time.time()
+         results = retriever.get_relevant_documents(query)
+         retrieval_latency.append(time.time() - t0)
+
+         result_content = [result.page_content for result in results]
+
+         # record the top-N texts for this row, padding when fewer are returned
+         for rank in range(TOP_N):
+             query_result[rank].append(result_content[rank] if rank < len(result_content) else "")
+
+         retrieval_result.append("Success" if target in result_content else "Failed")
+
+     df["retrieval_result"] = retrieval_result
+     df["retrieval_latency"] = retrieval_latency
+     for rank in range(TOP_N):
+         df[f'q{rank+1}'] = query_result[rank]
+     df.to_csv('benchmark_result q3 topk 5.csv')
+     print(df['retrieval_result'].value_counts())
+     print(df['retrieval_result'].value_counts() / len(df))
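The success counts printed above can also be recomputed later from the CSV that get_benchmark_result writes; a minimal sketch of summarizing that output, assuming the 'benchmark_result q3 topk 5.csv' file produced above:

import pandas as pd

# Summarize the benchmark output written by get_benchmark_result.
df = pd.read_csv('benchmark_result q3 topk 5.csv')
hit_rate = (df['retrieval_result'] == 'Success').mean()
print(f"hit rate: {hit_rate:.2%}")
print(f"mean latency: {df['retrieval_latency'].mean() * 1000:.0f} ms")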
main.py ADDED
@@ -0,0 +1,33 @@
+ from dotenv import load_dotenv
+ import os
+
+ from retrieval_pipeline import get_retriever, get_compression_retriever
+ import benchmark
+
+ load_dotenv()
+ ELASTICSEARCH_URL = os.getenv('ELASTICSEARCH_URL')
+ # HUGGINGFACE_KEY = os.getenv('HUGGINGFACE_KEY')
+
+ os.environ["ES_ENDPOINT"] = ELASTICSEARCH_URL
+ print(ELASTICSEARCH_URL)
+
+ if __name__ == "__main__":
+     retriever = get_retriever(index='masa.ai', elasticsearch_url=ELASTICSEARCH_URL)
+     compression_retriever = get_compression_retriever(retriever)
+     retrieved_chunks = compression_retriever.get_relevant_documents('Gunung Semeru')
+     print(retrieved_chunks)
+
+     # retrieved_chunks = retriever.get_relevant_documents('Gunung Semeru')
+     # print(retrieved_chunks)
+
+     benchmark.get_benchmark_result("benchmark-reranker.csv", retriever=compression_retriever)
+
+     # for i in range(100):
+     #     query = input("query: ")
+     #     retrieved_chunks = retriever.get_relevant_documents(query)
+     #     print("Result:")
+     #     for r in retrieved_chunks:
+     #         print(r.page_content[:50])
+     #     print()
retrieval_pipeline/__init__.py ADDED
@@ -0,0 +1,2 @@
+ from retrieval_pipeline.main import *
+ from retrieval_pipeline.hybrid_search import *
retrieval_pipeline/hybrid_search.py ADDED
@@ -0,0 +1,95 @@
+ import logging
+
+ from langchain_core.vectorstores import VectorStoreRetriever
+ from langchain_community.retrievers import ElasticSearchBM25Retriever
+ from langchain_community.vectorstores.elastic_vector_search import ElasticVectorSearch
+ from langchain_core.retrievers import BaseRetriever
+ from langchain_core.documents import Document
+ import elasticsearch
+
+
+ class HybridRetriever(BaseRetriever):
+     dense_db: ElasticVectorSearch
+     dense_retriever: VectorStoreRetriever
+     sparse_retriever: ElasticSearchBM25Retriever
+     index_dense: str
+     index_sparse: str
+     top_k_dense: int
+     top_k_sparse: int
+
+     is_training: bool = False
+
+     @classmethod
+     def create(
+         cls, dense_db, dense_retriever, sparse_retriever, index_dense, index_sparse, top_k_dense, top_k_sparse
+     ):
+         return cls(
+             dense_db=dense_db,
+             dense_retriever=dense_retriever,
+             sparse_retriever=sparse_retriever,
+             index_dense=index_dense,
+             index_sparse=index_sparse,
+             top_k_dense=top_k_dense,
+             top_k_sparse=top_k_sparse,
+         )
+
+     def reset_indices(self):
+         result = self.dense_db.client.indices.delete(
+             index=self.index_dense,
+             ignore_unavailable=True,
+             allow_no_indices=True,
+         )
+         logging.info('dense_db delete: %s', result)
+
+         result = self.sparse_retriever.client.indices.delete(
+             index=self.index_sparse,
+             ignore_unavailable=True,
+             allow_no_indices=True,
+         )
+         logging.info('sparse_retriever delete: %s', result)
+
+     def add_documents(self, documents, batch_size=25):
+         # index each batch into both stores: Documents into the dense index,
+         # raw texts into the sparse BM25 index
+         for i in range(0, len(documents), batch_size):
+             dense_batch = documents[i:i + batch_size]
+             sparse_batch = [doc.page_content for doc in dense_batch]
+             self.dense_retriever.add_documents(dense_batch)
+             self.sparse_retriever.add_texts(sparse_batch)
+
+     def _get_relevant_documents(self, query: str, **kwargs):
+         dense_results = self.dense_retriever.get_relevant_documents(query)[:self.top_k_dense]
+         sparse_results = self.sparse_retriever.get_relevant_documents(query)[:self.top_k_sparse]
+
+         # Naive fusion: concatenate dense hits ahead of sparse hits; a real
+         # merge strategy (e.g. rank fusion or a reranker) is still needed here.
+         combined_results = dense_results + sparse_results
+
+         documents = [Document(page_content=doc.page_content, metadata=doc.metadata) for doc in combined_results]
+         return documents
+
+     async def aget_relevant_documents(self, query: str):
+         raise NotImplementedError
+
+ def get_dense_db(elasticsearch_url, index_dense, embeddings):
+     dense_db = ElasticVectorSearch(
+         elasticsearch_url=elasticsearch_url,
+         index_name=index_dense,
+         embedding=embeddings,
+     )
+     return dense_db
+
+ def get_sparse_retriever(elasticsearch_url, index_sparse):
+     sparse_retriever = ElasticSearchBM25Retriever(
+         client=elasticsearch.Elasticsearch(elasticsearch_url),
+         index_name=index_sparse,
+     )
+     return sparse_retriever
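The comment in _get_relevant_documents notes that the dense and sparse hits are merely concatenated and a real fusion strategy is still needed. One common option is reciprocal rank fusion; a minimal sketch of such a helper (hypothetical, not part of this commit) over two lists of LangChain Documents:

def rrf_merge(dense_results, sparse_results, k=60):
    # Score each document by its summed reciprocal rank across both lists;
    # documents found by both retrievers accumulate a higher score.
    scores = {}
    docs = {}
    for results in (dense_results, sparse_results):
        for rank, doc in enumerate(results):
            key = doc.page_content
            docs[key] = doc
            scores[key] = scores.get(key, 0.0) + 1.0 / (k + rank + 1)
    ranked = sorted(scores, key=scores.get, reverse=True)
    return [docs[key] for key in ranked]

The constant k=60 is the conventional damping term from the RRF literature; it keeps a single top rank from dominating the merged ordering.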
retrieval_pipeline/main.py ADDED
@@ -0,0 +1,98 @@
+ from langchain_community.embeddings import HuggingFaceEmbeddings
+ from retrieval_pipeline.hybrid_search import HybridRetriever, get_dense_db, get_sparse_retriever
+ from retrieval_pipeline.utils import get_hybrid_indexes
+
+ from langchain.retrievers import ContextualCompressionRetriever
+ from langchain.retrievers.document_compressors import CrossEncoderReranker
+ from langchain_community.cross_encoders import HuggingFaceCrossEncoder
+ from flashrank import Ranker, RerankRequest
+
+ import logging
+
+ # FlashRank reranker used by get_relevant_documents (defaults to the Nano model)
+ ranker = Ranker()
+
+
+ def get_compression_retriever(retriever):
+     model = HuggingFaceCrossEncoder(model_name="BAAI/bge-reranker-base")
+     compressor = CrossEncoderReranker(model=model, top_n=3)
+     compression_retriever = ContextualCompressionRetriever(
+         base_compressor=compressor, base_retriever=retriever
+     )
+     return compression_retriever
+
+ # Embedding model loader
+ def get_huggingface_embeddings(model_name):
+     logging.info("Loading HuggingFace embeddings")
+     embeddings = HuggingFaceEmbeddings(model_name=model_name)
+     return embeddings
+
+ def get_vectorstore(index_name, embeddings, elasticsearch_url=None):
+     logging.info("Loading vectorstore")
+
+     index_dense, index_sparse = get_hybrid_indexes(index_name)
+
+     dense_db = get_dense_db(elasticsearch_url, index_dense, embeddings)
+     dense_retriever = dense_db.as_retriever()
+
+     sparse_retriever = get_sparse_retriever(elasticsearch_url, index_sparse)
+
+     hybrid_retriever = HybridRetriever(
+         dense_db=dense_db,
+         dense_retriever=dense_retriever,
+         sparse_retriever=sparse_retriever,
+         index_dense=index_dense,
+         index_sparse=index_sparse,
+         top_k_dense=2,
+         top_k_sparse=3
+     )
+
+     return hybrid_retriever
+
+ def get_retriever(index, elasticsearch_url):
+     # the model directory is tracked with Git LFS (see .gitattributes)
+     embeddings = get_huggingface_embeddings(model_name="multilingual-e5-large")
+
+     retriever = get_vectorstore(
+         index,
+         embeddings=embeddings,
+         elasticsearch_url=elasticsearch_url,
+     )
+
+     return retriever
+
+ def get_relevant_documents(query, retriever, top_k):
+     results = retriever.get_relevant_documents(query)
+     passages = [{
+         "id": i,
+         "text": result.page_content
+     } for i, result in enumerate(results)]
+
+     reranked_result = ranker.rerank(RerankRequest(query=query, passages=passages))
+     return reranked_result[:top_k]
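get_relevant_documents is the FlashRank path that the app's cross-encoder route bypasses; a minimal usage sketch, assuming a reachable Elasticsearch instance at a placeholder URL and already-populated indexes (the "score"/"text" keys follow FlashRank's passage dict format):

from retrieval_pipeline import get_retriever, get_relevant_documents

# 'http://localhost:9200' is a placeholder; use the real ELASTICSEARCH_URL
retriever = get_retriever(index='masa.ai', elasticsearch_url='http://localhost:9200')
top_passages = get_relevant_documents('Gunung Semeru', retriever, top_k=3)
for passage in top_passages:
    print(passage["score"], passage["text"][:80])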
retrieval_pipeline/utils.py ADDED
@@ -0,0 +1,5 @@
+ def get_hybrid_indexes(index_name):
+     index_dense = f'{index_name}-dense'
+     index_sparse = f'{index_name}-sparse'
+
+     return index_dense, index_sparse
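As used by get_vectorstore, one logical index name fans out into the two physical Elasticsearch indexes:

from retrieval_pipeline.utils import get_hybrid_indexes

index_dense, index_sparse = get_hybrid_indexes('masa.ai')
print(index_dense, index_sparse)  # masa.ai-dense masa.ai-sparse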