File size: 1,338 Bytes
8b7a023
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
import json
import numpy as np

from tqdm import tqdm
from langchain.text_splitter import RecursiveCharacterTextSplitter

from models import env_config


class Chunker:
    """Split text into chunks using one of several strategies.

    Strategies:
        ``"recursive"``: LangChain ``RecursiveCharacterTextSplitter`` with
            ``chunk_size=chunk_len`` and ``split_seq`` as its only separator.
        ``"sequence"``: plain ``str.split`` on ``split_seq`` (empty pieces
            are kept in the result).
        ``"constant"``: consecutive slices of exactly ``chunk_len``
            characters; the last slice may be shorter.

    The chosen splitter is exposed as ``self.split(text)``, which returns a
    list of strings.

    Raises:
        ValueError: if ``strategy`` is not one of the names above.
    """

    def __init__(self, strategy, split_seq=".", chunk_len=512):
        self.split_seq = split_seq
        self.chunk_len = chunk_len
        if strategy == "recursive":
            self.split = RecursiveCharacterTextSplitter(
                chunk_size=chunk_len,
                separators=[split_seq]
            ).split_text
        elif strategy == "sequence":
            self.split = self.seq_splitter
        elif strategy == "constant":
            self.split = self.const_splitter
        else:
            # Fail fast: previously an unknown strategy left ``self.split``
            # undefined and only surfaced later as an AttributeError.
            raise ValueError(
                f"Unknown chunking strategy {strategy!r}; expected one of "
                "'recursive', 'sequence', 'constant'"
            )

    def seq_splitter(self, text):
        """Split ``text`` on every occurrence of ``self.split_seq``."""
        return text.split(self.split_seq)

    def const_splitter(self, text):
        """Cut ``text`` into consecutive ``self.chunk_len``-character slices.

        Returns an empty list for empty ``text``.

        Raises:
            ValueError: if ``self.chunk_len`` is not positive.
        """
        if self.chunk_len <= 0:
            raise ValueError(
                f"chunk_len must be positive, got {self.chunk_len!r}"
            )
        # Stepping the start offset avoids float division / ceil when
        # counting chunks; slicing past the end just truncates the last one.
        return [
            text[start:start + self.chunk_len]
            for start in range(0, len(text), self.chunk_len)
        ]


def chunk_generator(input_dataset, chunker, tmp_file):
    """Yield every non-empty text chunk of ``input_dataset`` as a one-key dict.

    Each record's ``env_config.input_text_col`` field is split with
    ``chunker.split``. Every non-empty chunk is written to ``tmp_file`` as a
    JSON line and then yielded as ``{env_config.input_text_col: chunk}``.
    Progress over records is shown with ``tqdm``.

    NOTE(review): ``input_dataset`` is assumed to support ``len()`` and
    integer indexing that returns a mapping, and ``tmp_file`` is assumed to
    be a text-mode writable file object. Confirm against callers.
    """
    text_col = env_config.input_text_col
    for idx in tqdm(range(len(input_dataset))):
        record_text = input_dataset[idx][text_col]
        # filter(None, ...) drops falsy (empty) chunks.
        for piece in filter(None, chunker.split(record_text)):
            row = {text_col: piece}
            # Write before yielding so every chunk handed out is already logged.
            tmp_file.write(json.dumps(row) + "\n")
            yield row