NilanE committed
Commit 018b055
1 Parent(s): 92d3816

Upload datasetChunker.py

Files changed (1)
  1. datasetChunker.py +99 -0
datasetChunker.py ADDED
@@ -0,0 +1,99 @@
+ from transformers import AutoTokenizer
+ import jsonlines
+ import random
+ import os
+ 
+ tokenizer = AutoTokenizer.from_pretrained("NilanE/tinyllama-relora-merge")
+ 
+ max_seq_len = 2048  # max context length
+ 
+ prompt = "Translate this from Japanese to English:\n### JAPANESE: \n### ENGLISH: </s>"  # SFT prompt template, included in the token count
+ 
+ input_file_path = "dataset-parallel-complete.jsonl"
+ 
+ output_file_path = input_file_path.split('.')[0] + "-chunked." + input_file_path.split('.')[1]
+ promptTokens = len(tokenizer.tokenize(prompt))
+ 
+ def load_jsonl(file_path):
+     """Load parallel src/trg pairs, stripping stray EOS tokens and whitespace."""
+     data = []
+     with jsonlines.open(file_path) as reader:
+         for entry in reader:
+             source = entry['src'].replace('</s>', '').strip()
+             target = entry['trg'].replace('</s>', '').strip()
+             data.append([source, target])
+     return data
+ 
+ def save_jsonl(file_path, data):
+     with jsonlines.open(file_path, 'w') as writer:
+         writer.write_all(data)
+ 
+ chunks = []
+ 
+ data = load_jsonl(input_file_path)
+ 
+ # tolerance
+ max_seq_len -= 10
+ 
+ skippedDocs = 0
+ longLines = 0
+ 
+ for doc in data:
+ 
+     src_lines = doc[0].split('\n')
+     trg_lines = doc[1].split('\n')
+ 
+     out_src = []
+     out_trg = []
+     tokenCount = 0
+     lastTokenCount = 0
+ 
+     try:
+         for x in range(len(src_lines)):
+             # grow the current chunk one aligned line pair at a time
+             out_src.append(src_lines[x])
+             out_trg.append(trg_lines[x])
+             out_src_string = "\n".join(out_src)
+             out_trg_string = "\n".join(out_trg)
+             tokenCount = len(tokenizer.tokenize(out_src_string.strip() + out_trg_string.strip())) + promptTokens
+             if tokenCount - lastTokenCount < max_seq_len - 1:  # skip single line pairs longer than the max length
+                 if tokenCount > max_seq_len - 1:
+                     # chunk is full: emit everything except the line pair that overflowed,
+                     # then start the next chunk with that pair
+                     src_end = out_src.pop()
+                     trg_end = out_trg.pop()
+                     out_src_string = "\n".join(out_src)
+                     out_trg_string = "\n".join(out_trg)
+                     chunk = {
+                         'src': out_src_string.strip(),
+                         'trg': out_trg_string.strip()
+                     }
+                     chunks.append(chunk)
+                     out_src = [src_end]
+                     out_trg = [trg_end]
+                 elif x + 1 == len(src_lines):  # and len(out_src) > 2:
+                     # end of document: emit whatever remains as the final chunk
+                     chunk = {
+                         'src': out_src_string.strip(),
+                         'trg': out_trg_string.strip()
+                     }
+                     chunks.append(chunk)
+             else:
+                 # remove offending line pair longer than max_seq_len
+                 out_src.pop()
+                 out_trg.pop()
+                 out_src_string = "\n".join(out_src)
+                 out_trg_string = "\n".join(out_trg)
+                 tokenCount = len(tokenizer.tokenize(prompt + out_src_string.strip() + out_trg_string.strip()))
+                 longLines += 1
+ 
+             lastTokenCount = tokenCount
+     except IndexError:
+         # src/trg line counts don't match; skip the document
+         skippedDocs += 1
+ 
+ 
+ random.shuffle(chunks)
+ 
+ print(f"LINES LONGER THAN MAX SEQUENCE LENGTH: {longLines}")
+ print(f"SKIPPED DOCS: {skippedDocs}")
+ 
+ # Save the shuffled chunks to a new JSONL file
+ if os.path.exists(output_file_path):
+     os.remove(output_file_path)
+ save_jsonl(output_file_path, chunks)
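
For context, a minimal, hypothetical sketch (not part of this commit) of how a chunked record could be reassembled into the SFT prompt and checked against the context limit. The exact way src and trg are spliced into the template is an assumption, since the script above only uses the empty template for token counting; build_example is an illustrative helper, not from the repo.

from transformers import AutoTokenizer
import jsonlines

tokenizer = AutoTokenizer.from_pretrained("NilanE/tinyllama-relora-merge")
max_seq_len = 2048

def build_example(src, trg):
    # Assumed prompt assembly around the same template used for token counting above.
    return f"Translate this from Japanese to English:\n### JAPANESE: {src}\n### ENGLISH: {trg}</s>"

with jsonlines.open("dataset-parallel-complete-chunked.jsonl") as reader:
    for entry in reader:
        text = build_example(entry['src'], entry['trg'])
        # Chunks emitted by datasetChunker.py should fit the context window,
        # given the 10-token tolerance it reserves while counting.
        assert len(tokenizer.tokenize(text)) <= max_seq_len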