File size: 7,655 Bytes
9f495ed
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e912502
9f495ed
 
 
e912502
 
9f495ed
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
088873c
9f495ed
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
import gradio as gr
import torch
import numpy as np
import matplotlib.pyplot as plt
from torch import nn, optim
from torch.utils.data import DataLoader
from io import StringIO
import os
import  base64
# Import your modules
from logistic_regression import LogisticRegressionModel
from softmax_regression import SoftmaxRegressionModel
from shallow_neural_network import ShallowNeuralNetwork
import convolutional_neural_networks
from dataset_loader import CustomMNISTDataset
from final_project import train_final_model, get_dataset_options, FinalCNN
import torchvision.transforms as transforms

import torch
import matplotlib.pyplot as plt
from matplotlib import font_manager
import matplotlib.pyplot as plt
def number_to_char(number):
    """Translate a class index into the character it represents.

    Indices 0-9 become the digits '0'-'9', 10-35 the lowercase letters
    'a'-'z', and 36-61 the uppercase letters 'A'-'Z'. Any other index maps
    to the empty string.
    """
    # (lowest index, highest index, converter) for each contiguous range.
    char_ranges = (
        (0, 9, str),
        (10, 35, lambda n: chr(ord('a') + n - 10)),
        (36, 61, lambda n: chr(ord('A') + n - 36)),
    )
    for low, high, convert in char_ranges:
        if low <= number <= high:
            return convert(number)
    return ''



def visualize_predictions_svg(model, train_loader, stage):
    """Plot the model's predictions on one batch and return the figure as SVG.

    Takes the first batch from ``train_loader`` and runs up to 18 of its
    images through ``model`` without gradients. The images are drawn in a
    6x3 grid. Each cell is titled ``"<pred> = <label>"`` in green when the
    predicted class equals the label, or ``"<pred> != <label>"`` in red
    otherwise, with class indices rendered via ``number_to_char``. The SVG
    is also saved to the output folder as ``"{stage}_predictions.svg"``.

    Args:
        model: Module whose output's argmax along dim 1 is taken as the
            predicted class.
        train_loader: Iterable yielding ``(images, labels)`` batches.
        stage: Prefix for the saved file name (e.g. ``"Before"``/``"After"``).

    Returns:
        The rendered figure as an SVG string.

    Raises:
        ValueError: If ``train_loader`` yields no batches.
    """
    rows, cols = 6, 3
    max_images = rows * cols

    # Use the Daemon font for titles when the file is present. Otherwise fall
    # back to matplotlib's default font instead of failing while rendering.
    font_path = './Daemon.otf'
    if os.path.exists(font_path):
        prop = font_manager.FontProperties(fname=font_path)
    else:
        prop = font_manager.FontProperties()

    try:
        images, labels = next(iter(train_loader))
    except StopIteration:
        raise ValueError("train_loader yielded no batches; nothing to visualize") from None
    images, labels = images[:max_images], labels[:max_images]

    # Restore the caller's mode afterwards. This function runs right before
    # training, and leaving the model in eval mode would keep any
    # dropout/batch-norm layers in inference behavior for a training loop
    # that does not call model.train() itself.
    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            outputs = model(images)
            _, predictions = torch.max(outputs, 1)
    finally:
        model.train(was_training)

    fig, axes = plt.subplots(rows, cols, figsize=(12, 16))
    try:
        # axes.flat walks the grid row by row, matching [i // cols, i % cols].
        for i, ax in enumerate(axes.flat):
            # A batch smaller than the grid leaves the remaining cells blank
            # instead of raising IndexError.
            if i < len(images):
                ax.imshow(images[i].squeeze(), cmap='gray')

                pred_idx = predictions[i].item()
                label_idx = labels[i].item()
                pred_char = number_to_char(pred_idx)
                label_char = number_to_char(label_idx)

                # Compare class indices, not characters. number_to_char maps
                # every out-of-range index to '', so two different
                # out-of-range classes would otherwise look like a match.
                if pred_idx == label_idx:
                    title_text = f"{pred_char} = {label_char}"
                    color = 'green'
                else:
                    title_text = f"{pred_char} != {label_char}"
                    color = 'red'

                ax.set_title(title_text, fontproperties=prop, fontsize=12, color=color)
            ax.axis('off')

        svg_str = figure_to_svg(fig)
    finally:
        # Always release the figure, even if plotting or rendering fails.
        plt.close(fig)

    save_svg_to_output_folder(svg_str, f"{stage}_predictions.svg")
    return svg_str

