Spaces:
Runtime error
Runtime error
import json | |
import numpy as np | |
from tqdm import tqdm | |
from langchain.text_splitter import RecursiveCharacterTextSplitter | |
from models import env_config | |
class Chunker:
    """Split text into chunks using a strategy chosen at construction time.

    Supported strategies:
        "recursive": delegate to ``RecursiveCharacterTextSplitter`` built with
            ``chunk_size=chunk_len`` and ``separators=[split_seq]``.
        "sequence": split the text on every occurrence of ``split_seq``.
        "constant": cut the text into consecutive slices of ``chunk_len``
            characters (the last slice may be shorter).

    After construction, ``self.split`` is the callable to use:
    ``chunker.split(text)`` returns the chunks for ``text``.
    """

    def __init__(self, strategy, split_seq=".", chunk_len=512):
        """Select the splitting callable for ``strategy``.

        Args:
            strategy: One of "recursive", "sequence" or "constant".
            split_seq: Separator string used by the "recursive" and
                "sequence" strategies.
            chunk_len: Chunk size used by the "recursive" and "constant"
                strategies.

        Raises:
            ValueError: If ``strategy`` is not one of the supported names.
        """
        self.split_seq = split_seq
        self.chunk_len = chunk_len

        if strategy == "recursive":
            self.split = RecursiveCharacterTextSplitter(
                chunk_size=chunk_len,
                separators=[split_seq],
            ).split_text
        elif strategy == "sequence":
            self.split = self.seq_splitter
        elif strategy == "constant":
            self.split = self.const_splitter
        else:
            # Fail fast: without this, ``self.split`` would be left unset and
            # the mistake would only surface later as an AttributeError.
            raise ValueError(
                f"Unknown chunking strategy {strategy!r}; expected one of "
                "'recursive', 'sequence', 'constant'."
            )

    def seq_splitter(self, text):
        """Return ``text`` split on every occurrence of ``self.split_seq``."""
        return text.split(self.split_seq)

    def const_splitter(self, text):
        """Return consecutive ``self.chunk_len``-character slices of ``text``.

        An empty ``text`` yields an empty list.
        """
        # Stepping by chunk_len keeps the arithmetic in integers instead of
        # deriving the chunk count through a float ceil.
        return [
            text[start:start + self.chunk_len]
            for start in range(0, len(text), self.chunk_len)
        ]
def chunk_generator(input_dataset, chunker, tmp_file):
    """Yield one record per non-empty chunk, mirroring each to ``tmp_file``.

    For every index of ``input_dataset`` the text stored under
    ``env_config.input_text_col`` is passed to ``chunker.split``. Each truthy
    chunk is written to ``tmp_file`` as one JSON line and then yielded as a
    ``{env_config.input_text_col: chunk}`` dict. Falsy chunks are skipped.
    """
    for row_idx in tqdm(range(len(input_dataset))):
        source_text = input_dataset[row_idx][env_config.input_text_col]
        for piece in chunker.split(source_text):
            if not piece:
                continue
            record = {env_config.input_text_col: piece}
            # Persist first, then hand the same record to the consumer.
            tmp_file.write(json.dumps(record) + "\n")
            yield record