import json

import numpy as np
from tqdm import tqdm
from langchain.text_splitter import RecursiveCharacterTextSplitter

from models import env_config


class Chunker:
    """Split text into chunks using one of several named strategies.

    Supported strategies:

    * ``"recursive"`` -- LangChain's ``RecursiveCharacterTextSplitter`` with
      ``chunk_size=chunk_len`` and ``split_seq`` as its only separator.
    * ``"sequence"`` -- plain ``str.split`` on ``split_seq``.
    * ``"constant"`` -- consecutive fixed-width slices of ``chunk_len``
      characters.

    After construction, ``self.split(text)`` returns the list of chunks.
    """

    def __init__(self, strategy, split_seq=".", chunk_len=512):
        """Select the splitting function for ``strategy``.

        Args:
            strategy: one of ``"recursive"``, ``"sequence"``, ``"constant"``.
            split_seq: separator used by the "recursive" and "sequence"
                strategies.
            chunk_len: chunk size in characters, used by the "recursive" and
                "constant" strategies.

        Raises:
            ValueError: if ``strategy`` is not recognised, or if
                ``chunk_len`` is not positive for the "constant" strategy.
        """
        self.split_seq = split_seq
        self.chunk_len = chunk_len
        if strategy == "recursive":
            # NOTE(review): LangChain's splitter has its own default
            # chunk_overlap and may reject a chunk_size smaller than it —
            # confirm against the installed version if chunk_len is small.
            self.split = RecursiveCharacterTextSplitter(
                chunk_size=chunk_len, separators=[split_seq]
            ).split_text
        elif strategy == "sequence":
            self.split = self.seq_splitter
        elif strategy == "constant":
            # A zero width would divide by zero / loop forever; a negative
            # width would silently produce no chunks.
            if chunk_len <= 0:
                raise ValueError(
                    f"chunk_len must be positive for the 'constant' strategy, "
                    f"got {chunk_len!r}"
                )
            self.split = self.const_splitter
        else:
            # Previously an unknown strategy left ``self.split`` undefined and
            # only failed later with an AttributeError.
            raise ValueError(
                f"Unknown chunking strategy {strategy!r}; expected one of "
                f"'recursive', 'sequence', 'constant'"
            )

    def seq_splitter(self, text):
        """Split ``text`` on every occurrence of ``self.split_seq``.

        Empty strings are kept (e.g. after a trailing separator); callers such
        as ``chunk_generator`` filter them out.
        """
        return text.split(self.split_seq)

    def const_splitter(self, text):
        """Return consecutive ``self.chunk_len``-character slices of ``text``.

        The last slice may be shorter; an empty ``text`` yields ``[]``.
        """
        size = self.chunk_len
        return [text[start:start + size] for start in range(0, len(text), size)]


def chunk_generator(input_dataset, chunker, tmp_file):
    """Yield every non-empty chunk of every row, mirroring each to ``tmp_file``.

    Args:
        input_dataset: sized, integer-indexable collection whose rows support
            ``row[env_config.input_text_col]`` (presumably returning a
            ``str`` — confirm against callers).
        chunker: object exposing ``split(text) -> list[str]``, e.g. a
            ``Chunker``.
        tmp_file: writable text file; each chunk is appended as one JSON
            object per line.

    Yields:
        ``{env_config.input_text_col: chunk}`` for each non-empty chunk, in
        dataset order. Each record is written to ``tmp_file`` before it is
        yielded.
    """
    text_col = env_config.input_text_col
    for i in tqdm(range(len(input_dataset))):
        for chunk in chunker.split(input_dataset[i][text_col]):
            if not chunk:
                # Splitters can emit "" (e.g. str.split after a trailing
                # separator); skip those.
                continue
            record = {text_col: chunk}
            tmp_file.write(json.dumps(record) + "\n")
            yield record