"""
UnifiedEncoder - unified text encoder.

Combines sentence splitting and multiple embedding models behind a single interface.
"""

| | import torch |
| | import numpy as np |
| | import pickle |
| | import os |
| | from typing import List, Tuple, Union |
| | from .sentenizer import Sentenceizer |
| | from .freechunker import FreeChunkerModel |
| | from .aggregator import TextAggregator |
| | from . import utils |
| |
|
class UnifiedEncoder:
    """
    Unified text encoder: splits text into sentences, embeds them with a
    Sentenceizer, and re-encodes those embeddings with a FreeChunker model.

    The instance can also hold an in-memory vector store (``self.vector_store``),
    built from a long text via ``build_vector_store`` or loaded from a pickle
    file via ``load_vector_store``, and queried by dot-product similarity.
    """

    # Short backbone names -> Hugging Face Hub IDs handed to Sentenceizer.
    _MODEL_CONFIGS = {
        'bge-m3': 'BAAI/bge-m3',
        'nomic-embed-text-v1.5': 'nomic-ai/nomic-embed-text-v1.5',
        'jina': 'jinaai/jina-embeddings-v2-small-en',
    }

    def __init__(self, model_name: str, model_name_or_path: Union[str, None] = None,
                 granularities: Union[List[int], None] = None, **kwargs):
        """
        Initialize the unified text encoder.

        Args:
            model_name (str): Backbone name. Keys of ``_MODEL_CONFIGS``
                ('bge-m3', 'nomic-embed-text-v1.5', 'jina') are mapped to their
                HF Hub IDs; any other value is passed to Sentenceizer as-is.
            model_name_or_path (str, optional): FreeChunker model path or HF Hub ID.
                Defaults to ``model_name``.
            granularities (List[int], optional): Granularities forwarded to the
                FreeChunker model on every ``encode`` call.
            **kwargs: Forwarded to ``FreeChunkerModel.from_pretrained``.
        """
        self.model_name = model_name
        self.granularities = granularities
        self.device = self._select_device()
        self.aggregator = TextAggregator()
        # No store until build_vector_store() or load_vector_store() is called.
        self.vector_store = None

        # Resolve the default before logging so the printed path is the one actually loaded.
        if model_name_or_path is None:
            model_name_or_path = model_name

        print(f"Initializing unified text encoder, model: {model_name}")
        print(f"Using model path: {model_name_or_path}")
        print(f"Using device: {self.device}")

        self.model = FreeChunkerModel.from_pretrained(model_name_or_path, **kwargs)
        self.model.to(self.device)
        self.model.eval()

        hf_id = self._MODEL_CONFIGS.get(model_name)
        if hf_id is None:
            print(f"Unknown predefined model name: {model_name}, trying to load directly...")
            hf_id = model_name
        self.sentenceizer = Sentenceizer(model_name=hf_id)

        print("Unified text encoder initialized!")

    @staticmethod
    def _select_device() -> torch.device:
        """Return CUDA if available, else Apple MPS if available, else CPU."""
        if torch.cuda.is_available():
            return torch.device('cuda')
        # torch.backends.mps only exists in newer torch releases, hence getattr.
        mps_backend = getattr(torch.backends, 'mps', None)
        if mps_backend is not None and mps_backend.is_available():
            return torch.device('mps')
        # Without this fallback, moving the model to 'mps' fails on CPU-only hosts.
        return torch.device('cpu')

    @classmethod
    def from_pretrained(cls, model_name_or_path: str, model_name: Union[str, None] = None, **kwargs):
        """
        Load a UnifiedEncoder from a pretrained FreeChunker checkpoint.

        Args:
            model_name_or_path (str): HF Hub ID or local path of the FreeChunker model.
            model_name (str, optional): Backbone name used for sentence embedding
                (e.g. 'nomic-embed-text-v1.5'). Defaults to
                'nomic-embed-text-v1.5' when not provided.
            **kwargs: Forwarded to ``UnifiedEncoder.__init__``.

        Returns:
            UnifiedEncoder: The initialized encoder.
        """
        if model_name is None:
            model_name = "nomic-embed-text-v1.5"
        return cls(model_name=model_name, model_name_or_path=model_name_or_path, **kwargs)

    @classmethod
    def register_for_auto_class(cls, auto_class="AutoModel"):
        """No-op, accepted so callers can invoke it without error."""
        return

    def encode(self, text: str, show_progress: bool = True) -> Tuple[List[str], np.ndarray, List[List[str]]]:
        """
        Split text into sentences, encode them, and group them by the model's shift_matrix.

        Args:
            text (str): Input text.
            show_progress (bool): Whether to show progress during sentence encoding.

        Returns:
            Tuple[List[str], np.ndarray, List[List[str]]]:
                - the sentences, each wrapped in 【Begin-i】/【End-i】 markers
                  carrying its index (an empty list if no sentences were found);
                - the model's ``'embedding'`` output as a numpy array
                  (``np.array([])`` when there are no sentences);
                - the marked sentences grouped by the model's ``'shift_matrix'``.
        """
        with torch.no_grad():
            sentences, input_embeddings = self.sentenceizer.split_and_encode(text, show_progress=show_progress)
            if len(sentences) == 0:
                return sentences, np.array([]), []

            # Sentenceizer may return numpy or torch; the model needs a tensor.
            if isinstance(input_embeddings, np.ndarray):
                input_embeddings = torch.from_numpy(input_embeddings)
            # Add a leading batch dimension of size 1.
            inputs_embeds = input_embeddings.to(self.device).unsqueeze(0)
            outputs = self.model(inputs_embeds=inputs_embeds, granularities=self.granularities)
            final_embeddings = outputs['embedding']
            shift_matrix = outputs['shift_matrix']

        sentences = [f"【Begin-{num}】" + sentence + f"【End-{num}】" for num, sentence in enumerate(sentences)]
        grouped_sentences = self._group_sentences_by_shift_matrix(sentences, shift_matrix)
        result_embeddings = final_embeddings.detach().cpu().numpy()

        return sentences, result_embeddings, grouped_sentences

    def _group_sentences_by_shift_matrix(self, sentences: List[str], shift_matrix: torch.Tensor) -> List[List[str]]:
        """
        Group sentences according to a 0/1 mask matrix.

        Row ``c`` of ``shift_matrix`` selects the sentences at every column
        equal to 1. Columns at or beyond ``len(sentences)`` are ignored, and
        rows that select no sentence are skipped entirely.

        Args:
            sentences (List[str]): Sentence list.
            shift_matrix (torch.Tensor): Mask matrix of shape [num_chunks, seq_len].

        Returns:
            List[List[str]]: One list of sentences per non-empty row, in row order.

        Raises:
            ValueError: If ``shift_matrix`` is not 2-D.
        """
        if shift_matrix.dim() != 2:
            raise ValueError(
                f"shift_matrix must be 2-D [num_chunks, seq_len], got shape {tuple(shift_matrix.shape)}"
            )

        # Slicing to len(sentences) drops padding columns in one vectorized step.
        mask = (shift_matrix[:, :len(sentences)] == 1).cpu()

        grouped_sentences = []
        for row in mask:
            indices = row.nonzero(as_tuple=True)[0].tolist()
            # NOTE(review): skipping empty rows shifts later groups relative to the
            # model's per-chunk embeddings, if any exist. Confirm the model never
            # emits an empty chunk, or keep positions aligned.
            if indices:
                grouped_sentences.append([sentences[idx] for idx in indices])

        return grouped_sentences

    def build_vector_store(self, text: str, show_progress: bool = True):
        """
        Encode a long text and store the result in ``self.vector_store``.

        The store is a dict with keys 'sentences', 'embeddings',
        'grouped_sentences' and 'grouped_texts'. 'grouped_texts' is every
        marked sentence followed by every group joined with single spaces.

        Args:
            text (str): Long text.
            show_progress (bool): Whether to show progress and print a summary.
        """
        sentences, embeddings, grouped_sentences = self.encode(text, show_progress)

        # NOTE(review): query() indexes grouped_texts by embedding row. That assumes
        # the model emits one embedding per sentence followed by one per group. Confirm.
        grouped_texts = sentences + [
            " ".join(group) if isinstance(group, list) else str(group) for group in grouped_sentences
        ]

        self.vector_store = {
            'sentences': sentences,
            'embeddings': embeddings,
            'grouped_sentences': grouped_sentences,
            'grouped_texts': grouped_texts,
        }

        if show_progress:
            print(f"Vector store built: {len(sentences)} original sentences, {len(grouped_sentences)} groups, {len(embeddings)} embedding vectors")
            print(f"Vector store verification: embeddings.shape={embeddings.shape}, grouped_texts count={len(grouped_texts)}\n")

    def query(self, query: str, top_k: int = 5, aggregation_mode: str = 'post', tokenizer=None) -> Union[List[Tuple[str, float]], str]:
        """
        Query the vector store by dot-product similarity.

        Args:
            query (str): Query text.
            top_k (int): Maximum number of most-similar entries to use.
            aggregation_mode (str): Aggregation mode.
                - 'none': no aggregation; return up to top_k results [(text, score), ...]
                - 'post': post-aggregation; return the aggregator's output for the top_k texts
                Any other value prints a warning and behaves like 'none'.
            tokenizer: Forwarded to ``_post_aggregation``, where it is currently unused.

        Returns:
            Union[List[Tuple[str, float]], str]:
                - aggregation_mode='none': [(text, similarity_score), ...], best first
                - aggregation_mode='post': the aggregated result

        Raises:
            ValueError: If no vector store has been built or loaded.
        """
        if not self.has_vector_store():
            raise ValueError("Vector store not built, please call build_vector_store method first")

        query_embedding = self.sentenceizer.encode([query])[0]
        # Match encode(): the Sentenceizer may hand back torch tensors, and np.dot needs numpy.
        if isinstance(query_embedding, torch.Tensor):
            query_embedding = query_embedding.detach().cpu().numpy()

        store_embeddings = self.vector_store['embeddings']
        if len(store_embeddings) == 0:
            # encode() stores np.array([]) for empty text; np.dot would fail on its shape.
            similarities = np.array([], dtype=float)
        else:
            similarities = np.dot(store_embeddings, query_embedding)

        # Indices ordered from most to least similar.
        sorted_indices = np.argsort(similarities)[::-1]

        if aggregation_mode == 'none':
            return self._get_direct_results(sorted_indices, similarities, top_k)
        if aggregation_mode == 'post':
            return self._post_aggregation(sorted_indices, similarities, top_k, tokenizer=tokenizer)
        print(f"Warning: Unknown aggregation_mode '{aggregation_mode}', falling back to 'none'")
        return self._get_direct_results(sorted_indices, similarities, top_k)

    def _get_direct_results(self, sorted_indices: np.ndarray, similarities: np.ndarray, top_k: int) -> List[Tuple[str, float]]:
        """
        Return up to ``top_k`` (grouped_text, score) pairs in ``sorted_indices`` order.

        Indices without a matching entry in ``grouped_texts`` are skipped before
        truncating, so as many as ``top_k`` valid results are returned whenever
        they exist. ``top_k <= 0`` yields an empty list.
        """
        grouped_texts = self.vector_store['grouped_texts']
        available_count = len(grouped_texts)

        results = []
        for idx in sorted_indices:
            if len(results) >= top_k:
                break
            if idx < available_count:
                results.append((grouped_texts[idx], float(similarities[idx])))

        return results

    def _post_aggregation(self, sorted_indices: np.ndarray, similarities: np.ndarray, top_k: int, tokenizer=None) -> Union[List[Tuple[str, float]], str]:
        """
        Aggregate the top_k texts with ``self.aggregator.aggregate_segments``.

        The scores are discarded; only the texts are passed to the aggregator.
        ``tokenizer`` is accepted for interface compatibility but is currently unused.

        Returns:
            Whatever ``TextAggregator.aggregate_segments`` returns. ``query``
            documents this as the aggregated text.
        """
        direct_results = self._get_direct_results(sorted_indices, similarities, top_k)
        texts = [text for text, _score in direct_results]
        return self.aggregator.aggregate_segments(texts)

    def load_vector_store(self, file_path: str):
        """
        Load a vector store from a pickle file into ``self.vector_store``.

        Args:
            file_path (str): Vector store file path.

        Raises:
            FileNotFoundError: If ``file_path`` does not exist.
        """
        if not os.path.exists(file_path):
            raise FileNotFoundError(f"Vector store file not found: {file_path}")

        # SECURITY: pickle.load can execute arbitrary code. Only load files from
        # trusted sources; never load user-supplied or downloaded stores unchecked.
        with open(file_path, 'rb') as f:
            self.vector_store = pickle.load(f)

        print(f"Vector store loaded from {file_path}")
        print(f"Vector store info: {len(self.vector_store['grouped_texts'])} groups, embedding dimension: {self.vector_store['embeddings'].shape}")

    def has_vector_store(self) -> bool:
        """
        Check whether a vector store has been built or loaded.

        Returns:
            bool: True if ``self.vector_store`` exists and is not None.
        """
        return getattr(self, 'vector_store', None) is not None
| |
|