Spaces:
Runtime error
Runtime error
File size: 1,810 Bytes
7fe3ab0 af26a9d 7fe3ab0 af26a9d 7fe3ab0 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 |
import logging
import lancedb
import os
from pathlib import Path
from sentence_transformers import SentenceTransformer
import openai
from sentence_transformers import CrossEncoder
# Cross-encoder that scores (query, document) pairs for rerank_documents;
# instantiated once at import time. max_length=512 bounds the encoded pair
# length (longer pairs are presumably truncated by the library — confirm).
cross_encoder = CrossEncoder('cross-encoder/ms-marco-MiniLM-L-12-v2', max_length=512)
def rerank_documents(query, documents):
    """Reorder *documents* by cross-encoder relevance to *query*, best first.

    Every ``(query, document)`` pair is scored with the module-level
    ``cross_encoder`` and the documents are returned by descending score.

    Args:
        query: The search query text.
        documents: Iterable of candidate passages to rerank.

    Returns:
        A new list with the same documents, highest score first. Documents
        with equal scores keep their original relative order. An empty input
        returns ``[]`` without calling the model.
    """
    # Materialize once: the documents are consumed twice (scoring and
    # pairing), which would silently exhaust a generator on the first pass.
    documents = list(documents)
    if not documents:
        return []
    scores = cross_encoder.predict([(query, doc) for doc in documents])
    # Sort on the score alone. Sorting the raw (score, document) tuples falls
    # back to comparing documents on tied scores, which raises TypeError for
    # non-orderable documents (e.g. dicts). Python's sort is stable even with
    # reverse=True, so ties keep input order.
    ranked = sorted(zip(scores, documents), key=lambda pair: pair[0], reverse=True)
    return [doc for _, doc in ranked]
# Placeholder configuration values; both are empty here.
# NOTE(review): presumably assigned elsewhere before use — confirm.
EMB_MODEL_NAME = ""
DB_TABLE_NAME = ""
# Configure root logging at INFO level and create this module's logger.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
# Registry of embedding functions keyed by retriever name; each entry is
# called as fn(text, key) and is registered further down in this module.
retrievers = {}
import tiktoken
def num_tokens_from_string(string: str, encoding_name: str = "cl100k_base") -> int:
    """Return the number of tokens *string* encodes to under *encoding_name*.

    Text that happens to spell a special token (for example ``<|endoftext|>``
    or ``<|endofprompt|>``) is counted as ordinary text rather than raising.

    Args:
        string: The text to measure.
        encoding_name: Name of the tiktoken encoding to use.

    Returns:
        The token count of *string*.
    """
    encoding = tiktoken.get_encoding(encoding_name)
    # By default tiktoken's encode() raises ValueError when the input contains
    # special-token text. Arbitrary documents can contain such text, so
    # disable that check and encode it as plain text.
    return len(encoding.encode(string, disallowed_special=()))
def trim(text, length=8190, *, count_tokens=None):
    """Collapse whitespace in *text* and cut it to at most *length* tokens.

    Runs of whitespace are collapsed to single spaces and literal
    ``<|endoftext|>`` markers are removed. If the result still exceeds
    *length* tokens, whole trailing words are dropped, keeping the longest
    word prefix that fits.

    Args:
        text: The input string.
        length: Maximum allowed token count.
        count_tokens: Optional callable returning the token count of a string;
            defaults to ``num_tokens_from_string``.

    Returns:
        The normalized, possibly truncated text; ``""`` if not even one word
        fits within *length*.
    """
    if count_tokens is None:
        count_tokens = num_tokens_from_string
    text = ' '.join(text.split()).replace('<|endoftext|>', '')
    if count_tokens(text) <= length:
        return text
    words = text.split()
    # Binary-search the longest word prefix within budget. This assumes the
    # token count does not decrease as words are appended. It replaces a loop
    # that re-tokenized everything after dropping 10 words at a time. That
    # loop was quadratic on long inputs and never ended for a negative length.
    lo, hi = 0, len(words)  # invariant: a prefix of `lo` words fits
    while lo < hi:
        mid = (lo + hi + 1) // 2
        if count_tokens(' '.join(words[:mid])) <= length:
            lo = mid
        else:
            hi = mid - 1
    return ' '.join(words[:lo])
def openai_embedding(text, key=None):
    """Embed *text* with OpenAI's ``text-embedding-ada-002`` model.

    The input is normalized and shortened by ``trim`` (default token budget)
    and sent as a single-element batch.

    Args:
        text: The string to embed.
        key: API key handed to ``openai.OpenAI``. With ``None`` the client
            presumably falls back to its own default (e.g. the
            OPENAI_API_KEY environment variable) — confirm against the docs.

    Returns:
        The embedding of the single input, as returned by the API.
    """
    api = openai.OpenAI(api_key=key)
    result = api.embeddings.create(
        input=[trim(text)],
        model="text-embedding-ada-002",
    )
    return result.data[0].embedding
# Local sentence-transformer models, each loaded once at import time.
minilm = SentenceTransformer('sentence-transformers/all-MiniLM-L6-v2')
mpnet = SentenceTransformer('sentence-transformers/all-mpnet-base-v2')
# Register every embedding backend by name. Entries are called as
# fn(text, key); the local models ignore `key`, while the OpenAI backend
# passes it on as its API key.
retrievers.update({
    'MiniLM': lambda t, key: minilm.encode(t),
    'mpnet': lambda t, key: mpnet.encode(t),
    'OpenAI': openai_embedding,
})
# db: connect to the LanceDB database in a ".lancedb" directory located in
# the parent of this file's directory, and open every existing table once at
# import time. `tables` maps table name -> opened table handle.
# NOTE(review): Path(__file__).parents[1] would raise IndexError if __file__
# were a bare relative filename with no directory components.
db_uri = os.path.join(Path(__file__).parents[1], ".lancedb")
db = lancedb.connect(db_uri)
tables = {}
for table_name in db.table_names():
    tables[table_name] = db.open_table(table_name)
|