# streamlit_app / app.py
# Last commit: 52ed7df "Update app.py" by MuhammmadRizwanRizwan (verified)
import importlib
import os
import subprocess
import sys
import warnings

import cv2
import numpy as np
import streamlit as st
from PIL import Image

# Silence library warnings (FutureWarning, UserWarning and every other
# category) before the heavy ML imports below get a chance to emit them.
warnings.filterwarnings("ignore")

# TensorFlow/Keras provides the two classification models. If it cannot be
# imported, show the error and halt this script run rather than crashing
# later with an unrelated traceback.
try:
    import tensorflow as tf
    from tensorflow.keras.models import load_model
    from tensorflow.keras.preprocessing import image
except ImportError:
    st.error("Failed to import TensorFlow. Please make sure it's installed correctly.")
    st.stop()

tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.ERROR)

# PyTorch and Detectron2 (used for damage segmentation) are installed on first
# run if missing. pip is invoked through the running interpreter
# (`python -m pip`) with an argument list (no shell); check_call raises
# CalledProcessError if an install fails instead of silently continuing.
try:
    import torch
    import detectron2
except ImportError:
    with st.spinner("Installing PyTorch and Detectron2..."):
        subprocess.check_call(
            [sys.executable, "-m", "pip", "install", "torch", "torchvision"]
        )
        subprocess.check_call(
            [
                sys.executable, "-m", "pip", "install",
                "git+https://github.com/facebookresearch/detectron2.git",
            ]
        )
    # Packages installed after interpreter start-up may not be visible to the
    # import system until its finder caches are invalidated.
    importlib.invalidate_caches()
    import torch
    import detectron2

from detectron2.config import get_cfg
from detectron2.data import MetadataCatalog
from detectron2.engine import DefaultPredictor
from detectron2.utils.visualizer import Visualizer
@st.cache_resource
def load_models():
    """Load both Keras classifiers once and reuse them across Streamlit reruns.

    Returns:
        A ``(type_classifier, quality_classifier)`` pair, read from
        ``name_model_inception.h5`` and ``type_model_inception.h5`` in the
        working directory, in that order.
    """
    weight_files = ('name_model_inception.h5', 'type_model_inception.h5')
    return tuple(load_model(path) for path in weight_files)


# Module-level handles read by predict_fruit().
model_name, model_quality = load_models()
# Detectron2 setup
@st.cache_resource
def load_detectron_model(fruit_name):
    """Build (and cache) a CPU Detectron2 predictor for one fruit's damage model.

    Two files are looked up in the working directory:
      * ``<fruit_name lowercased>_config.yaml`` - the Detectron2 config;
      * ``<fruit_name as given>_model.pth`` - the trained weights.
    Note the differing casing: only the config name is lowercased.
    NOTE(review): confirm this matches the files shipped with the app.

    Args:
        fruit_name: Fruit label as produced by the classifier (e.g. "Banana").

    Returns:
        ``(predictor, cfg)``: a ``DefaultPredictor`` and the config it was
        built from.

    Raises:
        FileNotFoundError: if the config or weights file for this fruit is
            missing (reported up front instead of surfacing as an opaque
            error from deep inside Detectron2).
    """
    config_path = f"{fruit_name.lower()}_config.yaml"
    model_path = f"{fruit_name}_model.pth"

    missing = [path for path in (config_path, model_path) if not os.path.isfile(path)]
    if missing:
        raise FileNotFoundError(
            f"Damage-detection files for {fruit_name!r} not found: {', '.join(missing)}"
        )

    cfg = get_cfg()
    cfg.merge_from_file(config_path)
    cfg.MODEL.WEIGHTS = model_path
    # Detection score threshold applied at inference time.
    cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST = 0.5
    cfg.MODEL.DEVICE = 'cpu'
    predictor = DefaultPredictor(cfg)
    return predictor, cfg
