# ohio_space / app.py
# Last commit 4cc329e by sowbaranika13: "Update app.py" (3.13 kB)
import gradio as gr
import tensorflow as tf
from tensorflow.keras.models import load_model
from tensorflow.keras.preprocessing.image import img_to_array,load_img
import numpy as np
from PIL import Image
import os
# Label vocabularies for the two-level (class -> species) classifier.
# labels['class'] is indexed by the argmax of the class model's output in
# predict(), so its order must match that model's output units. Every other
# key is a class label whose list is indexed by the argmax of that class's
# species model, so it must match that model's output order.
# NOTE(review): 'invertebrates', 'lacertilia' and 'testudines' appear in
# labels['class'] but have no species list here, so species-level prediction
# is not possible for them.
labels = {
'class': ['amphibia', 'aves', 'invertebrates', 'lacertilia', 'mammalia', 'serpentes', 'testudines'],
'serpentes': ["Butler's Gartersnake", "Dekay's Brownsnake", 'Eastern Gartersnake', 'Eastern Hog-nosed snake', 'Eastern Massasauga', 'Eastern Milksnake', 'Eastern Racer Snake', 'Eastern Ribbonsnake', 'Gray Ratsnake', "Kirtland's Snake", 'Northern Watersnake', 'Plains Gartersnake', 'Red-bellied Snake', 'Smooth Greensnake'],
'mammalia': ['American Mink', 'Brown Rat', 'Eastern Chipmunk', 'Eastern Cottontail', 'Long-tailed Weasel', 'Masked Shrew', 'Meadow Jumping Mouse', 'Meadow Vole', 'N. Short-tailed Shrew', 'Raccoon', 'Star-nosed mole', 'Striped Skunk', 'Virginia Opossum', 'White-footed Mouse', 'Woodchuck', 'Woodland Jumping Mouse'],
'aves': ['Common Yellowthroat', 'Gray Catbird', 'Indigo Bunting', 'Northern House Wren', 'Song Sparrow', 'Sora'],
'amphibia': ['American Bullfrog', 'American Toad', 'Green Frog', 'Northern Leopard Frog']
}
# Loaded Keras models keyed by level: 'class' holds the top-level model;
# predict() stores species models here under their class label.
hierarchical_models = {}
# Weights file for the top-level class model (relative to the working dir).
model_path = r"inceptionv3_class.h5"
# The class model is loaded once, at import time.
hierarchical_models['class'] = load_model(model_path)
def load_and_preprocess_image(image, target_size=(224, 224)):
    """Turn a PIL image into a single-image batch ready for the classifiers.

    Args:
        image: a PIL ``Image``. Any mode is accepted and converted to RGB.
        target_size: ``(width, height)`` passed to ``Image.resize``.
            NOTE(review): InceptionV3's default input is 299x299; 224 is
            presumably what these models were trained on -- confirm.

    Returns:
        A float array of shape ``(1, height, width, 3)`` with pixel values
        scaled to ``[-1, 1]``.
    """
    # Grayscale ("L"), palette ("P") or RGBA images would otherwise produce
    # 1 or 4 channels, which a 3-channel InceptionV3 model cannot consume.
    # Convert before resizing so palette images are resampled in RGB space.
    rgb = image.convert("RGB").resize(target_size)
    batch = np.expand_dims(img_to_array(rgb), axis=0)
    # InceptionV3 preprocessing scales pixels to [-1, 1]. This is the same
    # "tf"-mode transform the previous mobilenet_v2 call applied, but it names
    # the architecture actually in use.
    return tf.keras.applications.inception_v3.preprocess_input(batch)
def _get_species_model(class_label):
    """Return the species model for *class_label*, loading and caching it once.

    Returns None when there is no species label list or no weights file for
    that class, so the caller can report the class-level result only.
    """
    cached = hierarchical_models.get(class_label)
    if cached is not None:
        return cached
    model_file = f"inceptionv3_{class_label}.h5"
    if class_label not in labels or not os.path.isfile(model_file):
        return None
    # Cache so later requests do not re-read the weights from disk.
    hierarchical_models[class_label] = load_model(model_file)
    return hierarchical_models[class_label]


def predict(image):
    """Classify *image* at the class level, then at the species level.

    Args:
        image: a PIL image from the Gradio input, or None if nothing was
            uploaded.

    Returns:
        ``(class_level, species_level)``, each formatted as
        ``"<label> (<confidence>%)"``. When the predicted class has no species
        model or labels, ``species_level`` is an explanatory message.

    Raises:
        gr.Error: if no image was provided.
    """
    if image is None:
        raise gr.Error("Please upload an image.")
    image_array = load_and_preprocess_image(image)

    # Class level: take the single batch row, then its most likely class.
    class_probs = hierarchical_models['class'].predict(image_array, verbose=0)[0]
    class_idx = int(np.argmax(class_probs))
    class_label = labels['class'][class_idx]
    class_level = f"{class_label} ({class_probs[class_idx]*100:.2f}%)"

    # Species level: this needs a per-class model and label list. Some classes
    # (e.g. 'invertebrates') have none, so degrade instead of raising KeyError.
    species_model = _get_species_model(class_label)
    if species_model is None:
        return class_level, f"No species-level model available for '{class_label}'"

    species_probs = species_model.predict(image_array, verbose=0)[0]
    species_idx = int(np.argmax(species_probs))
    species_label = labels[class_label][species_idx]
    species_level = f"{species_label} ({species_probs[species_idx]*100:.2f}%)"
    return class_level, species_level
# Sample images (you can add paths to images here)
# sample_images = [
# ("Sample Amphibia", "path/to/amphibia.jpg"),
# ("Sample Aves", "path/to/aves.jpg"),
# ("Sample Mammalia", "path/to/mammalia.jpg"),
# # Add more sample images as needed
# ]
# Gradio UI: one PIL image in, two labels out (class first, then species),
# in the same order as the tuple returned by predict().
iface = gr.Interface(
    predict,
    gr.Image(type="pil"),
    [
        gr.Label(label="class_label"),
        gr.Label(label="species_label"),
    ],
    title="Image Classification",
    description="Upload an image to classify it into species and class level.",
    # examples=sample_images,  # re-enable once sample image paths are added
)

# Start the Gradio app.
iface.launch()