"""Embedding retrievers, cross-encoder reranking and LanceDB table handles.

Importing this module loads the sentence-transformer models, configures
logging, connects to the local LanceDB directory and opens every table in it.
"""

import logging
import os
from pathlib import Path

import lancedb
import openai
import tiktoken
from sentence_transformers import CrossEncoder, SentenceTransformer

cross_encoder = CrossEncoder('cross-encoder/ms-marco-MiniLM-L-12-v2', max_length=512)


def rerank_documents(query, documents):
    """Return ``documents`` ordered from highest to lowest cross-encoder score.

    Each ``(query, document)`` pair is scored with the module-level
    ``cross_encoder``. Documents with equal scores keep their input order.

    Args:
        query: The query each document is scored against.
        documents: Iterable of documents to rank (presumably text passages;
            confirm against callers).

    Returns:
        A new list with the same documents, best match first. An empty input
        yields an empty list without calling the model.
    """
    # Materialize once: the input is consumed both for scoring and for pairing
    # with the scores, which would silently drop everything for a generator.
    documents = list(documents)
    if not documents:
        return []
    scores = cross_encoder.predict([(query, doc) for doc in documents])
    # Sort on the score alone. Sorting the raw (score, document) tuples would
    # compare the documents themselves on ties and raise TypeError for
    # documents that do not support ordering (e.g. dicts).
    ranked = sorted(zip(scores, documents), key=lambda pair: pair[0], reverse=True)
    return [doc for _, doc in ranked]


EMB_MODEL_NAME = ""
DB_TABLE_NAME = ""

# Setting up the logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Enable multiple retrievers: maps a retriever name to a callable
# ``(text, key=None) -> embedding``.
retrievers = {}


def num_tokens_from_string(string: str, encoding_name: str = "cl100k_base") -> int:
    """Returns the number of tokens in a text string."""
    encoding = tiktoken.get_encoding(encoding_name)
    num_tokens = len(encoding.encode(string))
    return num_tokens


def trim(text, length=8190):
    """Normalize whitespace and shorten ``text`` to at most ``length`` tokens.

    Runs of whitespace are collapsed to single spaces and every literal
    ``<|endoftext|>`` marker is removed. While the token count (as measured
    by ``num_tokens_from_string``) exceeds ``length``, the last ten
    whitespace-separated words are dropped.

    Args:
        text: The text to trim.
        length: Maximum number of tokens allowed in the result. The default
            is presumably chosen to stay under the embedding model's input
            limit (confirm).

    Returns:
        The normalized, possibly shortened text.
    """
    text = ' '.join(text.split()).replace('<|endoftext|>', '')
    # Split once and shrink the word list instead of re-splitting the whole
    # text on every iteration.
    words = text.split()
    while num_tokens_from_string(text) > length:
        if not text:
            # Empty text cannot shrink further; without this guard a negative
            # ``length`` would loop forever.
            break
        del words[-10:]
        text = ' '.join(words)
    return text


def openai_embedding(text, key=None):
    """Embed ``text`` with OpenAI's ``text-embedding-ada-002`` model.

    The text is passed through ``trim`` before being sent.

    Args:
        text: The text to embed.
        key: API key handed to ``openai.OpenAI``. When ``None`` the client
            decides how to find credentials (presumably the
            ``OPENAI_API_KEY`` environment variable; confirm).

    Returns:
        The embedding of the first (only) input, as returned by the API.
    """
    client = openai.OpenAI(
        api_key=key,
    )
    trimmed = trim(text)
    rs = client.embeddings.create(input=[trimmed], model="text-embedding-ada-002")
    return rs.data[0].embedding


minilm = SentenceTransformer('sentence-transformers/all-MiniLM-L6-v2')
mpnet = SentenceTransformer('sentence-transformers/all-mpnet-base-v2')

# The local models ignore ``key``; it defaults to None so that every retriever
# can be called the same way as ``openai_embedding``, with or without a key.
retrievers['MiniLM'] = lambda t, key=None: minilm.encode(t)
retrievers['mpnet'] = lambda t, key=None: mpnet.encode(t)
retrievers['OpenAI'] = openai_embedding

# db
# The database lives in ".lancedb" under the parent of this file's directory.
db_uri = os.path.join(Path(__file__).parents[1], ".lancedb")
db = lancedb.connect(db_uri)

# Open every existing table up front, keyed by table name.
tables = {table_name: db.open_table(table_name) for table_name in db.table_names()}