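# Copied from the Hugging Face Hub file view of app.py:
# author zihaoz96, commit 3f221c4 ("Update app.py"), file size 1.84 kB.
# (Page badges at copy time: "raw", "history", "blame", "No virus".)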
import functools

import gradio as gr
import numpy as np
from PIL import Image
from tensorflow.keras import models
from tensorflow.keras.preprocessing.image import load_img
import tensorflow as tf
from hugsvision.inference.TorchVisionClassifierInference import TorchVisionClassifierInference
# Model choices offered in the UI. predict_image routes "DenseNet" to a Keras
# model and every other name to a HugsVision classifier.
models_name = [
    "VGG16",
    "DenseNet121",
    "DenseNet",
]

# categories.txt holds the class labels on its first line, separated by ";".
# Use a context manager so the file handle is closed. Strip the line ending so
# the last label is not "name\n" (or "name\r\n"). A stray newline would corrupt
# both the prediction keys and the "samples/<label>.jpg" example paths.
with open("categories.txt", "r", encoding="utf-8") as categories:
    labels = categories.readline().strip().split(";")

# Radio selector for the model to use; defaults to DenseNet121.
radio = gr.inputs.Radio(models_name, default="DenseNet121", type="value")
@functools.lru_cache(maxsize=8)
def _load_keras_model(model_name):
    """Load the Keras model at ./models/<model_name>/model.h5 once and reuse it."""
    return models.load_model("./models/" + model_name + "/model.h5")


@functools.lru_cache(maxsize=8)
def _load_hugsvision_classifier(model_name):
    """Build the HugsVision classifier for ./models/<model_name> once and reuse it."""
    return TorchVisionClassifierInference(model_path="./models/" + model_name)


def predict_image(image, model_name):
    """Classify ``image`` with the model selected by ``model_name``.

    Args:
        image: the uploaded image from the Gradio Image input. Presumably an
            HxWx3 uint8 numpy array; TODO confirm for the installed Gradio
            version.
        model_name: the selected model. ``"DenseNet"`` is served by a Keras
            model. Any other name is loaded through HugsVision from
            ``./models/<model_name>``.

    Returns:
        dict mapping each class label to a float confidence, for
        ``gr.outputs.Label``.
    """
    if model_name == "DenseNet":
        # Keras model: scale pixels to [0, 1] and add a leading batch axis.
        batch = np.expand_dims(np.array(image) / 255, axis=0)
        # The model is cached, so disk loading happens only on the first request.
        scores = _load_keras_model(model_name).predict(batch)[0]
        # Round to 2 decimals as the old "%.2f" formatting did, but keep the
        # values numeric so both branches return the same type.
        pred = {label: round(float(score), 2) for label, score in zip(labels, scores)}
    else:
        # HugsVision model: takes an RGB PIL image.
        pil_image = Image.fromarray(np.uint8(image)).convert("RGB")
        classifier = _load_hugsvision_classifier(model_name)
        raw = classifier.predict_image(img=pil_image, return_str=False)
        # Scale by 1/100. Presumably HugsVision reports percentages and the
        # Label widget wants values in [0, 1]; verify against HugsVision docs.
        pred = {key: value / 100 for key, value in raw.items()}
    print(pred)
    return pred
# Gradio UI: an image upload plus the model radio, rendered into a label
# widget that lists every known class.
image = gr.inputs.Image(shape=(300, 300), label="Upload Your Image Here")
label = gr.outputs.Label(num_top_classes=len(labels))

# One example per class, expected at samples/<label>.jpg.
samples = [[f"samples/{name}.jpg"] for name in labels]

interface_options = dict(
    fn=predict_image,
    inputs=[image, radio],
    outputs=label,
    capture_session=True,
    allow_flagging=False,
    title="🦈 Shark image classifier",
    description="Made with HugsVision & ❤️",
    examples=samples,
    theme=None,
)
interface = gr.Interface(**interface_options)
interface.launch()