taellinglin's picture
Update app.py
088873c verified
import gradio as gr
import torch
import numpy as np
import matplotlib.pyplot as plt
from torch import nn, optim
from torch.utils.data import DataLoader
from io import StringIO
import os
import base64
# Import your modules
from logistic_regression import LogisticRegressionModel
from softmax_regression import SoftmaxRegressionModel
from shallow_neural_network import ShallowNeuralNetwork
import convolutional_neural_networks
from dataset_loader import CustomMNISTDataset
from final_project import train_final_model, get_dataset_options, FinalCNN
import torchvision.transforms as transforms
import torch
import matplotlib.pyplot as plt
from matplotlib import font_manager
import matplotlib.pyplot as plt
def number_to_char(number):
    """Return the display character for a class index.

    Index ranges map as: 0-9 -> '0'-'9', 10-35 -> 'a'-'z', 36-61 -> 'A'-'Z'.
    Any other value yields an empty string.
    """
    if 0 <= number <= 9:
        # Digits are displayed as themselves.
        return str(number)
    if 10 <= number <= 35:
        # Offset from 'a' so that 10 -> 'a' and 35 -> 'z'.
        return chr(ord('a') + number - 10)
    if 36 <= number <= 61:
        # Offset from 'A' so that 36 -> 'A' and 61 -> 'Z'.
        return chr(ord('A') + number - 36)
    return ''
def visualize_predictions_svg(model, train_loader, stage):
    """Plot model predictions for up to 18 images of one batch and return SVG.

    The figure is also written as ``{stage}_predictions.svg`` through
    ``save_svg_to_output_folder``.

    Args:
        model: Called as ``model(images)``; the outputs are reduced with
            ``torch.max(outputs, 1)`` to obtain predicted class indices.
            Assumed to be a ``torch.nn.Module`` (it exposes ``eval()`` /
            ``train()`` / ``training``).
        train_loader: Iterable of ``(images, labels)`` batches; only the first
            batch is used.
        stage: Label used in the output filename (e.g. "Before" / "After").

    Returns:
        The SVG markup of the figure as a string.
    """
    rows, cols = 6, 3
    max_images = rows * cols  # 6 rows x 3 columns = 18 panels
    # Load the Daemon font used for the panel titles.
    prop = font_manager.FontProperties(fname='./Daemon.otf')

    images, labels = next(iter(train_loader))
    images, labels = images[:max_images], labels[:max_images]
    # A small dataset can yield a first batch with fewer than 18 samples.
    count = min(len(images), len(labels))

    # Remember the caller's mode so this helper does not silently leave the
    # model in eval mode (it is invoked right before training starts).
    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            outputs = model(images)
            _, predictions = torch.max(outputs, 1)
    finally:
        model.train(was_training)

    fig, axes = plt.subplots(rows, cols, figsize=(12, 16))
    try:
        # axes.flat walks the grid row by row, like ax[i // 3, i % 3].
        for i, axis in enumerate(axes.flat):
            axis.axis('off')
            if i >= count:
                continue  # leave unused panels blank
            axis.imshow(images[i].squeeze(), cmap='gray')
            pred_idx = predictions[i].item()
            label_idx = labels[i].item()
            pred_char = number_to_char(pred_idx)
            label_char = number_to_char(label_idx)
            # Compare raw indices: number_to_char maps every out-of-range
            # index to '', which would make distinct classes look equal.
            if pred_idx == label_idx:
                title_text, color = f"{pred_char} = {label_char}", 'green'
            else:
                title_text, color = f"{pred_char} != {label_char}", 'red'
            axis.set_title(title_text, fontproperties=prop, fontsize=12, color=color)
        svg_str = figure_to_svg(fig)
    finally:
        # Always release the figure, even if rendering fails.
        plt.close(fig)

    save_svg_to_output_folder(svg_str, f"{stage}_predictions.svg")
    return svg_str
