# app.py — "Create app.py" by Ahmed235 (commit 0b188bd, 2.23 kB).
import io
import json

import ipywidgets as widgets
import numpy as np
import tensorflow as tf
from IPython.display import display
from PIL import Image
from tensorflow.keras.models import load_model
from transformers import TFAutoModelForSequenceClassification, AutoTokenizer
# Trained Keras classifier stored as a local HDF5 file, resolved relative to
# the current working directory (it is NOT fetched from the Hugging Face hub).
# Loaded once at import time and reused by every prediction.
model_path = 'final_teath_classifier.h5'
model = tf.keras.models.load_model(model_path)
def preprocess_image(image: "Image.Image", size=(256, 256)) -> np.ndarray:
    """Convert a PIL image into a one-image batch for the classifier.

    Args:
        image: Input image in any PIL mode; it is converted to RGB first.
        size: ``(width, height)`` passed to ``Image.resize``. Defaults to
            256x256, the size this pipeline has always fed the model.

    Returns:
        A float32 array of shape ``(1, height, width, 3)`` with pixel values
        scaled into ``[0, 1]``.
    """
    # Uploads may be RGBA, palette or grayscale PNGs. Without conversion these
    # yield 4-channel arrays or 2-D palette-index arrays instead of RGB pixels.
    # NOTE(review): assumes the model takes 3-channel RGB input, as ordinary
    # JPEG photo uploads already are — confirm against the model's input shape.
    rgb = image.convert("RGB").resize(size)
    pixels = np.asarray(rgb, dtype=np.float32) / 255.0
    # Prepend the batch axis expected by the Keras model call.
    return np.expand_dims(pixels, axis=0)
def predict_image(image_path):
    """Classify one image as "Clean" or "Carries".

    Args:
        image_path: A filesystem path or binary file-like object accepted by
            ``PIL.Image.open``.

    Returns:
        A ``(label, probabilities)`` tuple. ``label`` is ``"Clean"`` when the
        arg-max class index is 0 and ``"Carries"`` otherwise. ``probabilities``
        is the flattened softmax output as a NumPy array.
    """
    # preprocess_image returns a standalone NumPy array, so the image file can
    # be closed before running inference.
    with Image.open(image_path) as img:
        img_array = preprocess_image(img)

    outputs = model(img_array)
    # A Keras model loaded from .h5 returns a plain tensor, while Hugging Face
    # TF models return an object exposing ``.logits``. Accept either.
    logits = getattr(outputs, "logits", outputs)
    # NOTE(review): if the saved model already ends in a softmax layer, this
    # second softmax flattens the scores. If it has a single output unit,
    # softmax is always 1.0 and every image is labelled "Clean". Confirm the
    # model's final layer.
    probabilities = tf.nn.softmax(logits, axis=-1).numpy().flatten()
    predicted_class = int(np.argmax(probabilities))

    # NOTE(review): "Carries" looks like a misspelling of "Caries"; it is kept
    # verbatim because consumers may match on the exact label text.
    predict_label = "Clean" if predicted_class == 0 else "Carries"
    return predict_label, probabilities
# Single-file picker that only offers image MIME types; it is rendered in the
# notebook output as soon as this cell runs.
uploader = widgets.FileUpload(multiple=False, accept="image/*")
display(uploader)
def _uploaded_file_content(value):
    """Return the raw bytes of the first uploaded file, or None if empty.

    Supports both FileUpload ``value`` layouts. ipywidgets 7 uses a dict
    mapping filename -> {"metadata": ..., "content": bytes}. ipywidgets 8 uses
    a tuple of dicts whose "content" is a memoryview.
    """
    if not value:
        return None
    entries = list(value.values()) if hasattr(value, "values") else list(value)
    return bytes(entries[0]["content"])


def on_upload(change):
    """Classify a freshly uploaded image and print the result as JSON.

    Intended as a traitlets observer on the upload widget's ``value`` trait.

    Args:
        change: The traitlets change notification; ``change["new"]`` holds
            the widget's new value.
    """
    content = _uploaded_file_content(change["new"])
    if content is None:
        # The upload was cleared (value reset to empty); nothing to classify.
        return
    # Image.open accepts binary file-like objects, so decode from memory
    # instead of writing a shared temp file into the working directory.
    predict_label, probabilities = predict_image(io.BytesIO(content))
    predictions_json = {
        "predicted_class": predict_label,
        "evaluations": [f"{p * 100:.4f}%" for p in probabilities],
    }
    print(json.dumps(predictions_json, indent=4))
# Run on_upload every time the widget's ``value`` trait changes, which is
# what happens when a new file is uploaded.
uploader.observe(handler=on_upload, names="value")