File size: 3,211 Bytes
3b8ed21 828ae30 64da2cc 3b8ed21 734e772 3b8ed21 64da2cc 734e772 64da2cc 734e772 64da2cc 734e772 64da2cc da48f05 64da2cc 734e772 64da2cc 3b8ed21 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 |
import gradio as gr
import os
import requests
from zipfile import ZipFile
from tqdm import tqdm
import numpy as np
from PIL import Image, ImageOps
import random
def download_nist_sd19(url, dest_folder):
    """Download the zip archive at *url* into *dest_folder* and extract it there.

    The archive is saved as ``dest_folder/<last URL path segment>``. If that
    file already exists, the download is skipped. The archive is still
    (re-)extracted on every call.

    Args:
        url: HTTP(S) URL of a zip archive.
        dest_folder: Directory to download into and extract into. It is created
            if missing.

    Raises:
        requests.HTTPError: If the server answers with an error status.
        zipfile.BadZipFile: If the saved file is not a valid zip archive.
    """
    os.makedirs(dest_folder, exist_ok=True)
    filename = os.path.join(dest_folder, url.split('/')[-1])
    if not os.path.exists(filename):
        # Stream into a temporary name and rename only after success. This way
        # an interrupted download (or an HTTP error page) is never mistaken for
        # a complete archive on the next run.
        partial = filename + '.part'
        with requests.get(url, stream=True, timeout=60) as response:
            response.raise_for_status()
            total_size = int(response.headers.get('content-length', 0))
            with open(partial, 'wb') as file, tqdm(
                desc=filename,
                total=total_size or None,  # None = unknown size, not "0 bytes"
                unit='iB',
                unit_scale=True,
                unit_divisor=1024,
            ) as bar:
                for data in response.iter_content(chunk_size=1024 * 1024):
                    bar.update(file.write(data))
        os.replace(partial, filename)
    with ZipFile(filename, 'r') as zip_ref:
        zip_ref.extractall(dest_folder)
# URL of the NIST Special Database 19 "by_class" zip archive.
nist_sd19_url = "https://s3.amazonaws.com/nist-srd/SD19/by_class.zip"
# Local directory that receives both the downloaded zip and its extracted contents.
download_folder = "nist_sd19"
# Download the dataset (skipped if the zip is already present), then extract it.
# NOTE(review): this runs at import time, so merely importing this module
# performs network and disk I/O.
download_nist_sd19(nist_sd19_url, download_folder)
# Root directory scanned for sample images.
# NOTE(review): confirm that extracting by_class.zip really creates
# `nist_sd19/hsf_0`. If it does not, os.walk yields nothing, the dataset is
# empty, and every character renders as a blank cell.
nist_dataset_path = os.path.join(download_folder, "hsf_0")
def load_nist_dataset(path):
    """Collect ``.png`` sample paths under *path* together with their labels.

    A file's label is the second ``_``-separated field of its name, with the
    extension removed. For example, ``"x_A_0.png"`` and ``"x_A.png"`` both give
    ``"A"``. Files whose names have no such field are skipped.

    Directories and files are visited in sorted order, so the result is
    reproducible across filesystems. A *path* that does not exist yields two
    empty lists, because ``os.walk`` ignores the error.

    NOTE(review): the label is assumed to be encoded in the filename. Confirm
    this against the extracted archive's layout. If the class is encoded in a
    directory name instead, every label produced here will be wrong.

    Returns:
        ``(images, labels)``: parallel lists of file paths and label strings.
    """
    images = []
    labels = []
    for root, dirs, files in os.walk(path):
        dirs.sort()  # sorting in place makes os.walk descend in a fixed order
        for file in sorted(files):
            if not file.endswith(".png"):
                continue
            # Strip the extension first so "x_A.png" yields "A", not "A.png".
            fields = os.path.splitext(file)[0].split('_')
            if len(fields) < 2:
                continue  # no label field; indexing [1] would raise IndexError
            images.append(os.path.join(root, file))
            labels.append(fields[1])
    return images, labels
# Load the dataset once at import time. generate_handwritten_text reads these
# module-level lists on every call.
images, labels = load_nist_dataset(nist_dataset_path)
def generate_handwritten_text(input_text):
    """Render *input_text* as a horizontal strip of handwritten glyphs.

    Each character is upper-cased and replaced by a randomly chosen sample from
    the module-level ``images``/``labels`` lists whose label equals it. Each
    sample is resized and padded (``ImageOps.pad``) to fit a 28x28 white cell.
    Characters with no matching sample are left as blank cells.

    Args:
        input_text: Text to render. ``None`` or ``""`` yields a single blank
            cell, instead of a zero-width image that cannot be displayed.

    Returns:
        A grayscale (mode ``"L"``) PIL image, 28 px tall and 28 px wide per
        character.
    """
    cell = 28
    if not input_text:
        return Image.new('L', (cell, cell), color=255)

    # Group candidate paths for only the characters we need, in one pass over
    # the dataset. The original rescanned every image once per input character.
    wanted = {char.upper() for char in input_text}
    candidates = {}
    for img_path, label in zip(images, labels):
        if label in wanted:
            candidates.setdefault(label, []).append(img_path)

    output_image = Image.new('L', (cell * len(input_text), cell), color=255)
    for idx, char in enumerate(input_text):
        paths = candidates.get(char.upper())
        if not paths:
            continue  # canvas is already white, so this leaves a blank cell
        with Image.open(random.choice(paths)) as sample:
            glyph = ImageOps.pad(sample.convert('L'), (cell, cell), color='white')
        output_image.paste(glyph, (idx * cell, 0))
    return output_image
# Build the Gradio UI: one text box in, one rendered image out.
interface = gr.Interface(
    generate_handwritten_text,
    "text",
    "image",
    title="NIST Handwritten Text Generator",
    description="Enter text to generate a handwritten text image using the NIST SD19 dataset.",
)
# Start serving the UI.
interface.launch()
|