|
|
|
import random |
|
import argparse |
|
|
|
import pygame |
|
from PIL import Image |
|
import numpy as np |
|
import os |
|
from pypinyin import lazy_pinyin |
|
|
|
|
|
def set_seed():
    """Seed Python's ``random`` module with the value of the ``--seed`` option.

    Reads the seed from the module-level ``args`` namespace, so it must be
    called only after the command-line arguments have been parsed.
    """
    seed_value = args.seed
    random.seed(seed_value)
|
|
|
|
|
def load_chars(path):
    """Read groups of characters from a UTF-8 text file, one group per line.

    Args:
        path: Path of the text file to read.

    Returns:
        A list with one entry per line of the file, in file order. Each entry
        is the list of individual characters found on that line after leading
        and trailing whitespace has been stripped. A blank line yields an
        empty list.
    """
    # Open read-only: the file is never written, and 'r+' would fail on
    # datasets that are not writable by the current user.
    with open(path, 'r', encoding='utf-8') as f:
        return [list(line.strip()) for line in f]
|
|
|
|
|
def char2image(chars_list):
    """Render every character to a PNG file, grouped into train/test folders.

    The groups are split into a training and a test subset by ``data_split``.
    Group ``i`` (0-based) is written to
    ``<output>/<train|test>/character_{i + 1}/`` and its ``j``-th character to
    ``{i + 1}_{j + 1}.png`` inside that folder. Images that already exist on
    disk are left untouched, so an interrupted run can be resumed.

    Reads ``output_path`` and ``train_percentage`` from the module-level
    ``args`` namespace and the font from ``./simhei.ttf``.

    Args:
        chars_list: List of character groups; each group is a list of
            single-character strings.

    Raises:
        ValueError: If a group index is in neither the train nor the test
            split returned by ``data_split``.
    """
    pygame.init()

    output_dir_path = os.path.join('./', args.output_path)
    os.makedirs(output_dir_path, exist_ok=True)

    train_ids_list, test_ids_list = data_split(len(chars_list), args.train_percentage)
    # Sets make the per-group membership tests O(1) instead of scanning lists.
    train_ids = set(train_ids_list)
    test_ids = set(test_ids_list)

    # The font is loaded at most once, and only when an image actually has to
    # be rendered, instead of being re-read from disk for every character.
    font = None

    for i, chars in enumerate(chars_list):
        if i in train_ids:
            subset = 'train'
        elif i in test_ids:
            subset = 'test'
        else:
            raise ValueError(f"The length of dataset is out of range, which idx is {i}")

        chars_class_path = os.path.join(output_dir_path, subset, f'character_{i + 1}')
        os.makedirs(chars_class_path, exist_ok=True)

        for j, char in enumerate(chars):
            char_path = os.path.join(chars_class_path, f'{i + 1}_{j + 1}.png')
            if os.path.exists(char_path):
                continue

            if font is None:
                font = pygame.font.Font(r"./simhei.ttf", 100)

            # Antialiased black glyph on a white background.
            rtext = font.render(char, True, (0, 0, 0), (255, 255, 255))
            pygame.image.save(rtext, char_path)
|
|
|
|
|
def data_split(num_examples, train_percentage):
    """Randomly split example indices into a training and a test subset.

    The random generator is re-seeded through ``set_seed`` first, so the same
    seed always produces the same split. As a side effect the ``train`` and
    ``test`` directories are created under ``args.output_path``.

    Args:
        num_examples: Total number of examples; indices run from 0 to
            ``num_examples - 1``.
        train_percentage: Fraction of examples assigned to the training
            subset, between 0 and 1 inclusive. The training size is
            ``int(num_examples * train_percentage)``.

    Returns:
        A tuple ``(train_ids_list, test_ids_list)``. The training indices are
        in sampling order; the test indices are in ascending order.

    Raises:
        ValueError: If ``train_percentage`` is outside ``[0, 1]``.
    """
    # Fail before touching the file system or the random state; previously
    # an invalid fraction only surfaced as an error from random.sample after
    # the output directories had already been created.
    if not 0 <= train_percentage <= 1:
        raise ValueError(
            f"train_percentage must be between 0 and 1, got {train_percentage}"
        )

    set_seed()
    num_train_examples = int(num_examples * train_percentage)
    num_test_examples = num_examples - num_train_examples

    for subset in ('train', 'test'):
        os.makedirs(os.path.join(args.output_path, subset), exist_ok=True)

    train_ids_list = random.sample(range(num_examples), num_train_examples)
    # Set lookup keeps the complement computation linear in num_examples.
    train_ids = set(train_ids_list)
    test_ids_list = [idx for idx in range(num_examples) if idx not in train_ids]

    # Internal sanity check: the two subsets must partition the index range.
    assert len(test_ids_list) == num_test_examples
    return train_ids_list, test_ids_list
|
|
|
|
|
def main():
    """Load the raw character groups and render them into the image dataset.

    The input file is taken from the ``--dataset_path`` command-line option
    (module-level ``args``) rather than a hard-coded path, so the option is
    actually honoured. Its default value points at the same file that was
    previously hard-coded.
    """
    chars_list = load_chars(args.dataset_path)
    char2image(chars_list)
|
|
|
|
|
if __name__ == '__main__':
    # One (flag, default, type, help) row per supported command-line option.
    cli_options = (
        ('--output_path', 'characters', str, 'Path to store dataset'),
        ('--seed', 718, int, 'Fixed seed for Random package'),
        ('--dataset_path', '形近字.txt', str, 'Path of raw dataset'),
        ('--train_percentage', 0.7, float, 'Percentage of training set'),
    )

    parser = argparse.ArgumentParser(description='Make the character-pairs dataset')
    for flag, default_value, value_type, help_text in cli_options:
        parser.add_argument(flag, default=default_value, type=value_type, help=help_text)

    # Assigning at module level already makes ``args`` a global, which is how
    # the functions in this file read the parsed options.
    args = parser.parse_args()

    main()
|
|