# Spaces: Sleeping
import gzip
import os
import random
import shutil
import string
import struct

import cv2
import numpy as np
import torch
from PIL import Image, ImageDraw, ImageFont
# Define the HandwrittenFontDataset class
class HandwrittenFontDataset(torch.utils.data.Dataset):
    """Render random characters from a font file as MNIST-style images.

    Each item is an ``(image, label)`` pair:
    - ``image`` is the output of ``preprocess_for_mnist`` wrapped in a
      ``torch.uint8`` tensor.
    - ``label`` is the character's position in ``self.characters``: digits,
      then uppercase, then lowercase ASCII letters (62 classes).

    Glyphs are drawn as 255 (ink) on 0 (background), matching MNIST, where
    0 is background.

    Note: ``index`` does not select a particular sample. Every access draws a
    fresh character from the module-level ``random`` generator, so
    ``dataset[i]`` is only reproducible if ``random`` has been seeded.
    """

    def __init__(self, font_path, num_samples, font_size=32, image_size=64):
        """
        Args:
            font_path: Path to a font file loadable by ``ImageFont.truetype``.
            num_samples: Value reported by ``len(dataset)``.
            font_size: Size used to render glyphs (default 32).
            image_size: Side length in pixels of the square canvas that glyphs
                are drawn on before resizing (default 64).
        """
        self.font_path = font_path
        self.num_samples = num_samples
        self.image_size = image_size
        self.font = ImageFont.truetype(self.font_path, font_size)
        self.characters = string.digits + string.ascii_uppercase + string.ascii_lowercase

    def __len__(self):
        return self.num_samples

    def __getitem__(self, index):
        # index is unused: each call samples a new random character.
        label = random.randrange(len(self.characters))
        char = self.characters[label]

        # MNIST stores background as 0 and ink as 255, so draw white on black.
        img = Image.new('L', (self.image_size, self.image_size), color=0)
        draw = ImageDraw.Draw(img)

        # Centre the glyph's ink bounding box on the canvas, as MNIST digits
        # are centred, instead of anchoring it at a fixed top-left offset.
        left, top, right, bottom = draw.textbbox((0, 0), char, font=self.font)
        x = (self.image_size - (right - left)) / 2 - left
        y = (self.image_size - (bottom - top)) / 2 - top
        draw.text((x, y), char, font=self.font, fill=255)

        img = preprocess_for_mnist(np.array(img))
        return torch.tensor(img, dtype=torch.uint8), label
# Resize and preprocess images for MNIST format
def preprocess_for_mnist(img, size=(28, 28)):
    """Resize an image to MNIST dimensions and return it as ``np.uint8``.

    Args:
        img: Image array, typically a 2-D grayscale array such as
            ``np.array(pil_image)``.
        size: Target ``(width, height)`` passed to ``cv2.resize``. Defaults
            to MNIST's 28x28.

    Returns:
        The resized array cast to ``np.uint8``. Values are clipped to
        [0, 255], not rescaled, so the input must already be on a 0-255
        scale.
    """
    # INTER_AREA is OpenCV's recommended interpolation when shrinking.
    img = cv2.resize(img, size, interpolation=cv2.INTER_AREA)
    # Clip before casting: out-of-range values cast directly to uint8 wrap
    # around and would silently corrupt pixels. A no-op for uint8 input.
    return np.clip(img, 0, 255).astype(np.uint8)
# Write images to idx3-ubyte format
def write_idx3_ubyte(images, file_path):
    """Write equally sized 2-D images to an IDX3 (``idx3-ubyte``) file.

    File layout: four big-endian unsigned 32-bit integers (magic number
    2051 = 0x00000803, image count, rows, cols), followed by every image's
    pixels as unsigned bytes in row-major order.

    Args:
        images: Sequence of 2-D arrays of identical shape with values in
            0..255. The header's rows/cols come from the first image; an
            empty sequence is written with MNIST's 28x28.
        file_path: Destination path. An existing file is overwritten.

    Raises:
        ValueError: If the first image is not 2-D, an image's shape differs
            from the first image's, or a non-uint8 image holds values outside
            0..255. All images are checked before the file is opened, so no
            partial file is left behind.
    """
    arrays = [np.asarray(image) for image in images]
    if arrays:
        if arrays[0].ndim != 2:
            raise ValueError(f"expected 2-D images, got shape {arrays[0].shape}")
        rows, cols = arrays[0].shape
    else:
        rows, cols = 28, 28

    pixels = []
    for i, arr in enumerate(arrays):
        if arr.shape != (rows, cols):
            raise ValueError(
                f"image {i} has shape {arr.shape}, expected {(rows, cols)}"
            )
        if arr.dtype != np.uint8:
            # Non-uint8 arrays would otherwise be written with several bytes
            # per pixel; out-of-range values would wrap when cast, so reject them.
            if arr.size and (arr.min() < 0 or arr.max() > 255):
                raise ValueError(f"image {i} has pixel values outside 0..255")
            arr = arr.astype(np.uint8)
        pixels.append(arr)

    with open(file_path, 'wb') as f:
        # Magic 0x00000803: unsigned-byte data (0x08) with 3 dimensions (0x03).
        f.write(struct.pack(">IIII", 2051, len(pixels), rows, cols))
        for arr in pixels:
            f.write(arr.tobytes())
