File size: 2,044 Bytes
59d97af
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
#!/usr/bin/env python
#  coding=utf-8
#  Copyright (c) Microsoft Corporation.
#  Licensed under the MIT license.

import jsonlines
import fire


def _norm_text(text):
    w, *toks = text.strip().split()
    try:
        w = float(w)
    except Exception:
        toks = [w] + toks
        w = 1.0
    return w, ' '.join(toks)


def _get_inputs_from_text(text):
    """Parse one raw line into parallel lists of turn weights and turn texts.

    The stripped line must contain exactly one tab: source turns (separated
    by ``' EOS '``) on the left, the target turn on the right. Each turn is
    passed through ``_norm_text`` to peel off its weight.

    Args:
        text: One raw input line.

    Returns:
        A ``(weights, inputs)`` tuple of equal-length lists. The target turn
        is appended only when its weight is non-zero.

    Raises:
        ValueError: If the stripped line does not split into exactly two
            tab-separated fields.
    """
    source_part, target_part = text.strip().split('\t')

    parsed_sources = [_norm_text(turn) for turn in source_part.split(' EOS ')]
    weights = [weight for weight, _ in parsed_sources]
    inputs = [turn for _, turn in parsed_sources]

    target_weight, target = _norm_text(target_part)
    if target_weight != 0:
        weights.append(target_weight)
        inputs.append(target)

    return weights, inputs


def process(reddit_path,
            session_path='../data/reddit_session_level.jsonl',
            output_path='../data/reddit.jsonl'):
    """Convert a raw Reddit dump into turn-level JSONL training examples.

    Pass 1 reads ``reddit_path`` line by line, parses each line with
    ``_get_inputs_from_text`` and writes one ``{'text': ...}`` record per
    session to ``session_path``. Sessions containing any turn with weight
    0.0 are dropped.

    Pass 2 re-reads ``session_path`` and, for every turn after the first,
    writes one record to ``output_path`` with the keys ``'Context'`` (all
    preceding turns), ``'Knowledge'`` (always empty) and ``'Response'``
    (the stripped turn).

    Args:
        reddit_path: Path to the raw, UTF-8 encoded input file.
        session_path: Where the intermediate session-level JSONL is written.
        output_path: Where the final turn-level JSONL is written.
    """
    # Pass 1: raw lines -> session-level records. Both files are managed by
    # the ``with`` block so the writer is flushed and closed before the file
    # is read back in pass 2.
    with open(reddit_path, "r", encoding="utf-8") as reader, \
            jsonlines.open(session_path, 'w') as writer:
        for line_no, line in enumerate(reader, start=1):
            if line_no % 10000 == 0:
                print(line_no)
            if not line.strip():
                # A blank line (e.g. at end of file) carries no session and
                # would otherwise fail the tab split.
                continue
            weights, inputs = _get_inputs_from_text(line)
            if 0.0 in weights:
                continue
            writer.write({'text': ' EOS '.join(inputs)})

    # Pass 2: session-level records -> one example per (history, response).
    with open(session_path, "r", encoding="utf-8") as reader, \
            jsonlines.open(output_path, mode='w') as writer:
        for session_no, item in enumerate(jsonlines.Reader(reader), start=1):
            if session_no % 10000 == 0:
                print(session_no)

            # Splitting on the bare 'EOS' keeps the spaces that surrounded
            # each separator, so re-joining with 'EOS' restores ' EOS '.
            # NOTE(review): a turn that itself contains the substring 'EOS'
            # is split as well -- confirm whether that is acceptable.
            context = item['text'].split('EOS')

            # A dedicated loop variable keeps the progress counter intact.
            for turn in range(len(context) - 1):
                history = 'EOS'.join(context[:turn + 1])
                response = context[turn + 1]

                if len(history) == 0:
                    continue

                writer.write({
                    'Context': history,
                    'Knowledge': '',
                    'Response': response.strip(),
                })


def main():
    """Expose ``process`` as a command-line interface through ``fire``."""
    fire.Fire(component=process)


# Run the command-line entry point only when executed as a script, not on import.
if __name__ == '__main__':
    main()