"""Gradio demo that classifies a flower photo into one of 102 flower types.

On start-up the script downloads a Keras SavedModel from the Hugging Face Hub
into ``pretrained_model/``, loads it together with the class names listed in
``flower_names.txt`` and serves a Gradio interface showing the top-5 labels.
"""

import os

import gradio as gr
import numpy as np
import requests
import tensorflow as tf
from gradio.components import Image as gradio_image
from gradio.components import Label as gradio_label


def softmax(vector):
    """Return the softmax of *vector* as a NumPy probability array.

    The maximum is subtracted before exponentiating so that large scores do
    not overflow to ``inf`` (which would turn every probability into ``nan``);
    the result is mathematically identical to ``exp(v) / sum(exp(v))``.
    An empty input yields an empty array.
    """
    values = np.asarray(vector, dtype=np.float64)
    if values.size == 0:
        return values
    exponentials = np.exp(values - values.max())
    return exponentials / exponentials.sum()


def image_to_output(input_img):
    """Classify one image and return ``{flower_name: probability}``.

    Args:
        input_img: the image delivered by the Gradio ``Image`` input —
            presumably an ``(H, W, 3)`` uint8 numpy array (Gradio's default
            ``type="numpy"``); TODO confirm against the Gradio version in use.

    Returns:
        A dict mapping each name in ``labels`` to its softmax probability.
        The Gradio ``Label`` output displays the 5 highest.

    Raises:
        ValueError: if no image was supplied.
    """
    if input_img is None:
        raise ValueError("No image was provided; please upload a flower picture.")

    # Add a batch axis, scale pixels to [0, 1] and resize to the 224x224
    # input size the model is fed with.
    batch = tf.cast([input_img], tf.float32) / 255.0
    batch = tf.image.resize(batch, [224, 224])

    scores = model2.predict(np.asarray(batch), batch_size=1).flatten()
    # NOTE(review): if the model's last layer already applies softmax, this
    # second softmax flattens the distribution; confirm against the model.
    probabilities = softmax(scores)

    # zip pairs each class name with its score without relying on a
    # hard-coded class count.
    return {name: float(prob) for name, prob in zip(labels, probabilities)}


# Download the model checkpoint (Keras SavedModel layout).
pretrained_repo = "pretrained_model"
model_repo_link = "https://huggingface.co/qmjnh/flowerClassification_2/resolve/main/"
model_files = (
    "variables.data-00000-of-00001",
    "variables.index",
    "keras_metadata.pb",
    "saved_model.pb",
)

for item in model_files:
    url = model_repo_link + item
    # SavedModel expects the variables shards inside a "variables" subfolder.
    if item.startswith("variables"):
        output_file = os.path.join(pretrained_repo, "variables", item)
    else:
        output_file = os.path.join(pretrained_repo, item)
    os.makedirs(os.path.dirname(output_file), exist_ok=True)

    print(f"Downloading from {url} to {output_file}")
    with requests.get(url, stream=True, timeout=60) as response:
        # Fail loudly rather than writing an HTTP error page into the model
        # file, which would only surface later as an obscure load error.
        response.raise_for_status()
        with open(output_file, "wb") as f:
            # Stream in chunks so the weights file is never fully buffered
            # in memory.
            for chunk in response.iter_content(chunk_size=1 << 20):
                f.write(chunk)

# Load the model.
model2 = tf.keras.models.load_model(pretrained_repo)

# Read the class names, one per line. strip() removes the trailing newline
# that readlines() used to leave on every label; one entry is kept per line so
# label indices stay aligned with the model's output positions.
with open("flower_names.txt", encoding="utf-8") as f:
    labels = [line.strip() for line in f]

# Build and run the Gradio interface.
description = (
    "This model was trained to recognize 102 types of flowers. For the model to "
    "work with high accuracy, refer to the trained flowers "
    "[here](https://huggingface.co/spaces/qmjnh/flowerClassification_2/blob/main/flower_names.txt)"
)
article1 = (
    "This is an AI model trained to predict the name of the flower in the input "
    "picture, given that the flower is one of 102 types of flowers (flowers list "
    "can be found [here](https://huggingface.co/spaces/qmjnh/flowerClassification_2/blob/main/flower_names.txt) ). "
    "To try out the model, simply drop/upload a picture into the '*input box*' "
    "and press 'Submit'. The predictions will show up in the '*output box*'\n"
)
article2 = "\n *built by qmjnh*"

UI = gr.Interface(
    fn=image_to_output,
    inputs=gradio_image(shape=(224, 224)),
    outputs=gradio_label(num_top_classes=5),
    interpretation="none",
    description=description,
    title="Flower Classifier",
    article=article1 + article2,
)

# NOTE(review): share=True creates a public link when run locally; hosted
# Spaces presumably ignore it — confirm.
UI.launch(share=True)