Spaces:
Running
Running
File size: 3,733 Bytes
aefc9ef |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 |
# Copyright (c) Guangsheng Bao.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import time
import numpy as np
import datasets
import torch
import random
import argparse
import os
import json
import custom_datasets
from model import load_tokenizer, load_model
def stats_str(data):
    """Return a one-line summary of the average word counts found in *data*.

    Words are counted by splitting each text on whitespace. Two layouts are
    supported:

    * a dict (or dict subclass) with ``'original'`` and ``'sampled'`` keys,
      each holding a list of strings;
    * otherwise, a sequence of records, each with ``'original'`` and
      ``'sampled'`` strings plus ``'perturbed_original'`` and
      ``'perturbed_sampled'`` lists of strings.

    Averages are formatted with zero decimal places.
    """
    def _word_count(text):
        return len(text.split())

    # isinstance (rather than an exact type comparison) keeps dict subclasses
    # such as OrderedDict on the dict code path.
    if isinstance(data, dict):
        mean_orig = np.mean([_word_count(v) for v in data['original']])
        mean_samp = np.mean([_word_count(v) for v in data['sampled']])
        return f'{mean_orig:.0f} words (original), {mean_samp:.0f} words (sampled).'

    mean_orig = np.mean([_word_count(v['original']) for v in data])
    mean_samp = np.mean([_word_count(v['sampled']) for v in data])
    # For perturbations, average within each record first, then across records.
    mean_perturb_orig = np.mean([np.mean([_word_count(p) for p in v['perturbed_original']]) for v in data])
    mean_perturb_samp = np.mean([np.mean([_word_count(p) for p in v['perturbed_sampled']]) for v in data])
    return f'{mean_orig:.0f} words (original), {mean_samp:.0f} words (sampled), {mean_perturb_orig:.0f} words (perturb original), {mean_perturb_samp:.0f} words (perturb sampled).'
def _write_json(target, payload):
    """Serialize *payload* as JSON (4-space indent) into the file *target*."""
    with open(target, "w") as handle:
        json.dump(payload, handle, indent=4)


def save_data(output_file, args, data):
    """Write *args* and *data* as two JSON files sharing the *output_file* prefix.

    *args* goes to ``{output_file}.args.json`` and *data* goes to
    ``{output_file}.raw_data.json``. A confirmation line is printed after each
    file; the second one includes the word statistics of *data*.
    """
    # Settings first.
    args_path = f"{output_file}.args.json"
    _write_json(args_path, args)
    print(f"Args written into {args_path}")

    # Then the texts themselves.
    data_path = f"{output_file}.raw_data.json"
    _write_json(data_path, data)
    print(f"Raw data written into {data_path}: {stats_str(data)}")
def _read_json(source):
    """Parse and return the JSON document stored in the file *source*."""
    with open(source, "r") as handle:
        return json.load(handle)


def load_data(input_file):
    """Read the args/data JSON pair that shares the *input_file* prefix.

    Reads ``{input_file}.args.json`` and ``{input_file}.raw_data.json``, prints
    a confirmation line for each (the second includes the word statistics of
    the data), and returns the tuple ``(args, data)``.
    """
    # Settings first.
    args_path = f"{input_file}.args.json"
    args = _read_json(args_path)
    print(f"Args loaded from {args_path}")

    # Then the texts themselves.
    data_path = f"{input_file}.raw_data.json"
    data = _read_json(data_path)
    print(f"Raw data loaded from {data_path}: {stats_str(data)}")

    return args, data
def convert_data(input_file, output_file, max_words):
    """Truncate every text of a saved dataset to at most *max_words* words.

    The dataset is read with ``load_data(input_file)``, each text is shortened,
    and the result is written with ``save_data(output_file, args, data)``; the
    loaded args are passed through unchanged.

    Two data layouts are handled: a dict with ``'original'`` / ``'sampled'``
    lists of strings, or a list of records that hold ``'original'`` and
    ``'sampled'`` strings plus ``'perturbed_original'`` and
    ``'perturbed_sampled'`` lists of strings.
    """
    def _reduce(text):
        # Work line by line so the line breaks of the kept part survive.
        # Whitespace inside a line is normalized to single spaces.
        kept_lines = []
        remaining = max_words
        for line in text.split('\n'):
            if remaining <= 0:
                # Word budget spent: drop this line and everything after it.
                break
            words = line.split()[:remaining]
            kept_lines.append(' '.join(words))
            remaining -= len(words)
        return '\n'.join(kept_lines)

    args, data = load_data(input_file)
    # isinstance instead of an exact type comparison, so dict subclasses are
    # treated as dicts too.
    if isinstance(data, dict):
        data['original'] = [_reduce(x) for x in data['original']]
        data['sampled'] = [_reduce(x) for x in data['sampled']]
    else:
        for item in data:
            item['original'] = _reduce(item['original'])
            item['sampled'] = _reduce(item['sampled'])
            item['perturbed_original'] = [_reduce(x) for x in item['perturbed_original']]
            item['perturbed_sampled'] = [_reduce(x) for x in item['perturbed_sampled']]
    save_data(output_file, args, data)
if __name__ == '__main__':
    # Script entry point: truncate every "*.raw_data.json" dataset found in
    # --input_path to --max_words words and write the results, under the same
    # file prefixes, into --output_path.
    import glob
    import os.path as path

    parser = argparse.ArgumentParser()
    parser.add_argument('--input_path', type=str, default="./exp_gpt3to4/data/")
    parser.add_argument('--output_path', type=str, default="./exp_maxlen150/data/")
    parser.add_argument('--max_words', type=int, default=150)
    args = parser.parse_args()

    # open(..., "w") does not create missing directories, so make sure the
    # output folder exists before the first file is written.
    if args.output_path:
        os.makedirs(args.output_path, exist_ok=True)

    suffix = '.raw_data.json'
    for file_name in glob.glob(f'{args.input_path}/*.raw_data.json'):
        print(file_name)
        # Strip only the trailing suffix; str.replace would also remove any
        # earlier occurrence of the same text inside the file name.
        file_name = path.basename(file_name)[:-len(suffix)]
        convert_data(path.join(args.input_path, file_name), path.join(args.output_path, file_name), args.max_words)
|