# Class-index -> label lookups for the two Keras classifiers; the key is the
# argmax position in the corresponding model's output.
# Spellings such as "Peeper" and lowercase "tomato" must not be changed
# casually: the name label is shown to the user and is also used to build
# the Detectron2 config/weights file names.
label_map_name = dict(enumerate([
    "Banana", "Cucumber", "Grape", "Kaki", "Papaya", "Peach",
    "Pear", "Peeper", "Strawberry", "Watermelon", "tomato",
]))
label_map_quality = dict(enumerate(["Good", "Mild", "Rotten"]))
def predict_fruit(img):
    """Classify a fruit image by type and by quality.

    Args:
        img: Image as a NumPy array (e.g. ``np.array(pil_image)``). Greyscale
            ``(H, W)``, RGB ``(H, W, 3)`` and RGBA ``(H, W, 4)`` arrays are
            accepted; values are cast to uint8.

    Returns:
        ``(predicted_name, predicted_quality, resized)``: the labels looked up
        in ``label_map_name`` / ``label_map_quality``, and the 224x224 RGB PIL
        image that was fed to both models.
    """
    # Let PIL infer the mode from the array shape, then normalise to RGB.
    # Forcing mode 'RGB' on a 4-channel (transparent PNG) or 2-D greyscale
    # array misreads the pixel buffer or fails outright.
    resized = Image.fromarray(img.astype('uint8')).convert('RGB').resize((224, 224))

    # Batch of one image with pixel values scaled to [0, 1].
    batch = np.expand_dims(image.img_to_array(resized), axis=0) / 255.0

    # verbose=0 keeps Keras from printing a progress bar for every request.
    pred_name = model_name.predict(batch, verbose=0)
    pred_quality = model_quality.predict(batch, verbose=0)

    predicted_name = label_map_name[int(np.argmax(pred_name, axis=1)[0])]
    predicted_quality = label_map_quality[int(np.argmax(pred_quality, axis=1)[0])]
    return predicted_name, predicted_quality, resized
def main():
    """Render the Streamlit page.

    The user uploads a fruit image; on "Analyze" it is classified by type and
    quality. For "Mild" or "Rotten" fruit of any class in ``label_map_name``,
    Detectron2 segmentation of the defective region is displayed. Errors from
    the damage-detection step are shown in the page instead of aborting it.
    """
    st.title("Automated Fruits Monitoring System")
    st.write("Upload an image of a fruit to detect its type, quality, and potential damage.")
    uploaded_file = st.file_uploader("Choose a fruit image...", type=["jpg", "jpeg", "png"])
    if uploaded_file is None:
        return

    # Named `uploaded_image` so it does not shadow the Keras `image` module.
    uploaded_image = Image.open(uploaded_file)
    st.image(uploaded_image, caption="Uploaded Image", use_column_width=True)

    if not st.button("Analyze"):
        return

    # PNG uploads may be RGBA, greyscale or palette images; the models expect RGB.
    rgb_upload = np.array(uploaded_image.convert("RGB"))
    with st.spinner("Analyzing image..."):
        predicted_name, predicted_quality, img = predict_fruit(rgb_upload)
    st.write(f"Fruits Type Detection: {predicted_name}")
    st.write(f"Fruits Quality Classification: {predicted_quality}")

    # Segmentation is attempted for every class the name model knows about,
    # derived from the label map instead of a second hard-coded list.
    known_fruits = {label.lower() for label in label_map_name.values()}
    if predicted_name.lower() in known_fruits and predicted_quality in ("Mild", "Rotten"):
        st.write("Segmentation of Defective Region:")
        try:
            predictor, cfg = load_detectron_model(predicted_name)
            rgb = np.array(img)
            # The predictor is fed a BGR array; the Visualizer draws on the RGB one.
            outputs = predictor(cv2.cvtColor(rgb, cv2.COLOR_RGB2BGR))
            metadata = MetadataCatalog.get(cfg.DATASETS.TRAIN[0])
            v = Visualizer(rgb, metadata, scale=0.8)
            out = v.draw_instance_predictions(outputs["instances"].to("cpu"))
            st.image(out.get_image(), caption="Damage Detection Result", use_column_width=True)
        except Exception as e:
            # UI boundary: report any failure (missing files, bad config, ...) in the page.
            st.error(f"Error in damage detection: {str(e)}")
    else:
        st.write("No damage detection performed for this fruit or quality level.")


if __name__ == "__main__":
    main()