def figure_to_svg(fig):
    """Render *fig* with matplotlib's SVG backend and return the markup as a str."""
    from io import StringIO
    from matplotlib.backends.backend_svg import FigureCanvasSVG

    buffer = StringIO()
    FigureCanvasSVG(fig).print_svg(buffer)
    return buffer.getvalue()
def save_svg_to_output_folder(svg_str, filename, output_dir='./output'):
    """Write an SVG string to ``output_dir/filename`` and return that path.

    Args:
        svg_str: SVG markup to write.
        filename: Name of the file inside ``output_dir``.
        output_dir: Target directory (default ``./output``). It is created,
            including missing parents, if it does not exist yet.

    Returns:
        The path of the written file.
    """
    # Previously a missing ./output folder raised FileNotFoundError.
    os.makedirs(output_dir, exist_ok=True)
    output_path = os.path.join(output_dir, filename)
    # Explicit UTF-8 so the bytes on disk do not depend on the platform's
    # default locale encoding.
    with open(output_path, 'w', encoding='utf-8') as f:
        f.write(svg_str)
    return output_path
def plot_metrics_svg(losses, accuracies):
    """Plot training loss and accuracy side by side and return the figure as SVG.

    The figure is also written as ``training_metrics.svg`` through
    ``save_svg_to_output_folder``.

    Args:
        losses: Sequence of loss values, plotted in order.
        accuracies: Sequence of accuracy values, plotted in order.

    Returns:
        The SVG markup of the figure as a string.
    """
    fig, (loss_ax, acc_ax) = plt.subplots(1, 2, figsize=(12, 5))
    try:
        loss_ax.plot(losses, label='Loss', color='red')
        loss_ax.set_title('Training Loss')
        loss_ax.set_xlabel('Epoch')
        loss_ax.set_ylabel('Loss')
        loss_ax.legend()

        acc_ax.plot(accuracies, label='Accuracy', color='green')
        acc_ax.set_title('Training Accuracy')
        acc_ax.set_xlabel('Epoch')
        acc_ax.set_ylabel('Accuracy')
        acc_ax.legend()

        # Lay out this figure explicitly: plt.tight_layout() acts on pyplot's
        # global "current figure", which may be a different figure when
        # callbacks run concurrently.
        fig.tight_layout()
        svg_str = figure_to_svg(fig)
    finally:
        # Always release the figure, even if rendering fails.
        plt.close(fig)

    save_svg_to_output_folder(svg_str, "training_metrics.svg")
    return svg_str
def _build_model(module):
    """Instantiate the model registered under the UI name *module*, or return None if unknown."""
    # NOTE(review): only FinalCNN is exposed in the UI; the other models'
    # input sizes do not look compatible with 28x28 images — confirm before
    # enabling them.
    if module == "Logistic Regression":
        return LogisticRegressionModel(input_size=1)
    if module == "Softmax Regression":
        return SoftmaxRegressionModel(input_size=2, num_classes=2)
    if module == "Shallow Neural Networks":
        return ShallowNeuralNetwork(input_size=2, hidden_size=5, output_size=2)
    if module == "Deep Networks":
        import deep_networks  # imported lazily, only when this module is chosen
        return deep_networks.DeepNeuralNetwork(input_size=10, hidden_sizes=[20, 10], output_size=2)
    if module == "Convolutional Neural Networks":
        return convolutional_neural_networks.ConvolutionalNeuralNetwork()
    if module == "AI Calligraphy":
        return FinalCNN()
    return None


