"""Streamlit app: classify a fruit image's type and quality, then segment damage.

Module setup: silence noisy library warnings, import the TensorFlow/Keras
classification stack, and import the PyTorch/Detectron2 segmentation stack
(installing it on first run if it is missing).
"""
import importlib
import os
import subprocess
import sys
import warnings

# Silence library warnings before the heavy imports below run; the app
# ignores all warnings from here on.
warnings.filterwarnings("ignore")

import cv2
import numpy as np
import streamlit as st
from PIL import Image

# TensorFlow is required for both classifiers. Without it the app cannot do
# anything, so show a readable error and halt this script run instead of
# failing later with a raw ImportError traceback.
try:
    import tensorflow as tf
    from tensorflow.keras.models import load_model
    from tensorflow.keras.preprocessing import image
except ImportError:
    st.error("Failed to import TensorFlow. Please make sure it's installed correctly.")
    st.stop()

tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.ERROR)

# PyTorch/Detectron2 power the damage segmentation. If either is missing,
# install into *this* interpreter (``sys.executable -m pip``) rather than
# whichever ``pip`` is first on PATH, and raise if pip fails instead of
# silently ignoring its exit status.
# NOTE(review): installing packages at runtime is fragile (needs network
# access, build tools, and a writable environment); prefer declaring torch,
# torchvision and detectron2 in the deployment's requirements.
try:
    import torch
    import detectron2
except ImportError:
    with st.spinner("Installing PyTorch and Detectron2..."):
        subprocess.check_call(
            [sys.executable, "-m", "pip", "install", "torch", "torchvision"]
        )
        subprocess.check_call(
            [
                sys.executable, "-m", "pip", "install",
                "git+https://github.com/facebookresearch/detectron2.git",
            ]
        )
    # Modules installed after the interpreter started may be missed by the
    # import system's cached directory listings until caches are cleared.
    importlib.invalidate_caches()
    import torch
    import detectron2

from detectron2.config import get_cfg
from detectron2.data import MetadataCatalog
from detectron2.engine import DefaultPredictor
from detectron2.utils.visualizer import Visualizer
@st.cache_resource
def load_models():
    """Load the two Keras classifiers used by ``predict_fruit``.

    Streamlit re-executes this whole script on every widget interaction
    (e.g. each click of "Analyze"). Without caching, both ``.h5`` models
    would be read from disk and rebuilt on every rerun. ``st.cache_resource``
    keeps one shared instance and returns it on later calls.

    NOTE(review): cached resources are shared across all sessions, so
    concurrent users call ``predict`` on the same model objects. Confirm this
    is acceptable for the deployed TensorFlow version.

    Returns:
        tuple: ``(model_name, model_quality)``. The first is the fruit-type
        classifier loaded from ``name_model_inception.h5``; the second is the
        quality classifier loaded from ``type_model_inception.h5``.
    """
    model_name = load_model('name_model_inception.h5')
    model_quality = load_model('type_model_inception.h5')
    return model_name, model_quality


model_name, model_quality = load_models()
# Detectron2 setup | |
@st.cache_resource
def load_detectron_model(fruit_name):
    """Build (and cache) a CPU Detectron2 predictor for one fruit type.

    Cached with ``st.cache_resource`` so each fruit's model is built once
    instead of on every "Analyze" click. A call that raises is not cached.

    Args:
        fruit_name: Label produced by the type classifier (e.g. ``"Banana"``).
            The config is read from ``<fruit_name lowercased>_config.yaml``
            and the weights from ``<fruit_name as given>_model.pth``, both
            resolved relative to the current working directory.

    Returns:
        tuple: ``(predictor, cfg)``, a ``DefaultPredictor`` and the config it
        was built from.

    Raises:
        FileNotFoundError: If the config or weights file does not exist.
    """
    # NOTE(review): the config name is lowercased but the weights name is not
    # (e.g. "banana_config.yaml" vs "Banana_model.pth"). This is kept as-is so
    # existing files keep resolving; confirm against the files on disk.
    config_path = f"{fruit_name.lower()}_config.yaml"
    model_path = f"{fruit_name}_model.pth"
    # Fail early with a message that names the missing file.
    for path in (config_path, model_path):
        if not os.path.isfile(path):
            raise FileNotFoundError(f"Model file not found for {fruit_name}: {path}")
    cfg = get_cfg()
    cfg.merge_from_file(config_path)
    cfg.MODEL.WEIGHTS = model_path
    cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST = 0.5
    cfg.MODEL.DEVICE = 'cpu'
    predictor = DefaultPredictor(cfg)
    return predictor, cfg
# Class-index -> label lookups for the two Keras classifiers. predict_fruit
# indexes these with the argmax position of each model's output.
# NOTE(review): "Peeper" looks like a misspelling of "Pepper", and "tomato" is
# the only lowercase label. These strings are shown to the user and are also
# used to build Detectron2 file names in load_detectron_model, so renaming
# them means renaming those files too.
label_map_name = dict(enumerate([
    "Banana", "Cucumber", "Grape", "Kaki", "Papaya",
    "Peach", "Pear", "Peeper", "Strawberry", "Watermelon",
    "tomato",
]))
label_map_quality = dict(enumerate(["Good", "Mild", "Rotten"]))
def predict_fruit(img):
    """Classify a fruit image's type and quality with the two Keras models.

    Args:
        img: Image as a NumPy array, e.g. ``np.array(pil_image)``. Values are
            cast to uint8. Grayscale ``(H, W)``, RGB ``(H, W, 3)`` and RGBA
            ``(H, W, 4)`` arrays are all accepted and converted to RGB.

    Returns:
        tuple: ``(predicted_name, predicted_quality, img)``. The two labels
        come from ``label_map_name`` and ``label_map_quality``; ``img`` is the
        224x224 RGB ``PIL.Image`` that was fed to the models.
    """
    # Let Pillow infer the mode from the array shape, then normalise to RGB.
    # Forcing mode 'RGB' onto a 4-channel (RGBA PNG) or 2-D (grayscale) array
    # misreads the pixel buffer or raises.
    pil_img = Image.fromarray(np.asarray(img).astype('uint8')).convert('RGB')
    pil_img = pil_img.resize((224, 224))
    # Build a single-image batch with pixels scaled from [0, 255] to [0, 1].
    x = image.img_to_array(pil_img)
    x = np.expand_dims(x, axis=0)
    x = x / 255.0
    # Each model outputs one score vector per image; argmax picks the class.
    pred_name = model_name.predict(x)
    pred_quality = model_quality.predict(x)
    predicted_name = label_map_name[np.argmax(pred_name, axis=1)[0]]
    predicted_quality = label_map_quality[np.argmax(pred_quality, axis=1)[0]]
    return predicted_name, predicted_quality, pil_img
def main():
    """Render the Streamlit UI.

    The user uploads an image, which is classified for type and quality.
    For "Mild" or "Rotten" fruit, the damaged regions are then segmented
    with the fruit's Detectron2 model.
    """
    st.title("Automated Fruits Monitoring System")
    st.write("Upload an image of a fruit to detect its type, quality, and potential damage.")
    uploaded_file = st.file_uploader("Choose a fruit image...", type=["jpg", "jpeg", "png"])
    if uploaded_file is None:
        return
    # Named uploaded_image (not `image`) so it does not shadow the Keras
    # `image` preprocessing module imported at file level.
    uploaded_image = Image.open(uploaded_file)
    st.image(uploaded_image, caption="Uploaded Image", use_column_width=True)
    if not st.button("Analyze"):
        return
    # Normalise to RGB before inference. Palette ("P") or RGBA PNGs would
    # otherwise reach the models as palette indices or 4-channel arrays.
    predicted_name, predicted_quality, img = predict_fruit(
        np.array(uploaded_image.convert("RGB"))
    )
    st.write(f"Fruits Type Detection: {predicted_name}")
    st.write(f"Fruits Quality Classification: {predicted_quality}")
    # Damage segmentation is attempted for every label the type classifier
    # can emit. Derive the set from label_map_name so the two stay in sync.
    known_fruits = {name.lower() for name in label_map_name.values()}
    if predicted_name.lower() in known_fruits and predicted_quality in ("Mild", "Rotten"):
        st.write("Segmentation of Defective Region:")
        try:
            predictor, cfg = load_detectron_model(predicted_name)
            rgb = np.array(img)
            # The predictor gets a BGR copy; the Visualizer draws on the RGB array.
            outputs = predictor(cv2.cvtColor(rgb, cv2.COLOR_RGB2BGR))
            # Fall back to empty metadata when the config names no training
            # dataset, instead of failing with an IndexError.
            train_sets = cfg.DATASETS.TRAIN
            metadata = MetadataCatalog.get(train_sets[0]) if train_sets else None
            v = Visualizer(rgb, metadata, scale=0.8)
            out = v.draw_instance_predictions(outputs["instances"].to("cpu"))
            st.image(out.get_image(), caption="Damage Detection Result", use_column_width=True)
        except Exception as e:
            # UI boundary: report any model-loading/inference failure to the user.
            st.error(f"Error in damage detection: {str(e)}")
    else:
        st.write("No damage detection performed for this fruit or quality level.")
if "__main__" == __name__:
    # `streamlit run <file>` executes the script as __main__, so this guard
    # launches the UI under Streamlit as well as under plain `python`.
    main()
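# ---------------------------------------------------------------------------
# Disabled alternative implementation, kept for reference only: everything
# below is commented out and never runs.
# NOTE(review): before re-enabling it, fix these issues:
#   * `torch` is used (torch.cuda.is_available) but never imported.
#   * The damage-detection check compares predicted_name.lower() against
#     LABEL_MAPS['name'].values(). Those values are capitalized except
#     "tomato", so only tomato would ever be segmented.
#   * cv2.imencode treats its input as BGR, but damage_image is RGB, so the
#     downloaded PNG would have its red and blue channels swapped.
# ---------------------------------------------------------------------------
# import streamlit as st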
# import numpy as np | |
# import cv2 | |
# import warnings | |
# import os | |
# from pathlib import Path | |
# from PIL import Image | |
# import tensorflow as tf | |
# from tensorflow.keras.models import load_model | |
# from tensorflow.keras.preprocessing import image | |
# from detectron2.engine import DefaultPredictor | |
# from detectron2.config import get_cfg | |
# from detectron2.utils.visualizer import Visualizer | |
# from detectron2.data import MetadataCatalog | |
# # Suppress warnings | |
# warnings.filterwarnings("ignore") | |
# tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.ERROR) | |
# # Configuration | |
# MODEL_CONFIG = { | |
# 'name_model': 'name_model_inception.h5', | |
# 'quality_model': 'type_model_inception.h5', | |
# 'input_size': (224, 224), | |
# 'score_threshold': 0.5 | |
# } | |
# LABEL_MAPS = { | |
# 'name': { | |
# 0: "Banana", 1: "Cucumber", 2: "Grape", 3: "Kaki", 4: "Papaya", | |
# 5: "Peach", 6: "Pear", 7: "Peeper", 8: "Strawberry", 9: "Watermelon", | |
# 10: "tomato" | |
# }, | |
# 'quality': {0: "Good", 1: "Mild", 2: "Rotten"} | |
# } | |
# @st.cache_resource | |
# def load_classification_models(): | |
# """Load and cache the classification models.""" | |
# try: | |
# model_name = load_model(MODEL_CONFIG['name_model']) | |
# model_quality = load_model(MODEL_CONFIG['quality_model']) | |
# return model_name, model_quality | |
# except Exception as e: | |
# st.error(f"Error loading classification models: {str(e)}") | |
# return None, None | |
# @st.cache_resource | |
# def load_detectron_model(fruit_name: str): | |
# """Load and cache the Detectron2 model for damage detection.""" | |
# try: | |
# cfg = get_cfg() | |
# config_path = Path(f"{fruit_name.lower()}_config.yaml") | |
# model_path = Path(f"{fruit_name}_model.pth") | |
# if not config_path.exists() or not model_path.exists(): | |
# raise FileNotFoundError(f"Model files not found for {fruit_name}") | |
# cfg.merge_from_file(str(config_path)) | |
# cfg.MODEL.WEIGHTS = str(model_path) | |
# cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST = MODEL_CONFIG['score_threshold'] | |
# cfg.MODEL.DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu' | |
# return DefaultPredictor(cfg), cfg | |
# except Exception as e: | |
# st.error(f"Error loading Detectron2 model: {str(e)}") | |
# return None, None | |
# def preprocess_image(img: np.ndarray) -> tuple: | |
# """Preprocess the input image for model prediction.""" | |
# try: | |
# # Convert to PIL Image if necessary | |
# if isinstance(img, np.ndarray): | |
# img = Image.fromarray(img.astype('uint8'), 'RGB') | |
# # Resize and prepare for model input | |
# img_resized = img.resize(MODEL_CONFIG['input_size']) | |
# img_array = image.img_to_array(img_resized) | |
# img_expanded = np.expand_dims(img_array, axis=0) | |
# img_normalized = img_expanded / 255.0 | |
# return img_normalized, img_resized | |
# except Exception as e: | |
# st.error(f"Error preprocessing image: {str(e)}") | |
# return None, None | |
# def predict_fruit(img: np.ndarray) -> tuple: | |
# """Predict fruit type and quality.""" | |
# model_name, model_quality = load_classification_models() | |
# if model_name is None or model_quality is None: | |
# return None, None, None | |
# img_normalized, img_resized = preprocess_image(img) | |
# if img_normalized is None: | |
# return None, None, None | |
# try: | |
# # Make predictions | |
# pred_name = model_name.predict(img_normalized) | |
# pred_quality = model_quality.predict(img_normalized) | |
# # Get predicted labels | |
# predicted_name = LABEL_MAPS['name'][np.argmax(pred_name, axis=1)[0]] | |
# predicted_quality = LABEL_MAPS['quality'][np.argmax(pred_quality, axis=1)[0]] | |
# return predicted_name, predicted_quality, img_resized | |
# except Exception as e: | |
# st.error(f"Error making predictions: {str(e)}") | |
# return None, None, None | |
# def detect_damage(img: Image, fruit_name: str) -> np.ndarray: | |
# """Detect and visualize damage in the fruit image.""" | |
# predictor, cfg = load_detectron_model(fruit_name) | |
# if predictor is None or cfg is None: | |
# return None | |
# try: | |
# outputs = predictor(cv2.cvtColor(np.array(img), cv2.COLOR_RGB2BGR)) | |
# v = Visualizer(np.array(img), MetadataCatalog.get(cfg.DATASETS.TRAIN[0]), scale=0.8) | |
# out = v.draw_instance_predictions(outputs["instances"].to("cpu")) | |
# return out.get_image() | |
# except Exception as e: | |
# st.error(f"Error in damage detection: {str(e)}") | |
# return None | |
# def main(): | |
# st.set_page_config(page_title="Fruit Quality Analysis", layout="wide") | |
# st.title("Automated Fruits Monitoring System") | |
# st.write("Upload an image of a fruit to detect its type, quality, and potential damage.") | |
# uploaded_file = st.file_uploader("Choose a fruit image...", type=["jpg", "jpeg", "png"]) | |
# if uploaded_file is not None: | |
# # Create two columns for layout | |
# col1, col2 = st.columns(2) | |
# # Display uploaded image | |
# image = Image.open(uploaded_file) | |
# col1.image(image, caption="Uploaded Image", use_column_width=True) | |
# if col1.button("Analyze"): | |
# with st.spinner("Analyzing image..."): | |
# predicted_name, predicted_quality, img_resized = predict_fruit(np.array(image)) | |
# if predicted_name and predicted_quality: | |
# # Display results | |
# col2.markdown("### Analysis Results") | |
# col2.markdown(f"**Fruit Type:** {predicted_name}") | |
# col2.markdown(f"**Quality:** {predicted_quality}") | |
# # Check if damage detection is needed | |
# if (predicted_name.lower() in LABEL_MAPS['name'].values() and | |
# predicted_quality in ["Mild", "Rotten"]): | |
# col2.markdown("### Damage Detection") | |
# damage_image = detect_damage(img_resized, predicted_name) | |
# if damage_image is not None: | |
# col2.image(damage_image, caption="Detected Damage Regions", | |
# use_column_width=True) | |
# # Add download button for the damage detection result | |
# col2.download_button( | |
# label="Download Analysis Result", | |
# data=cv2.imencode('.png', damage_image)[1].tobytes(), | |
# file_name=f"{predicted_name}_damage_analysis.png", | |
# mime="image/png" | |
# ) | |
# if __name__ == "__main__": | |
# main() |