Spaces:
Sleeping
Sleeping
| import gradio as gr | |
| import torch | |
| import numpy as np | |
| import matplotlib.pyplot as plt | |
| from torch import nn, optim | |
| from torch.utils.data import DataLoader | |
| from io import StringIO | |
| import os | |
| import base64 | |
| # Import your modules | |
| from logistic_regression import LogisticRegressionModel | |
| from softmax_regression import SoftmaxRegressionModel | |
| from shallow_neural_network import ShallowNeuralNetwork | |
| import convolutional_neural_networks | |
| from dataset_loader import CustomMNISTDataset | |
| from final_project import train_final_model, get_dataset_options, FinalCNN | |
| import torchvision.transforms as transforms | |
| import torch | |
| import matplotlib.pyplot as plt | |
| from matplotlib import font_manager | |
| import matplotlib.pyplot as plt | |
def number_to_char(number):
    """Translate a class index into the character it represents.

    Indices 0-9 become the digits '0'-'9', 10-35 become lowercase 'a'-'z',
    and 36-61 become uppercase 'A'-'Z'. Any value outside 0-61 yields an
    empty string.
    """
    # Digits are produced by plain string conversion.
    if 0 <= number <= 9:
        return str(number)

    # Lowercase range: index 10 must land on 'a' (code point 97).
    lowercase_shift = 87
    if 10 <= number <= 35:
        return chr(number + lowercase_shift)

    # Uppercase range: index 36 must land on 'A' (code point 65).
    uppercase_shift = 29
    if 36 <= number <= 61:
        return chr(number + uppercase_shift)

    # Not a known class index.
    return ''
def visualize_predictions_svg(model, train_loader, stage):
    """Draw a grid of sample predictions and return it as an SVG string.

    One batch is pulled from ``train_loader``; up to 18 of its images are run
    through ``model`` and drawn on a 6x3 grid. Each cell is titled
    ``"<pred> = <label>"`` in green when the predicted class index equals the
    label, or ``"<pred> != <label>"`` in red otherwise. The SVG is also saved
    as ``"<stage>_predictions.svg"`` via ``save_svg_to_output_folder``.

    Args:
        model: Callable torch module; its output is reduced with
            ``torch.max(outputs, 1)`` to obtain class indices.
        train_loader: Iterable yielding ``(images, labels)`` batches.
        stage: Prefix for the saved file name (the caller in this file passes
            "Before" and "After").

    Returns:
        The SVG markup of the rendered figure.

    Raises:
        StopIteration: If ``train_loader`` yields no batch.
    """
    rows, cols = 6, 3
    max_images = rows * cols

    # Use the bundled Daemon font when present; otherwise fall back to
    # matplotlib's default font instead of failing at render time.
    font_path = './Daemon.otf'
    if os.path.isfile(font_path):
        prop = font_manager.FontProperties(fname=font_path)
    else:
        prop = font_manager.FontProperties()

    # Run inference in eval mode, then restore whatever mode the model was in
    # so that a preview rendered before training does not leave the model in
    # eval mode for the training run that follows.
    was_training = model.training
    model.eval()
    try:
        images, labels = next(iter(train_loader))
        images, labels = images[:max_images], labels[:max_images]
        with torch.no_grad():
            outputs = model(images)
        _, predictions = torch.max(outputs, 1)
    finally:
        model.train(was_training)

    # A batch may hold fewer than 18 samples; only draw what is available.
    shown = len(images)

    fig, ax = plt.subplots(rows, cols, figsize=(12, 16))
    try:
        for i, cell in enumerate(ax.flat):
            if i < shown:
                cell.imshow(images[i].squeeze(), cmap='gray')

                predicted = predictions[i].item()
                expected = labels[i].item()
                pred_char = number_to_char(predicted)
                label_char = number_to_char(expected)

                # Compare class indices, not the mapped characters: two
                # different out-of-range indices both map to '' and would
                # otherwise be reported as a match.
                if predicted == expected:
                    title_text = f"{pred_char} = {label_char}"
                    color = 'green'
                else:
                    title_text = f"{pred_char} != {label_char}"
                    color = 'red'

                cell.set_title(title_text, fontproperties=prop, fontsize=12, color=color)
            # Hide the axes on every cell, including unused ones.
            cell.axis('off')

        svg_str = figure_to_svg(fig)
        save_svg_to_output_folder(svg_str, f"{stage}_predictions.svg")
    finally:
        # Always release the figure, even if rendering or saving fails.
        plt.close(fig)
    return svg_str
