Spaces:
Sleeping
Sleeping
import torch | |
import torch.nn as nn | |
import torch.optim as optim | |
from torchvision import datasets, transforms | |
from torch.utils.data import DataLoader | |
import matplotlib.pyplot as plt | |
from tqdm import tqdm | |
from dataset_loader import CustomMNISTDataset | |
import os | |
import matplotlib.font_manager as fm | |
# CNN model with one output logit per class (62 by default: 0-9, a-z, A-Z).
class FinalCNN(nn.Module):
    """LeNet-style CNN: two conv/pool stages followed by three linear layers.

    ``fc1`` expects ``32 * 4 * 4`` features, which is what a batched
    ``(N, 1, 28, 28)`` input produces:
    28 -conv5-> 24 -pool2-> 12 -conv5-> 8 -pool2-> 4.

    Args:
        num_classes: number of output logits (default 62).
    """

    def __init__(self, num_classes=62):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 16, kernel_size=5, stride=1, padding=0)
        self.conv2 = nn.Conv2d(16, 32, kernel_size=5, stride=1, padding=0)
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
        self.fc1 = nn.Linear(32 * 4 * 4, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, num_classes)  # one logit per class

    def forward(self, x):
        """Return raw class logits of shape ``(N, num_classes)`` for a batched input."""
        x = self.pool(torch.relu(self.conv1(x)))
        x = self.pool(torch.relu(self.conv2(x)))
        # Flatten per sample so the batch dimension is preserved. The old
        # view(-1, 512) could silently regroup samples when the spatial size
        # was wrong; now a mismatch fails loudly at fc1 instead.
        x = torch.flatten(x, 1)
        x = torch.relu(self.fc1(x))
        x = torch.relu(self.fc2(x))
        return self.fc3(x)  # unnormalised logits (no softmax)
def plot_loss_accuracy(losses, accuracies, output_name="plot.svg"):
    """Plot per-epoch training loss and accuracy on one figure and save it.

    Loss is drawn against the left y-axis and accuracy against a right-hand
    twin axis, so the loss curve is not flattened by the 0-100 accuracy scale.

    Args:
        losses: per-epoch loss values.
        accuracies: per-epoch accuracy values (train_final_model passes
            percentages).
        output_name: file the figure is written to (default "plot.svg").
    """
    fig, loss_ax = plt.subplots(figsize=(10, 6))
    acc_ax = loss_ax.twinx()

    loss_line, = loss_ax.plot(losses, color='red', label='Loss (Cost)', linestyle='-', marker='o')
    acc_line, = acc_ax.plot(accuracies, color='blue', label='Accuracy', linestyle='-', marker='x')

    loss_ax.set_title('Training Loss and Accuracy', fontsize=14)
    loss_ax.set_xlabel('Epochs', fontsize=12)
    loss_ax.set_ylabel('Loss', fontsize=12)
    acc_ax.set_ylabel('Accuracy', fontsize=12)
    # Each twin axis has its own legend; merge both lines into a single one.
    loss_ax.legend(handles=[loss_line, acc_line], loc='best')
    loss_ax.grid(True)

    fig.savefig(output_name)
    # pyplot keeps every figure alive until it is closed; release it.
    plt.close(fig)
# Locate and load a dataset stored under ./data/<dataset_name>/raw.
def choose_dataset(dataset_name):
    """Load the custom dataset stored under ``./data/<dataset_name>/raw``.

    The directory must contain a file whose name includes ``images`` and a
    separate file whose name includes ``labels``. Only the directory path is
    handed to ``CustomMNISTDataset``, which is responsible for reading them.

    Args:
        dataset_name: sub-folder of ``./data`` to load.

    Returns:
        A ``CustomMNISTDataset`` using ``ToTensor`` followed by
        ``Normalize(0.5, 0.5)``.

    Raises:
        ValueError: if no dataset name is given, the raw directory does not
            exist, or the image/label files are missing.
    """
    if dataset_name is None:
        # No valid selection upstream; os.path.join(None) would otherwise
        # raise an unhelpful TypeError.
        raise ValueError("No dataset selected")

    base_path = './data'
    dataset_path = os.path.join(base_path, dataset_name, 'raw')
    if not os.path.isdir(dataset_path):
        raise ValueError(f"β Dataset {dataset_name} not found at {dataset_path}")

    # Presence check only. A file whose name contains both words counts as the
    # images file, matching the original if/elif precedence.
    file_names = os.listdir(dataset_path)
    has_images = any('images' in name for name in file_names)
    has_labels = any('labels' in name and 'images' not in name for name in file_names)
    if not (has_images and has_labels):
        raise ValueError(f"β Missing image or label files in {dataset_path}")

    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.5,), (0.5,)),  # (x - 0.5) / 0.5: maps [0, 1] to [-1, 1]
    ])
    return CustomMNISTDataset(dataset_path=dataset_path, transform=transform)
