"""Streamlit app for animal image classification and detection.

The "Image Classification" tab runs a Keras classifier on the whole uploaded
image. The "Object Detection" tab runs a YOLO detector and labels each
detected box with the same Keras classifier.
"""

import zlib

import cv2
import numpy as np
import streamlit as st
import tensorflow as tf
from PIL import Image
from ultralytics import YOLO

CLASSIFIER_PATH = "./models.h5"
DETECTOR_PATH = "./best.pt"
LABELS_PATH = "labels.txt"
INPUT_SIZE = 224  # side length of the square classifier input
# Box classifications whose top score is below this are shown as FALLBACK_LABEL.
MIN_BOX_CONFIDENCE = 0.66
FALLBACK_LABEL = "animal"


@st.cache_resource
def _load_classification_model():
    """Load the Keras classifier once per server process.

    Streamlit re-executes this script on every widget interaction; caching
    avoids reloading the weights from disk on each rerun.
    """
    return tf.keras.models.load_model(CLASSIFIER_PATH)


@st.cache_resource
def _load_detection_model():
    """Load the YOLO detector once per server process (cached across reruns)."""
    return YOLO(DETECTOR_PATH)


@st.cache_data
def _load_labels(path=LABELS_PATH):
    """Return the stripped lines of *path* as class names.

    Index i is assumed to name classifier output i -- TODO confirm against
    the model's training label order.
    """
    with open(path, encoding="utf-8") as f:
        return [line.strip() for line in f]


# Load models
st.sidebar.title("Settings")
classification_model = _load_classification_model()
detection_model = _load_detection_model()

# Load labels
labels = _load_labels()


def _as_rgb_array(img):
    """Return a writable numpy copy of *img*.

    PIL images are converted with ``convert("RGB")`` first so RGBA,
    grayscale and palette uploads still yield three channels.
    """
    if isinstance(img, Image.Image):
        img = img.convert("RGB")
    return np.array(img)


def _predict_scores(batch):
    """Run the classifier on a (N, INPUT_SIZE, INPUT_SIZE, 3) batch.

    Returns the model's (N, num_outputs) score array; callers treat these
    scores as per-class confidences.
    """
    batch = tf.keras.applications.efficientnet.preprocess_input(batch)
    return classification_model.predict(batch, verbose=0)


def _label_scores(scores):
    """Map each label to its score as a float.

    ``zip`` stops at the shorter of ``labels`` / ``scores``, so a label file
    and model that disagree in length cannot raise IndexError.
    """
    return {label: float(score) for label, score in zip(labels, scores)}


# Function to classify image
def classify_image(img):
    """Classify a whole PIL image.

    Returns:
        dict mapping each label to the classifier's score for it.
    """
    resized = img.convert("RGB").resize((INPUT_SIZE, INPUT_SIZE))
    batch = np.array(resized)[np.newaxis]  # add batch axis -> (1, H, W, 3)
    scores = _predict_scores(batch)[0]
    return _label_scores(scores)


# Function to detect animals and classify them
def animal_detect_and_classify(img, detect_results):
    """Classify the crop under every box reported by the detector.

    Args:
        img: the image the detector ran on (PIL image or RGB array).
        detect_results: iterable of Ultralytics results whose ``.boxes``
            expose ``.xyxy`` pixel coordinates.

    Returns:
        List of ``((x1, y1, x2, y2), [label])`` tuples. ``label`` is the
        classifier's top label when its score is at least
        MIN_BOX_CONFIDENCE, otherwise FALLBACK_LABEL. Boxes that are empty
        after clipping to the image bounds are dropped.
    """
    frame = _as_rgb_array(img)
    height, width = frame.shape[:2]

    rects, crops = [], []
    for result in detect_results:
        for box in result.boxes:
            x1, y1, x2, y2 = map(int, box.xyxy[0])
            # Clip to the image: negative indices would wrap around when slicing.
            x1, y1 = max(x1, 0), max(y1, 0)
            x2, y2 = min(x2, width), min(y2, height)
            if x2 <= x1 or y2 <= y1:
                continue  # cv2.resize raises on an empty crop
            rects.append((x1, y1, x2, y2))
            crops.append(cv2.resize(frame[y1:y2, x1:x2], (INPUT_SIZE, INPUT_SIZE)))

    if not crops:
        return []

    # One batched predict call instead of one call per box.
    all_scores = _predict_scores(np.stack(crops))
    combined_results = []
    for rect, scores in zip(rects, all_scores):
        best = int(np.argmax(scores))
        confident = scores[best] >= MIN_BOX_CONFIDENCE and best < len(labels)
        label = labels[best] if confident else FALLBACK_LABEL
        combined_results.append((rect, [label]))
    return combined_results


