File size: 3,211 Bytes
3b8ed21
828ae30
64da2cc
 
 
3b8ed21
 
 
734e772
3b8ed21
64da2cc
 
 
 
734e772
64da2cc
 
 
 
 
 
 
 
 
 
 
 
 
734e772
64da2cc
 
734e772
64da2cc
da48f05
64da2cc
734e772
64da2cc
 
3b8ed21
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
import gradio as gr
import os
import requests
from zipfile import ZipFile
from tqdm import tqdm
import numpy as np
from PIL import Image, ImageOps
import random

# Function to download and extract the NIST SD19 dataset
def download_nist_sd19(url, dest_folder, timeout=60):
    """Download the zip archive at *url* into *dest_folder* and extract it there.

    The archive is stored as ``dest_folder/<last path segment of url>``. If that
    file already exists the download is skipped; the archive is extracted into
    *dest_folder* on every call regardless.

    Args:
        url: HTTP(S) URL of a zip archive.
        dest_folder: Directory that receives the archive and its extracted
            contents; created if missing.
        timeout: Seconds ``requests`` waits to connect and between received
            bytes (not a limit on the total download time).

    Raises:
        requests.HTTPError: if the server answers with an error status.
        zipfile.BadZipFile: if the stored file is not a valid zip archive.
    """
    os.makedirs(dest_folder, exist_ok=True)
    filename = os.path.join(dest_folder, url.split('/')[-1])

    if not os.path.exists(filename):
        # Stream into a temporary name and rename only after success, so an
        # interrupted or failed download never leaves a truncated file that a
        # later run would mistake for the finished archive.
        partial = filename + '.part'
        try:
            with requests.get(url, stream=True, timeout=timeout) as response:
                # Refuse to save an HTTP error page as if it were the archive.
                response.raise_for_status()
                total_size = int(response.headers.get('content-length', 0))
                with open(partial, 'wb') as file, tqdm(
                    desc=filename,
                    total=total_size or None,  # None tells tqdm the size is unknown
                    unit='iB',
                    unit_scale=True,
                    unit_divisor=1024,
                ) as bar:
                    # 64 KiB chunks: far fewer Python-level iterations than 1 KiB.
                    for data in response.iter_content(chunk_size=64 * 1024):
                        bar.update(file.write(data))
            os.replace(partial, filename)
        finally:
            # Remove any leftover partial file (no-op after a successful rename).
            if os.path.exists(partial):
                os.remove(partial)

    with ZipFile(filename, 'r') as zip_ref:
        zip_ref.extractall(dest_folder)

# Source archive for the NIST SD19 handwriting dataset (the "by_class" zip).
nist_sd19_url = "https://s3.amazonaws.com/nist-srd/SD19/by_class.zip"
# Local directory that receives both the downloaded zip and its extracted files.
download_folder = "nist_sd19"

# Runs at import time: fetches the archive unless nist_sd19/by_class.zip is
# already present, then extracts it into download_folder on every run.
download_nist_sd19(nist_sd19_url, download_folder)

# Root directory that load_nist_dataset() walks for glyph images.
# NOTE(review): this assumes the archive extracts a top-level "hsf_0" folder.
# If it does not, os.walk yields nothing, no label ever matches, and every
# generated character is a blank cell. Confirm against the extracted layout.
nist_dataset_path = os.path.join(download_folder, "hsf_0")

# Function to load the dataset
def load_nist_dataset(path):
    """Collect ``.png`` files under *path* together with a label for each.

    The label is the second ``_``-separated field of the file name with its
    extension removed (``x_A.png`` -> ``'A'``, ``hsf_0_00001.png`` -> ``'0'``).
    Files whose name has no such field are skipped, so one stray image cannot
    abort the whole scan.

    NOTE(review): the assumption that NIST SD19 encodes the character class in
    this filename field comes from the original code. Confirm it against the
    extracted archive layout.

    Args:
        path: Root directory to walk recursively. A missing directory yields
            no files (``os.walk`` ignores the error by default).

    Returns:
        ``(images, labels)``: parallel lists of image file paths and label
        strings, in ``os.walk`` order.
    """
    images = []
    labels = []
    for root, _dirs, files in os.walk(path):
        for file in files:
            if not file.endswith(".png"):
                continue
            # Split the stem, not the full name, so the extension never leaks
            # into the label.
            stem, _ext = os.path.splitext(file)
            fields = stem.split('_')
            if len(fields) < 2:
                continue  # no label field to read
            images.append(os.path.join(root, file))
            labels.append(fields[1])
    return images, labels

# Load the dataset once at import time. images[i] is a PNG path and labels[i]
# its label string; generate_handwritten_text() reads these module-level lists.
images, labels = load_nist_dataset(nist_dataset_path)

# Function to generate handwritten text image
def generate_handwritten_text(input_text):
    """Render *input_text* as a horizontal strip of NIST glyph images.

    Each character is upper-cased and matched against the module-level
    ``labels``. One of the corresponding paths in ``images`` is picked with
    ``random.choice``, converted to grayscale and padded into a 28x28 cell.
    Characters with no matching image are left as a blank white cell.

    Args:
        input_text: Text to render. ``None`` is treated as an empty string.

    Returns:
        A mode ``'L'`` PIL image of size ``(28 * len(input_text), 28)``. For
        empty input, a single blank 28x28 cell instead of a zero-width image.
    """
    cell = 28
    text = input_text or ''

    # Index the dataset once per call. Rescanning every image for each
    # character would cost O(len(text) * len(images)).
    wanted = {char.upper() for char in text}
    paths_by_label = {}
    for img_path, label in zip(images, labels):
        if label in wanted:
            paths_by_label.setdefault(label, []).append(img_path)

    # White canvas; unmatched characters simply keep their blank cell. At least
    # one cell wide, so empty input still yields a displayable image.
    output_image = Image.new('L', (cell * max(len(text), 1), cell), color=255)
    for idx, char in enumerate(text):
        candidates = paths_by_label.get(char.upper())
        if not candidates:
            continue
        # The context manager releases the file handle held by Image.open.
        with Image.open(random.choice(candidates)) as source:
            glyph = ImageOps.pad(source.convert('L'), (cell, cell), color='white')
        output_image.paste(glyph, (idx * cell, 0))

    return output_image

# Gradio UI: one text box in, the rendered handwriting image out.
interface = gr.Interface(
    fn=generate_handwritten_text,
    title="NIST Handwritten Text Generator",
    description="Enter text to generate a handwritten text image using the NIST SD19 dataset.",
    inputs="text",
    outputs="image",
)

# Start the Gradio web app.
interface.launch()