import glob
import os
import random

import torch
from PIL import Image
from PIL import ImageDraw
from PIL import ImageFont
from torch.utils.data import Dataset

from create_fontstyle import fontstyle_list

# Folder and font-family names handed to ``fontstyle_list``.  The strings are
# passed through verbatim (including the spelling 'heveltica').
font_folder = 'font'
font_name = ['arial', 'bodoni', 'calibri', 'futura', 'heveltica', 'times-new-roman']
# NOTE(review): presumably a list of font file paths -- each element is later
# passed to ``ImageFont.truetype``; confirm against ``create_fontstyle``.
fonts = fontstyle_list(font_folder, font_name)


class PrintedMNIST(Dataset):
    """Synthetic dataset of rendered (printed) digits for digit recognition.

    Every item is generated on the fly: a random digit 0-9 is drawn with a
    randomly chosen font, size, position and brightness onto a 256x256
    grayscale canvas, which is then resized to 28x28.

    Args:
        samples: Value reported by ``len(dataset)``.
        random_state: Seed passed to ``random.seed``.  This seeds the global
            ``random`` module, so it affects every other user of that module
            in the process.
        transform: Optional callable applied to the 28x28 PIL image before it
            is returned.
    """

    def __init__(self, samples, random_state, transform = None):
        self.samples = samples
        self.random_state = random_state
        # The attribute name is misspelled ("transfrom"); it is kept as-is so
        # that external code reading or setting it keeps working.
        self.transfrom = transform
        self.fonts = fonts

        random.seed(random_state)

    def __len__(self):
        """Return the configured number of samples."""
        return self.samples

    def __getitem__(self, index):
        """Generate one ``(image, label)`` pair.

        ``index`` is not used: the sample depends only on the current state
        of the global ``random`` generator, so asking for the same index
        twice generally yields different images.

        Returns:
            A tuple ``(img, label)`` where ``label`` is an int in 0..9 and
            ``img`` is the 28x28 mode-"L" PIL image, or whatever the
            transform returns when one is set.
        """
        # The order of the random draws is kept stable so that a given seed
        # keeps producing the same sequence of (color, label, size, x, y, font).
        color = random.randint(200, 255)

        # Generate image
        img = Image.new("L", (256, 256))

        label = random.randint(0, 9)
        size = random.randint(180, 220)
        x = random.randint(60, 80)
        y = random.randint(30, 60)

        draw = ImageDraw.Draw(img)

        # Choose random font style in font style list
        font = ImageFont.truetype(random.choice(self.fonts), size)
        # Fix: draw with the font loaded above.  The original passed the
        # module-level ``fonts`` list here instead of ``font``.
        draw.text((x, y), str(label), color, font=font)

        img = img.resize((28, 28), Image.BILINEAR)

        if self.transfrom:
            img = self.transfrom(img)

        return img, label


class AddSPNoise(object):
    """Tensor transform that adds ``tensor.max()`` to randomly chosen elements.

    Each element is selected independently with probability ``prob``.

    NOTE(review): despite the "SP" (salt-and-pepper) name, only the additive
    "salt" component is implemented; no element is ever lowered.  Confirm
    whether that is intended.
    """

    def __init__(self, prob):
        self.prob = prob

    def __call__(self, tensor):
        # Build the mask on the input's device so the addition below does not
        # mix devices; for CPU tensors this matches the previous behavior.
        mask = torch.rand(tensor.size(), device=tensor.device) < self.prob
        sp = mask * tensor.max()
        return tensor + sp

    def __repr__(self):
        return self.__class__.__name__ + "(prob={0})".format(self.prob)


class AddGaussianNoise(object):
    """Tensor transform that adds Gaussian noise ``N(mean, std**2)`` elementwise."""

    def __init__(self, mean=0.0, std=1.0):
        self.mean = mean
        self.std = std

    def __call__(self, tensor):
        # Sample the noise on the input's device; CPU behavior is unchanged.
        noise = torch.randn(tensor.size(), device=tensor.device)
        return tensor + noise * self.std + self.mean

    def __repr__(self):
        return self.__class__.__name__ + "(mean={0}, std={1})".format(
            self.mean, self.std
        )