# Function to generate color for bounding boxes
def generate_color(class_name):
    """Return a deterministic (R, G, B) color for *class_name*.

    Uses CRC32 instead of the built-in ``hash()``, which is salted per
    process for strings, so a class keeps its color across server restarts.
    """
    value = zlib.crc32(class_name.encode("utf-8")) & 0xFFFFFF
    return (value >> 16) & 0xFF, (value >> 8) & 0xFF, value & 0xFF


# Function to draw bounding boxes
def plot_detected_rectangles(image, detections):
    """Draw labelled boxes for *detections* on a copy of *image*.

    Args:
        image: PIL image or RGB array to draw on (not modified).
        detections: iterable of ``((x1, y1, x2, y2), class_names)`` pairs.

    Returns:
        A new PIL image with the boxes and labels drawn.
    """
    canvas = _as_rgb_array(image)
    for (x1, y1, x2, y2), class_names in detections:
        # NOTE(review): this module only emits labels.txt names or
        # FALLBACK_LABEL, so this skip fires only if labels.txt contains "unknown".
        if not class_names or class_names[0] == "unknown":
            continue
        color = generate_color(class_names[0])
        cv2.rectangle(canvas, (x1, y1), (x2, y2), color, 2)
        for i, class_name in enumerate(class_names):
            text_y = y1 - 10 - i * 20
            if text_y < 10:
                # No room above the box: draw the label just inside its top edge.
                text_y = y1 + 20 + i * 20
            cv2.putText(
                canvas, class_name, (x1, text_y),
                cv2.FONT_HERSHEY_SIMPLEX, 0.5, color, 2,
            )
    return Image.fromarray(canvas)


# Function to run object detection
def detection_image(img, conf_threshold, iou_threshold):
    """Run the YOLO detector on *img* and return an annotated PIL image."""
    rgb = img.convert("RGB") if isinstance(img, Image.Image) else img
    results = detection_model.predict(
        source=rgb,
        conf=conf_threshold,
        iou=iou_threshold,
        imgsz=640,
    )
    combined_results = animal_detect_and_classify(rgb, results)
    return plot_detected_rectangles(rgb, combined_results)


def _open_upload(uploaded):
    """Decode an uploaded file as a PIL image; show an error and return None on failure."""
    try:
        img = Image.open(uploaded)
        img.load()  # Image.open is lazy; force decoding so corrupt files fail here
        return img
    except (OSError, Image.DecompressionBombError):
        st.error("Could not read the uploaded file as an image.")
        return None


# Streamlit UI
st.title("Animal Image Processing")
tab1, tab2 = st.tabs(["Image Classification", "Object Detection"])

with tab1:
    st.header("Image Classification")
    uploaded_file = st.file_uploader("Upload an image...", type=["jpg", "jpeg", "png"])
    image = _open_upload(uploaded_file) if uploaded_file is not None else None
    if image is not None:
        st.image(image, caption="Uploaded Image", use_container_width=True)
        predictions = classify_image(image)
        sorted_preds = sorted(predictions.items(), key=lambda x: x[1], reverse=True)[:3]
        st.subheader("Top Predictions:")
        for label, confidence in sorted_preds:
            st.write(f"**{label}**: {confidence*100:.2f}%")

with tab2:
    st.header("Object Detection")
    uploaded_file_detect = st.file_uploader(
        "Upload an image for object detection...", type=["jpg", "jpeg", "png"]
    )
    conf_threshold = st.slider("Confidence Threshold", 0.0, 1.0, 0.25)
    iou_threshold = st.slider("IoU Threshold", 0.0, 1.0, 0.45)
    image = _open_upload(uploaded_file_detect) if uploaded_file_detect is not None else None
    if image is not None:
        st.image(image, caption="Uploaded Image", use_container_width=True)
        detected_image = detection_image(image, conf_threshold, iou_threshold)
        st.image(detected_image, caption="Detected Objects", use_container_width=True)