# digit_detector / app.py
# (Hugging Face Spaces page residue: author "anoexpected", commit "my initial commit", 58724e3)
# MNIST Handwritten Digit Generation Web App
# TensorFlow/Keras version using VAE and Gradio for Google Colab
# Auto-training version - model trains on startup
import numpy as np
import gradio as gr
import tensorflow as tf
from tensorflow.keras import layers, Model
from tensorflow.keras.datasets import mnist
import matplotlib.pyplot as plt
from PIL import Image
import io
import threading
import time
# =============================================================================
# PART 1: VAE MODEL DEFINITION
# =============================================================================
class Sampling(layers.Layer):
    """Reparameterization trick: draw z ~ N(z_mean, exp(z_log_var))."""

    def call(self, inputs):
        z_mean, z_log_var = inputs
        # Sample standard-normal noise matching the (batch, latent_dim) shape.
        shape = tf.shape(z_mean)
        epsilon = tf.random.normal(shape=(shape[0], shape[1]))
        # Scale by the standard deviation (exp of half the log-variance) and shift.
        return z_mean + tf.exp(0.5 * z_log_var) * epsilon
class VAE(Model):
    """Variational Autoencoder wrapping an encoder/decoder pair.

    Overrides train_step so Keras' fit() optimizes the sum of a binary
    cross-entropy reconstruction term and the analytic KL divergence
    between the encoder's Gaussian and the standard normal prior.
    """

    def __init__(self, encoder, decoder, **kwargs):
        super(VAE, self).__init__(**kwargs)
        self.encoder = encoder  # maps inputs -> (z_mean, z_log_var, z)
        self.decoder = decoder  # maps latent z -> reconstruction
        # Running-mean trackers reported by fit() each epoch.
        self.total_loss_tracker = tf.keras.metrics.Mean(name="total_loss")
        self.reconstruction_loss_tracker = tf.keras.metrics.Mean(name="reconstruction_loss")
        self.kl_loss_tracker = tf.keras.metrics.Mean(name="kl_loss")

    @property
    def metrics(self):
        # Listing the trackers here makes Keras reset them at epoch boundaries.
        return [
            self.total_loss_tracker,
            self.reconstruction_loss_tracker,
            self.kl_loss_tracker,
        ]

    def train_step(self, data):
        # fit(x, x) delivers an (inputs, targets) tuple; the VAE only needs inputs.
        if isinstance(data, tuple):
            data = data[0]
        with tf.GradientTape() as tape:
            z_mean, z_log_var, z = self.encoder(data)
            reconstruction = self.decoder(z)
            # Reconstruction term: binary cross-entropy aggregated across the batch.
            reconstruction_loss = tf.reduce_mean(
                tf.reduce_sum(
                    tf.keras.losses.binary_crossentropy(data, reconstruction), axis=-1
                )
            )
            # Closed-form KL( N(z_mean, exp(z_log_var)) || N(0, I) ), summed over
            # latent dims then averaged over the batch.
            kl_loss = -0.5 * (1 + z_log_var - tf.square(z_mean) - tf.exp(z_log_var))
            kl_loss = tf.reduce_mean(tf.reduce_sum(kl_loss, axis=1))
            total_loss = reconstruction_loss + kl_loss
        grads = tape.gradient(total_loss, self.trainable_weights)
        self.optimizer.apply_gradients(zip(grads, self.trainable_weights))
        self.total_loss_tracker.update_state(total_loss)
        self.reconstruction_loss_tracker.update_state(reconstruction_loss)
        self.kl_loss_tracker.update_state(kl_loss)
        return {
            "loss": self.total_loss_tracker.result(),
            "reconstruction_loss": self.reconstruction_loss_tracker.result(),
            "kl_loss": self.kl_loss_tracker.result(),
        }
def build_vae(input_shape=(784,), latent_dim=20):
    """Construct the encoder, decoder, and compiled VAE (dense 400-unit MLPs).

    Returns:
        (vae, encoder, decoder) — the combined model plus both halves.
    """
    # Encoder: 784 -> 400 -> 400 -> (z_mean, z_log_var) -> sampled z.
    enc_in = layers.Input(shape=input_shape)
    h = layers.Dense(400, activation='relu')(enc_in)
    h = layers.Dense(400, activation='relu')(h)
    mu = layers.Dense(latent_dim, name='z_mean')(h)
    log_var = layers.Dense(latent_dim, name='z_log_var')(h)
    sampled = Sampling()([mu, log_var])
    enc = Model(enc_in, [mu, log_var, sampled], name='encoder')

    # Decoder: latent -> 400 -> 400 -> 784 sigmoid pixel intensities.
    dec_in = layers.Input(shape=(latent_dim,))
    h = layers.Dense(400, activation='relu')(dec_in)
    h = layers.Dense(400, activation='relu')(h)
    dec_out = layers.Dense(784, activation='sigmoid')(h)
    dec = Model(dec_in, dec_out, name='decoder')

    model = VAE(enc, dec)
    model.compile(optimizer='adam')
    return model, enc, dec
