Spaces:
Sleeping
Sleeping
File size: 7,655 Bytes
9f495ed e912502 9f495ed e912502 9f495ed 088873c 9f495ed |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 |
import gradio as gr
import torch
import numpy as np
import matplotlib.pyplot as plt
from torch import nn, optim
from torch.utils.data import DataLoader
from io import StringIO
import os
import base64
# Import your modules
from logistic_regression import LogisticRegressionModel
from softmax_regression import SoftmaxRegressionModel
from shallow_neural_network import ShallowNeuralNetwork
import convolutional_neural_networks
from dataset_loader import CustomMNISTDataset
from final_project import train_final_model, get_dataset_options, FinalCNN
import torchvision.transforms as transforms
import torch
import matplotlib.pyplot as plt
from matplotlib import font_manager
import matplotlib.pyplot as plt
def number_to_char(number):
    """Translate a class index into the character it represents.

    Indices 0-9 become the digits '0'-'9', 10-35 the lowercase letters
    'a'-'z', and 36-61 the uppercase letters 'A'-'Z'. Anything else yields
    an empty string.
    """
    if 0 <= number <= 9:
        return str(number)
    # Each letter range is (first index, last index, character for first index).
    for first, last, start_char in ((10, 35, 'a'), (36, 61, 'A')):
        if first <= number <= last:
            return chr(ord(start_char) + number - first)
    return ''
def visualize_predictions_svg(model, train_loader, stage):
    """Visualize model predictions on one batch and return the SVG string.

    Takes the first batch from ``train_loader``, runs ``model`` on up to 18
    of its images, and draws them in a 6x3 grid. Each cell is titled
    ``"<pred> = <label>"`` in green when the predicted class index equals the
    label, or ``"<pred> != <label>"`` in red otherwise. Class indices are
    shown as characters via ``number_to_char``.

    Args:
        model: torch module; its output is reduced with ``torch.max`` over
            dim 1 to obtain predicted class indices.
        train_loader: iterable yielding ``(images, labels)`` batches.
        stage: prefix for the saved file name (``"<stage>_predictions.svg"``).

    Returns:
        The SVG document as a string. A copy is also written via
        ``save_svg_to_output_folder``.
    """
    max_images, n_cols = 18, 3
    n_rows = max_images // n_cols

    # Use the custom Daemon font when present; otherwise fall back to the
    # default font instead of failing when the titles are rendered.
    font_path = './Daemon.otf'
    if os.path.exists(font_path):
        prop = font_manager.FontProperties(fname=font_path)
    else:
        prop = font_manager.FontProperties()

    images, labels = next(iter(train_loader))
    # The batch can hold fewer than 18 samples (e.g. a tiny dataset).
    count = min(max_images, len(images))
    images, labels = images[:count], labels[:count]

    # Remember the caller's mode so a subsequent training run is not left
    # in eval mode by this visualization.
    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            outputs = model(images)
    finally:
        model.train(was_training)
    _, predictions = torch.max(outputs, 1)

    fig, axes = plt.subplots(n_rows, n_cols, figsize=(12, 16))
    try:
        for i, ax in enumerate(axes.flat):
            ax.axis('off')
            if i >= count:
                continue  # no sample for this cell; leave it blank
            ax.imshow(images[i].squeeze(), cmap='gray')
            pred_idx = predictions[i].item()
            label_idx = labels[i].item()
            pred_char = number_to_char(pred_idx)
            label_char = number_to_char(label_idx)
            # Compare the raw indices: two different out-of-range indices
            # both map to '' and would otherwise look like a match.
            if pred_idx == label_idx:
                title_text, color = f"{pred_char} = {label_char}", 'green'
            else:
                title_text, color = f"{pred_char} != {label_char}", 'red'
            ax.set_title(title_text, fontproperties=prop, fontsize=12, color=color)
        svg_str = figure_to_svg(fig)
    finally:
        # Release the figure even if rendering fails.
        plt.close(fig)
    save_svg_to_output_folder(svg_str, f"{stage}_predictions.svg")
    return svg_str
