Spaces:
Runtime error
Runtime error
Michael Bernovskiy
commited on
Commit
•
1268b93
1
Parent(s):
b4085f5
reranker added
Browse files- app.py +7 -2
- backend/reranker.py +10 -0
app.py
CHANGED
@@ -11,7 +11,7 @@ from jinja2 import Environment, FileSystemLoader
|
|
11 |
|
12 |
from backend.query_llm import generate_hf, generate_openai
|
13 |
from backend.semantic_search import retrieve
|
14 |
-
|
15 |
|
16 |
TOP_K = int(os.getenv("TOP_K", 4))
|
17 |
|
@@ -44,11 +44,16 @@ def bot(history, api_kind):
|
|
44 |
# Retrieve documents relevant to query
|
45 |
document_start = perf_counter()
|
46 |
|
47 |
-
documents = retrieve(query,
|
48 |
|
49 |
document_time = perf_counter() - document_start
|
50 |
logger.info(f'Finished Retrieving documents in {round(document_time, 2)} seconds...')
|
51 |
|
|
|
|
|
|
|
|
|
|
|
52 |
# Create Prompt
|
53 |
prompt = template.render(documents=documents, query=query)
|
54 |
prompt_html = template_html.render(documents=documents, query=query)
|
|
|
11 |
|
12 |
from backend.query_llm import generate_hf, generate_openai
|
13 |
from backend.semantic_search import retrieve
|
14 |
+
from study.backend.reranker import rerank
|
15 |
|
16 |
TOP_K = int(os.getenv("TOP_K", 4))
|
17 |
|
|
|
44 |
# Retrieve documents relevant to query
|
45 |
document_start = perf_counter()
|
46 |
|
47 |
+
documents = retrieve(query, 10)
|
48 |
|
49 |
document_time = perf_counter() - document_start
|
50 |
logger.info(f'Finished Retrieving documents in {round(document_time, 2)} seconds...')
|
51 |
|
52 |
+
rerank_start = perf_counter()
|
53 |
+
documents = rerank(query, documents, TOP_K)
|
54 |
+
rerank_time = perf_counter() - rerank_start
|
55 |
+
logger.info(f'Finished Reranking documents in {round(rerank_time, 2)} seconds...')
|
56 |
+
|
57 |
# Create Prompt
|
58 |
prompt = template.render(documents=documents, query=query)
|
59 |
prompt_html = template_html.render(documents=documents, query=query)
|
backend/reranker.py
ADDED
@@ -0,0 +1,10 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from FlagEmbedding import FlagReranker
|
2 |
+
|
3 |
+
reranker = FlagReranker('BAAI/bge-reranker-large',
|
4 |
+
use_fp16=True)
|
5 |
+
|
6 |
+
|
7 |
+
def rerank(query: str, documents: [str], k: int) -> [str]:
|
8 |
+
scores = reranker.compute_score([(query, document) for document in documents])
|
9 |
+
sorted_docs = sorted(enumerate(scores), key=lambda x: x[1], reverse=True)
|
10 |
+
return [sorted_docs[i] for i, _ in sorted_docs[:k]]
|