import numpy as np
import tensorflow as tf
import gradio as gr
from huggingface_hub import from_pretrained_keras
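
# load the pretrained Keras classifier (trained with CutMix augmentation) from the Hugging Face Hub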
model = from_pretrained_keras("harsha163/CutMix_data_augmentation_for_image_classification")
# functions for inference
IMG_SIZE = 32
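# class names in label-index order (the ten CIFAR-10 categories)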
class_names = [
"Airplane",
"Automobile",
"Bird",
"Cat",
"Deer",
"Dog",
"Frog",
"Horse",
"Ship",
"Truck",
]
# resize the image and convert it to a float in the [0, 1] range
def preprocess_image(image, label):
    image = tf.image.resize(image, (IMG_SIZE, IMG_SIZE))
    image = tf.image.convert_image_dtype(image, tf.float32) / 255.0
    return image, label
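
# convert the raw image passed in by Gradio into the preprocessed tensor the model expects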
def read_image(image):
    image = tf.convert_to_tensor(image)
    image.set_shape([None, None, 3])
    image, _ = preprocess_image(image, 1)  # 1 here is a dummy label; only the image is used
    return image
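
# run a single image through the model and return the name of the most probable class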
def infer(input_image):
    image_tensor = read_image(input_image)
    predictions = model.predict(np.expand_dims(image_tensor, axis=0))
    predictions = np.squeeze(predictions)
    predicted_index = np.argmax(predictions)
    predicted_label = class_names[predicted_index.item()]
    return str(predicted_label)

# get the inputs
input = gr.inputs.Image(shape=(IMG_SIZE, IMG_SIZE))
# the app outputs a single predicted class label
output = [gr.outputs.Label()]
# it's good practice to pass examples, description and a title to guide users
examples = [["./content/examples/Frog.jpg"], ["./content/examples/Truck.jpg"]]
title = "Image classification"
description = "Upload an image or select from examples to classify it"
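
# build the Gradio interface and launch the app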
gr_interface = gr.Interface(infer, input, output, examples=examples, allow_flagging=False, analytics_enabled=False, title=title, description=description)
gr_interface.launch(enable_queue=True, debug=True)