import base64
import os
from io import StringIO

import gradio as gr
import matplotlib
import matplotlib.pyplot as plt
import numpy as np
import torch
import torchvision.transforms as transforms
from matplotlib import font_manager
from torch import nn, optim
from torch.utils.data import DataLoader

# Import your modules
import convolutional_neural_networks
from dataset_loader import CustomMNISTDataset
from final_project import train_final_model, get_dataset_options, FinalCNN
from logistic_regression import LogisticRegressionModel
from shallow_neural_network import ShallowNeuralNetwork
from softmax_regression import SoftmaxRegressionModel

# Figures here are only ever serialized to SVG, never shown on screen, and
# Gradio invokes callbacks from worker threads where GUI backends misbehave,
# so force the non-interactive Agg backend.
matplotlib.use("Agg")

# Placeholder entry shown in the dataset dropdown when nothing is available.
NO_DATASETS_LABEL = "No datasets found"

# Default directory that rendered SVG files are written to.
OUTPUT_DIR = "./output"

# Layout of the prediction grid drawn by visualize_predictions_svg.
_GRID_ROWS = 6
_GRID_COLS = 3


def number_to_char(number):
    """Map a class index to its character label.

    0-9 map to '0'-'9', 10-35 to 'a'-'z' and 36-61 to 'A'-'Z'; any other
    value maps to the empty string.
    """
    if 0 <= number <= 9:
        return str(number)
    if 10 <= number <= 35:
        return chr(number + 87)  # 10 -> 'a' (97), 35 -> 'z' (122)
    if 36 <= number <= 61:
        return chr(number + 29)  # 36 -> 'A' (65), 61 -> 'Z' (90)
    return ''


def visualize_predictions_svg(model, train_loader, stage, font_path='./Daemon.otf'):
    """Render a grid of sample predictions and return it as an SVG string.

    Takes the first batch from ``train_loader``, runs ``model`` on up to
    18 images (6x3 grid) and titles each cell ``"<pred> = <label>"`` in green
    or ``"<pred> != <label>"`` in red. The SVG is also saved as
    ``<OUTPUT_DIR>/<stage>_predictions.svg``.

    Args:
        model: torch module; its output is reduced with ``torch.max(..., 1)``,
            so class scores are expected along dim 1.
        train_loader: iterable yielding ``(images, labels)`` batches.
        stage: tag used in the output filename (e.g. "Before" / "After").
        font_path: font file for the subplot titles; matplotlib's default font
            is used when the file does not exist.

    Returns:
        The rendered figure as an SVG string.

    Note:
        Leaves ``model`` in eval mode.
    """
    # Fall back to the default font instead of failing at render time when
    # the custom font file is missing.
    if os.path.isfile(font_path):
        prop = font_manager.FontProperties(fname=font_path)
    else:
        prop = font_manager.FontProperties()

    max_images = _GRID_ROWS * _GRID_COLS

    model.eval()
    images, labels = next(iter(train_loader))
    images, labels = images[:max_images], labels[:max_images]
    with torch.no_grad():
        outputs = model(images)
        _, predictions = torch.max(outputs, 1)

    fig, axes = plt.subplots(_GRID_ROWS, _GRID_COLS, figsize=(12, 16))
    try:
        for i, ax in enumerate(axes.flat):
            ax.axis('off')
            # The batch may be smaller than the grid; leave extra cells blank.
            if i >= len(images):
                continue
            ax.imshow(images[i].squeeze(), cmap='gray')

            pred_char = number_to_char(predictions[i].item())
            label_char = number_to_char(labels[i].item())
            if pred_char == label_char:
                title_text, color = f"{pred_char} = {label_char}", 'green'
            else:
                title_text, color = f"{pred_char} != {label_char}", 'red'
            ax.set_title(title_text, fontproperties=prop, fontsize=12, color=color)

        svg_str = figure_to_svg(fig)
    finally:
        # Always release the figure, even if rendering fails.
        plt.close(fig)

    save_svg_to_output_folder(svg_str, f"{stage}_predictions.svg")
    return svg_str


def figure_to_svg(fig):
    """Serialize a matplotlib figure to an SVG string.

    Uses ``Figure.savefig`` into an in-memory text buffer, which does not
    replace the figure's canvas.
    """
    buffer = StringIO()
    fig.savefig(buffer, format='svg')
    return buffer.getvalue()


def save_svg_to_output_folder(svg_str, filename, output_dir=OUTPUT_DIR):
    """Write ``svg_str`` to ``<output_dir>/<filename>`` as UTF-8.

    The output directory is created if it does not exist yet.

    Returns:
        The path of the written file.
    """
    os.makedirs(output_dir, exist_ok=True)
    output_path = os.path.join(output_dir, filename)
    # matplotlib's SVG header declares utf-8, so write with that encoding
    # regardless of the platform default.
    with open(output_path, 'w', encoding='utf-8') as f:
        f.write(svg_str)
    return output_path


def plot_metrics_svg(losses, accuracies):
    """Plot loss and accuracy curves side by side and return them as SVG.

    The x axis is the position in each sequence, labelled "Epoch"; this
    presumes one value per epoch from ``train_final_model`` -- TODO confirm.
    The SVG is also saved as ``<OUTPUT_DIR>/training_metrics.svg``.
    """
    fig, (loss_ax, acc_ax) = plt.subplots(1, 2, figsize=(12, 5))
    try:
        loss_ax.plot(losses, label='Loss', color='red')
        loss_ax.set_title('Training Loss')
        loss_ax.set_xlabel('Epoch')
        loss_ax.set_ylabel('Loss')
        loss_ax.legend()

        acc_ax.plot(accuracies, label='Accuracy', color='green')
        acc_ax.set_title('Training Accuracy')
        acc_ax.set_xlabel('Epoch')
        acc_ax.set_ylabel('Accuracy')
        acc_ax.legend()

        fig.tight_layout()
        svg_str = figure_to_svg(fig)
    finally:
        plt.close(fig)

    save_svg_to_output_folder(svg_str, "training_metrics.svg")
    return svg_str