# =============================================================================
# PART 2: DATA LOADING AND TRAINING
# =============================================================================
# Module-level state shared between the background training thread (writer)
# and the Gradio request handlers (readers).
encoder = None  # trained Keras encoder Model, set by train_model_background
decoder = None  # trained Keras decoder Model, set by train_model_background
digit_latents = None  # dict: digit -> array of encoder z_mean vectors
model_ready = False  # flipped to True once training and latents are computed
training_progress = "Initializing..."  # human-readable status polled by the UI
def train_model_background():
    """Train the VAE on a 10k-sample MNIST subset and publish the artifacts.

    Runs in a daemon thread at startup. Progress is exposed through the
    module-level `training_progress` string; on success the globals
    `encoder`, `decoder`, `digit_latents`, and `model_ready` are set.
    Any exception is caught and surfaced via `training_progress`.
    """
    global encoder, decoder, digit_latents, model_ready, training_progress
    try:
        training_progress = "Loading MNIST data..."
        print("Loading MNIST data...")
        (x_train, y_train), _ = mnist.load_data()
        # Scale pixels to [0, 1] and flatten to 784-vectors for the dense VAE.
        x_train = x_train.astype('float32') / 255.0
        x_train = x_train.reshape((-1, 784))
        # Cap at 10k samples to keep startup training reasonably fast.
        x_train = x_train[:10000]
        y_train = y_train[:10000]

        training_progress = "Building VAE model..."
        print("Building VAE model...")
        vae, encoder_model, decoder_model = build_vae()

        training_progress = "Training VAE model (20 epochs)..."
        print("Training VAE model (20 epochs)...")

        class ProgressCallback(tf.keras.callbacks.Callback):
            # Mirrors per-epoch loss into the status string polled by the UI.
            def on_epoch_end(self, epoch, logs=None):
                global training_progress
                training_progress = f"Training... Epoch {epoch + 1}/20 (Loss: {logs.get('loss', 0):.4f})"
                print(f"Epoch {epoch + 1}/20 completed")

        # fit()'s History return value was unused; the callback reports progress.
        vae.fit(
            x_train, x_train,
            epochs=20,
            batch_size=128,
            verbose=0,
            callbacks=[ProgressCallback()]
        )

        # Publish the trained halves before computing latents so the helper
        # and the UI see the same objects.
        encoder = encoder_model
        decoder = decoder_model

        training_progress = "Computing digit latent representations..."
        print("Computing digit latent representations...")
        digit_latents = compute_digit_latents(encoder, x_train, y_train)

        training_progress = "✅ Model ready! You can now generate digits."
        model_ready = True
        print("Model training completed successfully!")
    except Exception as e:
        # Surface the failure in the UI instead of dying silently in the thread.
        training_progress = f"❌ Error training model: {str(e)}"
        print(f"Error training model: {str(e)}")
def compute_digit_latents(encoder_model, x_train, y_train):
    """Group encoder mean vectors (z_mean) by digit label.

    Args:
        encoder_model: model whose predict() returns (z_mean, z_log_var, z).
        x_train: flattened training images, shape (n, 784).
        y_train: integer labels 0-9, length n.

    Returns:
        dict mapping each digit 0-9 to an (k, latent_dim) array of its
        z_mean vectors, or None if encoding fails. Digits absent from
        y_train fall back to a single standard-normal latent of the same
        dimensionality as the encoder output (the original hard-coded 20,
        which would mismatch the decoder for other latent sizes).
    """
    try:
        z_means, _, _ = encoder_model.predict(x_train, verbose=0)
        z_means = np.asarray(z_means)
        labels = np.asarray(y_train)
        latent_dim = z_means.shape[1]
        digit_latents = {}
        for d in range(10):
            # Vectorized grouping: one boolean mask per digit instead of a
            # Python-level append loop over every sample.
            mask = labels == d
            if mask.any():
                digit_latents[d] = z_means[mask]
            else:
                # No samples for this digit: use a random standard-normal latent.
                digit_latents[d] = np.random.normal(0, 1, (1, latent_dim))
        return digit_latents
    except Exception as e:
        print(f"Error computing digit latents: {str(e)}")
        return None
def get_training_status():
    """Return the latest background-training status string for the UI."""
    status = training_progress
    return status
