0xnu
/

Image Classification
Keras
vision
Edit model card

The MNIST OCR (Optical Character Recognition) model is a deep learning model trained to recognise and classify handwritten digits from 0 to 9. It is trained on the MNIST dataset, whose training split contains 60,000 small, square, 28×28-pixel grayscale images of single handwritten digits. It is therefore most accurate on isolated handwritten digits written in a style similar to those in the training set.

Training History

Install Packages

pip install numpy opencv-python requests pillow huggingface_hub tensorflow

Usage

import os
os.environ["KERAS_BACKEND"] = "tensorflow"

import keras
import numpy as np
import cv2
import requests
from PIL import Image
from io import BytesIO
from typing import List, Optional
from huggingface_hub import hf_hub_download
import tensorflow as tf
import pickle

class ImageTokenizer:
    """Map raw pixel values to contiguous integer tokens and back.

    The vocabulary is the sorted set of every distinct pixel value seen by
    ``fit``; token ``i`` stands for the ``i``-th smallest value.
    """

    def __init__(self):
        # Keep these attribute names: pickled tokenizers restore them by name.
        self.unique_pixels = set()   # every distinct pixel value seen by fit()
        self.pixel_to_token = {}     # pixel value -> token id
        self.token_to_pixel = {}     # token id -> pixel value

    def fit(self, images):
        """Add each image's distinct values to the vocabulary and rebuild both maps.

        Values accumulated by earlier ``fit`` calls are kept, not discarded.
        """
        for img in images:
            self.unique_pixels |= set(np.unique(img))
        ordered = sorted(self.unique_pixels)
        self.pixel_to_token = dict(zip(ordered, range(len(ordered))))
        self.token_to_pixel = dict(enumerate(ordered))

    def tokenize(self, images):
        """Return an array shaped like ``images`` with each pixel replaced by its token.

        Lookups use ``dict.get``, so a value missing from the vocabulary yields
        ``None``; depending on the output dtype ``np.vectorize`` infers from the
        first element, that gives an object array or an error.
        """
        lookup = np.vectorize(self.pixel_to_token.get)
        return lookup(images)

    def detokenize(self, tokens):
        """Inverse of ``tokenize``: map each token back to its pixel value."""
        reverse_lookup = np.vectorize(self.token_to_pixel.get)
        return reverse_lookup(tokens)

class MNISTPredictor:
    """Detect and classify handwritten digits in an image.

    The Keras model (``mnist_model.keras``) and the pickled pixel tokenizer
    (``mnist_tokenizer.pkl``) are fetched from the Hugging Face Hub repository
    passed to the constructor.
    """

    def __init__(self, model_name):
        """Download and load the model and tokenizer.

        Args:
            model_name: Hugging Face Hub repository id (e.g. ``"0xnu/mnist-ocr"``).
        """
        model_path = hf_hub_download(repo_id=model_name, filename="mnist_model.keras")
        tokenizer_path = hf_hub_download(repo_id=model_name, filename="mnist_tokenizer.pkl")

        self.model = keras.models.load_model(model_path)
        # SECURITY: unpickling can execute arbitrary code; only load tokenizers
        # from repositories you trust.
        # NOTE(review): the pickle presumably holds an ImageTokenizer (predict()
        # calls its ``tokenize``); unpickling needs that class importable under
        # the module name it was saved from -- confirm.
        with open(tokenizer_path, 'rb') as tokenizer_file:
            self.tokenizer = pickle.load(tokenizer_file)

    def extract_features(self, image: Image.Image, min_area: float = 50) -> List[np.ndarray]:
        """Segment ``image`` into one 28x28 model input per detected digit.

        Args:
            image: Input image in any PIL mode. It is converted to RGB first so
                RGBA, palette and grayscale images are accepted.
            min_area: Contours whose area is not greater than this are treated
                as noise and skipped.

        Returns:
            A list of float32 arrays of shape (28, 28, 1) with values in [0, 1],
            ordered left to right by the x coordinate of each bounding box.
        """
        # COLOR_RGB2GRAY requires exactly three channels; PIL images may be RGBA
        # (common for PNG/WebP), palette ("P") or already grayscale ("L").
        gray = cv2.cvtColor(np.array(image.convert("RGB")), cv2.COLOR_RGB2GRAY)

        # Smooth out noise before thresholding.
        blurred = cv2.GaussianBlur(gray, (5, 5), 0)

        # Inverted binary threshold: dark ink becomes 255 on a 0 background.
        thresh = cv2.adaptiveThreshold(blurred, 255, cv2.ADAPTIVE_THRESH_GAUSSIAN_C, cv2.THRESH_BINARY_INV, 11, 2)

        # Outer contours only, so holes inside digits (0, 6, 8, 9) are not separate blobs.
        contours, _ = cv2.findContours(thresh, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)

        # findContours does not return contours in reading order; sort the
        # bounding boxes by x so predicted digits come out left to right.
        boxes = sorted(
            (cv2.boundingRect(contour) for contour in contours if cv2.contourArea(contour) > min_area),
            key=lambda box: box[0],
        )

        digit_images = []
        for x, y, w, h in boxes:
            roi = thresh[y:y + h, x:x + w]
            # NOTE(review): resizing straight to 28x28 stretches narrow digits
            # such as "1"; MNIST digits are size-normalised into a 20x20 box and
            # centred in a 28x28 frame, so padding to a square first may match
            # the training data better -- confirm against the training pipeline.
            resized = cv2.resize(roi, (28, 28), interpolation=cv2.INTER_AREA)
            digit_images.append(resized.reshape((28, 28, 1)).astype('float32') / 255)

        return digit_images

    def predict(self, image: Image.Image) -> Optional[List[int]]:
        """Predict the digits in ``image`` in left-to-right order.

        Returns:
            The argmax class index for each detected digit, an empty list when
            no digit-sized contour is found, or ``None`` if any step raised (the
            error is printed).
        """
        try:
            digit_images = self.extract_features(image)
            if not digit_images:
                # Nothing to classify; an empty batch would only make model.predict fail.
                return []
            # NOTE(review): the tokenizer maps exact pixel values; values it never
            # saw during fit map to None -- confirm its vocabulary covers these inputs.
            tokenized_images = [self.tokenizer.tokenize(img) for img in digit_images]
            predictions = self.model.predict(np.array(tokenized_images), verbose=0)
            return np.argmax(predictions, axis=1).tolist()
        except Exception as e:
            print(f"Error during prediction: {e}")
            return None

