File size: 1,810 Bytes
7fe3ab0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
af26a9d
7fe3ab0
af26a9d
7fe3ab0
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
import logging
import lancedb
import os
from pathlib import Path
from sentence_transformers import SentenceTransformer
import openai

from sentence_transformers import CrossEncoder
# Shared cross-encoder instance, constructed once at import time; rerank_documents
# calls its predict() to score (query, document) pairs.
# NOTE(review): max_length=512 is presumably the per-input truncation limit -- confirm
# against the sentence-transformers CrossEncoder documentation.
cross_encoder = CrossEncoder('cross-encoder/ms-marco-MiniLM-L-12-v2', max_length=512)

def rerank_documents(query, documents):
  """Return `documents` ordered from highest to lowest cross-encoder score.

  Args:
    query: Query that is paired with every document for scoring.
    documents: Iterable of documents to rank. It is materialized into a list
      once, so one-shot iterators are supported.

  Returns:
    A new list containing the documents sorted by descending score.
    Documents with equal scores keep their input order. Empty input
    returns [] without invoking the model.
  """
  documents = list(documents)
  if not documents:
    return []
  scores = cross_encoder.predict([(query, d) for d in documents])
  # Sort on the score alone. Sorting the (score, document) tuples directly
  # falls back to comparing the documents whenever two scores tie, which
  # raises TypeError for non-orderable documents (e.g. dicts).
  ranked = sorted(zip(scores, documents), key=lambda pair: pair[0], reverse=True)
  return [doc for _, doc in ranked]

# NOTE(review): both names are initialized to empty strings and are not referenced
# in the visible part of this file; presumably placeholders filled in elsewhere -- confirm.
EMB_MODEL_NAME = ""
DB_TABLE_NAME = ""

# Configure logging at INFO level and create this module's logger.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Registry of embedding callables keyed by retriever name; entries are added
# further down once the embedding backends have been defined.
retrievers = {}

import tiktoken

def num_tokens_from_string(string: str, encoding_name: str = "cl100k_base") -> int:
  """Count the tokens `string` encodes to under the tiktoken encoding `encoding_name`."""
  return len(tiktoken.get_encoding(encoding_name).encode(string))

def trim(text, length = 8190):
  """Normalize `text` and shorten it until it fits within `length` tokens.

  Every '<|endoftext|>' marker is removed and runs of whitespace are collapsed
  to single spaces. While the token count reported by num_tokens_from_string
  exceeds `length`, the last 10 words are dropped.

  Args:
    text: Input string.
    length: Maximum number of tokens allowed in the result.

  Returns:
    The normalized, possibly shortened string. The result is '' when no words
    remain, which includes the case where the first 10 or fewer words alone
    exceed `length`.
  """
  marker = '<|endoftext|>'
  # Repeat the removal: deleting one marker can splice a new one together
  # (e.g. '<|endof<|endoftext|>text|>'), and the marker must not reach the tokenizer.
  while marker in text:
    text = text.replace(marker, '')
  # Collapse whitespace after the marker is gone so its removal cannot leave
  # double spaces behind; split once and reuse the word list while shortening.
  words = text.split()
  text = ' '.join(words)
  # `words` in the condition guarantees termination even for a negative `length`.
  while words and num_tokens_from_string(text) > length:
    words = words[:-10]
    text = ' '.join(words)
  return text

def openai_embedding(text, key = None):
  """Embed `text` with OpenAI's "text-embedding-ada-002" model.

  Args:
    text: String to embed; it is passed through trim() before being sent.
    key: Value forwarded as `api_key` to openai.OpenAI.

  Returns:
    The `embedding` attribute of the first item in the response data.
  """
  response = openai.OpenAI(api_key=key).embeddings.create(
      input=[trim(text)],
      model="text-embedding-ada-002",
  )
  return response.data[0].embedding

# Local sentence-transformers models, loaded once when the module is imported.
minilm = SentenceTransformer('sentence-transformers/all-MiniLM-L6-v2')
mpnet = SentenceTransformer('sentence-transformers/all-mpnet-base-v2')
# Every retriever is callable as f(text, key). The local models ignore `key`;
# it defaults to None, like openai_embedding's `key`, so all three entries
# can also be called with the text alone.
retrievers['MiniLM'] = lambda t, key=None: minilm.encode(t)
retrievers['mpnet'] = lambda t, key=None: mpnet.encode(t)
retrievers['OpenAI'] = openai_embedding

# db
# Connect to the LanceDB store in the ".lancedb" directory that sits one level
# above the directory containing this file (Path(__file__).parents[1]).
db_uri = os.path.join(Path(__file__).parents[1], ".lancedb")
db = lancedb.connect(db_uri)
# Open every table the database reports at import time and keep the handles
# keyed by table name.
tables = {}
for table_name in db.table_names():
  tables[table_name] = db.open_table(table_name)