---
license: creativeml-openrail-m
---

Based on ETL9G:
- 607,200 samples, 3,036 classes (hiragana and kanji), 200 samples per class
- record_length: 8199 bytes
- image_width: 64 px, image_height: 64 px
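
As an illustration of the fixed-length record layout described above, here is a minimal sketch that iterates the raw 8,199-byte records in an ETL9G data file. Decoding the per-record fields (character code, image bits) follows the ETL9G binary spec and is omitted here; the file name `ETL9G_01` is hypothetical.

```python
RECORD_LENGTH = 8199  # bytes per ETL9G record, per the spec above

def iter_records(path):
    """Yield raw fixed-length records from an ETL9G data file."""
    with open(path, "rb") as f:
        while True:
            chunk = f.read(RECORD_LENGTH)
            if len(chunk) < RECORD_LENGTH:
                break  # end of file (or trailing partial record)
            yield chunk

# Example: count the records in one hypothetical data file.
print(sum(1 for _ in iter_records("ETL9G_01")))
```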

I was testing a few more samples locally with the script below. Note that the model's outputs are encoded class indices, so they need to be decoded back to character labels.


```python
from PIL import Image
import numpy as np
import torch

# Assumes `model`, `device`, and `label_to_index` are already in scope
# (see the setup sketch below).

# Preprocess an image to match the model's expected input:
# grayscale, 64x64, scaled to [0, 1], flattened to shape (1, 4096).
def preprocess_image(image_path):
    image = Image.open(image_path).convert('L')
    resized_image = image.resize((64, 64))
    image_array = np.array(resized_image) / 255.0
    return image_array.reshape(1, -1)

# Predict an encoded class index for an image using the PyTorch model.
def predict_label(image_path, model, device):
    # Convert the preprocessed image to a float32 tensor on the target device.
    processed_image = torch.tensor(preprocess_image(image_path), dtype=torch.float32).to(device)

    # Run inference without tracking gradients and take the argmax class.
    with torch.no_grad():
        outputs = model(processed_image)
        _, predicted_class = torch.max(outputs, 1)

    return predicted_class.item()

# Build the reverse dictionary for decoding indices back to labels.
index_to_label = {index: label for label, index in label_to_index.items()}

# Test using sample images.
sample_image_paths = ["example.png", "example.png", "example.png", "example.png", "example.png"]

for sample in sample_image_paths:
    predicted_encoded_label = predict_label(sample, model, device)

    # Decode the predicted index using the reversed dictionary.
    decoded_label = index_to_label[predicted_encoded_label]
    print(f"The model predicts the image label as: {decoded_label}")
```
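
The script above does not define `model`, `device`, or `label_to_index`. The following is a minimal setup sketch under stated assumptions, not the actual training code: a simple fully connected classifier whose input size (64 × 64 = 4096) and output size (3,036 classes) match the dataset description, a checkpoint at the hypothetical path `model.pt`, and a label mapping saved at training time under the hypothetical name `label_to_index.json`.

```python
import json
import torch
import torch.nn as nn

NUM_CLASSES = 3036    # hiragana + kanji classes in ETL9G
INPUT_SIZE = 64 * 64  # flattened 64x64 grayscale image

# Hypothetical classifier; substitute the architecture actually used in training.
model = nn.Sequential(
    nn.Linear(INPUT_SIZE, 1024),
    nn.ReLU(),
    nn.Linear(1024, NUM_CLASSES),
)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.load_state_dict(torch.load("model.pt", map_location=device))  # hypothetical checkpoint path
model.to(device)
model.eval()  # inference mode: disables dropout and batch-norm updates

# The label mapping must be the same one used to encode labels at training time;
# here it is assumed to have been saved as JSON.
with open("label_to_index.json", encoding="utf-8") as f:
    label_to_index = json.load(f)
```

The flattened input shape assumed here is what makes `preprocess_image` reshape each image to `(1, -1)`; a convolutional model would instead expect a `(1, 1, 64, 64)` tensor.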