# One-time dump of intermediate tensor shapes, used before training starts.
def print_activation_details(model, sample_batch):
    """Print the output shape after each layer for ``sample_batch``.

    The layers run without gradients and without the ReLUs used in
    ``forward``, which do not change shapes. The pooling layer is applied
    after each conv, as in the model.
    """
    pipeline = [
        ("Conv1", model.conv1),
        ("Pool1", model.pool),
        ("Conv2", model.conv2),
        ("Pool2", model.pool),
        ("Flattened", lambda t: t.view(-1, 32 * 4 * 4)),
        ("FC1", model.fc1),
        ("FC2", model.fc2),
    ]
    with torch.no_grad():
        print("\n--- CNN Activation Details (One-time) ---")
        activation = sample_batch
        for stage_name, stage in pipeline:
            activation = stage(activation)
            print(f"{stage_name}: {activation.shape}")
        activation = model.fc3(activation)
        print(f"Output (Logits): {activation.shape}\n")
# Training loop
def train_final_model(model, criterion, optimizer, train_loader, epochs=256):
    """Train ``model`` for ``epochs`` passes over ``train_loader``.

    Prints the per-layer shapes once, shows a tqdm progress bar per epoch,
    prints a per-epoch summary, and saves a loss/accuracy plot via
    ``plot_loss_accuracy`` when training finishes.

    Args:
        model: network optimised in place.
        criterion: loss called as ``criterion(outputs, labels)``.
        optimizer: optimizer over ``model``'s parameters.
        train_loader: iterable of ``(images, labels)`` batches.
        epochs: number of passes over the data.

    Returns:
        ``(losses, accuracies)``: per-epoch sample-weighted mean loss and
        training accuracy in percent.

    Raises:
        ValueError: if ``train_loader`` yields no batches.
    """
    losses = []
    accuracies = []

    # One-off shape dump using the first batch. Guard against an empty loader
    # so callers get a clear error rather than a bare StopIteration.
    first_batch = next(iter(train_loader), None)
    if first_batch is None:
        raise ValueError("train_loader yielded no batches")
    sample_batch, _ = first_batch
    print_activation_details(model, sample_batch)

    model.train()
    for epoch in range(epochs):
        weighted_loss = 0.0
        correct, total = 0, 0
        with tqdm(train_loader, desc=f'Epoch {epoch + 1}/{epochs}', unit='batch') as t:
            for images, labels in t:
                optimizer.zero_grad()
                outputs = model(images)
                loss = criterion(outputs, labels)
                loss.backward()
                optimizer.step()

                batch_size = labels.size(0)
                # Assumes criterion returns a per-batch mean (reduction='mean',
                # the nn.CrossEntropyLoss default). Weighting by batch size keeps
                # a short final batch from being over-counted in the epoch mean.
                weighted_loss += loss.item() * batch_size
                predicted = outputs.argmax(dim=1)
                total += batch_size
                correct += (predicted == labels).sum().item()
                t.set_postfix(loss=loss.item())

        epoch_loss = weighted_loss / total
        accuracy = 100 * correct / total
        losses.append(epoch_loss)
        accuracies.append(accuracy)
        print(f"Epoch [{epoch+1}/{epochs}], Loss: {epoch_loss:.4f}, Accuracy: {accuracy:.2f}%")

    plot_loss_accuracy(losses, accuracies)
    return losses, accuracies
