colbert-acl / knn_db_access.py
davidheineman's picture
improve readme
f9ad19d
from pymongo.mongo_client import MongoClient
from pymongo.server_api import ServerApi
from openai_embed import QueryEmbedder
# Module-level embedder instance, constructed once at import time; the script
# entry point calls its ``embed_query`` method to embed the query text.
OPENAI = QueryEmbedder()

# Pieces of the MongoDB connection URI assembled in ``MongoDBAccess.__init__``.
USER = "test"
SERVER = "dbbackend.c9tcfpp"

# The password is read from a local file.  Surrounding whitespace is stripped
# because a trailing newline (added by most editors and by ``echo``) would
# otherwise be interpolated into the URI as part of the password.
with open('.mongodb-secret', 'r') as f:
    PASS = f.read().strip()
class MongoDBAccess:
    """Access helper for the ``papers`` collection of the ``ColBERTPapers`` database.

    Attributes (set in ``__init__``):
        uri: connection string built from the module-level USER / PASS / SERVER.
        client: the ``MongoClient`` created from ``uri``.
        database: handle to the ``ColBERTPapers`` database.
        col: handle to the ``papers`` collection; every query method uses it.
    """

    def __init__(self) -> None:
        """Build the connection URI and create the client and collection handles."""
        self.uri = f"mongodb+srv://{USER}:{PASS}@{SERVER}.mongodb.net/?retryWrites=true&w=majority&appName=DBBackend"
        self.client = MongoClient(self.uri, server_api=ServerApi('1'))
        self.database = self.client["ColBERTPapers"]
        self.col = self.database["papers"]

    def ping(self) -> None:
        """Send a ``ping`` admin command and print the outcome.

        Best-effort connectivity check: any exception is printed, not raised.
        """
        try:
            self.client.admin.command('ping')
            print("Pinged your deployment. You successfully connected to MongoDB!")
        except Exception as e:
            print(e)

    def article_info_from_id_list(self, id_list: list):
        """Return the documents whose ``id`` field is in ``id_list``.

        Args:
            id_list: list of ids to look up (passed to ``$in``, so it must be
                a list, not a single int).

        Returns:
            A list of documents projected to ``id``, ``title``, ``year``,
            ``author`` and ``abstract`` (MongoDB also returns ``_id``, since
            it is not excluded by the projection).
        """
        # NOTE(review): ids are matched against the stored ``id`` field as-is,
        # whereas vector_knn_search returns the stored id minus one -- confirm
        # that callers convert between the two conventions.
        query = {"id": {'$in': id_list}}
        projection = {"id": 1, "title": 1, "year": 1, "author": 1, "abstract": 1}
        return list(self.col.find(query, projection))

    def vector_knn_search(self, query_embed, year, k=100, num_candidates=None):
        """Run a ``$vectorSearch`` over the ``embed`` field and filter by year.

        Args:
            query_embed: query vector handed to ``$vectorSearch`` as ``queryVector``.
            year: minimum year (inclusive); hits with ``int(year) < year`` are dropped.
            k: value used for the ``limit`` of the search stage.
            num_candidates: value for ``numCandidates``; defaults to ``k``
                (the original behaviour) when left as ``None``.

        Returns:
            A list of dicts with ``id``, ``title``, ``year`` and ``score``.
            ``id`` is decremented by one relative to the stored value.  The
            year filter runs after the search limit, so fewer than ``k``
            results may be returned.
        """
        candidates = k if num_candidates is None else num_candidates
        pipeline = [
            {
                '$vectorSearch': {
                    'index': 'vector_index',
                    'path': 'embed',
                    'queryVector': query_embed,
                    'numCandidates': candidates,
                    'limit': k,
                }
            },
            {
                "$project": {
                    '_id': 0,
                    'id': 1,
                    'title': 1,
                    'year': 1,
                    'score': {
                        '$meta': 'vectorSearchScore'
                    },
                }
            },
        ]

        # Use the collection handle created in __init__ rather than
        # re-deriving it from the client with hard-coded names.
        res_list = []
        for hit in self.col.aggregate(pipeline):
            if int(hit['year']) >= year:
                hit['id'] = hit['id'] - 1  # The MongoDB is 1-indexed
                res_list.append(hit)
        return res_list
if __name__ == "__main__":
    # Manual smoke test: embed a query, run the vector search, and compare
    # the top hit against the local collection file.
    # These modules are only needed for this check, so they are imported here.
    import json
    import os

    db = MongoDBAccess()
    db.ping()

    query_embed = OPENAI.embed_query("What is text simplification?")
    results = db.vector_knn_search(query_embed, 1900)

    # Sanity check to make sure it matches with the original collection of documents
    INDEX_NAME = os.getenv("INDEX_NAME", 'index_large')
    INDEX_ROOT = os.getenv("INDEX_ROOT", '.')
    INDEX_PATH = os.path.join(INDEX_ROOT, INDEX_NAME)
    COLLECTION_PATH = os.path.join(INDEX_ROOT, 'collection.json')

    # Load abstracts as a collection
    with open(COLLECTION_PATH, 'r', encoding='utf-8') as f:
        collection = json.load(f)

    print(f'Returned {len(results)} results!')

    # Only inspect the top hit when there is one; an empty result list would
    # otherwise raise IndexError right after the count is printed.
    if results:
        _id = int(results[0]['id'])
        print(results[0]['title'])
        print(collection[_id])