NilanE committed on
Commit
e58dc92
1 Parent(s): 018b055

Delete datasetChunker.py

Files changed (1)
  1. datasetChunker.py +0 -99
datasetChunker.py DELETED
@@ -1,99 +0,0 @@
from transformers import AutoTokenizer
import jsonlines
import random
import os

tokenizer = AutoTokenizer.from_pretrained("NilanE/tinyllama-relora-merge")

max_seq_len = 2048  # max context length

prompt = "Translate this from Japanese to English:\n### JAPANESE: \n### ENGLISH: </s>"  # SFT prompt, included in the token count

input_file_path = "dataset-parallel-complete.jsonl"

output_file_path = input_file_path.split('.')[0] + "-chunked." + input_file_path.split('.')[1]
promptTokens = len(tokenizer.tokenize(prompt))

def load_jsonl(file_path):
    data = []
    with jsonlines.open(file_path) as reader:
        for entry in reader:
            source = entry['src'].replace('</s>', '').strip()
            target = entry['trg'].replace('</s>', '').strip()
            data.append([source, target])
    return data

def save_jsonl(file_path, data):
    with jsonlines.open(file_path, 'w') as writer:
        writer.write_all(data)

chunks = []

data = load_jsonl(input_file_path)

# tolerance below the hard context limit
max_seq_len -= 10

skippedDocs = 0
longLines = 0

for doc in data:

    src_lines = doc[0].split('\n')
    trg_lines = doc[1].split('\n')

    out_src = []
    out_trg = []
    tokenCount = 0
    lastTokenCount = 0

    try:
        for x in range(len(src_lines)):
            out_src.append(src_lines[x])
            out_trg.append(trg_lines[x])
            out_src_string = "\n".join(out_src)
            trg_src_string = "\n".join(out_trg)
            tokenCount = len(tokenizer.tokenize(out_src_string.strip() + trg_src_string.strip())) + promptTokens
            if tokenCount - lastTokenCount < max_seq_len - 1:  # avoid single lines longer than the max length
                if tokenCount > max_seq_len - 1:
                    # chunk is full: emit everything except the line just added,
                    # then start the next chunk with that line
                    src_end = out_src.pop()
                    trg_end = out_trg.pop()
                    out_src_string = "\n".join(out_src)
                    trg_src_string = "\n".join(out_trg)
                    chunk = {
                        'src': out_src_string.strip(),
                        'trg': trg_src_string.strip()
                    }
                    chunks.append(chunk)
                    out_src = [src_end]
                    out_trg = [trg_end]
                elif x + 1 == len(src_lines):  # and len(out_src) > 2:
                    # end of document: emit whatever is left
                    chunk = {
                        'src': out_src_string.strip(),
                        'trg': trg_src_string.strip()
                    }
                    chunks.append(chunk)
            else:
                # remove offending line > max_seq_len
                out_src.pop()
                out_trg.pop()
                out_src_string = "\n".join(out_src)
                trg_src_string = "\n".join(out_trg)
                tokenCount = len(tokenizer.tokenize(prompt + out_src_string.strip() + trg_src_string.strip()))
                longLines += 1

            lastTokenCount = tokenCount
    except Exception:
        # typically a mismatch between source and target line counts
        skippedDocs += 1

random.shuffle(chunks)

print(f"LINES LONGER THAN MAX SEQUENCE LENGTH: {longLines}")
print(f"SKIPPED DOCS: {skippedDocs}")

# Save the randomized data to a new JSONL file
if os.path.exists(output_file_path):
    os.remove(output_file_path)
save_jsonl(output_file_path, chunks)
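For context, the deleted script read parallel 'src'/'trg' records, split long documents into chunks that fit the 2048-token context (minus the prompt and a 10-token tolerance), shuffled them, and wrote them back out. A minimal sketch of re-checking the chunked output against that budget, assuming the same tokenizer, prompt, and output filename as in the script above:

from transformers import AutoTokenizer
import jsonlines

tokenizer = AutoTokenizer.from_pretrained("NilanE/tinyllama-relora-merge")
prompt = "Translate this from Japanese to English:\n### JAPANESE: \n### ENGLISH: </s>"
budget = 2048 - 10  # same max_seq_len and tolerance as the chunker

prompt_tokens = len(tokenizer.tokenize(prompt))

# Re-count each chunk the same way the chunker did and flag anything over budget.
with jsonlines.open("dataset-parallel-complete-chunked.jsonl") as reader:
    for i, chunk in enumerate(reader):
        n = len(tokenizer.tokenize(chunk['src'] + chunk['trg'])) + prompt_tokens
        if n > budget:
            print(f"chunk {i}: {n} tokens (over budget)")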