def train_model_interface(module, dataset_name, epochs=100, lr=0.01):
    """Train the selected model on the chosen dataset.

    Args:
        module: UI name of the model to train (see ``_build_model``).
        dataset_name: Folder name under ``data/``; images are read from
            ``data/<dataset_name>/raw``.
        epochs: Number of training epochs (coerced to ``int``).
        lr: SGD learning rate (coerced to ``float``).

    Returns:
        On success, ``(model, losses, accuracies, before_svg, after_svg,
        metrics_svg)``. When the module name is unknown or no dataset is
        selected, a tuple of six ``None`` values. The failure tuple has the
        same length as the success tuple, so callers that unpack six items can
        detect failure via ``model is None``.
    """
    failure = (None,) * 6

    model = _build_model(module)
    if model is None or not dataset_name:
        # Fail before doing any dataset I/O; os.path.join(None) would raise.
        return failure

    # UI sliders may hand numbers over as floats; the epoch count must be an int.
    epochs = int(epochs)
    lr = float(lr)

    transform = transforms.Compose([
        transforms.Resize((28, 28)),
        transforms.Grayscale(num_output_channels=1),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5], std=[0.5])
    ])
    # Load dataset using CustomMNISTDataset
    train_dataset = CustomMNISTDataset(os.path.join("data", dataset_name, "raw"), transform=transform)
    train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)

    # Visualize before training
    before_svg = visualize_predictions_svg(model, train_loader, "Before")

    # Train the model
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.SGD(model.parameters(), lr=lr)
    losses, accuracies = train_final_model(model, criterion, optimizer, train_loader, epochs)

    # Visualize after training
    after_svg = visualize_predictions_svg(model, train_loader, "After")
    # Metrics SVG
    metrics_svg = plot_metrics_svg(losses, accuracies)
    return model, losses, accuracies, before_svg, after_svg, metrics_svg
def list_datasets():
    """Return the dataset names for the dropdown, or a one-item placeholder when none are found."""
    # An empty/falsy result from get_dataset_options() falls back to the placeholder.
    return get_dataset_options() or ["No datasets found"]
### 🎯 Gradio Interface ###
def run_module(module, dataset_name, epochs, lr):
    """Gradio callback: train the chosen model and return the four UI outputs.

    Returns:
        ``(status_message, before_svg, after_svg, metrics_svg)``, one value
        per component in the Train button's ``outputs`` list. On failure the
        three SVG slots are ``None``.
    """
    result = train_model_interface(module, dataset_name, epochs, lr)
    # A failed call may not have the 6-item success shape (or may carry a
    # None model), so validate before unpacking.
    if len(result) != 6 or result[0] is None:
        return "Error: Invalid selection.", None, None, None

    _, _, _, before_svg, after_svg, metrics_svg = result
    # Raw SVG markup is rendered directly by the gr.HTML outputs.
    return (
        f"Training completed for {module} with {epochs} epochs.",
        before_svg,
        after_svg,
        metrics_svg,
    )
### 🌟 Gradio UI ###
with gr.Blocks() as app:
    with gr.Tab("Techniques"):
        gr.Markdown("### 🧠 Select Model to Train")
        # Preselect the only module so an untouched dropdown does not send None.
        module_select = gr.Dropdown(
            choices=[
                "AI Calligraphy"
            ],
            value="AI Calligraphy",
            label="Select Module"
        )
        # Preselect the first dataset for the same reason; list_datasets()
        # falls back to a placeholder entry when nothing is found.
        dataset_choices = list_datasets()
        dataset_list = gr.Dropdown(
            choices=dataset_choices,
            value=dataset_choices[0] if dataset_choices else None,
            label="Select Dataset",
        )
        # step=1: with minimum=1 and step=10 the reachable values were
        # 1, 11, ..., 121, so the default of 100 could not be represented.
        epochs = gr.Slider(1, 128, value=100, step=1, label="Epochs")
        lr = gr.Slider(0.001, 0.1, value=0.01, step=0.001, label="Learning Rate")
        train_button = gr.Button("Train Model")
        output = gr.Textbox(label="Training Output")
        before_svg = gr.HTML(label="Before Training Predictions")
        after_svg = gr.HTML(label="After Training Predictions")
        metrics_svg = gr.HTML(label="Metrics")
        # run_module must return exactly one value per output component (4).
        train_button.click(
            run_module,
            inputs=[module_select, dataset_list, epochs, lr],
            outputs=[output, before_svg, after_svg, metrics_svg]
        )

# Launch Gradio app
app.launch()