File size: 3,374 Bytes
0da05b1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d7a6584
0da05b1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
# -*- coding: utf-8 -*-
import random
import argparse

import pygame
from PIL import Image
import numpy as np
import os
from pypinyin import lazy_pinyin


def set_seed():
    """Seed Python's ``random`` module with the ``--seed`` command-line value.

    Reads the module-level ``args`` namespace, so it can only be called after
    the command line has been parsed.
    """
    seed_value = args.seed
    random.seed(seed_value)


def load_chars(path):
    """Read the groups of characters stored one group per line in a text file.

    Args:
        path: Path of a UTF-8 encoded text file.

    Returns:
        A list with one entry per line of the file; each entry is the list of
        the individual characters of that line after stripping surrounding
        whitespace. A blank line yields an empty list, so entry ``k`` always
        corresponds to line ``k + 1`` of the file.
    """
    # Open read-only: the file is never written, and 'r+' would needlessly
    # fail on files the current user is not allowed to modify.
    with open(path, 'r', encoding='utf-8') as f:
        return [list(line.strip()) for line in f]


def char2image(chars_list):
    """Render every character to a black-on-white PNG image using pygame.

    The groups in ``chars_list`` are divided into a train and a test split by
    ``data_split``. Group ``i`` (0-based) is written to
    ``<output_path>/<train|test>/character_<i + 1>/`` and its ``j``-th
    character is saved as ``<i + 1>_<j + 1>.png``. Images that already exist
    on disk are left untouched.

    Args:
        chars_list: List of groups, each group being a list of single
            characters (as returned by ``load_chars``).

    Raises:
        ValueError: If a group index is in neither the train nor the test
            split returned by ``data_split``.
    """
    pygame.init()
    output_dir_path = os.path.join('./', args.output_path)
    os.makedirs(output_dir_path, exist_ok=True)

    train_ids_list, test_ids_list = data_split(len(chars_list), args.train_percentage)
    # Sets give O(1) membership tests inside the loop below.
    train_ids = set(train_ids_list)
    test_ids = set(test_ids_list)

    # The font is identical for every glyph, so it is loaded once, and only
    # when at least one image actually has to be rendered.
    font = None

    for i, chars in enumerate(chars_list):
        if i in train_ids:
            split_name = 'train'
        elif i in test_ids:
            split_name = 'test'
        else:
            raise ValueError(f"The length of dataset is out of range, which idx is {i}")

        chars_class_path = os.path.join(output_dir_path, split_name, f'character_{i + 1}')
        os.makedirs(chars_class_path, exist_ok=True)

        for j, char in enumerate(chars):
            char_path = os.path.join(chars_class_path, f'{i + 1}_{j + 1}.png')
            if os.path.exists(char_path):
                continue
            if font is None:
                # Other font files are available in the folder as well.
                font = pygame.font.Font(r"./simhei.ttf", 100)
            # Third argument is the text colour, fourth the background colour.
            rtext = font.render(char, True, (0, 0, 0), (255, 255, 255))
            pygame.image.save(rtext, char_path)


def data_split(num_examples, train_percentage):
    """Randomly split the indices ``0 .. num_examples - 1`` into train and test.

    The random generator is re-seeded through ``set_seed`` first, so the split
    is reproducible for a given ``--seed``. As a side effect the ``train`` and
    ``test`` directories are created under ``args.output_path``.

    Args:
        num_examples: Total number of examples to split.
        train_percentage: Fraction of the examples assigned to the training
            split; the count is truncated with ``int``.

    Returns:
        A tuple ``(train_ids_list, test_ids_list)``: the sampled training
        indices in sampling order, and the remaining indices in ascending
        order.
    """
    set_seed()
    num_train_examples = int(num_examples * train_percentage)
    num_test_examples = num_examples - num_train_examples

    for split_name in ('train', 'test'):
        os.makedirs(os.path.join(args.output_path, split_name), exist_ok=True)

    train_ids_list = random.sample(range(num_examples), num_train_examples)
    # Set lookup keeps the complement computation linear instead of quadratic.
    train_ids = set(train_ids_list)
    test_ids_list = [idx for idx in range(num_examples) if idx not in train_ids]

    # Internal sanity check: the two splits must partition the index range.
    assert len(test_ids_list) == num_test_examples
    return train_ids_list, test_ids_list


def main():
    """Load the raw character groups and render them to an image dataset.

    The input file is taken from the ``--dataset_path`` command-line option
    (read from the module-level ``args``) instead of a hard-coded path, so the
    option is actually honoured.
    """
    chars_list = load_chars(args.dataset_path)
    char2image(chars_list)


if __name__ == '__main__':
    # TODO: support other font types
    parser = argparse.ArgumentParser(description='Make the character-pairs dataset')

    # (flag, default, type, help) for every supported command-line option.
    option_specs = (
        ('--output_path', 'characters', str, 'Path to store dataset'),
        ('--seed', 718, int, 'Fixed seed for Random package'),
        ('--dataset_path', '形近字.txt', str, 'Path of raw dataset'),
        ('--train_percentage', 0.7, float, 'Percentage of training set'),
    )
    for flag, default_value, value_type, help_text in option_specs:
        parser.add_argument(flag, default=default_value, type=value_type,
                            help=help_text)

    # Assigned at module level on purpose: the functions above read ``args``
    # as a global.
    args = parser.parse_args()

    main()