# FP-KCV / appv2.py
# Uploaded by xcurv in commit 37bed94 ("Upload appv2.py", verified).
import zlib

import streamlit as st
import numpy as np
import tensorflow as tf
from PIL import Image
from ultralytics import YOLO
import cv2
# Load models and labels once per Streamlit server process.
#
# Streamlit re-executes this whole script on every widget interaction, so
# loading the models at plain module level would re-read both weight files
# on every upload or slider move. ``st.cache_resource`` keeps one shared
# instance across reruns instead.
st.sidebar.title("Settings")


@st.cache_resource
def _load_classification_model():
    """Load the Keras classifier from ``./models.h5``."""
    return tf.keras.models.load_model('./models.h5')


@st.cache_resource
def _load_detection_model():
    """Load the YOLO detector weights from ``./best.pt``."""
    return YOLO('./best.pt')


@st.cache_resource
def _load_labels():
    """Read class names from ``labels.txt``, one per line (whitespace-stripped).

    Lines are kept in file order and none are filtered out, because a label's
    position is matched against the classifier's output index.
    """
    with open("labels.txt", encoding="utf-8") as f:
        return [line.strip() for line in f]


classification_model = _load_classification_model()
detection_model = _load_detection_model()
labels = _load_labels()
def classify_image(img):
    """Classify a whole image with the Keras classifier.

    Args:
        img: A PIL image of any size or colour mode.

    Returns:
        dict mapping each label to its model score as a Python float.
    """
    # Convert first: PNG uploads are often RGBA, palette or greyscale, and the
    # reshape below requires exactly three channels.
    img = img.convert("RGB").resize((224, 224))
    img_array = np.array(img).reshape((-1, 224, 224, 3))
    img_array = tf.keras.applications.efficientnet.preprocess_input(img_array)
    prediction = classification_model.predict(img_array).flatten()
    # Pair labels with scores by position. zip stops at the shorter of the two
    # instead of assuming exactly 90 classes.
    return {label: float(score) for label, score in zip(labels, prediction)}
def animal_detect_and_classify(img, detect_results, threshold=0.66):
    """Classify every detected box's crop with the Keras classifier.

    Args:
        img: The image that was given to the detector (PIL image or array).
        detect_results: Iterable of detector results. Each must expose
            ``.boxes``, whose items carry an ``xyxy`` coordinate row.
        threshold: Minimum top score for a crop to receive its predicted
            label. Below it, the crop is labelled "animal".

    Returns:
        List of ``((x1, y1, x2, y2), [label])`` tuples, one per usable box.
        Coordinates are clamped to the image bounds.
    """
    # Force three channels so RGBA/greyscale uploads match the classifier input.
    if isinstance(img, Image.Image):
        img = img.convert("RGB")
    img = np.array(img)
    height, width = img.shape[:2]

    boxes = []
    crops = []
    for result in detect_results:
        for box in result.boxes:
            x1, y1, x2, y2 = map(int, box.xyxy[0])
            # Clamp so a negative coordinate cannot wrap around in the slice.
            x1, x2 = max(0, x1), min(width, x2)
            y1, y2 = max(0, y1), min(height, y2)
            if x2 <= x1 or y2 <= y1:
                # Degenerate box: cv2.resize raises on an empty crop.
                continue
            boxes.append((x1, y1, x2, y2))
            crops.append(cv2.resize(img[y1:y2, x1:x2], (224, 224)))

    if not crops:
        return []

    # One batched predict() call instead of one call per box.
    batch = tf.keras.applications.efficientnet.preprocess_input(np.stack(crops))
    predictions = classification_model.predict(batch)
    return [
        (box, [labels[np.argmax(pred)] if np.max(pred) >= threshold else "animal"])
        for box, pred in zip(boxes, predictions)
    ]
