# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
from typing import Any, Dict, List

import numpy as np

from camel.loaders import UnstructuredIO
from camel.retrievers import BaseRetriever
from camel.utils import dependencies_required

DEFAULT_TOP_K_RESULTS = 1


class BM25Retriever(BaseRetriever):
    r"""An implementation of the `BaseRetriever` using the `BM25` model.

    This class facilitates the retrieval of relevant information using a
    query-based approach, ranking documents based on the occurrence and
    frequency of the query terms.

    Attributes:
        bm25 (BM25Okapi): An instance of the BM25Okapi class used for
            calculating document scores.
        content_input_path (str): The path to the content that has been
            processed and stored.
        unstructured_modules (UnstructuredIO): A module for parsing files and
            URLs and chunking content based on specified parameters.

    References:
        https://github.com/dorianbrown/rank_bm25
    """

    @dependencies_required('rank_bm25')
    def __init__(self) -> None:
        r"""Initializes the BM25Retriever."""
        from rank_bm25 import BM25Okapi

        self.bm25: BM25Okapi = None
        self.content_input_path: str = ""
        # Chunks produced by `process`; empty until content is processed.
        self.chunks: List[Any] = []
        self.unstructured_modules: UnstructuredIO = UnstructuredIO()

    def process(
        self,
        content_input_path: str,
        chunk_type: str = "chunk_by_title",
        **kwargs: Any,
    ) -> None:
        r"""Processes content from a file or URL, divides it into chunks
        using `Unstructured IO`, and stores them internally. This method
        must be called before executing queries with the retriever.

        Args:
            content_input_path (str): File path or URL of the content to be
                processed.
            chunk_type (str): Type of chunking to apply. Defaults to
                "chunk_by_title".
            **kwargs (Any): Additional keyword arguments for content parsing.
        """
        from rank_bm25 import BM25Okapi

        # Load and preprocess documents
        self.content_input_path = content_input_path
        elements = self.unstructured_modules.parse_file_or_url(
            content_input_path, **kwargs
        )
        if elements:
            self.chunks = self.unstructured_modules.chunk_elements(
                chunk_type=chunk_type, elements=elements
            )
            # Convert chunks to a list of strings for tokenization
            tokenized_corpus = [str(chunk).split(" ") for chunk in self.chunks]
            self.bm25 = BM25Okapi(tokenized_corpus)
        else:
            self.bm25 = None

    def query(
        self,
        query: str,
        top_k: int = DEFAULT_TOP_K_RESULTS,
    ) -> List[Dict[str, Any]]:
        r"""Executes a query and compiles the results.

        Args:
            query (str): Query string for information retrieval.
            top_k (int, optional): The number of top results to return during
                retrieval. Must be a positive integer. Defaults to
                `DEFAULT_TOP_K_RESULTS`.

        Returns:
            List[Dict[str, Any]]: A list of query result dictionaries, sorted
                by similarity score from high to low.

        Raises:
            ValueError: If `top_k` is less than or equal to 0, or if the BM25
                model has not been initialized by calling `process` first.
        """
        if top_k <= 0:
            raise ValueError("top_k must be a positive integer.")
        if self.bm25 is None or not self.chunks:
            raise ValueError(
                "BM25 model is not initialized. Call `process` first."
            )

        # Preprocess query similarly to how documents were processed
        processed_query = query.split(" ")
        # Retrieve documents based on BM25 scores
        scores = self.bm25.get_scores(processed_query)
        top_k_indices = np.argpartition(scores, -top_k)[-top_k:]

        formatted_results = []
        for i in top_k_indices:
            result_dict = {
                'similarity score': scores[i],
                'content path': self.content_input_path,
                'metadata': self.chunks[i].metadata.to_dict(),
                'text': str(self.chunks[i]),
            }
            formatted_results.append(result_dict)

        # Sort the list of dictionaries by 'similarity score' from high to low
        formatted_results.sort(
            key=lambda x: x['similarity score'], reverse=True
        )
        return formatted_results
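

# Illustrative usage sketch: the content path and query below are hypothetical
# placeholders; any local file or URL that `UnstructuredIO` can parse works.
if __name__ == "__main__":
    retriever = BM25Retriever()
    # Parse, chunk, and index the content before querying.
    retriever.process(content_input_path="example_document.txt")
    # Retrieve the single best-matching chunk (top_k defaults to 1).
    results = retriever.query(query="What does the BM25 retriever do?")
    for result in results:
        print(result['similarity score'], result['text'])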