Spaces:
Sleeping
Sleeping
import asyncio | |
import logging | |
from typing import Callable, Optional | |
from uuid import UUID | |
import numpy as np | |
from ntr_fileparser import ParsedDocument | |
from ntr_text_fragmentation import (EntitiesExtractor, EntityRepository, | |
InjectionBuilder, InMemoryEntityRepository, LinkerEntity) | |
from common.configuration import Configuration | |
from components.dbo.chunk_repository import ChunkRepository | |
from components.embedding_extraction import EmbeddingExtractor | |
from components.llm.deepinfra_api import DeepInfraApi | |
from components.search.appendices_chunker import APPENDICES_CHUNKER | |
from components.search.faiss_vector_search import FaissVectorSearch | |
from components.services.llm_config import LLMConfigService | |
logger = logging.getLogger(__name__) | |
class EntityService: | |
""" | |
Сервис для работы с сущностями. | |
Объединяет функциональность chunk_repository, destructurer, injection_builder и faiss_vector_search. | |
""" | |
def __init__( | |
self, | |
vectorizer: EmbeddingExtractor, | |
chunk_repository: ChunkRepository, | |
config: Configuration, | |
llm_api: DeepInfraApi, | |
llm_config_service: LLMConfigService, | |
) -> None: | |
""" | |
Инициализация сервиса. | |
Args: | |
vectorizer: Модель для извлечения эмбеддингов | |
chunk_repository: Репозиторий для работы с чанками | |
config: Конфигурация приложения | |
llm_api: Клиент для взаимодействия с LLM API | |
llm_config_service: Сервис для получения конфигурации LLM | |
""" | |
self.vectorizer = vectorizer | |
self.config = config | |
self.chunk_repository = chunk_repository | |
self.llm_api = llm_api | |
self.llm_config_service = llm_config_service | |
self.faiss_search = None | |
self.current_dataset_id = None | |
self.neighbors_max_distance = config.db_config.entities.neighbors_max_distance | |
self.max_entities_per_message = config.db_config.search.max_entities_per_message | |
self.max_entities_per_dialogue = ( | |
config.db_config.search.max_entities_per_dialogue | |
) | |
self.main_extractor = EntitiesExtractor( | |
strategy_name=config.db_config.entities.strategy_name, | |
strategy_params=config.db_config.entities.strategy_params, | |
process_tables=config.db_config.entities.process_tables, | |
) | |
self.appendices_extractor = EntitiesExtractor( | |
strategy_name=APPENDICES_CHUNKER, | |
strategy_params={ | |
"llm_api": self.llm_api, | |
"llm_config_service": self.llm_config_service, | |
}, | |
process_tables=False, | |
) | |
self._in_memory_cache: InMemoryEntityRepository = None | |
self._cached_dataset_id: int | None = None | |
def invalidate_cache(self) -> None: | |
"""Инвалидирует (удаляет) текущий кеш в памяти.""" | |
if self._in_memory_cache: | |
self._in_memory_cache = None | |
self._cached_dataset_id = None | |
else: | |
logger.info("In-memory кеш уже пуст. Ничего не делаем.") | |
def build_cache(self, dataset_id: int) -> None: | |
"""Строит кеш для указанного датасета.""" | |
all_entities = self.chunk_repository.get_all_entities_for_dataset(dataset_id) | |
in_memory_repo = InMemoryEntityRepository(entities=all_entities) | |
self._in_memory_cache = in_memory_repo | |
self._cached_dataset_id = dataset_id | |
async def build_or_rebuild_cache_async(self, dataset_id: int) -> None: | |
""" | |
Строит или перестраивает кеш для указанного датасета, удаляя предыдущий кеш. | |
""" | |
all_entities = await self.chunk_repository.get_all_entities_for_dataset_async(dataset_id) | |
if not all_entities: | |
logger.warning(f"No entities found for dataset {dataset_id}. Cache not built.") | |
self._in_memory_cache = None | |
self._cached_dataset_id = None | |
return | |
logger.info(f"Building new in-memory cache for dataset {dataset_id}") | |
in_memory_repo = InMemoryEntityRepository(entities=all_entities) | |
self._in_memory_cache = in_memory_repo | |
self._cached_dataset_id = dataset_id | |
logger.info(f"Cached {len(all_entities)} entities for dataset {dataset_id}") | |
def _get_repository_for_dataset(self, dataset_id: int) -> EntityRepository: | |
""" | |
Возвращает кешированный репозиторий, если он существует и соответствует | |
запрошенному dataset_id, иначе возвращает основной репозиторий ChunkRepository. | |
""" | |
# Проверяем совпадение ID с закешированным | |
if self._cached_dataset_id == dataset_id and self._in_memory_cache is not None: | |
return self._in_memory_cache | |
else: | |
# Логируем причину промаха кеша для диагностики | |
if not self._in_memory_cache: | |
logger.warning(f"Cache miss for dataset {dataset_id}: Cache is empty. Using ChunkRepository (DB).") | |
elif self._cached_dataset_id != dataset_id: | |
logger.warning(f"Cache miss for dataset {dataset_id}: Cache contains data for dataset {self._cached_dataset_id}. Using ChunkRepository (DB).") | |
else: # На случай непредвиденной ситуации | |
logger.warning(f"Cache miss for dataset {dataset_id}: Unknown reason. Using ChunkRepository (DB).") | |
return self.chunk_repository | |
def _ensure_faiss_initialized(self, dataset_id: int) -> None: | |
""" | |
Проверяет и при необходимости инициализирует или обновляет FAISS индекс. | |
Args: | |
dataset_id: ID датасета для инициализации | |
""" | |
# Переинициализируем FAISS, только если ID датасета изменился | |
if self.faiss_search is None or self.current_dataset_id != dataset_id: | |
logger.info(f'Initializing FAISS for dataset {dataset_id}') | |
entities, embeddings = self.chunk_repository.get_searching_entities( | |
dataset_id | |
) | |
if entities: | |
embeddings_dict = { | |
str(entity.id): embedding # Преобразуем UUID в строку для ключа | |
for entity, embedding in zip(entities, embeddings) | |
if embedding is not None | |
} | |
if embeddings_dict: # Проверяем, что есть хотя бы один эмбеддинг | |
self.faiss_search = FaissVectorSearch( | |
self.vectorizer, | |
embeddings_dict, | |
) | |
self.current_dataset_id = dataset_id | |
logger.info( | |
f'FAISS initialized for dataset {dataset_id} with {len(embeddings_dict)} embeddings' | |
) | |
else: | |
logger.warning( | |
f'No valid embeddings found for dataset {dataset_id}' | |
) | |
self.faiss_search = None | |
self.current_dataset_id = None | |
else: | |
logger.warning(f'No entities found for dataset {dataset_id}') | |
self.faiss_search = None | |
self.current_dataset_id = None | |
async def process_document( | |
self, | |
document: ParsedDocument, | |
dataset_id: int, | |
progress_callback: Optional[Callable] = None, | |
) -> None: | |
""" | |
Асинхронная обработка документа: разбиение на чанки и сохранение в базу. | |
Args: | |
document: Документ для обработки | |
dataset_id: ID датасета | |
progress_callback: Функция для отслеживания прогресса | |
""" | |
logger.info(f"Processing document {document.name} for dataset {dataset_id}") | |
# Определяем экстрактор в зависимости от имени документа | |
if 'Приложение' in document.name: | |
entities = await self.appendices_extractor.extract_async(document) | |
else: | |
entities = await self.main_extractor.extract_async(document) | |
# Фильтруем сущности для поиска | |
filtering_entities = [ | |
entity for entity in entities if entity.in_search_text is not None | |
] | |
filtering_texts = [entity.in_search_text for entity in filtering_entities] | |
embeddings = self.vectorizer.vectorize(filtering_texts, progress_callback) | |
# Собираем словарь эмбеддингов только для найденных сущностей | |
embeddings_dict = {} | |
if embeddings is not None: | |
embeddings_dict = { | |
str(entity.id): embedding | |
for entity, embedding in zip(filtering_entities, embeddings) | |
if embedding is not None | |
} | |
else: | |
logger.warning(f"Vectorizer returned None for document {document.name}") | |
# Сохраняем в базу | |
await self.chunk_repository.add_entities_async(entities, dataset_id, embeddings_dict) | |
logger.info(f"Added {len(entities)} entities to dataset {dataset_id}") | |
async def add_entities_batch_async( | |
self, | |
dataset_id: int, | |
entities: list[LinkerEntity], | |
embeddings: dict[str, np.ndarray], | |
): | |
"""Асинхронно добавляет батч сущностей и их эмбеддингов в БД.""" | |
if not entities: | |
logger.info("add_entities_batch_async called with empty entities list. Nothing to add.") | |
return | |
logger.info(f"Starting batch insertion of {len(entities)} entities for dataset {dataset_id}...") | |
try: | |
await asyncio.to_thread( | |
self.chunk_repository.add_entities, | |
entities, | |
dataset_id, | |
embeddings | |
) | |
logger.info(f"Batch insertion of {len(entities)} entities finished for dataset {dataset_id}.") | |
except Exception as e: | |
logger.error( | |
f"Error during batch insertion for dataset {dataset_id}: {e}", | |
exc_info=True, | |
) | |
raise e | |
async def prepare_document_data_async( | |
self, | |
document: ParsedDocument, | |
progress_callback: Optional[Callable] = None, | |
) -> tuple[list[LinkerEntity], dict[str, np.ndarray]]: | |
"""Асинхронно извлекает сущности и векторы для документа. | |
Не сохраняет данные в репозиторий, а возвращает их для последующей | |
батчевой обработки. | |
Args: | |
document: Документ для обработки. | |
progress_callback: Функция для отслеживания прогресса векторизации. | |
Returns: | |
Кортеж: (список извлеченных LinkerEntity, словарь эмбеддингов {id_str: embedding}). | |
""" | |
logger.debug(f"Preparing data for document {document.name}") | |
# 1. Извлечение сущностей | |
if 'Приложение' in document.name: | |
entities = await self.appendices_extractor.extract_async(document) | |
else: | |
entities = await self.main_extractor.extract_async(document) | |
# 2. Векторизация (если нужно) | |
filtering_entities = [ | |
entity for entity in entities if entity.in_search_text is not None | |
] | |
filtering_texts = [entity.in_search_text for entity in filtering_entities] | |
embeddings = self.vectorizer.vectorize(filtering_texts, progress_callback) | |
embeddings_dict = {} | |
if embeddings is not None: | |
embeddings_dict = { | |
str(entity.id): embedding | |
for entity, embedding in zip(filtering_entities, embeddings) | |
if embedding is not None | |
} | |
else: | |
logger.warning(f"Vectorizer returned None for document {document.name}") | |
logger.debug(f"Prepared data for document {document.name}: {len(entities)} entities, {len(embeddings_dict)} embeddings.") | |
return entities, embeddings_dict | |
async def build_text_async( | |
self, | |
entities: list[str], | |
dataset_id: int, | |
chunk_scores: Optional[list[float]] = None, | |
include_tables: bool = True, | |
max_documents: Optional[int] = None, | |
) -> str: | |
""" | |
Асинхронная сборка текста из сущностей с использованием кешированного или основного репозитория. | |
Args: | |
entities: Список идентификаторов сущностей (строки UUID) | |
dataset_id: ID датасета для получения репозитория (кешированного или БД) | |
chunk_scores: Список весов чанков (соответствует порядку entities) | |
include_tables: Флаг включения таблиц | |
max_documents: Максимальное количество документов | |
Returns: | |
Собранный текст | |
""" | |
if not entities: | |
logger.warning("build_text called with empty entities list.") | |
return "" | |
try: | |
entity_ids = [UUID(entity) for entity in entities] | |
except ValueError as e: | |
logger.error(f"Invalid UUID format found in entities list: {e}") | |
raise ValueError(f"Invalid UUID format in entities list: {entities}") from e | |
repository = self._get_repository_for_dataset(dataset_id) | |
# Передаем репозиторий (кеш или БД) в InjectionBuilder | |
builder = InjectionBuilder(repository=repository) | |
# Создаем словарь score_map UUID -> score, если chunk_scores предоставлены | |
scores_map: dict[UUID, float] | None = None | |
if chunk_scores is not None: | |
if len(entity_ids) == len(chunk_scores): | |
scores_map = {eid: score for eid, score in zip(entity_ids, chunk_scores)} | |
else: | |
logger.warning(f"Length mismatch between entities ({len(entity_ids)}) and chunk_scores ({len(chunk_scores)}). Scores ignored.") | |
logger.info(f"Building text for {len(entity_ids)} entities from dataset {dataset_id} using {repository.__class__.__name__}") | |
# Вызываем асинхронный метод сборщика | |
return await builder.build_async( | |
entities=entity_ids, # Передаем список UUID | |
scores=scores_map, # Передаем словарь UUID -> score | |
include_tables=include_tables, | |
neighbors_max_distance=self.neighbors_max_distance, | |
max_documents=max_documents, | |
) | |
def search_similar_old( | |
self, | |
query: str, | |
dataset_id: int, | |
k: int | None = None, | |
) -> tuple[np.ndarray, np.ndarray, np.ndarray]: | |
""" | |
Поиск похожих сущностей. | |
Args: | |
query: Текст запроса | |
dataset_id: ID датасета | |
k: Максимальное количество возвращаемых результатов (по умолчанию - все). | |
Returns: | |
tuple[np.ndarray, np.ndarray, np.ndarray]: | |
- Вектор запроса | |
- Оценки сходства | |
- Идентификаторы найденных сущностей | |
""" | |
logger.info(f"Searching similar entities for dataset {dataset_id} with k={k}") | |
self._ensure_faiss_initialized(dataset_id) | |
if self.faiss_search is None: | |
logger.warning( | |
f"FAISS search not initialized for dataset {dataset_id}. Returning empty results." | |
) | |
return np.array([]), np.array([]), np.array([]) | |
# Выполняем поиск с использованием параметра k | |
query_vector, scores, ids = self.faiss_search.search_vectors(query, max_entities=k) | |
logger.info(f"Found {len(ids)} similar entities.") | |
return query_vector, scores, ids | |
def search_similar( | |
self, | |
query: str, | |
dataset_id: int, | |
previous_entities: list[list[str]] = None, | |
) -> tuple[list[list[str]], list[str], list[float]]: | |
""" | |
Поиск похожих сущностей. | |
Args: | |
query: Текст запроса | |
dataset_id: ID датасета | |
previous_entities: Список идентификаторов сущностей, которые уже были найдены | |
Returns: | |
tuple[list[list[str]], list[str], list[float]]: | |
- Перефильтрованный список идентификаторов сущностей из прошлых запросов | |
- Список идентификаторов найденных сущностей (строки UUID) | |
- Скоры найденных сущностей | |
""" | |
self._ensure_faiss_initialized(dataset_id) | |
if self.faiss_search is None: | |
return previous_entities, [], [] | |
if ( | |
sum(len(entities) for entities in previous_entities) | |
< self.max_entities_per_dialogue - self.max_entities_per_message | |
): | |
_, scores, ids = self.faiss_search.search_vectors( | |
query, self.max_entities_per_message | |
) | |
try: | |
scores = scores.tolist() | |
ids = ids.tolist() | |
except: | |
scores = list(scores) | |
ids = list(ids) | |
return previous_entities, ids, scores | |
if previous_entities: | |
_, scores, ids = self.faiss_search.search_vectors( | |
query, self.max_entities_per_dialogue | |
) | |
scores = scores.tolist() | |
ids = ids.tolist() | |
print(ids) | |
previous_entities_ids = [ | |
[entity for entity in sublist if entity in ids] | |
for sublist in previous_entities | |
] | |
previous_entities_flat = [ | |
entity for sublist in previous_entities_ids for entity in sublist | |
] | |
new_entities = [] | |
new_scores = [] | |
for id_, score in zip(ids, scores): | |
if id_ not in previous_entities_flat: | |
new_entities.append(id_) | |
new_scores.append(score) | |
if len(new_entities) >= self.max_entities_per_message: | |
break | |
return previous_entities, new_entities, new_scores | |
else: | |
_, scores, ids = self.faiss_search.search_vectors( | |
query, self.max_entities_per_dialogue | |
) | |
scores = scores.tolist() | |
ids = ids.tolist() | |
return [], ids, scores | |