buster-dev / buster /utils.py
hbertrand's picture
PR: retriever interface (#77)
06bca0c unverified
import os
import urllib.request
from typing import Type
from buster.documents import DocumentsDB, DocumentsManager, DocumentsPickle
from buster.retriever import PickleRetriever, Retriever, SQLiteRetriever
PICKLE_EXTENSIONS = [".gz", ".bz2", ".zip", ".xz", ".zst", ".tar", ".tar.gz", ".tar.xz", ".tar.bz2"]
def get_file_extension(filepath: str) -> str:
return os.path.splitext(filepath)[1]
def download_db(db_url: str, output_dir: str):
os.makedirs(output_dir, exist_ok=True)
fname = os.path.join(output_dir, "documents.db")
if not os.path.exists(fname):
print(f"Downloading db file from {db_url} to {fname}...")
urllib.request.urlretrieve(db_url, fname)
print("Downloaded.")
else:
print("File already exists. Skipping.")
return fname
def get_documents_manager_from_extension(filepath: str) -> Type[DocumentsManager]:
ext = get_file_extension(filepath)
if ext in PICKLE_EXTENSIONS:
return DocumentsPickle
elif ext == ".db":
return DocumentsDB
else:
raise ValueError(f"Unsupported format: {ext}.")
def get_retriever_from_extension(filepath: str) -> Type[Retriever]:
ext = get_file_extension(filepath)
if ext in PICKLE_EXTENSIONS:
return PickleRetriever
elif ext == ".db":
return SQLiteRetriever
else:
raise ValueError(f"Unsupported format: {ext}.")