webhook-space / chunking_utils.py
plaggy's picture
refactor
8b7a023
raw
history blame
1.34 kB
import json
import numpy as np
from tqdm import tqdm
from langchain.text_splitter import RecursiveCharacterTextSplitter
from models import env_config
class Chunker:
    """Split text into chunks using a strategy chosen at construction time.

    After construction, ``self.split`` is a callable that takes a string and
    returns a list of string chunks. Supported strategies:

    * ``"recursive"`` -- delegate to the ``split_text`` method of a
      ``RecursiveCharacterTextSplitter`` built with ``chunk_size=chunk_len``
      and ``split_seq`` as its only separator.
    * ``"sequence"`` -- split on every occurrence of ``split_seq``.
    * ``"constant"`` -- cut the text into consecutive slices of
      ``chunk_len`` characters.

    Args:
        strategy: One of ``"recursive"``, ``"sequence"`` or ``"constant"``.
        split_seq: Separator string used by the "recursive" and "sequence"
            strategies.
        chunk_len: Chunk size used by the "recursive" and "constant"
            strategies; expected to be a positive integer.

    Raises:
        ValueError: If ``strategy`` is not one of the supported names.
    """

    def __init__(self, strategy, split_seq=".", chunk_len=512):
        self.split_seq = split_seq
        self.chunk_len = chunk_len

        if strategy == "recursive":
            self.split = RecursiveCharacterTextSplitter(
                chunk_size=chunk_len,
                separators=[split_seq],
            ).split_text
        elif strategy == "sequence":
            self.split = self.seq_splitter
        elif strategy == "constant":
            self.split = self.const_splitter
        else:
            # Fail at construction instead of leaving ``self.split`` unset,
            # which would only surface later as an AttributeError.
            raise ValueError(
                f"Unknown chunking strategy {strategy!r}; expected one of "
                "'recursive', 'sequence', 'constant'"
            )

    def seq_splitter(self, text):
        """Return ``text`` split on every occurrence of ``self.split_seq``.

        Empty pieces (e.g. between two adjacent separators) are kept.
        """
        return text.split(self.split_seq)

    def const_splitter(self, text):
        """Return consecutive slices of ``text``, each ``self.chunk_len`` long.

        The last slice holds the remainder and may be shorter. An empty
        ``text`` yields an empty list.
        """
        step = self.chunk_len
        # Integer stepping avoids the float division / ceil round trip.
        return [text[start:start + step] for start in range(0, len(text), step)]
def chunk_generator(input_dataset, chunker, tmp_file):
    """Yield one record per non-empty chunk of every row in ``input_dataset``.

    Each row's text is read from the ``env_config.input_text_col`` field and
    passed to ``chunker.split``. Every truthy chunk is written to
    ``tmp_file`` as a single JSON line and then yielded as a one-key dict
    keyed by ``env_config.input_text_col``. Falsy chunks (such as empty
    strings) are skipped entirely.

    Args:
        input_dataset: Collection supporting ``len()`` and integer indexing,
            whose items can be subscripted with ``env_config.input_text_col``.
        chunker: Object exposing a ``split(text)`` callable that returns an
            iterable of chunks.
        tmp_file: Object with a ``write(str)`` method.

    Yields:
        dict: ``{env_config.input_text_col: chunk}`` for each kept chunk.
    """
    for row_idx in tqdm(range(len(input_dataset))):
        row_text = input_dataset[row_idx][env_config.input_text_col]
        for piece in chunker.split(row_text):
            if not piece:
                continue
            # Build the record once; it is both serialized and yielded.
            record = {env_config.input_text_col: piece}
            tmp_file.write(json.dumps(record) + "\n")
            yield record