broken.god / scripts /create_reddit.py
amatiger's picture
Upload 6 files
59d97af
raw
history blame contribute delete
No virus
2.04 kB
#!/usr/bin/env python
# coding=utf-8
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import jsonlines
import fire
def _norm_text(text):
    """Split an optional leading numeric weight off a dialogue turn.

    The turn is tokenized on whitespace.  If the first token parses as a
    float it is taken as the turn's weight and removed from the text;
    otherwise the weight defaults to 1.0 and every token is kept.

    Args:
        text: One turn, optionally prefixed with a numeric weight.

    Returns:
        A ``(weight, text)`` tuple where ``weight`` is a float and ``text``
        is the remaining tokens joined by single spaces.  An empty or
        whitespace-only turn yields ``(1.0, '')``.
    """
    # str.split() with no separator already ignores leading/trailing
    # whitespace, so no explicit strip() is needed.
    tokens = text.split()
    if not tokens:
        # No tokens means no weight token either: apply the same default as
        # for any other turn without a parsable weight instead of failing
        # with an opaque unpacking error.
        return 1.0, ''
    try:
        weight = float(tokens[0])
    except ValueError:
        # The first token is ordinary text rather than a weight.
        return 1.0, ' '.join(tokens)
    return weight, ' '.join(tokens[1:])
def _get_inputs_from_text(text):
    """Parse one raw line into parallel lists of turn weights and turn texts.

    After stripping surrounding whitespace, the line must contain exactly one
    tab: the part before it holds the source turns separated by ' EOS ', the
    part after it is the target turn.  Every turn is passed through
    ``_norm_text`` to separate its weight from its text.

    Args:
        text: A single line of the input file.

    Returns:
        A ``(weights, inputs)`` tuple of equally long lists.  The target turn
        is included only when its weight is non-zero.

    Raises:
        ValueError: If the stripped line does not split into exactly two
            tab-separated fields.
    """
    source_part, target_part = text.strip().split('\t')

    # Collect (weight, text) pairs for the source turns first, then decide
    # whether the target turn joins them.
    parsed_turns = [_norm_text(turn) for turn in source_part.split(' EOS ')]
    parsed_target = _norm_text(target_part)
    if parsed_target[0] != 0:
        parsed_turns.append(parsed_target)

    weights = [turn_weight for turn_weight, _ in parsed_turns]
    inputs = [turn_text for _, turn_text in parsed_turns]
    return weights, inputs
def _session_to_examples(session_text):
    """Yield one Context/Knowledge/Response dict per turn after the first."""
    # Splitting on bare 'EOS' (no surrounding spaces) keeps each piece's own
    # spaces, so re-joining with 'EOS' reproduces the original spacing in the
    # history string.
    turns = session_text.split('EOS')
    for turn_idx in range(len(turns) - 1):
        history = 'EOS'.join(turns[:turn_idx + 1])
        if not history:
            # Nothing precedes the response, so there is no context to emit.
            continue
        yield {
            'Context': history,
            'Knowledge': '',
            'Response': turns[turn_idx + 1].strip(),
        }


def process(reddit_path,
            session_path='../data/reddit_session_level.jsonl',
            output_path='../data/reddit.jsonl'):
    """Convert a raw Reddit dialogue file into turn-level JSON Lines examples.

    Stage 1 reads ``reddit_path`` line by line, parses each line with
    ``_get_inputs_from_text`` and writes ``{'text': ...}`` records (turns
    joined by ' EOS ') to ``session_path``.  Lines containing any turn with
    weight 0.0 are skipped.

    Stage 2 reads ``session_path`` back and, for every session, writes one
    ``{'Context', 'Knowledge', 'Response'}`` record per turn after the first
    to ``output_path``.

    A running count is printed every 10000 input lines / sessions.

    Args:
        reddit_path: Path of the raw, UTF-8 encoded input file.
        session_path: Path of the intermediate session-level JSON Lines file.
        output_path: Path of the final turn-level JSON Lines file.
    """
    # Stage 1: raw lines -> session-level records.  The writer is managed by
    # a ``with`` block so it is flushed and closed before the file is read
    # back in stage 2.
    with open(reddit_path, "r", encoding="utf-8") as reader, \
            jsonlines.open(session_path, mode='w') as session_writer:
        for line_no, line in enumerate(reader, start=1):
            if line_no % 10000 == 0:
                print(line_no)
            weights, inputs = _get_inputs_from_text(line)
            if 0.0 in weights:
                continue
            session_writer.write({'text': ' EOS '.join(inputs)})

    # Stage 2: session-level records -> turn-level examples.  The session
    # counter has its own name so the per-turn loop cannot overwrite it.
    with open(session_path, "r", encoding="utf-8") as reader, \
            jsonlines.open(output_path, mode='w') as example_writer:
        for session_no, item in enumerate(jsonlines.Reader(reader), start=1):
            if session_no % 10000 == 0:
                print(session_no)
            for example in _session_to_examples(item['text']):
                example_writer.write(example)
def main():
    """Expose ``process`` as a command-line interface through Python Fire."""
    fire.Fire(component=process)


# Run the command-line interface only when executed as a script.
if __name__ == '__main__':
    main()