def _build_model(module):
    """Instantiate the model for a UI module name, or return None if unknown."""
    if module == "Logistic Regression":
        return LogisticRegressionModel(input_size=1)
    if module == "Softmax Regression":
        return SoftmaxRegressionModel(input_size=2, num_classes=2)
    if module == "Shallow Neural Networks":
        return ShallowNeuralNetwork(input_size=2, hidden_size=5, output_size=2)
    if module == "Deep Networks":
        import deep_networks  # imported lazily, only when this option is chosen
        return deep_networks.DeepNeuralNetwork(input_size=10, hidden_sizes=[20, 10], output_size=2)
    if module == "Convolutional Neural Networks":
        return convolutional_neural_networks.ConvolutionalNeuralNetwork()
    if module == "AI Calligraphy":
        return FinalCNN()
    return None


def _make_train_loader(dataset_name, batch_size=64):
    """Build a shuffled DataLoader over ``data/<dataset_name>/raw``.

    Images are resized to 28x28, converted to one grayscale channel and
    normalized from [0, 1] to [-1, 1].
    """
    transform = transforms.Compose([
        transforms.Resize((28, 28)),
        transforms.Grayscale(num_output_channels=1),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5], std=[0.5]),
    ])
    train_dataset = CustomMNISTDataset(os.path.join("data", dataset_name, "raw"), transform=transform)
    return DataLoader(train_dataset, batch_size=batch_size, shuffle=True)


def train_model_interface(module, dataset_name, epochs=100, lr=0.01):
    """Train the selected model on the chosen dataset.

    Args:
        module: UI module name (see ``_build_model``).
        dataset_name: dataset folder under ``data/``.
        epochs: number of training epochs (coerced to ``int``).
        lr: SGD learning rate.

    Returns:
        ``(model, losses, accuracies, before_svg, after_svg, metrics_svg)``.
        For an unknown ``module`` all six items are ``None``, so callers can
        always unpack the same arity and test ``model is None``.
    """
    model = _build_model(module)
    if model is None:
        return None, None, None, None, None, None

    train_loader = _make_train_loader(dataset_name)

    # Visualize before training
    before_svg = visualize_predictions_svg(model, train_loader, "Before")

    criterion = nn.CrossEntropyLoss()
    optimizer = optim.SGD(model.parameters(), lr=lr)
    # The visualization above switched the model to eval mode; restore
    # training mode so dropout/batch-norm layers (if any) behave correctly.
    model.train()
    # Slider values may arrive as floats; the training loop needs an int.
    losses, accuracies = train_final_model(model, criterion, optimizer, train_loader, int(epochs))

    # Visualize after training
    after_svg = visualize_predictions_svg(model, train_loader, "After")
    metrics_svg = plot_metrics_svg(losses, accuracies)

    return model, losses, accuracies, before_svg, after_svg, metrics_svg


def list_datasets():
    """Return dataset choices for the dropdown, or a single placeholder entry."""
    dataset_options = get_dataset_options()
    return dataset_options if dataset_options else [NO_DATASETS_LABEL]


### 🎯 Gradio Interface ###
def run_module(module, dataset_name, epochs, lr):
    """Gradio callback: train and return ``(message, before, after, metrics)``.

    Always returns exactly four values, matching the four output components
    wired to the Train button.
    """
    if not dataset_name or dataset_name == NO_DATASETS_LABEL:
        return "Error: Please select a dataset.", None, None, None

    epochs = int(epochs)
    model, _losses, _accuracies, before_svg, after_svg, metrics_svg = train_model_interface(
        module, dataset_name, epochs, lr
    )
    if model is None:
        return "Error: Invalid selection.", None, None, None

    # The SVG strings are rendered directly by the gr.HTML components.
    return (
        f"Training completed for {module} with {epochs} epochs.",
        before_svg,
        after_svg,
        metrics_svg,
    )


### 🌟 Gradio UI ###
with gr.Blocks() as app:
    with gr.Tab("Techniques"):
        gr.Markdown("### 🧠 Select Model to Train")
        module_select = gr.Dropdown(
            choices=["AI Calligraphy"],
            value="AI Calligraphy",
            label="Select Module",
        )
        dataset_list = gr.Dropdown(choices=list_datasets(), label="Select Dataset")
        # step=1 so the default of 100 (and every integer) is reachable.
        epochs = gr.Slider(1, 128, value=100, step=1, label="Epochs")
        lr = gr.Slider(0.001, 0.1, value=0.01, step=0.001, label="Learning Rate")
        train_button = gr.Button("Train Model")
        output = gr.Textbox(label="Training Output")
        before_svg = gr.HTML(label="Before Training Predictions")
        after_svg = gr.HTML(label="After Training Predictions")
        metrics_svg = gr.HTML(label="Metrics")

        train_button.click(
            run_module,
            inputs=[module_select, dataset_list, epochs, lr],
            outputs=[output, before_svg, after_svg, metrics_svg],
        )

# Launch Gradio app
app.launch()