|
|
|
|
|
import os |
|
import argparse |
|
import torch |
|
|
|
import transformers |
|
from src.normalize_text import normalize |
|
|
|
|
|
def save(tensor, split_path):
    """Serialize ``tensor`` to ``split_path`` with ``torch.save``.

    Missing parent directories are created first. A bare filename (no
    directory component) is written relative to the current working directory.

    Args:
        tensor: Object to serialize; in this script, the token tensor
            returned by ``apply_tokenizer``.
        split_path: Destination file path.
    """
    parent = os.path.dirname(split_path)
    # os.makedirs("") raises FileNotFoundError, so only create a directory when
    # the path actually has one; exist_ok avoids a check-then-create race.
    if parent:
        os.makedirs(parent, exist_ok=True)
    with open(split_path, 'wb') as fout:
        torch.save(tensor, fout)
|
|
|
def _encode_batch(lines, tokenizer):
    """Tokenize ``lines`` without special tokens into one int32 tensor per line."""
    input_ids = tokenizer.batch_encode_plus(lines, add_special_tokens=False)['input_ids']
    return [torch.tensor(ids, dtype=torch.int) for ids in input_ids]


def apply_tokenizer(path, tokenizer, normalize_text=False, chunk_size=1000000):
    """Tokenize every line of a UTF-8 text file and concatenate the token ids.

    Lines are encoded in batches of ``chunk_size`` so that a single tokenizer
    call never receives the whole file at once. Lines keep their trailing
    newline, and no special tokens are added.

    Args:
        path: Path of the text file to read (decoded as UTF-8).
        tokenizer: Object exposing ``batch_encode_plus(lines,
            add_special_tokens=False)`` that returns a mapping with an
            ``'input_ids'`` list of per-line id lists.
        normalize_text: If True, pass each line through ``normalize`` before
            tokenizing.
        chunk_size: Number of lines per tokenizer call; should be positive.

    Returns:
        A 1-D ``torch.int`` tensor containing the ids of all lines in file
        order. It is empty if the file contains no lines.
    """
    alltokens = []
    lines = []
    with open(path, "r", encoding="utf-8") as fin:
        for line in fin:
            if normalize_text:
                line = normalize(line)
            lines.append(line)
            if len(lines) >= chunk_size:
                alltokens.extend(_encode_batch(lines, tokenizer))
                lines = []

    # Skip the trailing call when nothing is left over; some tokenizers
    # presumably reject an empty batch.
    if lines:
        alltokens.extend(_encode_batch(lines, tokenizer))

    # torch.cat raises on an empty list, so an empty file needs its own return.
    if not alltokens:
        return torch.empty(0, dtype=torch.int)
    return torch.cat(alltokens)
|
|
|
def tokenize_file(args):
    """Tokenize ``args.datapath`` and save the token tensor under ``args.outdir``.

    The output path is ``<outdir>/<basename(datapath)>.pkl``. If that file
    already exists, the function prints a message and returns without doing
    anything, unless ``args.overwrite`` is set.

    Args:
        args: Namespace with ``datapath``, ``outdir``, ``tokenizer``,
            ``overwrite`` and ``normalize_text`` attributes, as produced by
            this script's argument parser.
    """
    filename = os.path.basename(args.datapath)
    savepath = os.path.join(args.outdir, f"{filename}.pkl")

    if os.path.exists(savepath):
        if not args.overwrite:
            print(f"File {savepath} already exists, exiting")
            return
        print(f"File {savepath} already exists, overwriting")

    try:
        # Try the local cache only first.
        tokenizer = transformers.AutoTokenizer.from_pretrained(args.tokenizer, local_files_only=True)
    except Exception:
        # Fall back to allowing downloads. This was a bare ``except:``, which
        # also swallowed KeyboardInterrupt/SystemExit.
        tokenizer = transformers.AutoTokenizer.from_pretrained(args.tokenizer, local_files_only=False)

    print(f"Encoding {args.datapath}...")
    tokens = apply_tokenizer(args.datapath, tokenizer, normalize_text=args.normalize_text)

    print(f"Saving at {savepath}...")
    save(tokens, savepath)
|
|
|
|
|
if __name__ == '__main__':
    parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    # These three are required: tokenize_file cannot run without them, and
    # previously a missing value only failed later with an unrelated TypeError.
    parser.add_argument("--datapath", type=str, required=True,
                        help="UTF-8 text file to tokenize line by line")
    parser.add_argument("--outdir", type=str, required=True,
                        help="Directory where <basename of datapath>.pkl is written")
    parser.add_argument("--tokenizer", type=str, required=True,
                        help="Name or path passed to transformers.AutoTokenizer.from_pretrained")
    parser.add_argument("--overwrite", action="store_true",
                        help="Re-tokenize even if the output file already exists")
    parser.add_argument("--normalize_text", action="store_true",
                        help="Normalize each line before tokenizing")

    # parse_known_args ignores unrecognized flags instead of erroring;
    # presumably so this can be invoked with extra options meant elsewhere.
    args, _ = parser.parse_known_args()

    tokenize_file(args)
|
|