rag-gradio / backend /semantic_search.py
pavvloff's picture
Update backend/semantic_search.py
af26a9d
raw
history blame
1.81 kB
import logging
import lancedb
import os
from pathlib import Path
from sentence_transformers import SentenceTransformer
import openai
from sentence_transformers import CrossEncoder
cross_encoder = CrossEncoder('cross-encoder/ms-marco-MiniLM-L-12-v2', max_length=512)
def rerank_documents(query, documents):
scores = cross_encoder.predict([(query,d) for d in documents])
return [pair[1] for pair in sorted(zip(scores, documents), reverse=True)]
EMB_MODEL_NAME = ""
DB_TABLE_NAME = ""
# Setting up the logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
# Enable multiple retrievers
retrievers = {}
import tiktoken
def num_tokens_from_string(string: str, encoding_name: str = "cl100k_base") -> int:
"""Returns the number of tokens in a text string."""
encoding = tiktoken.get_encoding(encoding_name)
num_tokens = len(encoding.encode(string))
return num_tokens
def trim(text, length = 8190):
text = ' '.join(text.split()).replace('<|endoftext|>','')
while num_tokens_from_string(text) > length:
text = ' '.join(text.split()[:-10])
return text
def openai_embedding(text, key = None):
client = openai.OpenAI(
api_key=key,
)
trimmed = trim(text)
rs = client.embeddings.create(input=[trimmed], model="text-embedding-ada-002")
return rs.data[0].embedding
minilm = SentenceTransformer('sentence-transformers/all-MiniLM-L6-v2')
mpnet = SentenceTransformer('sentence-transformers/all-mpnet-base-v2')
retrievers['MiniLM'] = lambda t, key: minilm.encode(t)
retrievers['mpnet'] = lambda t, key: mpnet.encode(t)
retrievers['OpenAI'] = openai_embedding
# db
db_uri = os.path.join(Path(__file__).parents[1], ".lancedb")
db = lancedb.connect(db_uri)
tables = {}
for table_name in db.table_names():
tables[table_name] = db.open_table(table_name)