|
import imp |
|
import os |
|
|
|
from datasets import DatasetDict |
|
from elasticsearch import Elasticsearch |
|
from elastic_transport import ConnectionError |
|
from dotenv import load_dotenv |
|
|
|
from src.retrievers.base_retriever import RetrieveType, Retriever |
|
from src.utils.log import logger |
|
from src.utils.timing import timeit |
|
|
|
|
|
load_dotenv() |
|
|
|
|
|
class ESRetriever(Retriever): |
|
def __init__(self, paragraphs: DatasetDict) -> None: |
|
self.paragraphs = paragraphs["train"] |
|
|
|
es_host = os.getenv("ELASTIC_HOST", "localhost") |
|
es_password = os.getenv("ELASTIC_PASSWORD") |
|
es_username = os.getenv("ELASTIC_USERNAME") |
|
|
|
self.client = Elasticsearch( |
|
hosts=[es_host], |
|
http_auth=(es_username, es_password), |
|
ca_certs="./http_ca.crt") |
|
|
|
try: |
|
self.client.info() |
|
except ConnectionError: |
|
logger.error("Could not connect to ElasticSearch. " + |
|
"Make sure it is running. Exiting now...") |
|
exit() |
|
|
|
if self.client.indices.exists(index="paragraphs"): |
|
self.paragraphs.load_elasticsearch_index( |
|
"paragraphs", es_index_name="paragraphs", |
|
es_client=self.client) |
|
else: |
|
logger.info(f"Creating index 'paragraphs' on {es_host}") |
|
self.paragraphs.add_elasticsearch_index(column="text", |
|
index_name="paragraphs", |
|
es_index_name="paragraphs", |
|
es_client=self.client) |
|
|
|
def retrieve(self, query: str, k: int = 5) -> RetrieveType: |
|
return self.paragraphs.get_nearest_examples("paragraphs", query, k) |
|
|