qmjnh's picture
Update app.py
6456b67
raw history blame
No virus
3.49 kB
import gradio as gr
import numpy as np
import tensorflow as tf
def softmax(vector):
    """Return the softmax of ``vector``, normalised over all of its elements.

    Args:
        vector: Array-like of scores.

    Returns:
        A NumPy array of the same shape whose entries sum to 1. An empty
        input yields an empty array.
    """
    scores = np.asarray(vector)
    if scores.size == 0:
        # Nothing to normalise; max() would raise on an empty array.
        return np.exp(scores)
    # Subtracting the maximum leaves the result mathematically unchanged but
    # keeps np.exp from overflowing to inf (and the ratio from becoming NaN).
    e = np.exp(scores - scores.max())
    return e / e.sum()
def image_to_output(input_img):
    """Classify one image and return a label-to-confidence mapping.

    Args:
        input_img: The image handed over by the Gradio input component
            (presumably an H x W x C pixel array with 0-255 values; confirm
            against the Gradio version in use).

    Returns:
        A dict mapping entries of the module-level ``labels`` list to the
        softmax-normalised score (as a Python float) at the same index.
    """
    # Wrapping the image in a list gives the tensor a leading batch axis;
    # the values are divided by 255 and the result resized to 224 x 224.
    batch = tf.image.resize(tf.cast([input_img], tf.float32) / 255., [224, 224])
    x_test = np.asarray(batch)
    prediction = model2.predict(x_test, batch_size=1).flatten()
    # NOTE(review): softmax is applied on top of the model output; this is
    # only appropriate if the model emits raw logits — confirm.
    prediction = softmax(prediction)
    # Iterate over the actual prediction vector instead of a hard-coded class
    # count, so the mapping follows whatever the model outputs.
    return {labels[i]: float(score) for i, score in enumerate(prediction)}
# Download the model checkpoint
import os
import requests
# Local directory that receives the downloaded model files.
pretrained_repo = 'pretrained_model'
# Base URL from which each model file is fetched.
model_repo_link = 'https://huggingface.co/qmjnh/flowerClassification_2/resolve/main/'
for item in [
    'variables.data-00000-of-00001',
    'variables.index',
    'keras_metadata.pb',
    'saved_model.pb',
]:
    # Files named "variables*" are placed in a "variables" sub-directory;
    # everything else sits directly in the model directory.
    if item.startswith('variables'):
        output_file = os.path.join(pretrained_repo, 'variables', item)
    else:
        output_file = os.path.join(pretrained_repo, item)
    os.makedirs(os.path.dirname(output_file), exist_ok=True)
    print(f'Downloading from {model_repo_link+item} to {output_file}')
    # The timeout keeps a stalled connection from hanging start-up forever.
    params = requests.get(model_repo_link + item, timeout=60)
    # Fail loudly on an HTTP error rather than writing an error page to disk
    # and letting the model loader choke on it later.
    params.raise_for_status()
    with open(output_file, 'wb') as f:
        f.write(params.content)
# Load the model from the directory populated by the download step.
model2 = tf.keras.models.load_model(pretrained_repo)
# Read the labels, one per line. The line terminator is stripped so the
# label strings used as dictionary keys do not carry a trailing newline.
with open('flower_names.txt') as f:
    labels = [line.rstrip('\r\n') for line in f]
# Run gradio
from gradio.components import Image as gradio_image
from gradio.components import Label as gradio_label
# Texts shown on the page: a short description next to the title and a longer
# article assembled from two parts.
description = "This model was trained to recognize 102 types flowers. For the model to work with high-accuracy, refer to the trained flowers [here](https://huggingface.co/spaces/qmjnh/flowerClassification_2/blob/main/flower_names.txt)"
article1="This is an AI model trained to predict the name of the flower in the input picture, given that the flower is one of 102 types of flowers (flowers list can be found [here](https://huggingface.co/spaces/qmjnh/flowerClassification_2/blob/main/flower_names.txt) ). To try out the model, simply drop/upload a picture into the '*input box*' and press 'Submit'. The predictions will show up in the '*output box*'\n"
article2="\n *built by qmjnh*"

# Build the interface once. Earlier throw-away constructions that were
# immediately overwritten have been dropped; only this configuration was
# ever launched.
UI = gr.Interface(
    fn=image_to_output,
    inputs=gradio_image(shape=(224,224)),
    outputs=gradio_label(num_top_classes=5),
    interpretation="none",
    description=description,
    title="Flower Classifier",
    article=article1 + article2,
)
UI.launch(share=True)