def figure_to_svg(fig):
    """Serialize a matplotlib figure to SVG markup.

    Args:
        fig: The figure to render.

    Returns:
        The SVG document produced by matplotlib's SVG backend, as a string.
    """
    from matplotlib.backends.backend_svg import FigureCanvasSVG

    # Render into an in-memory text buffer rather than a file on disk.
    buffer = StringIO()
    FigureCanvasSVG(fig).print_svg(buffer)
    return buffer.getvalue()

def save_svg_to_output_folder(svg_str, filename, output_dir='./output'):
    """Write an SVG string to ``output_dir/filename``.

    The output directory is created if it does not exist yet. The file is
    written as UTF-8, so the result does not depend on the platform's
    default encoding.

    Args:
        svg_str: SVG markup to write.
        filename: Name of the file to create or overwrite inside ``output_dir``.
        output_dir: Destination directory (defaults to ``./output``).

    Returns:
        The path of the written file.
    """
    os.makedirs(output_dir, exist_ok=True)
    output_path = os.path.join(output_dir, filename)
    with open(output_path, 'w', encoding='utf-8') as f:
        f.write(svg_str)
    return output_path


def plot_metrics_svg(losses, accuracies):
    """Plot training loss and accuracy side by side and return the figure as SVG.

    Args:
        losses: Sequence of loss values, plotted against their index on an
            axis labelled "Epoch".
        accuracies: Sequence of accuracy values, plotted the same way.

    Returns:
        The rendered figure as an SVG string. It is also saved to the output
        folder as ``training_metrics.svg``.
    """
    fig, (loss_ax, acc_ax) = plt.subplots(1, 2, figsize=(12, 5))
    try:
        for ax, values, name, color in (
            (loss_ax, losses, 'Loss', 'red'),
            (acc_ax, accuracies, 'Accuracy', 'green'),
        ):
            ax.plot(values, label=name, color=color)
            ax.set_title(f'Training {name}')
            ax.set_xlabel('Epoch')
            ax.set_ylabel(name)
            ax.legend()

        # Lay out this figure explicitly. pyplot's "current figure" is global
        # state and may not be `fig` if other code creates figures in between
        # (e.g. concurrent requests).
        fig.tight_layout()
        svg_str = figure_to_svg(fig)
    finally:
        # Always release the figure, even if plotting or rendering fails.
        plt.close(fig)

    save_svg_to_output_folder(svg_str, "training_metrics.svg")
    return svg_str

def train_model_interface(module, dataset_name, epochs=100, lr=0.01):
    """Train the selected model on the chosen dataset and visualize the results.

    Args:
        module: Display name of the model to train (one of the names handled
            below, e.g. "AI Calligraphy").
        dataset_name: Folder name under ``data/``. Images are read from
            ``data/<dataset_name>/raw``.
        epochs: Number of training epochs, passed to ``train_final_model``.
        lr: SGD learning rate.

    Returns:
        A 6-tuple ``(model, losses, accuracies, before_svg, after_svg,
        metrics_svg)``. For an unknown ``module`` every element is ``None``,
        so callers can unpack it the same way and check ``model is None``.
    """
    # Select the model first so an invalid choice fails fast, before any
    # dataset I/O.
    if module == "Logistic Regression":
        model = LogisticRegressionModel(input_size=1)
    elif module == "Softmax Regression":
        model = SoftmaxRegressionModel(input_size=2, num_classes=2)
    elif module == "Shallow Neural Networks":
        model = ShallowNeuralNetwork(input_size=2, hidden_size=5, output_size=2)
    elif module == "Deep Networks":
        import deep_networks
        model = deep_networks.DeepNeuralNetwork(input_size=10, hidden_sizes=[20, 10], output_size=2)
    elif module == "Convolutional Neural Networks":
        model = convolutional_neural_networks.ConvolutionalNeuralNetwork()
    elif module == "AI Calligraphy":
        model = FinalCNN()
    else:
        # Same arity as the success path so callers can always unpack 6 values.
        return None, None, None, None, None, None

    # NOTE(review): only "AI Calligraphy" is exposed in the UI. The other
    # models are built with input sizes (1, 2, 10) that do not obviously match
    # 28x28 images; confirm before exposing them.

    # UI sliders may deliver numeric values as floats. The epoch count is
    # presumably used as a loop count by train_final_model, so make it an int.
    epochs = int(epochs)

    transform = transforms.Compose([
        transforms.Resize((28, 28)),
        transforms.Grayscale(num_output_channels=1),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5], std=[0.5])
    ])

    # Load dataset using CustomMNISTDataset
    train_dataset = CustomMNISTDataset(os.path.join("data", dataset_name, "raw"), transform=transform)
    train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)

    # Visualize before training
    before_svg = visualize_predictions_svg(model, train_loader, "Before")

    # Train the model
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.SGD(model.parameters(), lr=lr)

    losses, accuracies = train_final_model(model, criterion, optimizer, train_loader, epochs)

    # Visualize after training
    after_svg = visualize_predictions_svg(model, train_loader, "After")

    # Loss/accuracy curves
    metrics_svg = plot_metrics_svg(losses, accuracies)

    return model, losses, accuracies, before_svg, after_svg, metrics_svg


