"""MongoDB access helpers for the ``ColBERTPapers.papers`` collection.

Provides :class:`MongoDBAccess`, which can look papers up by id and run a
``$vectorSearch`` k-nearest-neighbour query, plus a small sanity-check
script that runs when the module is executed directly.
"""

import json
import os
from typing import Any, Dict, Iterable, List, Optional

from pymongo.mongo_client import MongoClient
from pymongo.server_api import ServerApi

from openai_embed import QueryEmbedder

# Shared embedder, constructed once at import time; the sanity-check script
# below uses it to turn a text query into a vector.
OPENAI = QueryEmbedder()

USER = "test"
SERVER = "dbbackend.c9tcfpp"

# The password is read from a file resolved relative to the current working
# directory. Surrounding whitespace is stripped so that a trailing newline in
# the secret file does not end up inside the connection URI.
with open('.mongodb-secret', 'r') as f:
    PASS = f.read().strip()


class MongoDBAccess:
    """Client wrapper bound to the ``papers`` collection of ``ColBERTPapers``."""

    def __init__(self) -> None:
        """Build the connection URI and open handles to the database/collection."""
        # NOTE(review): USER and PASS are interpolated verbatim. PyMongo
        # expects reserved characters in URI credentials to be
        # percent-escaped, so the secret file must already hold an escaped
        # value if the password contains any -- confirm.
        self.uri = f"mongodb+srv://{USER}:{PASS}@{SERVER}.mongodb.net/?retryWrites=true&w=majority&appName=DBBackend"
        self.client = MongoClient(self.uri, server_api=ServerApi('1'))
        self.database = self.client["ColBERTPapers"]
        self.col = self.database["papers"]

    def ping(self) -> None:
        """Send a ``ping`` command and print either a success message or the error.

        Best-effort diagnostic: any exception is printed rather than raised.
        """
        try:
            self.client.admin.command('ping')
            print("Pinged your deployment. You successfully connected to MongoDB!")
        except Exception as e:
            print(e)

    def article_info_from_id_list(self, id_list: Iterable[int]) -> List[Dict[str, Any]]:
        """Return the stored documents whose ``id`` field is in *id_list*.

        Args:
            id_list: Values to match against the documents' ``id`` field.
                They are compared as-is, i.e. no index shift is applied here
                (unlike :meth:`vector_knn_search`, which subtracts 1 from the
                ids it returns).

        Returns:
            A list of documents projected to ``id``, ``title``, ``year``,
            ``author`` and ``abstract``. ``_id`` is not excluded by the
            projection. The list is empty when nothing matches.
        """
        # Materialise the ids so any iterable (not only a list) can be used
        # as the ``$in`` operand.
        query = {"id": {'$in': list(id_list)}}
        projection = {"id": 1, "title": 1, "year": 1, "author": 1, "abstract": 1}
        return list(self.col.find(query, projection))

    def vector_knn_search(
        self,
        query_embed,
        year,
        k=100,
        num_candidates: Optional[int] = None,
    ) -> List[Dict[str, Any]]:
        """Run a ``$vectorSearch`` on the ``embed`` field and filter by year.

        Args:
            query_embed: Query vector passed as ``queryVector``.
            year: Minimum year (inclusive); results whose ``int(year)`` is
                lower are dropped.
            k: Value used for the search stage's ``limit``.
            num_candidates: Value used for ``numCandidates``. Defaults to
                *k*, which matches the previous hard-coded behaviour.

        Returns:
            A list of dicts with ``id``, ``title``, ``year`` and ``score``.
            ``id`` is shifted down by one relative to the stored value.
            Because the year filter is applied after the search, fewer than
            *k* items may be returned.
        """
        candidates = k if num_candidates is None else num_candidates
        pipeline = [
            {
                '$vectorSearch': {
                    'index': 'vector_index',
                    'path': 'embed',
                    'queryVector': query_embed,
                    'numCandidates': candidates,
                    'limit': k,
                }
            },
            {
                # 'author' and 'abstract' are not projected here;
                # article_info_from_id_list returns them.
                "$project": {
                    '_id': 0,
                    'id': 1,
                    'title': 1,
                    'year': 1,
                    'score': {
                        '$meta': 'vectorSearchScore'
                    },
                }
            },
        ]

        # Use the collection handle created in __init__ rather than
        # re-resolving the database and collection from the client.
        res_list = []
        for doc in self.col.aggregate(pipeline):
            if int(doc['year']) >= year:
                # Stored ids are 1-indexed; convert to 0-based.
                doc['id'] = doc['id'] - 1
                res_list.append(doc)
        return res_list


def main() -> None:
    """Sanity check: search once and compare the top hit with the local collection."""
    db = MongoDBAccess()
    db.ping()
    query_embed = OPENAI.embed_query("What is text simplification?")
    results = db.vector_knn_search(query_embed, 1900)

    # Sanity check to make sure it matches with the original collection of documents
    index_root = os.getenv("INDEX_ROOT", '.')
    collection_path = os.path.join(index_root, 'collection.json')

    # Load abstracts as a collection
    with open(collection_path, 'r', encoding='utf-8') as f:
        collection = json.load(f)

    print(f'Returned {len(results)} results!')
    if not results:
        # Nothing passed the year filter, so there is no top hit to compare.
        return

    _id = int(results[0]['id'])
    print(results[0]['title'])
    print(collection[_id])


if __name__ == "__main__":
    main()