Spaces:
Sleeping
Sleeping
import gradio as gr | |
import torch | |
import numpy as np | |
import matplotlib.pyplot as plt | |
from torch import nn, optim | |
from torch.utils.data import DataLoader | |
from io import StringIO | |
import os | |
import base64 | |
# Import your modules | |
from logistic_regression import LogisticRegressionModel | |
from softmax_regression import SoftmaxRegressionModel | |
from shallow_neural_network import ShallowNeuralNetwork | |
import convolutional_neural_networks | |
from dataset_loader import CustomMNISTDataset | |
from final_project import train_final_model, get_dataset_options, FinalCNN | |
import torchvision.transforms as transforms | |
import torch | |
import matplotlib.pyplot as plt | |
from matplotlib import font_manager | |
import matplotlib.pyplot as plt | |
def number_to_char(number):
    """Translate a class index into the character it stands for.

    Mapping: 0-9 -> '0'-'9', 10-35 -> 'a'-'z', 36-61 -> 'A'-'Z'.
    Any value outside those ranges yields the empty string.
    """
    if 0 <= number <= 9:
        return str(number)
    # (lowest index, highest index, offset added to the index to get the code point)
    letter_ranges = (
        (10, 35, ord('a') - 10),  # lowercase block: 10 -> 'a'
        (36, 61, ord('A') - 36),  # uppercase block: 36 -> 'A'
    )
    for low, high, offset in letter_ranges:
        if low <= number <= high:
            return chr(number + offset)
    return ''
def visualize_predictions_svg(model, train_loader, stage):
    """Render predictions on one batch as a 6x3 image grid and return it as SVG.

    Takes the first batch yielded by ``train_loader`` and keeps at most 18
    samples. It runs ``model`` on them without gradient tracking and titles
    each panel "<pred> = <label>" in green when the prediction matches the
    label, or "<pred> != <label>" in red otherwise. Class indices are mapped to
    characters with ``number_to_char``. The SVG is also saved to the output
    folder as ``{stage}_predictions.svg``.

    Args:
        model: torch module. Its output is reduced with ``torch.max(outputs, 1)``,
            so presumably shaped (batch, num_classes) -- TODO confirm.
        train_loader: iterable yielding ``(images, labels)`` batches.
        stage: tag used in the saved file name (e.g. "Before" / "After").

    Returns:
        The SVG document as a string.
    """
    nrows, ncols = 6, 3
    max_images = nrows * ncols
    font_path = './Daemon.otf'  # Custom font used for the panel titles
    # Fall back to matplotlib's default font when the file is missing: a
    # FontProperties pointing at a nonexistent file only fails at draw time.
    if os.path.isfile(font_path):
        prop = font_manager.FontProperties(fname=font_path)
    else:
        prop = font_manager.FontProperties()

    images, labels = next(iter(train_loader))
    images, labels = images[:max_images], labels[:max_images]

    # Restore the caller's train/eval mode afterwards. This runs right before
    # training, and leaving the model in eval mode would disable dropout /
    # batch-norm updates unless the trainer switches it back itself.
    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            outputs = model(images)
    finally:
        model.train(was_training)
    _, predictions = torch.max(outputs, 1)

    fig, axes = plt.subplots(nrows, ncols, figsize=(12, 16))
    try:
        for i, axis in enumerate(axes.flat):
            axis.axis('off')
            # The batch may hold fewer than 18 samples (small dataset / batch size).
            if i >= len(images):
                continue
            axis.imshow(images[i].squeeze(), cmap='gray')
            pred_char = number_to_char(predictions[i].item())
            label_char = number_to_char(labels[i].item())
            if pred_char == label_char:
                title_text, color = f"{pred_char} = {label_char}", 'green'
            else:
                title_text, color = f"{pred_char} != {label_char}", 'red'
            axis.set_title(title_text, fontproperties=prop, fontsize=12, color=color)
        svg_str = figure_to_svg(fig)
        save_svg_to_output_folder(svg_str, f"{stage}_predictions.svg")
    finally:
        # Always release the figure. pyplot keeps figures alive until closed,
        # which leaks memory in a long-running server if saving raises.
        plt.close(fig)
    return svg_str
