# -*- coding: utf-8 -*-
"""101234444_aml_assignment_1.ipynb
Automatically generated by Colaboratory.
Original file is located at
https://colab.research.google.com/drive/1GBU5kKqfnliMP-lElZZ4VgVgcsyqy1wQ
"""
import requests
import tensorflow as tf
import PIL.Image
import numpy as np
import json
import gradio as gr
# Download the final_model.h5 file
url_model = "https://huggingface.co/ImanAmran/ml_assignment_1/resolve/main/final_model.h5"
response_model = requests.get(url_model)
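# Abort with a clear error if the model download failed (e.g. HTTP 404)
response_model.raise_for_status()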
with open("final_model.h5", "wb") as f_model:
    f_model.write(response_model.content)
# Download the class_indices.json file
url_indices = "https://huggingface.co/ImanAmran/ml_assignment_1/resolve/main/class_indices.json"
response_indices = requests.get(url_indices)
class_indices = response_indices.json() # Parse the JSON response
# Load the model
model = tf.keras.models.load_model("final_model.h5")
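# Uncomment to inspect the architecture and confirm the expected input size
# (the resize in classify_image below assumes 375x375 RGB inputs):
# model.summary()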
# Reverse the key-value pairs in the class_indices dictionary
index_to_class = {v: k for k, v in class_indices.items()}
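# Assumed shape of the mappings (actual labels depend on the training data), e.g.:
#   class_indices  -> {"daisy": 0, "dandelion": 1, ...}
#   index_to_class -> {0: "daisy", 1: "dandelion", ...}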

def classify_image(image: PIL.Image.Image):
    try:
        # Ensure the input is a PIL Image, resize it to the size the model
        # expects, and then convert it to a NumPy array
        if not isinstance(image, PIL.Image.Image):
            image = PIL.Image.fromarray(image)
        image_resized = image.resize((375, 375))
        image_array = np.array(image_resized)
        image_array = np.expand_dims(image_array, axis=0)  # Add a batch dimension

        # Apply the same ResNet50 preprocessing used for the manual predictions
        # (converts RGB to BGR and zero-centers each channel on ImageNet statistics)
        img_preprocessed = tf.keras.applications.resnet50.preprocess_input(image_array)

        # Perform inference
        predictions = model.predict(img_preprocessed)
        predicted_class_idx = int(np.argmax(predictions))  # Index of the highest-scoring class

        # Map the index back to a human-readable label
        predicted_class_label = index_to_class[predicted_class_idx]
        return predicted_class_label
    except Exception as e:
        return str(e)  # Return the exception message to help identify the issue
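
# Optional local sanity check ("test_image.jpg" is just a placeholder path):
# print(classify_image(PIL.Image.open("test_image.jpg")))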

# Create a Gradio Interface
iface = gr.Interface(
    fn=classify_image,
    inputs=gr.components.Image(),
    outputs=gr.components.Textbox(),
    live=True,  # Optional: enables real-time feedback but may slow down performance
)

# share=True belongs to launch(), not the Interface constructor; it creates a
# public link so the app is reachable when run from a Colab notebook.
iface.launch(share=True)