Spaces:
Sleeping
Sleeping
| import torch | |
| import torchvision.transforms as T | |
| import torchvision | |
| import numpy as np | |
| import os | |
| import pickle | |
| import ssl | |
| from src.hybrid_model import SimpleCNN | |
| from src import config | |
| import cv2 | |
def load_data_split(dataset_name="mnist", train=True, digits=None, flatten=False):
    """
    Unified entry point for data loading.

    Supports MNIST and Fashion-MNIST, with optional filtering down to a
    subset of class labels (e.g., [3, 8]).

    Args:
        dataset_name: "mnist" or "fashion" (matched case-insensitively).
        train: Passed through to the torchvision dataset's `train` argument.
        digits: Optional iterable of labels to keep. When exactly two labels
            are given, they are remapped to 0 (first) and 1 (second).
        flatten: If True, each image is flattened to one row; otherwise a
            channel axis is inserted at position 1.

    Returns:
        Tuple (X, y): pixel data scaled by 1/255 as floats, and the labels.

    Raises:
        ValueError: If `dataset_name` is neither "mnist" nor "fashion".
    """
    # Bypass SSL verification issues for dataset downloads.
    # NOTE: this replaces the process-wide default HTTPS context.
    ssl._create_default_https_context = ssl._create_unverified_context

    to_tensor = T.Compose([T.ToTensor()])

    name_key = dataset_name.lower()
    if name_key == "mnist":
        dataset_cls = torchvision.datasets.MNIST
    elif name_key == "fashion":
        dataset_cls = torchvision.datasets.FashionMNIST
    else:
        raise ValueError(f"Unknown dataset: {dataset_name}")
    dataset = dataset_cls(config.DATA_DIR, train=train, download=True, transform=to_tensor)

    # Work on the raw stored arrays, scaled by 1/255.
    images = dataset.data.float() / 255.0
    labels = dataset.targets

    if digits is not None:
        # Build a keep-mask that is True wherever the label is one of `digits`.
        keep = torch.zeros(len(labels), dtype=torch.bool)
        for wanted in digits:
            keep |= (labels == wanted)
        images, labels = images[keep], labels[keep]

        # Binary case: first requested label becomes 0, the other becomes 1.
        if len(digits) == 2:
            is_first = labels == digits[0]
            labels = torch.where(is_first, torch.tensor(0), torch.tensor(1))

    if flatten:
        # One row per sample.
        images = images.view(images.size(0), -1)
    else:
        # Insert the channel axis: (B, 1, 28, 28).
        images = images.unsqueeze(1)

    return images, labels
def load_models(dataset_name="mnist"):
    """
    Loads the pre-trained SVD transformer and CNN model for a specific dataset.

    Args:
        dataset_name: "mnist" (matched case-insensitively) selects the MNIST
            artifacts; any other value selects the Fashion-MNIST artifacts.

    Returns:
        Tuple (svd, cnn). Either entry is None when its file is missing, in
        which case a note is printed. When present, the CNN weights are
        loaded onto the CPU and the model is switched to eval mode.
    """
    # Compare case-insensitively, as load_data_split does, so that "MNIST"
    # is not silently routed to the Fashion-MNIST artifacts.
    is_mnist = str(dataset_name).lower() == "mnist"
    svd_path = config.SVD_MODEL_PATH if is_mnist else config.FASHION_SVD_PATH
    cnn_path = config.CNN_MODEL_PATH if is_mnist else config.FASHION_CNN_PATH

    svd, cnn = None, None

    if os.path.exists(svd_path):
        # NOTE: pickle.load can execute arbitrary code embedded in the file;
        # only point the configured path at artifacts from a trusted source.
        with open(svd_path, "rb") as f:
            svd = pickle.load(f)
    else:
        print(f"Note: SVD model for {dataset_name} not found at {svd_path}")

    if os.path.exists(cnn_path):
        cnn = SimpleCNN()
        cnn.load_state_dict(torch.load(cnn_path, map_location="cpu"))
        cnn.eval()
    else:
        print(f"Note: CNN model for {dataset_name} not found at {cnn_path}")

    return svd, cnn
# --- Backward Compatibility Aliases ---
# `load_data` is a second name bound to the same function object as
# `load_data_split`, so both names accept the same arguments.
load_data = load_data_split
def preprocess_digit(img):
    """
    Original preprocessing logic used by the Streamlit app.

    Crops the image to the bounding box of its bright pixels, resizes the
    crop so its longer side is 20px, and centers it in a 28x28 frame.

    Args:
        img: Image as a torch.Tensor or an array OpenCV can threshold.
            Presumably a single-channel 2D image on a 0-255 scale -- confirm
            against callers. Tensors are cast to uint8, so a float tensor in
            [0, 1] truncates to 0/1 and comes back blank.

    Returns:
        A (28, 28) float tensor scaled to [0, 1]. All zeros when no pixel
        exceeds the threshold or when the resized crop would have a
        zero-length side.
    """
    if isinstance(img, torch.Tensor):
        # detach()/cpu() let tensors that require grad or live on another
        # device be converted; for plain CPU tensors this is a no-op.
        img = img.detach().cpu().numpy().astype(np.uint8)

    # 1. Threshold & find the bounding box of pixels brighter than 127.
    _, thresh = cv2.threshold(img, 127, 255, cv2.THRESH_BINARY)
    coords = cv2.findNonZero(thresh)
    if coords is None:
        # Nothing above the threshold: return a blank frame.
        return torch.zeros((28, 28))
    x, y, w, h = cv2.boundingRect(coords)
    img_crop = img[y:y + h, x:x + w]

    # 2. Resize so the longer side is 20px, keeping the aspect ratio.
    if w > h:
        new_w = 20
        new_h = int(h * (20 / w))
    else:
        new_h = 20
        new_w = int(w * (20 / h))
    if new_w == 0 or new_h == 0:
        # Extremely thin crops round down to zero pixels on the short side.
        return torch.zeros((28, 28))
    img_resize = cv2.resize(img_crop, (new_w, new_h), interpolation=cv2.INTER_AREA)

    # 3. Center in 28x28 (any odd leftover pixel goes to the bottom/right).
    final_img = np.zeros((28, 28), dtype=np.uint8)
    pad_y = (28 - new_h) // 2
    pad_x = (28 - new_w) // 2
    final_img[pad_y:pad_y + new_h, pad_x:pad_x + new_w] = img_resize

    # 4. Normalize to [0, 1].
    return torch.tensor(final_img).float() / 255.0