def figure_to_svg(fig):
    """Serialize a matplotlib figure to SVG markup.

    Attaches an SVG canvas to ``fig`` and renders it into an in-memory text
    buffer. Returns the buffer contents as a string.
    """
    from io import StringIO as _TextBuffer
    from matplotlib.backends.backend_svg import FigureCanvasSVG as _SVGCanvas

    buffer = _TextBuffer()
    _SVGCanvas(fig).print_svg(buffer)
    return buffer.getvalue()
def save_svg_to_output_folder(svg_str, filename, output_dir='./output'):
    """Write an SVG document to ``output_dir/filename`` and return that path.

    Args:
        svg_str: SVG markup to write.
        filename: file name inside ``output_dir``.
        output_dir: destination folder. It is created, including parents, if
            it does not exist yet.

    Returns:
        The path of the written file.
    """
    # Create the folder on demand instead of failing with FileNotFoundError.
    os.makedirs(output_dir, exist_ok=True)
    output_path = os.path.join(output_dir, filename)
    # Write UTF-8 explicitly, the encoding matplotlib declares in its SVG
    # header, rather than relying on the platform's locale encoding.
    with open(output_path, 'w', encoding='utf-8') as f:
        f.write(svg_str)
    return output_path
def plot_metrics_svg(losses, accuracies):
    """Plot training loss and accuracy side by side and return the chart as SVG.

    The SVG is also saved to the output folder as ``training_metrics.svg``.

    Args:
        losses: sequence of loss values, one per x-axis step (labelled "Epoch").
        accuracies: sequence of accuracy values, plotted the same way.

    Returns:
        The SVG document as a string.
    """
    fig, (loss_ax, acc_ax) = plt.subplots(1, 2, figsize=(12, 5))
    try:
        panels = (
            (loss_ax, losses, 'Loss', 'red'),
            (acc_ax, accuracies, 'Accuracy', 'green'),
        )
        for axis, values, name, color in panels:
            axis.plot(values, label=name, color=color)
            axis.set_title(f'Training {name}')
            axis.set_xlabel('Epoch')
            axis.set_ylabel(name)
            axis.legend()
        # Lay out this figure explicitly, not whatever pyplot considers "current".
        fig.tight_layout()
        svg_str = figure_to_svg(fig)
        save_svg_to_output_folder(svg_str, "training_metrics.svg")
    finally:
        # Release the figure even if rendering or saving fails.
        plt.close(fig)
    return svg_str
def _build_model(module):
    """Return a fresh model for the UI label ``module``, or None if it is unknown."""
    if module == "Logistic Regression":
        return LogisticRegressionModel(input_size=1)
    if module == "Softmax Regression":
        return SoftmaxRegressionModel(input_size=2, num_classes=2)
    if module == "Shallow Neural Networks":
        return ShallowNeuralNetwork(input_size=2, hidden_size=5, output_size=2)
    if module == "Deep Networks":
        import deep_networks  # imported lazily; only needed for this option
        return deep_networks.DeepNeuralNetwork(input_size=10, hidden_sizes=[20, 10], output_size=2)
    if module == "Convolutional Neural Networks":
        return convolutional_neural_networks.ConvolutionalNeuralNetwork()
    if module == "AI Calligraphy":
        return FinalCNN()
    return None