def figure_to_svg(fig):
    """Serialize a matplotlib figure to SVG markup and return it as text."""
    # Both helpers are imported at call time, inside the function.
    from matplotlib.backends.backend_svg import FigureCanvasSVG
    from io import StringIO

    # Render through an SVG canvas into an in-memory text buffer.
    with StringIO() as svg_buffer:
        FigureCanvasSVG(fig).print_svg(svg_buffer)
        return svg_buffer.getvalue()
def save_svg_to_output_folder(svg_str, filename):
    """Write an SVG string to ``./output/<filename>``.

    The ``./output`` directory is created if it does not exist yet, and the
    file is written as UTF-8 so the result does not depend on the platform's
    default text encoding.

    Args:
        svg_str: SVG markup to write.
        filename: Name of the file to create inside ``./output``.
    """
    output_dir = './output'
    # Create the folder on demand instead of requiring it to pre-exist.
    os.makedirs(output_dir, exist_ok=True)

    output_path = f'{output_dir}/{filename}'
    with open(output_path, 'w', encoding='utf-8') as f:
        f.write(svg_str)
def plot_metrics_svg(losses, accuracies):
    """Plot loss and accuracy curves side by side and return the SVG markup.

    The left panel shows ``losses`` in red and the right panel shows
    ``accuracies`` in green, both against an 'Epoch' x-axis. The SVG is also
    saved as "training_metrics.svg" through ``save_svg_to_output_folder``.
    """
    fig, (loss_axis, accuracy_axis) = plt.subplots(1, 2, figsize=(12, 5))

    # (axis, series, legend/y-axis label, line colour, panel title)
    panels = (
        (loss_axis, losses, 'Loss', 'red', 'Training Loss'),
        (accuracy_axis, accuracies, 'Accuracy', 'green', 'Training Accuracy'),
    )
    for axis, series, name, colour, heading in panels:
        axis.plot(series, label=name, color=colour)
        axis.set_title(heading)
        axis.set_xlabel('Epoch')
        axis.set_ylabel(name)
        axis.legend()

    plt.tight_layout()

    markup = figure_to_svg(fig)
    save_svg_to_output_folder(markup, "training_metrics.svg")
    plt.close(fig)
    return markup
def train_model_interface(module, dataset_name, epochs=100, lr=0.01):
    """Train the selected model on the chosen dataset.

    Args:
        module: Display name of the model to build (e.g. "AI Calligraphy").
        dataset_name: Name of the dataset folder; data is read from
            ``data/<dataset_name>/raw``.
        epochs: Number of training epochs; coerced to ``int`` before use.
        lr: Learning rate for the SGD optimizer.

    Returns:
        A 6-tuple ``(model, losses, accuracies, before_svg, after_svg,
        metrics_svg)``. When ``module`` is not recognised or ``dataset_name``
        is empty/``None``, every element is ``None`` so that callers can
        always unpack six values and test ``model is None``.
    """
    # Same arity as the success result, so callers can unpack either outcome.
    failure = (None, None, None, None, None, None)

    # Reject a missing dataset selection before it reaches os.path.join.
    if not dataset_name:
        return failure

    # Select the model first so an invalid choice costs no dataset loading.
    # NOTE(review): the non-CNN models are built with input_size 1/2/10 while
    # the loader yields 28x28 images — presumably only the CNN variants can
    # consume these batches; verify before exposing the other modules.
    if module == "Logistic Regression":
        model = LogisticRegressionModel(input_size=1)
    elif module == "Softmax Regression":
        model = SoftmaxRegressionModel(input_size=2, num_classes=2)
    elif module == "Shallow Neural Networks":
        model = ShallowNeuralNetwork(input_size=2, hidden_size=5, output_size=2)
    elif module == "Deep Networks":
        import deep_networks
        model = deep_networks.DeepNeuralNetwork(input_size=10, hidden_sizes=[20, 10], output_size=2)
    elif module == "Convolutional Neural Networks":
        model = convolutional_neural_networks.ConvolutionalNeuralNetwork()
    elif module == "AI Calligraphy":
        model = FinalCNN()
    else:
        return failure

    # Resize to 28x28, convert to single-channel grayscale, and normalize
    # with mean 0.5 / std 0.5.
    transform = transforms.Compose([
        transforms.Resize((28, 28)),
        transforms.Grayscale(num_output_channels=1),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5], std=[0.5])
    ])

    # Load dataset using CustomMNISTDataset
    train_dataset = CustomMNISTDataset(os.path.join("data", dataset_name, "raw"), transform=transform)
    train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)

    # Visualize before training
    before_svg = visualize_predictions_svg(model, train_loader, "Before")

    # Train the model. UI sliders may deliver a non-int number, so coerce the
    # epoch count to an integer before handing it to the training loop.
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.SGD(model.parameters(), lr=lr)
    losses, accuracies = train_final_model(model, criterion, optimizer, train_loader, int(epochs))

    # Visualize after training
    after_svg = visualize_predictions_svg(model, train_loader, "After")

    # Metrics SVG
    metrics_svg = plot_metrics_svg(losses, accuracies)

    return model, losses, accuracies, before_svg, after_svg, metrics_svg
