|
import json
import os

from pymongo.mongo_client import MongoClient
from pymongo.server_api import ServerApi

from openai_embed import QueryEmbedder
|
# Module-level embedding client; the script entry point calls
# ``OPENAI.embed_query(...)`` on it to obtain a query vector.
OPENAI = QueryEmbedder()

# Connection settings interpolated into the MongoDB URI by ``MongoDBAccess``.
USER = "test"
SERVER = "dbbackend.c9tcfpp"

# The password is read from ``.mongodb-secret`` (resolved against the current
# working directory).  Surrounding whitespace is stripped because a trailing
# newline left by an editor would otherwise be embedded in the URI's password.
with open('.mongodb-secret', 'r', encoding='utf-8') as f:
    PASS = f.read().strip()
|
|
|
|
|
class MongoDBAccess:
    """Access layer for the ``papers`` collection of the ``ColBERTPapers`` database.

    The connection URI is built from the module-level ``USER``, ``PASS`` and
    ``SERVER`` constants.  Besides a connectivity check, the class offers a
    lookup of article metadata by id and a ``$vectorSearch`` nearest-neighbour
    query over the ``embed`` field.
    """

    def __init__(self) -> None:
        """Create the client and bind the database and collection handles."""
        self.uri = f"mongodb+srv://{USER}:{PASS}@{SERVER}.mongodb.net/?retryWrites=true&w=majority&appName=DBBackend"
        self.client = MongoClient(self.uri, server_api=ServerApi('1'))
        self.database = self.client["ColBERTPapers"]
        self.col = self.database["papers"]

    def ping(self) -> None:
        """Send a ``ping`` admin command and print the outcome.

        Failures are reported on stdout instead of being raised, so callers
        cannot detect a failed ping from the return value.
        """
        try:
            self.client.admin.command('ping')
        except Exception as e:  # deliberate best-effort: report, do not raise
            print(e)
        else:
            print("Pinged your deployment. You successfully connected to MongoDB!")

    def article_info_from_id_list(self, id_list: list) -> list:
        """Fetch metadata for every article whose ``id`` is in *id_list*.

        Args:
            id_list: Article ids to look up.  Any iterable is accepted; it is
                materialised into a list before being sent to the server.

        Returns:
            The matching documents as a list, projected to ``id``, ``title``,
            ``year``, ``author`` and ``abstract`` (``_id`` is not excluded by
            the projection).  The order is whatever the cursor yields, which
            is not tied to the order of *id_list*.
        """
        # ``$in`` takes an array, so tuples/generators are converted up front.
        query = {"id": {'$in': list(id_list)}}
        projection = {"id": 1, "title": 1, "year": 1, "author": 1, "abstract": 1}
        return list(self.col.find(query, projection))

    def vector_knn_search(self, query_embed, year: int, k: int = 100) -> list:
        """Run a ``$vectorSearch`` query and keep the hits published from *year* on.

        Args:
            query_embed: Query vector, passed through as ``queryVector``.
            year: Inclusive lower bound; each hit's ``year`` is converted with
                ``int`` before the comparison.
            k: Used for both ``numCandidates`` and ``limit``.

        Returns:
            A list of dicts with ``id``, ``title``, ``year`` and ``score``
            (the ``vectorSearchScore`` metadata), in the order the server
            returned them.  Each ``id`` is one less than the stored value.
            Because the year filter runs client-side after the server has
            already limited the result to *k* documents, fewer than *k* hits
            may come back.

        Raises:
            KeyError: If a returned document has no ``year`` or ``id`` field.
        """
        pipeline = [
            {
                '$vectorSearch': {
                    'index': 'vector_index',
                    'path': 'embed',
                    'queryVector': query_embed,
                    # NOTE(review): candidates == limit; the Atlas docs suggest
                    # a larger candidate pool for better recall -- confirm
                    # whether this is intentional.
                    'numCandidates': k,
                    'limit': k,
                }
            },
            {
                "$project": {
                    '_id': 0,
                    'id': 1,
                    'title': 1,
                    'year': 1,
                    'score': {
                        '$meta': 'vectorSearchScore'
                    },
                }
            },
        ]

        hits = []
        # Use the collection handle bound in __init__ rather than re-deriving
        # it from the client by name.
        for doc in self.col.aggregate(pipeline):
            if int(doc['year']) >= year:
                # NOTE(review): presumably shifts a 1-based stored id to a
                # 0-based position in the JSON collection that the script
                # entry point indexes with this value -- confirm.
                doc['id'] = doc['id'] - 1
                hits.append(doc)
        return hits
|
|
|
|
|
if __name__ == "__main__":
    # Manual end-to-end check: ping the deployment, embed a fixed question,
    # run the vector search, then print the best hit next to the entry at the
    # same position in the local JSON collection.
    db = MongoDBAccess()
    db.ping()

    query_embed = OPENAI.embed_query("What is text simplification?")
    results = db.vector_knn_search(query_embed, 1900)

    # Index location is configurable through the environment; the defaults
    # point at the current directory.
    INDEX_NAME = os.getenv("INDEX_NAME", 'index_large')
    INDEX_ROOT = os.getenv("INDEX_ROOT", '.')
    INDEX_PATH = os.path.join(INDEX_ROOT, INDEX_NAME)
    COLLECTION_PATH = os.path.join(INDEX_ROOT, 'collection.json')

    with open(COLLECTION_PATH, 'r', encoding='utf-8') as f:
        collection = json.load(f)

    print(f'Returned {len(results)} results!')

    # An empty result list used to raise IndexError on ``results[0]``; only
    # show the top hit when there is one.
    if results:
        top_hit = results[0]
        _id = int(top_hit['id'])
        print(top_hit['title'])
        print(collection[_id])
|
|