# import streamlit as st
# import numpy as np
# import cv2
# import warnings
# import os
# from pathlib import Path
# from PIL import Image
# import tensorflow as tf
# from tensorflow.keras.models import load_model
# from tensorflow.keras.preprocessing import image
# from detectron2.engine import DefaultPredictor
# from detectron2.config import get_cfg
# from detectron2.utils.visualizer import Visualizer
# from detectron2.data import MetadataCatalog
# # Suppress warnings
# warnings.filterwarnings("ignore")
# tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.ERROR)
# # Configuration
# MODEL_CONFIG = {
# 'name_model': 'name_model_inception.h5',
# 'quality_model': 'type_model_inception.h5',
# 'input_size': (224, 224),
# 'score_threshold': 0.5
# }
# LABEL_MAPS = {
# 'name': {
# 0: "Banana", 1: "Cucumber", 2: "Grape", 3: "Kaki", 4: "Papaya",
# 5: "Peach", 6: "Pear", 7: "Peeper", 8: "Strawberry", 9: "Watermelon",
# 10: "tomato"
# },
# 'quality': {0: "Good", 1: "Mild", 2: "Rotten"}
# }
# @st.cache_resource
# def load_classification_models():
# """Load and cache the classification models."""
# try:
# model_name = load_model(MODEL_CONFIG['name_model'])
# model_quality = load_model(MODEL_CONFIG['quality_model'])
# return model_name, model_quality
# except Exception as e:
# st.error(f"Error loading classification models: {str(e)}")
# return None, None
# @st.cache_resource
# def load_detectron_model(fruit_name: str):
# """Load and cache the Detectron2 model for damage detection."""
# try:
# cfg = get_cfg()
# config_path = Path(f"{fruit_name.lower()}_config.yaml")
# model_path = Path(f"{fruit_name}_model.pth")
# if not config_path.exists() or not model_path.exists():
# raise FileNotFoundError(f"Model files not found for {fruit_name}")
# cfg.merge_from_file(str(config_path))
# cfg.MODEL.WEIGHTS = str(model_path)
# cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST = MODEL_CONFIG['score_threshold']
# cfg.MODEL.DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
# return DefaultPredictor(cfg), cfg
# except Exception as e:
# st.error(f"Error loading Detectron2 model: {str(e)}")
# return None, None
# def preprocess_image(img: np.ndarray) -> tuple:
# """Preprocess the input image for model prediction."""
# try:
# # Convert to PIL Image if necessary
# if isinstance(img, np.ndarray):
# img = Image.fromarray(img.astype('uint8'), 'RGB')
# # Resize and prepare for model input
# img_resized = img.resize(MODEL_CONFIG['input_size'])
# img_array = image.img_to_array(img_resized)
# img_expanded = np.expand_dims(img_array, axis=0)
# img_normalized = img_expanded / 255.0
# return img_normalized, img_resized
# except Exception as e:
# st.error(f"Error preprocessing image: {str(e)}")
# return None, None
# def predict_fruit(img: np.ndarray) -> tuple:
# """Predict fruit type and quality."""
# model_name, model_quality = load_classification_models()
# if model_name is None or model_quality is None:
# return None, None, None
# img_normalized, img_resized = preprocess_image(img)
# if img_normalized is None:
# return None, None, None
# try:
# # Make predictions
# pred_name = model_name.predict(img_normalized)
# pred_quality = model_quality.predict(img_normalized)
# # Get predicted labels
# predicted_name = LABEL_MAPS['name'][np.argmax(pred_name, axis=1)[0]]
# predicted_quality = LABEL_MAPS['quality'][np.argmax(pred_quality, axis=1)[0]]
# return predicted_name, predicted_quality, img_resized
# except Exception as e:
# st.error(f"Error making predictions: {str(e)}")
# return None, None, None
# def detect_damage(img: Image, fruit_name: str) -> np.ndarray:
# """Detect and visualize damage in the fruit image."""
# predictor, cfg = load_detectron_model(fruit_name)
# if predictor is None or cfg is None:
# return None
# try:
# outputs = predictor(cv2.cvtColor(np.array(img), cv2.COLOR_RGB2BGR))
# v = Visualizer(np.array(img), MetadataCatalog.get(cfg.DATASETS.TRAIN[0]), scale=0.8)
# out = v.draw_instance_predictions(outputs["instances"].to("cpu"))
# return out.get_image()
# except Exception as e:
# st.error(f"Error in damage detection: {str(e)}")
# return None
# def main():
# st.set_page_config(page_title="Fruit Quality Analysis", layout="wide")
# st.title("Automated Fruits Monitoring System")
# st.write("Upload an image of a fruit to detect its type, quality, and potential damage.")
# uploaded_file = st.file_uploader("Choose a fruit image...", type=["jpg", "jpeg", "png"])
# if uploaded_file is not None:
# # Create two columns for layout
# col1, col2 = st.columns(2)
# # Display uploaded image
# image = Image.open(uploaded_file)
# col1.image(image, caption="Uploaded Image", use_column_width=True)
# if col1.button("Analyze"):
# with st.spinner("Analyzing image..."):
# predicted_name, predicted_quality, img_resized = predict_fruit(np.array(image))
# if predicted_name and predicted_quality:
# # Display results
# col2.markdown("### Analysis Results")
# col2.markdown(f"**Fruit Type:** {predicted_name}")
# col2.markdown(f"**Quality:** {predicted_quality}")
# # Check if damage detection is needed
# if (predicted_name.lower() in LABEL_MAPS['name'].values() and
# predicted_quality in ["Mild", "Rotten"]):
# col2.markdown("### Damage Detection")
# damage_image = detect_damage(img_resized, predicted_name)
# if damage_image is not None:
# col2.image(damage_image, caption="Detected Damage Regions",
# use_column_width=True)
# # Add download button for the damage detection result
# col2.download_button(
# label="Download Analysis Result",
# data=cv2.imencode('.png', damage_image)[1].tobytes(),
# file_name=f"{predicted_name}_damage_analysis.png",
# mime="image/png"
# )
# if __name__ == "__main__":
# main()