# =============================================================================
# PART 3: IMAGE GENERATION
# =============================================================================
def generate_digit_images(digit, num_images):
    """Generate a grid image of VAE samples for the requested digit.

    Args:
        digit: digit class 0-9 selected in the UI dropdown.
        num_images: number of samples to draw; Gradio sliders can deliver
            floats, so the value is coerced to int before use (the original
            passed it straight into np.random.choice, which rejects floats).

    Returns:
        (PIL.Image or None, status message string).
    """
    global encoder, decoder, digit_latents, model_ready
    if not model_ready:
        return None, "⏳ Model is still training. Please wait..."
    if encoder is None or decoder is None or digit_latents is None:
        return None, "❌ Model not ready yet. Please wait for training to complete."
    try:
        num_images = int(num_images)
        digit = int(digit)
        latent_vectors = digit_latents[digit]
        if len(latent_vectors) == 0:
            # No stored latents for this digit: sample the prior directly.
            selected_latents = np.random.normal(0, 1, (num_images, 20))
        else:
            # Sample without replacement only when enough latents exist.
            replace = len(latent_vectors) < num_images
            indices = np.random.choice(len(latent_vectors), num_images, replace=replace)
            # Small Gaussian jitter so repeated draws of the same latent vary.
            selected_latents = latent_vectors[indices] + np.random.normal(
                0, 0.1, (num_images, latent_vectors.shape[1])
            )
        generated = decoder.predict(selected_latents, verbose=0)
        images = (generated.reshape(-1, 28, 28) * 255).astype(np.uint8)
        if num_images == 1:
            grid_img = Image.fromarray(images[0], mode='L')
        else:
            # Tile the samples into a grid of at most 5 columns.
            cols = min(5, num_images)
            rows = (num_images + cols - 1) // cols
            grid_img = Image.new('L', (cols * 28, rows * 28), color=255)
            for i, img in enumerate(images):
                grid_img.paste(Image.fromarray(img, mode='L'),
                               ((i % cols) * 28, (i // cols) * 28))
        success_msg = f"✅ Generated {len(images)} images of digit {digit}!"
        return grid_img, success_msg
    except Exception as e:
        error_msg = f"❌ Error generating images: {str(e)}"
        return None, error_msg
# =============================================================================
# PART 4: GRADIO INTERFACE
# =============================================================================
def create_interface():
    """Build the Gradio Blocks UI: training status, generation controls, output image."""
    with gr.Blocks(title="MNIST VAE Digit Generator", theme=gr.themes.Soft()) as app:
        gr.Markdown("# 🔢 TensorFlow VAE Handwritten Digit Generator")
        gr.Markdown("Generate MNIST-style handwritten digits using a Variational Autoencoder (VAE).")
        with gr.Row():
            with gr.Column(scale=1):
                # Left column: training status panel plus the generation controls.
                gr.Markdown("## Training Status")
                training_status = gr.Textbox(
                    label="Model Status",
                    value="Initializing...",
                    interactive=False
                )
                refresh_btn = gr.Button("🔄 Refresh Status", size="sm")
                gr.Markdown("## Generation Controls")
                selected_digit = gr.Dropdown(
                    choices=list(range(10)),
                    value=0,
                    label="Select Digit to Generate"
                )
                num_images = gr.Slider(
                    minimum=1,
                    maximum=10,
                    value=5,
                    step=1,
                    label="Number of Images"
                )
                generate_btn = gr.Button("🎲 Generate Images", variant="primary", size="lg")
            with gr.Column(scale=2):
                # Right column: the generated image grid and a per-request status box.
                gr.Markdown("## Generated Images")
                output_image = gr.Image(label="Generated Digits", type="pil")
                generation_status = gr.Textbox(
                    label="Generation Status",
                    value="Model is training... Please wait before generating images.",
                    interactive=False
                )
        with gr.Accordion("ℹ️ About this App", open=False):
            gr.Markdown("""
This app uses a **Variational Autoencoder (VAE)** to generate handwritten digits similar to the MNIST dataset.
- Wait for training to finish
- Select digit & number of images
- Click 'Generate'
""")
        # Event wiring: manual status refresh; generate then re-poll status;
        # and an initial status fetch when the page loads.
        refresh_btn.click(fn=get_training_status, outputs=training_status)
        generate_btn.click(fn=generate_digit_images, inputs=[selected_digit, num_images],
                           outputs=[output_image, generation_status]).then(
            fn=get_training_status, outputs=training_status
        )
        app.load(fn=get_training_status, outputs=training_status)
    return app
# =============================================================================
# PART 5: MAIN EXECUTION
# =============================================================================
def _main():
    """Start background training, then serve the Gradio interface."""
    print("Starting MNIST VAE Digit Generator...")
    print("Model will train automatically in the background...")
    # Daemon thread: training never blocks interpreter shutdown.
    worker = threading.Thread(target=train_model_background, daemon=True)
    worker.start()
    ui = create_interface()
    ui.launch(share=True, debug=True, show_error=True)


if __name__ == "__main__":
    _main()