# Coconut-MNIST / src/utils.py
# Author: ymlin105
# Commit 58839b6: "feat: initial implementation of MNIST Hybrid SVD-CNN core"
import torch
import torchvision.transforms as T
import torchvision
import numpy as np
import os
import pickle
import ssl
from src.hybrid_model import SimpleCNN
from src import config
import cv2
def load_data_split(dataset_name="mnist", train=True, digits=None, flatten=False):
    """
    Unified entry point for data loading.

    Supports: MNIST, Fashion-MNIST, and custom digit filtering (e.g., [3, 8]).

    Args:
        dataset_name: "mnist" or "fashion" (case-insensitive).
        train: Load the training split if True, otherwise the test split.
        digits: Optional sequence of labels to keep. When exactly two labels
            are given, they are remapped to 0 (first) and 1 (second); with any
            other count the original labels are kept.
        flatten: If True, X is returned as (N, 28*28); otherwise a channel
            dimension is added, giving (N, 1, 28, 28).

    Returns:
        (X, y): X is a float tensor scaled to [0, 1]; y is the label tensor.

    Raises:
        ValueError: If dataset_name is not recognised.
    """
    name = dataset_name.lower()
    if name == "mnist":
        dataset_cls = torchvision.datasets.MNIST
    elif name == "fashion":
        dataset_cls = torchvision.datasets.FashionMNIST
    else:
        raise ValueError(f"Unknown dataset: {dataset_name}")

    transform = T.Compose([T.ToTensor()])

    # Bypass SSL verification issues for dataset downloads, but only while
    # the dataset is being fetched. The previous factory is restored
    # afterwards so certificate checks stay enabled for every other HTTPS
    # request made by the process.
    previous_https_context = ssl._create_default_https_context
    ssl._create_default_https_context = ssl._create_unverified_context
    try:
        dataset = dataset_cls(config.DATA_DIR, train=train, download=True, transform=transform)
    finally:
        ssl._create_default_https_context = previous_https_context

    # Read the raw tensors directly (the transform is not applied here) and
    # scale 0-255 pixel values to [0, 1].
    X = dataset.data.float() / 255.0
    y = dataset.targets

    # Filter for specific digits if requested (e.g., [3, 8] for binary analysis)
    if digits is not None:
        mask = torch.zeros(len(y), dtype=torch.bool)
        for d in digits:
            mask |= (y == d)
        X = X[mask]
        y = y[mask]
        # Remap labels to 0, 1 for binary tasks: digits[0] -> 0, else -> 1.
        if len(digits) == 2:
            y = torch.where(y == digits[0], torch.tensor(0), torch.tensor(1))

    # Add channel dimension if not flattened (B, 1, 28, 28)
    if not flatten:
        X = X.unsqueeze(1)
    else:
        X = X.view(X.size(0), -1)
    return X, y
def load_models(dataset_name="mnist"):
    """
    Loads pre-trained SVD transformer and CNN model for a specific dataset.

    Args:
        dataset_name: "mnist" (case-insensitive, matching load_data_split)
            selects the MNIST artifacts; any other value selects the
            Fashion-MNIST artifacts.

    Returns:
        (svd, cnn). Either can be None if its file is missing. The CNN is a
        SimpleCNN whose weights are loaded onto CPU and which is put in
        eval mode.
    """
    is_mnist = dataset_name.lower() == "mnist"
    svd_path = config.SVD_MODEL_PATH if is_mnist else config.FASHION_SVD_PATH
    cnn_path = config.CNN_MODEL_PATH if is_mnist else config.FASHION_CNN_PATH

    svd, cnn = None, None

    if os.path.exists(svd_path):
        # SECURITY: pickle.load can execute arbitrary code embedded in the
        # file. Only load SVD artifacts produced by this project; never point
        # these paths at untrusted files.
        with open(svd_path, "rb") as f:
            svd = pickle.load(f)
    else:
        print(f"Note: SVD model for {dataset_name} not found at {svd_path}")

    if os.path.exists(cnn_path):
        cnn = SimpleCNN()
        # map_location="cpu" lets checkpoints saved from a GPU load on
        # CPU-only hosts.
        cnn.load_state_dict(torch.load(cnn_path, map_location="cpu"))
        cnn.eval()
    else:
        print(f"Note: CNN model for {dataset_name} not found at {cnn_path}")

    return svd, cnn
# --- Backward Compatibility Aliases ---
# `load_data` is an alias for `load_data_split`, kept so code that still
# uses the older name continues to work.
load_data = load_data_split
def preprocess_digit(img):
    """
    Original preprocessing logic used by the Streamlit app.

    Crops the digit to the bounding box of pixels brighter than 127, resizes
    it so its longer side is 20 px (preserving aspect ratio), and centers it
    on a 28x28 canvas.

    Args:
        img: 2-D grayscale image, either a NumPy array or a torch tensor.
            Tensors are detached, moved to CPU, and cast to uint8. A
            floating-point tensor whose values are all <= 1.0 is treated as
            [0, 1] intensities and rescaled to [0, 255] before the cast.
            NumPy arrays are passed to OpenCV unchanged.

    Returns:
        A (28, 28) float tensor with values in [0, 1]. An all-zero tensor is
        returned when no pixel exceeds the threshold, or when the resized
        digit would collapse to zero width or height.
    """
    if isinstance(img, torch.Tensor):
        # Plain .numpy() raises for tensors that require grad or live on a
        # GPU; detach and move to CPU first.
        img = img.detach().cpu()
        # Float tensors in [0, 1] (the format load_data_split produces) would
        # otherwise collapse to 0/1 when cast to uint8 and never pass the
        # 127 threshold below, yielding a blank image.
        if img.is_floating_point() and img.numel() > 0 and img.max() <= 1.0:
            img = img * 255.0
        img = img.numpy().astype(np.uint8)

    # 1. Threshold & Find Bounding Box
    _, thresh = cv2.threshold(img, 127, 255, cv2.THRESH_BINARY)
    coords = cv2.findNonZero(thresh)
    if coords is None:
        # Nothing drawn: return an empty canvas.
        return torch.zeros((28, 28))
    x, y, w, h = cv2.boundingRect(coords)
    img_crop = img[y:y+h, x:x+w]

    # 2. Resize so the longer side is 20 px, keeping the aspect ratio.
    if w > h:
        new_w = 20
        new_h = int(h * (20 / w))
    else:
        new_h = 20
        new_w = int(w * (20 / h))
    # Very thin strokes can truncate to 0 px, which cv2.resize rejects.
    if new_w == 0 or new_h == 0:
        return torch.zeros((28, 28))
    img_resize = cv2.resize(img_crop, (new_w, new_h), interpolation=cv2.INTER_AREA)

    # 3. Center in 28x28
    final_img = np.zeros((28, 28), dtype=np.uint8)
    pad_y = (28 - new_h) // 2
    pad_x = (28 - new_w) // 2
    final_img[pad_y:pad_y+new_h, pad_x:pad_x+new_w] = img_resize

    # 4. Normalize to [0, 1]
    return torch.tensor(final_img).float() / 255.0