# plants_disease / app.py
# (Hugging Face file-view header: RandomCatLover's picture | fuuuuuuu | 2e6883a |
#  raw | history blame | 1.39 kB)
# %%
import gradio as gr
import tensorflow as tf
import numpy as np
import cv2
import os
# Local directory holding the model repository. The code below reads
# `labels.txt` and `last_layer.hdf5` from it.
model_folder = 'model'
destination = model_folder
repo_url = "https://huggingface.co/RandomCatLover/plants_disease"
if not os.path.exists(destination):
    import subprocess
    # Argument list instead of a shell string: the URL/path are never parsed
    # by a shell, and a missing `git` binary surfaces as OSError below.
    command = ['git', 'clone', repo_url, destination]
    try:
        subprocess.check_output(command, stderr=subprocess.STDOUT)
        print('Repository cloned successfully.')
    except subprocess.CalledProcessError as e:
        # errors="replace" so non-UTF-8 git output cannot mask the real error.
        print(f'Error cloning repository: {e.output.decode(errors="replace")}')
        # Without the clone the model files are missing and loading them would
        # fail anyway; stop here with the actual cause attached.
        raise RuntimeError(f'Could not clone {repo_url} into {destination}') from e
    except OSError as e:
        print(f'Error cloning repository: {e}')
        raise RuntimeError(f'Could not run git to clone {repo_url}') from e
# %%
# Class names, one per line. classify_image pairs them by position with the
# model's output scores, so their order presumably matches the model's output
# units (TODO confirm against the training setup).
with open(f'{model_folder}/labels.txt', 'r', encoding='utf-8') as f:
    # splitlines() (unlike split('\n')) yields no trailing '' entry when the
    # file ends with a newline, and it also drops '\r' from CRLF line endings.
    labels = f.read().splitlines()
# Keras model saved in HDF5 format inside the cloned repository.
model = tf.keras.models.load_model(f'{model_folder}/last_layer.hdf5')
# %%
def classify_image(inp):
    """Run the model on one image and return a confidence per label.

    Args:
        inp: image array supplied by the gradio Image component. It is
            presumably an RGB (H, W, 3) numpy array (the gr.Image default);
            TODO confirm.

    Returns:
        dict mapping each label to the model's score for it, as a float.
        This is the format gr.Label consumes.
    """
    # Force the spatial size the reshape below hard-codes.
    inp = cv2.resize(inp, (224, 224))
    # Add a leading batch dimension: (224, 224, 3) -> (1, 224, 224, 3).
    inp = inp.reshape((-1, 224, 224, 3))
    # MobileNetV2 preprocessing: scales pixel values into [-1, 1].
    inp = tf.keras.applications.mobilenet_v2.preprocess_input(inp)
    prediction = model.predict(inp).flatten()
    print(prediction)
    # zip pairs labels and scores positionally and stops at the shorter one.
    # An extra label, e.g. from a trailing blank line, therefore no longer
    # raises IndexError.
    return {label: float(score) for label, score in zip(labels, prediction)}
# Web UI: one image in, the three highest-scoring labels out.
# NOTE(review): gr.Image(shape=...) is a Gradio 3.x argument (removed in 4.x);
# this assumes the Space pins a 3.x release.
demo = gr.Interface(
    fn=classify_image,
    inputs=gr.Image(shape=(224, 224)),
    outputs=gr.Label(num_top_classes=3),
    examples=["TomatoHealthy2.jpg", "TomatoYellowCurlVirus3.jpg"],
)
demo.launch()