# List the datasets available under the data folder.
def get_dataset_options(base_path='./data'):
    """Return the names of the sub-directories of ``base_path``, sorted.

    Sorting keeps the numbered menu shown in ``__main__`` stable across runs
    and platforms, since ``os.listdir`` returns entries in arbitrary order.

    Args:
        base_path: directory to scan (default ``./data``).

    Returns:
        Sorted list of sub-directory names. If ``base_path`` does not exist,
        a message is printed and an empty list is returned.
    """
    try:
        entries = os.listdir(base_path)
    except FileNotFoundError:
        print(f"β Directory {base_path} not found!")
        return []
    return sorted(
        entry for entry in entries
        if os.path.isdir(os.path.join(base_path, entry))
    )
def number_to_char(number):
    """Map a class index to its character.

    0-9 -> '0'-'9', 10-35 -> 'a'-'z', 36-61 -> 'A'-'Z'; any other value -> ''.
    NOTE(review): this ordering follows the model's "(0-9, a-z, A-Z)" comment;
    confirm it matches the dataset's actual label order.
    """
    if 0 <= number <= 9:
        return str(number)
    if 10 <= number <= 35:
        return chr(ord('a') + number - 10)
    if 36 <= number <= 61:
        # Offset from 36: a raw chr(number + 65) mapped 36 to 'e', not 'A'.
        return chr(ord('A') + number - 36)
    return ''
def display_predictions(model, data_loader, output_name, num_samples=6, font_path='./Daemon.otf'):
    """Save a grid of sample images titled with predicted vs. actual characters.

    Runs ``model`` on the first batch of ``data_loader`` and writes the figure
    to ``output_name``. A panel is titled ``"<pred> = <actual>"`` when the
    prediction is correct and ``"<pred> != <actual>"`` otherwise, using the
    font at ``font_path``.

    Args:
        model: classifier returning one logit per class.
        data_loader: iterable yielding ``(images, labels)`` batches.
        output_name: path of the saved figure.
        num_samples: number of panels (capped at the batch size), laid out
            three per row.
        font_path: font file used for the panel titles.
    """
    was_training = model.training
    model.eval()
    prop = fm.FontProperties(fname=font_path)

    images, labels = next(iter(data_loader))
    with torch.no_grad():
        predictions = model(images).argmax(dim=1)
    # Put the model back in whatever mode the caller had it in.
    model.train(was_training)

    # Never index past the batch, and grow the grid instead of overflowing a
    # fixed 2x3 layout (which is still used for the default 6 samples).
    num_samples = min(num_samples, len(images))
    cols = 3
    rows = max(2, (num_samples + cols - 1) // cols)

    fig = plt.figure(figsize=(12, 3 * rows))
    for i in range(num_samples):
        plt.subplot(rows, cols, i + 1)
        plt.imshow(images[i].squeeze(), cmap='gray')
        predicted_char = number_to_char(predictions[i].item())
        actual_char = number_to_char(labels[i].item())
        relation = '=' if predicted_char == actual_char else '!='
        plt.title(f'{predicted_char} {relation} {actual_char}', fontsize=84, fontproperties=prop)
        plt.axis('off')
    fig.savefig(output_name)
    # Release the figure; pyplot keeps it alive otherwise.
    plt.close(fig)
if __name__ == "__main__":
    # Let the user pick one of the sub-folders of ./data.
    dataset_name = None
    dataset_options = get_dataset_options()
    if dataset_options:
        print("Available datasets:")
        for i, option in enumerate(dataset_options, 1):
            print(f"{i}. {option}")
        raw_choice = input(f"Enter the number corresponding to the dataset (1-{len(dataset_options)}): ")
        try:
            dataset_index = int(raw_choice) - 1
        except ValueError:
            dataset_index = -1  # non-numeric input counts as an invalid selection
        if 0 <= dataset_index < len(dataset_options):
            dataset_name = dataset_options[dataset_index]
            print(f"You selected: {dataset_name}")
        else:
            print("β Invalid selection.")
    else:
        print("β No datasets found in the data folder.")

    # Nothing to train on: stop cleanly instead of crashing in choose_dataset.
    if dataset_name is None:
        raise SystemExit(1)

    train_dataset = choose_dataset(dataset_name)
    train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)

    # Model, loss, and optimizer
    model = FinalCNN()
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.SGD(model.parameters(), lr=0.005)

    # Snapshot predictions of the untrained model, train, then compare.
    display_predictions(model, train_loader, output_name="before.svg")
    losses, accuracies = train_final_model(model, criterion, optimizer, train_loader, epochs=256)
    display_predictions(model, train_loader, output_name="after.svg")