# Write labels to idx1-ubyte format
def write_idx1_ubyte(labels, file_path):
    """Serialize integer labels to an IDX1 (``idx1-ubyte``) file.

    The file starts with two big-endian unsigned 32-bit integers: the magic
    number 2049 (0x00000801) and the label count. Then comes one unsigned
    byte per label.
    """
    pack_label = struct.Struct("B").pack
    with open(file_path, 'wb') as out:
        out.write(struct.pack(">II", 2049, len(labels)))
        for value in labels:
            out.write(pack_label(value))
# Compress file to .gz format
def compress_file(input_path, output_path):
    """Gzip-compress the file at ``input_path`` into ``output_path``.

    The input file is left in place.
    """
    with open(input_path, 'rb') as f_in, gzip.open(output_path, 'wb') as f_out:
        # Copy in fixed-size chunks. Iterating a binary file splits on b"\n",
        # so newline-free data would be read into memory as a single "line".
        shutil.copyfileobj(f_in, f_out)
# Save dataset in MNIST format
def save_mnist_format(images, labels, output_dir):
    """Write ``images``/``labels`` as MNIST training files under ``output_dir/raw``.

    Produces ``train-images-idx3-ubyte`` and ``train-labels-idx1-ubyte``,
    plus a gzip-compressed ``.gz`` copy of each. The uncompressed files are
    kept alongside the compressed ones.
    """
    raw_folder = os.path.join(output_dir, "raw")
    os.makedirs(raw_folder, exist_ok=True)

    image_file = os.path.join(raw_folder, "train-images-idx3-ubyte")
    label_file = os.path.join(raw_folder, "train-labels-idx1-ubyte")

    # Write the uncompressed files first: the .gz copies are built from them.
    write_idx3_ubyte(images, image_file)
    write_idx1_ubyte(labels, label_file)

    for path in (image_file, label_file):
        compress_file(path, f"{path}.gz")

    print(f"Dataset saved in MNIST format at {raw_folder}")
# Generate and save the dataset
def create_mnist_dataset(font_path, num_samples=4096):
    """Render ``num_samples`` characters with ``font_path`` and save them in MNIST format.

    Output is written beneath ``./data/<font file name without extension>``.
    """
    stem, _ = os.path.splitext(os.path.basename(font_path))
    target_dir = os.path.join("./data", stem)
    os.makedirs(target_dir, exist_ok=True)

    dataset = HandwrittenFontDataset(font_path, num_samples)
    samples = [dataset[idx] for idx in range(num_samples)]
    images = [tensor.numpy() for tensor, _ in samples]
    labels = [label for _, label in samples]

    save_mnist_format(images, labels, target_dir)
# Example usage
def choose_font_and_create_dataset():
    """Prompt for a .ttf/.otf font in the current directory and build a dataset from it.

    Re-prompts until the user enters a number within the listed range.
    If no font files are present, prints a message and returns without
    creating anything.
    """
    # Sorted so numbering is stable between runs (os.listdir order is
    # arbitrary). The extension match is case-insensitive so "Font.TTF" is
    # offered too.
    font_files = sorted(
        f for f in os.listdir("./")
        if f.lower().endswith((".ttf", ".otf")) and os.path.isfile(f)
    )
    if not font_files:
        print("No .ttf or .otf font files found in the current directory.")
        return

    print("Available fonts:")
    for i, font_file in enumerate(font_files, start=1):
        print(f"{i}. {font_file}")

    while True:
        raw = input(f"Choose a font (1-{len(font_files)}): ")
        try:
            choice = int(raw)
        except ValueError:
            choice = None
        # Reject 0 and negatives explicitly: font_files[choice - 1] would
        # otherwise wrap around and silently pick a font from the end.
        if choice is not None and 1 <= choice <= len(font_files):
            break
        print(f"Please enter a number between 1 and {len(font_files)}.")

    chosen_font = font_files[choice - 1]
    print(f"Creating dataset using font: {chosen_font}")
    create_mnist_dataset(chosen_font)
# Run the font selection and dataset creation only when executed as a script,
# so importing this module (e.g. to reuse HandwrittenFontDataset) does not
# start an interactive prompt.
if __name__ == "__main__":
    choose_font_and_create_dataset()