medrag / medrag_multi_modal /semantic_chunking.py
geekyrakshit's picture
update: app
170d9a9
raw
history blame
5.12 kB
import asyncio
from typing import Callable, Optional, Union
import huggingface_hub
import semchunk
import tiktoken
import tokenizers
from datasets import Dataset, concatenate_datasets, load_dataset
from rich.progress import track
from transformers import PreTrainedTokenizer
TOKENIZER_OR_TOKEN_COUNTER = Union[
str,
tiktoken.Encoding,
PreTrainedTokenizer,
tokenizers.Tokenizer,
Callable[[str], int],
]
class SemanticChunker:
"""
SemanticChunker is a class that chunks documents into smaller segments and
publishes them as datasets.
This class uses the `semchunk` library to break down large documents into
smaller, manageable chunks based on a specified tokenizer or token counter.
This is particularly useful for processing large text datasets where
smaller segments are needed for analysis or other operations.
!!! example "Example Usage"
```python
from medrag_multi_modal.semantic_chunking import SemanticChunker
chunker = SemanticChunker(chunk_size=256)
chunker.chunk(
document_dataset="geekyrakshit/grays-anatomy-test",
chunk_dataset_repo_id="geekyrakshit/grays-anatomy-chunks-test",
)
```
Args:
tokenizer_or_token_counter (TOKENIZER_OR_TOKEN_COUNTER): The tokenizer or
token counter to be used for chunking.
chunk_size (Optional[int]): The size of each chunk. If not specified, the
default chunk size from `semchunk` will be used.
max_token_chars (Optional[int]): The maximum number of characters per token.
If not specified, the default value from `semchunk` will be used.
memoize (bool): Whether to memoize the chunking process for efficiency.
Default is True.
"""
def __init__(
self,
tokenizer_or_token_counter: TOKENIZER_OR_TOKEN_COUNTER = "o200k_base",
chunk_size: Optional[int] = None,
max_token_chars: Optional[int] = None,
memoize: bool = True,
) -> None:
self.chunker = semchunk.chunkerify(
tokenizer_or_token_counter,
chunk_size=chunk_size,
max_token_chars=max_token_chars,
memoize=memoize,
)
def chunk(
self,
document_dataset: Union[Dataset, str],
chunk_dataset_repo_id: Optional[str] = None,
overwrite_dataset: bool = False,
) -> Dataset:
"""
Chunks a document dataset into smaller segments and publishes them as a new dataset.
This function takes a document dataset, either as a HuggingFace Dataset object or a string
representing the dataset repository ID, and chunks the documents into smaller segments using
the specified chunker. The resulting chunks are then optionally published to a HuggingFace
dataset repository.
Args:
document_dataset (Union[Dataset, str]): The document dataset to be chunked. It can be either
a HuggingFace Dataset object or a string representing the dataset repository ID.
chunk_dataset_repo_id (Optional[str]): The repository ID of the HuggingFace dataset to publish
the chunks to, if provided. Defaults to None.
overwrite_dataset (bool): Whether to overwrite the existing dataset if it exists. Defaults to False.
Returns:
Dataset: A HuggingFace Dataset object containing the chunks.
"""
document_dataset = (
load_dataset(document_dataset, split="corpus")
if isinstance(document_dataset, str)
else document_dataset
).to_list()
chunks = []
async def process_document(idx, document):
document_chunks = self.chunker.chunk(str(document["text"]))
for chunk in document_chunks:
chunk_dict = {"document_idx": idx, "text": chunk}
for key, value in document.items():
if key not in chunk_dict:
chunk_dict[key] = value
chunks.append(chunk_dict)
async def process_all_documents():
tasks = []
for idx, document in track(
enumerate(document_dataset),
total=len(document_dataset),
description="Chunking documents",
):
tasks.append(process_document(idx, document))
await asyncio.gather(*tasks)
asyncio.run(process_all_documents())
chunks.sort(key=lambda x: x["document_idx"])
dataset = Dataset.from_list(chunks)
if chunk_dataset_repo_id:
if huggingface_hub.repo_exists(chunk_dataset_repo_id, repo_type="dataset"):
if not overwrite_dataset:
dataset = concatenate_datasets(
[
dataset,
load_dataset(chunk_dataset_repo_id, split="chunks"),
]
)
dataset.push_to_hub(repo_id=chunk_dataset_repo_id, split="chunks")
return dataset