Spaces:
Build error
Build error
import logging | |
import time | |
from core.rag.datasource.retrieval_service import RetrievalService | |
from core.rag.models.document import Document | |
from core.rag.retrieval.retrieval_methods import RetrievalMethod | |
from extensions.ext_database import db | |
from models.account import Account | |
from models.dataset import Dataset, DatasetQuery, DocumentSegment | |
# Fallback retrieval settings, used by HitTestingService.retrieve when neither
# the caller nor the dataset provides a retrieval model.
default_retrieval_model: dict = {
    # Plain semantic search; the enum's string value is what gets stored.
    "search_method": RetrievalMethod.SEMANTIC_SEARCH.value,
    # Reranking is off, so the provider/model names are left empty.
    "reranking_enable": False,
    "reranking_model": {
        "reranking_provider_name": "",
        "reranking_model_name": "",
    },
    "top_k": 2,
    # No score cut-off is applied by default.
    "score_threshold_enabled": False,
}
class HitTestingService:
    """Service helpers for running hit-testing retrievals against a dataset.

    All methods are called on the class itself; no instance state is used.
    """

    @classmethod
    def retrieve(
        cls,
        dataset: Dataset,
        query: str,
        account: Account,
        retrieval_model: dict,
        external_retrieval_model: dict,
        limit: int = 10,
    ) -> dict:
        """Retrieve documents for ``query`` from ``dataset`` and record the query.

        Args:
            dataset: Dataset to search.
            query: Raw query text. Double quotes are escaped before the text is
                handed to ``RetrievalService.retrieve``; the unescaped text is
                what gets stored in the ``DatasetQuery`` row.
            account: Account recorded as the creator of the ``DatasetQuery`` row.
            retrieval_model: Retrieval settings. When falsy, the dataset's own
                ``retrieval_model`` is used, falling back to
                ``default_retrieval_model``.
            external_retrieval_model: Not used by this method; kept so the
                signature stays compatible with existing callers.
            limit: Not used by this method; kept for signature compatibility.

        Returns:
            A dict with ``"query"`` and ``"records"`` keys. When the dataset has
            no available documents or segments, ``"records"`` is empty, the
            query entry carries a zeroed ``"tsne_position"``, and nothing is
            retrieved or persisted.
        """
        if dataset.available_document_count == 0 or dataset.available_segment_count == 0:
            return {
                "query": {
                    "content": query,
                    "tsne_position": {"x": 0, "y": 0},
                },
                "records": [],
            }

        start = time.perf_counter()

        # Use the caller's retrieval model; otherwise the dataset's, then the default.
        if not retrieval_model:
            retrieval_model = dataset.retrieval_model or default_retrieval_model

        # Read the feature flags with .get() so a partial config (flag omitted)
        # means "disabled" instead of raising KeyError.
        score_threshold = (
            retrieval_model.get("score_threshold", 0.0)
            if retrieval_model.get("score_threshold_enabled")
            else 0.0
        )
        reranking_model = (
            retrieval_model.get("reranking_model")
            if retrieval_model.get("reranking_enable")
            else None
        )

        all_documents = RetrievalService.retrieve(
            retrieval_method=retrieval_model.get("search_method", "semantic_search"),
            dataset_id=dataset.id,
            query=cls.escape_query_for_search(query),
            top_k=retrieval_model.get("top_k", 2),
            score_threshold=score_threshold,
            reranking_model=reranking_model,
            reranking_mode=retrieval_model.get("reranking_mode") or "reranking_model",
            weights=retrieval_model.get("weights"),
        )

        end = time.perf_counter()
        logging.debug("Hit testing retrieve in %0.4f seconds", end - start)

        cls._record_query(dataset, query, account)

        return cls.compact_retrieve_response(dataset, query, all_documents)

    @classmethod
    def external_retrieve(
        cls,
        dataset: Dataset,
        query: str,
        account: Account,
        external_retrieval_model: dict,
    ) -> dict:
        """Retrieve documents for ``query`` through the external retrieval path.

        Args:
            dataset: Dataset to search; only datasets whose ``provider`` is
                ``"external"`` are queried.
            query: Raw query text. Double quotes are escaped before the text is
                handed to ``RetrievalService.external_retrieve``.
            account: Account recorded as the creator of the ``DatasetQuery`` row.
            external_retrieval_model: Settings passed through unchanged to
                ``RetrievalService.external_retrieve``.

        Returns:
            A dict with ``"query"`` and ``"records"`` keys. For a non-external
            dataset ``"records"`` is empty and nothing is retrieved or persisted.
        """
        if dataset.provider != "external":
            return {
                "query": {"content": query},
                "records": [],
            }

        start = time.perf_counter()

        all_documents = RetrievalService.external_retrieve(
            dataset_id=dataset.id,
            query=cls.escape_query_for_search(query),
            external_retrieval_model=external_retrieval_model,
        )

        end = time.perf_counter()
        logging.debug("External knowledge hit testing retrieve in %0.4f seconds", end - start)

        cls._record_query(dataset, query, account)

        return cls.compact_external_retrieve_response(dataset, query, all_documents)

    @staticmethod
    def _record_query(dataset: Dataset, query: str, account: Account) -> None:
        """Add a ``hit_testing`` ``DatasetQuery`` row for ``query`` and commit it."""
        dataset_query = DatasetQuery(
            dataset_id=dataset.id,
            content=query,
            source="hit_testing",
            created_by_role="account",
            created_by=account.id,
        )
        db.session.add(dataset_query)
        db.session.commit()

    @classmethod
    def compact_retrieve_response(cls, dataset: Dataset, query: str, documents: list[Document]):
        """Pair each retrieved document with its stored segment.

        For every document, the segment whose ``index_node_id`` equals the
        document's ``metadata["doc_id"]`` is looked up (one query per document).
        Only enabled segments with status ``"completed"`` in this dataset match;
        documents without a matching segment are left out.

        Returns:
            ``{"query": {"content": query}, "records": [...]}`` where each record
            holds the ``segment`` and the document's ``score`` (``None`` when the
            metadata has no score).
        """
        records = []
        for document in documents:
            index_node_id = document.metadata["doc_id"]

            segment = (
                db.session.query(DocumentSegment)
                .filter(
                    DocumentSegment.dataset_id == dataset.id,
                    # "== True" builds a SQL expression here; it is not a Python truth test.
                    DocumentSegment.enabled == True,  # noqa: E712
                    DocumentSegment.status == "completed",
                    DocumentSegment.index_node_id == index_node_id,
                )
                .first()
            )

            if not segment:
                continue

            records.append(
                {
                    "segment": segment,
                    "score": document.metadata.get("score"),
                }
            )

        return {
            "query": {
                "content": query,
            },
            "records": records,
        }

    @classmethod
    def compact_external_retrieve_response(cls, dataset: Dataset, query: str, documents: list):
        """Shape externally retrieved documents into the hit-testing response.

        Each document is read with ``.get`` for ``content``, ``title``, ``score``
        and ``metadata``; missing keys become ``None``. When the dataset's
        ``provider`` is not ``"external"`` the records list stays empty.

        Returns:
            ``{"query": {"content": query}, "records": [...]}``.
        """
        records = []
        if dataset.provider == "external":
            records = [
                {
                    "content": document.get("content"),
                    "title": document.get("title"),
                    "score": document.get("score"),
                    "metadata": document.get("metadata"),
                }
                for document in documents
            ]

        return {
            "query": {
                "content": query,
            },
            "records": records,
        }

    @classmethod
    def hit_testing_args_check(cls, args):
        """Validate hit-testing request arguments.

        Args:
            args: Mapping of request arguments; only ``"query"`` is checked.

        Raises:
            ValueError: If ``"query"`` is missing, empty, or longer than 250
                characters.
        """
        # .get() so a missing key is reported as the documented ValueError
        # rather than surfacing as a KeyError.
        query = args.get("query")

        if not query or len(query) > 250:
            raise ValueError("Query is required and cannot exceed 250 characters")

    @staticmethod
    def escape_query_for_search(query: str) -> str:
        """Return ``query`` with every double quote prefixed by a backslash."""
        return query.replace('"', '\\"')