def list_datasets():
    """Return the dataset names to show in the dataset dropdown.

    Falls back to a single placeholder entry when ``get_dataset_options``
    reports nothing, so the dropdown is never empty.
    """
    options = get_dataset_options()
    return options if options else ["No datasets found"]

### 🎯 Gradio Interface ###
def run_module(module, dataset_name, epochs, lr):
    """Gradio callback: train the selected model and return the UI outputs.

    Args:
        module: Selected model name from the module dropdown (may be None).
        dataset_name: Selected dataset name from the dataset dropdown
            (may be None).
        epochs: Number of training epochs from the slider.
        lr: Learning rate from the slider.

    Returns:
        A 4-tuple matching the click handler's outputs: a status message,
        then the before-training, after-training and metrics SVG strings.
        The SVG entries are ``None`` when training did not run.
    """
    # Dropdowns are None until the user picks something. Catch that here
    # instead of failing deep inside dataset loading.
    if not module or not dataset_name:
        return "Error: Please select a module and a dataset.", None, None, None

    model, losses, accuracies, before_svg, after_svg, metrics_svg = train_model_interface(
        module, dataset_name, epochs, lr
    )

    if model is None:
        # Exactly one value per output component (4 of them).
        return "Error: Invalid selection.", None, None, None

    # The raw SVG strings are rendered directly by the gr.HTML components.
    return (
        f"Training completed for {module} with {epochs} epochs.",
        before_svg,
        after_svg,
        metrics_svg,
    )

### 🌟 Gradio UI ###
with gr.Blocks() as app:
    with gr.Tab("Techniques"):
        gr.Markdown("### 🧠 Select Model to Train")

        # Preselect the only available module so the callback does not
        # receive None for it.
        module_select = gr.Dropdown(
            choices=[
                "AI Calligraphy"
            ],
            value="AI Calligraphy",
            label="Select Module"
        )

        dataset_list = gr.Dropdown(choices=list_datasets(), label="Select Dataset")
        # step=1: with minimum=1 and step=10 the step grid is 1, 11, ..., 121,
        # which contains neither the default 100 nor the maximum 128.
        epochs = gr.Slider(1, 128, value=100, step=1, label="Epochs")
        lr = gr.Slider(0.001, 0.1, value=0.01, step=0.001, label="Learning Rate")

        train_button = gr.Button("Train Model")

        output = gr.Textbox(label="Training Output")
        before_svg = gr.HTML(label="Before Training Predictions")
        after_svg = gr.HTML(label="After Training Predictions")
        metrics_svg = gr.HTML(label="Metrics")

        train_button.click(
            run_module,
            inputs=[module_select, dataset_list, epochs, lr],
            outputs=[output, before_svg, after_svg, metrics_svg]
        )

# Launch the Gradio app only when run as a script, so importing this module
# does not start a server as a side effect.
if __name__ == "__main__":
    app.launch()