File size: 1,792 Bytes
476ac07
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
# Copyright (c) OpenMMLab. All rights reserved.
import argparse

from mmengine.config import Config

from xtuner.registry import BUILDER


def parse_args():
    """Parse the command-line arguments of the dataset logging tool.

    Returns:
        argparse.Namespace: Holds ``config`` (the positional config file
        name or path) and ``show`` (the selected display style, ``'text'``
        when the option is omitted).
    """
    show_styles = ['text', 'masked_text', 'input_ids', 'labels', 'all']

    cli = argparse.ArgumentParser(description='Log processed dataset.')
    cli.add_argument('config', help='config file name or path.')
    # Choose which representation of the dataset sample gets printed.
    cli.add_argument(
        '--show',
        choices=show_styles,
        default='text',
        help='which kind of dataset style to show')
    return cli.parse_args()


def _format_masked_labels(labels, tokenizer, ignore_index=-100):
    """Render label ids as text while keeping masked positions in place.

    Consecutive tokens are grouped into runs. A run of masked tokens is
    shown as one ``[-100]``-style placeholder per token; a run of unmasked
    tokens is decoded with ``tokenizer``. Runs are emitted in their original
    order, so samples whose masked and unmasked spans interleave are not
    reshuffled.

    Args:
        labels: Iterable of label token ids for a single sample.
        tokenizer: Object providing ``decode(ids)`` that returns a string.
        ignore_index: Label value that marks a masked position.

    Returns:
        str: The rendered segments joined by single spaces.
    """
    from itertools import groupby

    placeholder = f'[{ignore_index}]'
    segments = []
    runs = groupby(labels, key=lambda token: token == ignore_index)
    for is_masked, run in runs:
        if is_masked:
            segments.extend(placeholder for _ in run)
        else:
            segments.append(tokenizer.decode(list(run)))
    return ' '.join(segments)


def main():
    """Build the tokenizer and training dataset and print the first sample.

    The config named on the command line is loaded, the tokenizer and the
    training dataset are built through ``BUILDER``, and sample ``0`` is
    printed in the style(s) selected by ``--show``: decoded text, masked
    label text, raw ``input_ids`` and/or raw ``labels``.
    """
    args = parse_args()

    cfg = Config.fromfile(args.config)

    tokenizer = BUILDER.build(cfg.tokenizer)
    # The dataset config lives in a different place depending on the
    # framework declared by the config ('mmengine' when unspecified).
    if cfg.get('framework', 'mmengine').lower() == 'huggingface':
        train_dataset = BUILDER.build(cfg.train_dataset)
    else:
        train_dataset = BUILDER.build(cfg.train_dataloader.dataset)

    # Fetch the sample once: every view below describes the same item, and
    # indexing again would repeat whatever work the dataset does per lookup.
    sample = train_dataset[0]
    show_all = args.show == 'all'

    if show_all or args.show == 'text':
        print('#' * 20 + '   text   ' + '#' * 20)
        print(tokenizer.decode(sample['input_ids']))
    if show_all or args.show == 'masked_text':
        print('#' * 20 + '   text(masked)   ' + '#' * 20)
        print(_format_masked_labels(sample['labels'], tokenizer))
    if show_all or args.show == 'input_ids':
        print('#' * 20 + '   input_ids   ' + '#' * 20)
        print(sample['input_ids'])
    if show_all or args.show == 'labels':
        print('#' * 20 + '   labels   ' + '#' * 20)
        print(sample['labels'])

# Run the tool only when this file is executed as a script, not on import.
if __name__ == '__main__':
    main()