def generate_color(class_name):
    """Map a class name to a stable ``(R, G, B)`` colour for drawing.

    Uses CRC-32 rather than the built-in ``hash()``. ``str`` hashes are
    salted per interpreter process, so ``hash()`` gives each class a
    different colour every time the app restarts.

    Returns:
        Tuple of three ints, each in ``0..255``.
    """
    # Keep the low 24 bits: 8 bits each for R, G and B.
    color_hash = zlib.crc32(str(class_name).encode("utf-8")) & 0xFFFFFF
    return (color_hash >> 16, (color_hash >> 8) & 0xFF, color_hash & 0xFF)
def plot_detected_rectangles(image, detections):
    """Draw labelled bounding boxes onto a copy of *image*.

    Args:
        image: PIL image (or array) that the detections refer to.
        detections: Iterable of ``((x1, y1, x2, y2), class_names)`` pairs.
            Entries with an empty ``class_names``, or whose first name is
            "unknown", are skipped.

    Returns:
        A new PIL image. The input image is not modified.
    """
    # Draw on an RGB copy. With RGBA input the 3-tuple colour would leave the
    # alpha channel at 0, so the boxes would render transparent.
    if isinstance(image, Image.Image):
        image = image.convert("RGB")
    canvas = np.array(image)  # np.array copies, so the caller's data is untouched
    for (x1, y1, x2, y2), class_names in detections:
        if not class_names or class_names[0] == "unknown":
            continue
        color = generate_color(class_names[0])
        cv2.rectangle(canvas, (x1, y1), (x2, y2), color, 2)
        for i, class_name in enumerate(class_names):
            text_y = y1 - 10 - i * 20
            if text_y < 10:
                # Box touches the top edge: put the label inside the box
                # rather than drawing it off-canvas.
                text_y = y1 + 20 + i * 20
            cv2.putText(canvas, class_name, (x1, text_y),
                        cv2.FONT_HERSHEY_SIMPLEX, 0.5, color, 2)
    return Image.fromarray(canvas)
def detection_image(img, conf_threshold, iou_threshold):
    """Run the detector on *img*, classify each box, and return the drawing.

    Args:
        img: The uploaded image. It is passed unchanged as the detector's
            ``source``, to the crop classifier, and to the box plotter.
        conf_threshold: Forwarded as the detector's ``conf`` argument.
        iou_threshold: Forwarded as the detector's ``iou`` argument.

    Returns:
        The annotated image produced by ``plot_detected_rectangles``.
    """
    predict_kwargs = dict(
        source=img,
        conf=conf_threshold,
        iou=iou_threshold,
        imgsz=640,
    )
    yolo_results = detection_model.predict(**predict_kwargs)
    labelled_boxes = animal_detect_and_classify(img, yolo_results)
    return plot_detected_rectangles(img, labelled_boxes)
# Streamlit UI: one tab classifies the whole image, the other runs detection
# followed by per-box classification.
st.title("Animal Image Processing")
tab1, tab2 = st.tabs(["Image Classification", "Object Detection"])

with tab1:
    st.header("Image Classification")
    if (uploaded_file := st.file_uploader(
        "Upload an image...", type=["jpg", "jpeg", "png"]
    )) is not None:
        image = Image.open(uploaded_file)
        st.image(image, caption="Uploaded Image", use_container_width=True)
        predictions = classify_image(image)
        # The three highest-scoring labels, best first.
        sorted_preds = sorted(
            predictions.items(), key=lambda pair: pair[1], reverse=True
        )[:3]
        st.subheader("Top Predictions:")
        for label, confidence in sorted_preds:
            st.write(f"**{label}**: {confidence*100:.2f}%")

with tab2:
    st.header("Object Detection")
    uploaded_file_detect = st.file_uploader(
        "Upload an image for object detection...", type=["jpg", "jpeg", "png"]
    )
    # Sliders: (label, min, max, default).
    conf_threshold = st.slider("Confidence Threshold", 0.0, 1.0, 0.25)
    iou_threshold = st.slider("IoU Threshold", 0.0, 1.0, 0.45)
    if uploaded_file_detect is not None:
        image = Image.open(uploaded_file_detect)
        st.image(image, caption="Uploaded Image", use_container_width=True)
        detected_image = detection_image(image, conf_threshold, iou_threshold)
        st.image(detected_image, caption="Detected Objects", use_container_width=True)