def figure_to_svg(fig):
    """Render a matplotlib figure with the SVG backend and return the markup.

    A ``FigureCanvasSVG`` is attached to ``fig`` and printed into an
    in-memory text buffer, whose contents are returned as a ``str``.
    """
    from io import StringIO
    from matplotlib.backends.backend_svg import FigureCanvasSVG

    with StringIO() as buffer:
        FigureCanvasSVG(fig).print_svg(buffer)
        svg_markup = buffer.getvalue()
    return svg_markup
def save_svg_to_output_folder(svg_str, filename, output_dir='./output'):
    """Write an SVG string to ``<output_dir>/<filename>``.

    Args:
        svg_str: SVG markup to write.
        filename: name of the file to create inside ``output_dir``.
        output_dir: destination folder (default ``'./output'``). It is
            created if it does not exist yet.

    Returns:
        The path of the written file.
    """
    # Previously this crashed with FileNotFoundError when the folder was missing.
    os.makedirs(output_dir, exist_ok=True)
    output_path = os.path.join(output_dir, filename)
    # Explicit UTF-8 so the file contents do not depend on the platform's
    # default encoding.
    with open(output_path, 'w', encoding='utf-8') as f:
        f.write(svg_str)
    return output_path
def plot_metrics_svg(losses, accuracies):
    """Plot training loss and accuracy side by side and return the SVG string.

    Args:
        losses: sequence of loss values, one point per epoch on the x-axis.
        accuracies: sequence of accuracy values, one point per epoch.

    Returns:
        The SVG document as a string. A copy is also written as
        ``training_metrics.svg`` via ``save_svg_to_output_folder``.
    """
    fig, (loss_ax, acc_ax) = plt.subplots(1, 2, figsize=(12, 5))
    try:
        # (axes, values, series name, colour, panel title)
        panels = (
            (loss_ax, losses, 'Loss', 'red', 'Training Loss'),
            (acc_ax, accuracies, 'Accuracy', 'green', 'Training Accuracy'),
        )
        for ax, values, name, color, title in panels:
            ax.plot(values, label=name, color=color)
            ax.set_title(title)
            ax.set_xlabel('Epoch')
            ax.set_ylabel(name)
            ax.legend()
        # Lay out this figure explicitly: plt.tight_layout() targets pyplot's
        # global "current" figure, which a concurrent callback could change.
        fig.tight_layout()
        svg_str = figure_to_svg(fig)
    finally:
        plt.close(fig)
    save_svg_to_output_folder(svg_str, "training_metrics.svg")
    return svg_str
def _build_model(module):
    """Return a freshly constructed model for ``module``, or None if unknown."""
    # NOTE(review): the input sizes below (1, 2, 10) do not obviously match the
    # 1x28x28 images produced by train_model_interface's transform; only
    # "AI Calligraphy" is offered in the UI. Verify before exposing the others.
    if module == "Logistic Regression":
        return LogisticRegressionModel(input_size=1)
    if module == "Softmax Regression":
        return SoftmaxRegressionModel(input_size=2, num_classes=2)
    if module == "Shallow Neural Networks":
        return ShallowNeuralNetwork(input_size=2, hidden_size=5, output_size=2)
    if module == "Deep Networks":
        import deep_networks  # imported only when this option is selected
        return deep_networks.DeepNeuralNetwork(input_size=10, hidden_sizes=[20, 10], output_size=2)
    if module == "Convolutional Neural Networks":
        return convolutional_neural_networks.ConvolutionalNeuralNetwork()
    if module == "AI Calligraphy":
        return FinalCNN()
    return None