def train_model_interface(module, dataset_name, epochs=100, lr=0.01):
    """Train the selected model on the chosen dataset and render the results.

    Steps:
    1. Build the model for ``module``.
    2. Load ``data/<dataset_name>/raw`` with ``CustomMNISTDataset``. Images are
       resized to 28x28, converted to 1-channel grayscale, and normalized with
       mean 0.5 / std 0.5.
    3. Render predictions before training.
    4. Train with cross-entropy loss and SGD via ``train_final_model``.
    5. Render predictions again and plot the metrics.

    NOTE(review): the non-CNN options use small hard-coded input sizes that do
    not obviously match 28x28 images -- confirm before exposing them in the UI.

    Args:
        module: model label as shown in the UI dropdown.
        dataset_name: folder name under ``data/``; None when nothing is selected.
        epochs: number of training epochs passed to ``train_final_model``.
        lr: SGD learning rate.

    Returns:
        ``(model, losses, accuracies, before_svg, after_svg, metrics_svg)``.
        If the module is unknown or no dataset is selected, all six values are
        None, so callers can always unpack six values and test ``model is None``.
    """
    model = _build_model(module)
    if model is None or not dataset_name:
        return None, None, None, None, None, None

    transform = transforms.Compose([
        transforms.Resize((28, 28)),
        transforms.Grayscale(num_output_channels=1),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5], std=[0.5]),
    ])
    train_dataset = CustomMNISTDataset(os.path.join("data", dataset_name, "raw"), transform=transform)
    train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)

    before_svg = visualize_predictions_svg(model, train_loader, "Before")

    criterion = nn.CrossEntropyLoss()
    optimizer = optim.SGD(model.parameters(), lr=float(lr))
    # UI sliders may deliver floats; the trainer presumably loops over
    # range(epochs) -- TODO confirm -- so pass an int.
    losses, accuracies = train_final_model(model, criterion, optimizer, train_loader, int(epochs))

    after_svg = visualize_predictions_svg(model, train_loader, "After")
    metrics_svg = plot_metrics_svg(losses, accuracies)
    return model, losses, accuracies, before_svg, after_svg, metrics_svg
def list_datasets():
    """Return the dataset names offered in the UI dropdown.

    Falls back to a single placeholder entry so the dropdown never has an
    empty choice list.
    """
    return get_dataset_options() or ["No datasets found"]
### 🎯 Gradio Interface ### | |
def run_module(module, dataset_name, epochs, lr):
    """Gradio click handler: train the selected model and return the UI outputs.

    Returns:
        ``(status_message, before_svg, after_svg, metrics_svg)``, one value per
        component in the button's ``outputs`` list. On failure the three SVG
        slots are None.
    """
    result = train_model_interface(module, dataset_name, epochs, lr)
    model = result[0]
    # No model (None), or an error-message string in the model slot, means
    # training did not run. Return exactly four values to match the outputs.
    if model is None or isinstance(model, str):
        return "Error: Invalid selection.", None, None, None
    _, _, _, before_svg, after_svg, metrics_svg = result
    # The SVG strings are passed straight to the HTML components for rendering.
    return (
        f"Training completed for {module} with {epochs} epochs.",
        before_svg,
        after_svg,
        metrics_svg,
    )
### 🌟 Gradio UI ### | |
with gr.Blocks() as app:
    with gr.Tab("Techniques"):
        gr.Markdown("### 🧠 Select Model to Train")
        # Preselect defaults so the first click does not send None to the handler.
        module_select = gr.Dropdown(
            choices=["AI Calligraphy"],
            value="AI Calligraphy",
            label="Select Module",
        )
        dataset_choices = list_datasets()  # never empty (placeholder fallback)
        dataset_list = gr.Dropdown(
            choices=dataset_choices,
            value=dataset_choices[0],
            label="Select Dataset",
        )
        # step=1: with minimum 1 and step 10, the default 100 and the maximum
        # 128 would not be valid slider positions.
        epochs = gr.Slider(1, 128, value=100, step=1, label="Epochs")
        lr = gr.Slider(0.001, 0.1, value=0.01, step=0.001, label="Learning Rate")
        train_button = gr.Button("Train Model")
        output = gr.Textbox(label="Training Output")
        before_svg = gr.HTML(label="Before Training Predictions")
        after_svg = gr.HTML(label="After Training Predictions")
        metrics_svg = gr.HTML(label="Metrics")
        train_button.click(
            run_module,
            inputs=[module_select, dataset_list, epochs, lr],
            outputs=[output, before_svg, after_svg, metrics_svg],
        )

# Launch the Gradio app only when run as a script, not on import.
if __name__ == "__main__":
    app.launch()