def list_datasets():
    """Return the available dataset names, or a placeholder entry if none exist."""
    # An empty (falsy) result is replaced by a single placeholder choice.
    return get_dataset_options() or ["No datasets found"]
| ### 🎯 Gradio Interface ### | |
def run_module(module, dataset_name, epochs, lr):
    """Gradio callback: train the selected model and return the UI outputs.

    Args:
        module: Model name chosen in the module dropdown.
        dataset_name: Dataset name chosen in the dataset dropdown.
        epochs: Epoch count from the slider.
        lr: Learning rate from the slider.

    Returns:
        A 4-tuple ``(status_message, before_svg, after_svg, metrics_svg)``,
        one value per output component wired to the train button. On failure
        the three SVG slots are ``None``.
    """
    result = train_model_interface(module, dataset_name, epochs, lr)

    # Treat anything that is not a full six-value result with a real model as
    # a failure. This covers both an all-None result and a short tuple whose
    # first element is an error message string, either of which would
    # otherwise break the unpacking below.
    failed = (
        len(result) != 6
        or result[0] is None
        or isinstance(result[0], str)
    )
    if failed:
        # Exactly four values: one per output component.
        return "Error: Invalid selection.", None, None, None

    _model, _losses, _accuracies, before_svg, after_svg, metrics_svg = result

    # The SVG strings are returned as-is; the UI renders them as raw HTML.
    return (
        f"Training completed for {module} with {epochs} epochs.",
        before_svg,
        after_svg,
        metrics_svg,
    )
| ### 🌟 Gradio UI ### | |
# Build the single-tab training UI: model and dataset pickers,
# hyper-parameter sliders, a train button, and four output areas.
with gr.Blocks() as app:
    with gr.Tab("Techniques"):
        gr.Markdown("### 🧠 Select Model to Train")

        # Only the final CNN model is offered in the UI.
        module_select = gr.Dropdown(
            label="Select Module",
            choices=["AI Calligraphy"],
        )
        # Dataset choices are computed once, when the UI is built.
        dataset_list = gr.Dropdown(
            label="Select Dataset",
            choices=list_datasets(),
        )

        epochs = gr.Slider(
            minimum=1,
            maximum=128,
            value=100,
            step=10,
            label="Epochs",
        )
        lr = gr.Slider(
            minimum=0.001,
            maximum=0.1,
            value=0.01,
            step=0.001,
            label="Learning Rate",
        )

        train_button = gr.Button("Train Model")

        # Status text plus three HTML panes that receive raw SVG markup.
        output = gr.Textbox(label="Training Output")
        before_svg = gr.HTML(label="Before Training Predictions")
        after_svg = gr.HTML(label="After Training Predictions")
        metrics_svg = gr.HTML(label="Metrics")

        # Clicking the button runs training and fills the four outputs in order.
        train_button.click(
            fn=run_module,
            inputs=[module_select, dataset_list, epochs, lr],
            outputs=[output, before_svg, after_svg, metrics_svg],
        )

# Start the Gradio server.
app.launch()