def train_model_interface(module, dataset_name, epochs=100, lr=0.01):
    """Train the selected model on the chosen dataset.

    Images from ``data/<dataset_name>/raw`` are resized to 28x28, converted
    to one grayscale channel and normalized with mean 0.5 / std 0.5. The
    model is trained with cross-entropy loss and plain SGD.

    Args:
        module: display name of the model to train (see ``_build_model``).
        dataset_name: folder name under ``data/``.
        epochs: number of training epochs; coerced to ``int`` because Gradio
            sliders deliver floats.
        lr: SGD learning rate.

    Returns:
        ``(model, losses, accuracies, before_svg, after_svg, metrics_svg)``.
        For an unknown ``module`` all six entries are ``None`` so callers can
        still unpack the result and test ``model is None``.
    """
    # Validate the selection first so no data is loaded for an invalid choice.
    model = _build_model(module)
    if model is None:
        return None, None, None, None, None, None

    transform = transforms.Compose([
        transforms.Resize((28, 28)),
        transforms.Grayscale(num_output_channels=1),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5], std=[0.5])
    ])
    train_dataset = CustomMNISTDataset(os.path.join("data", dataset_name, "raw"), transform=transform)
    train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)

    before_svg = visualize_predictions_svg(model, train_loader, "Before")

    criterion = nn.CrossEntropyLoss()
    optimizer = optim.SGD(model.parameters(), lr=lr)
    # train_final_model presumably loops over range(epochs), which rejects floats.
    losses, accuracies = train_final_model(model, criterion, optimizer, train_loader, int(epochs))

    after_svg = visualize_predictions_svg(model, train_loader, "After")
    metrics_svg = plot_metrics_svg(losses, accuracies)
    return model, losses, accuracies, before_svg, after_svg, metrics_svg
def list_datasets():
    """Return the dataset names to offer in the dataset dropdown.

    Whatever ``get_dataset_options`` reports is returned as-is; when it is
    empty, a single placeholder entry is returned instead so the dropdown
    always has something to show.
    """
    return get_dataset_options() or ["No datasets found"]
### 🎯 Gradio Interface ###
def run_module(module, dataset_name, epochs, lr):
    """Gradio click callback: train the chosen model and report the results.

    Returns a 4-tuple ``(status_text, before_svg, after_svg, metrics_svg)``,
    one value per component in the click handler's ``outputs`` list. The
    three SVG strings are displayed by ``gr.HTML`` components.
    """
    model, _losses, _accuracies, before_svg, after_svg, metrics_svg = train_model_interface(
        module, dataset_name, epochs, lr
    )
    if model is None:
        # Exactly four values, matching the four output components.
        return "Error: Invalid selection.", None, None, None
    return (
        # Sliders deliver floats; show the epoch count as a whole number.
        f"Training completed for {module} with {int(epochs)} epochs.",
        before_svg,   # raw SVG markup for the "before training" grid
        after_svg,    # raw SVG markup for the "after training" grid
        metrics_svg,  # raw SVG markup for the loss/accuracy plots
    )
### 🌟 Gradio UI ###
with gr.Blocks() as app:
    with gr.Tab("Techniques"):
        gr.Markdown("### 🧠 Select Model to Train")
        # Pre-select the only option so clicking "Train Model" without
        # touching the dropdown does not send None to run_module.
        module_select = gr.Dropdown(
            choices=[
                "AI Calligraphy"
            ],
            value="AI Calligraphy",
            label="Select Module"
        )
        dataset_list = gr.Dropdown(choices=list_datasets(), label="Select Dataset")
        # step=1 so the default of 100 lies on the slider grid; with minimum=1
        # and step=10 only 1, 11, 21, ... were selectable.
        epochs = gr.Slider(1, 128, value=100, step=1, label="Epochs")
        lr = gr.Slider(0.001, 0.1, value=0.01, step=0.001, label="Learning Rate")
        train_button = gr.Button("Train Model")
        output = gr.Textbox(label="Training Output")
        before_svg = gr.HTML(label="Before Training Predictions")
        after_svg = gr.HTML(label="After Training Predictions")
        metrics_svg = gr.HTML(label="Metrics")
        train_button.click(
            run_module,
            inputs=[module_select, dataset_list, epochs, lr],
            outputs=[output, before_svg, after_svg, metrics_svg]
        )

# Launch Gradio app
app.launch()
|