def download_image(url: str, timeout: float = 30.0) -> Optional[Image.Image]:
    """Download an image from a URL and decode it with Pillow.

    Args:
        url: HTTP(S) URL of the image.
        timeout: Seconds ``requests`` waits for the connection and for each
            read. Without one, a stalled server would block forever.

    Returns:
        The decoded image, or ``None`` if the download or decoding failed (the
        error is printed).
    """
    try:
        response = requests.get(url, timeout=timeout)
        response.raise_for_status()
        image = Image.open(BytesIO(response.content))
        # Image.open is lazy: force decoding now so corrupt or unsupported data
        # is reported by this handler instead of failing later during prediction.
        image.load()
        return image
    except Exception as e:
        print(f"Error downloading image: {e}")
        return None

def save_predictions_to_file(predictions: List[int], output_path: str) -> None:
    """Overwrite ``output_path`` with a one-line summary of the predictions.

    The line reads ``Predicted digits are: `` followed by the digits joined
    with ``", "`` and a trailing newline. Failures are printed, not raised.
    """
    try:
        with open(output_path, 'w') as out_file:
            summary = ", ".join(str(digit) for digit in predictions)
            out_file.write("Predicted digits are: " + summary + "\n")
    except Exception as err:
        print(f"Error saving predictions to file: {err}")

def main(image_url: str, model_name: str, output_path: str) -> None:
    """Run the demo end to end: load the model, classify an image, save the result.

    Args:
        image_url: URL of the image containing handwritten digits.
        model_name: Hugging Face Hub repository id holding the model and tokenizer.
        output_path: Text file the predictions are written to.

    Any ``Exception`` is printed rather than propagated.
    """
    try:
        predictor = MNISTPredictor(model_name)

        image = download_image(image_url)
        if image is None:
            # download_image has already printed the underlying cause.
            print("An error occurred: Failed to download image")
            return
        print("Image downloaded successfully.")

        digits = predictor.predict(image)
        if digits is None:
            print("Failed to predict digits.")
            return

        print(f"Predicted digits are: {digits}")
        save_predictions_to_file(digits, output_path)
        # NOTE: save_predictions_to_file prints its own errors instead of
        # raising, so this message appears even if the write failed.
        print(f"Predictions saved to {output_path}")
    except Exception as e:
        print(f"An error occurred: {e}")

if __name__ == "__main__":
    # Demo: classify the digits in a sample image and write the result to
    # predictions.txt in the current working directory.
    main(
        image_url="https://miro.medium.com/v2/resize:fit:720/format:webp/1*w7pBsjI3t3ZP-4Gdog-JdQ.png",
        model_name="0xnu/mnist-ocr",
        output_path="predictions.txt",
    )

Copyright

(c) 2024 Finbarrs Oketunji. All Rights Reserved.

Downloads last month
152
Inference Examples
This model does not have enough activity to be deployed to Inference API (serverless) yet. Increase its social visibility and check back later, or deploy to Inference Endpoints (dedicated) instead.

Dataset used to train 0xnu/mnist-ocr