# streamlit_app / app.py
# MuhammmadRizwanRizwan's picture
# Update app.py
# ea250f3 verified
import streamlit as st
import numpy as np
import cv2
import warnings
import os
# Suppress warnings
# Silence the FutureWarning and UserWarning categories process-wide.
warnings.filterwarnings("ignore", category=FutureWarning)
warnings.filterwarnings("ignore", category=UserWarning)
# Try importing TensorFlow.
# The Keras loader and preprocessing helpers are required by the classifiers
# below, and the file imports them again unconditionally further down.  If
# they are missing, report it in the UI and halt this script run instead of
# continuing into a raw ImportError traceback.
try:
    from tensorflow.keras.models import load_model
    from tensorflow.keras.preprocessing import image
except ImportError:
    st.error("Failed to import TensorFlow. Please make sure it's installed correctly.")
    st.stop()
# Try importing PyTorch and Detectron2.
# If either is missing, install both into the interpreter that is running
# this app and retry the import.
try:
    import torch
    import detectron2
except ImportError:
    import importlib
    import subprocess
    import sys

    # Order matters: torch is installed before detectron2, as in the
    # original command sequence.
    _PIP_TARGETS = (
        ["torch", "torchvision"],
        ["git+https://github.com/facebookresearch/detectron2.git"],
    )
    with st.spinner("Installing PyTorch and Detectron2..."):
        for _packages in _PIP_TARGETS:
            # `sys.executable -m pip` targets the running environment; a bare
            # `pip` found on PATH may belong to a different Python.  Passing
            # an argument list also avoids going through a shell.
            _result = subprocess.run(
                [sys.executable, "-m", "pip", "install", *_packages],
                check=False,
            )
            if _result.returncode != 0:
                # Surface the failure instead of falling through to an
                # ImportError on the retry below.
                st.error(f"Failed to install {' '.join(_packages)}.")
                st.stop()
    # Let the import system notice packages installed after startup.
    importlib.invalidate_caches()
    import torch
    import detectron2
import streamlit as st
import numpy as np
import cv2
import torch
import os
from PIL import Image
from tensorflow.keras.models import load_model
from tensorflow.keras.preprocessing import image
from detectron2.engine import DefaultPredictor
from detectron2.config import get_cfg
from detectron2.utils.visualizer import Visualizer
from detectron2.data import MetadataCatalog
# Suppress warnings
import warnings
import tensorflow as tf
# Blanket filter: ignores every warning category, which also covers the
# FutureWarning/UserWarning filters registered near the top of the file.
warnings.filterwarnings("ignore")
# Limit TensorFlow's v1-compat logger to ERROR-level messages.
tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.ERROR)
@st.cache_resource
def load_models():
    """Load the two Keras classifiers used by the app.

    Decorated with ``st.cache_resource`` so the loaded models are reused
    across Streamlit reruns rather than re-read from disk each time.

    Returns:
        A ``(name_model, quality_model)`` tuple loaded from
        ``name_model_inception.h5`` and ``type_model_inception.h5``.
    """
    return (
        load_model('name_model_inception.h5'),
        load_model('type_model_inception.h5'),
    )
# Load both classifiers at import time; predict_fruit() reads these globals.
model_name, model_quality = load_models()
# Detectron2 setup
@st.cache_resource
def load_detectron_model(fruit_name):
    """Build a CPU Detectron2 predictor for one fruit.

    Args:
        fruit_name: Fruit label.  The config file is looked up as
            ``<fruit_name lower-cased>_config.yaml`` and the weights as
            ``<fruit_name>_model.pth`` (case kept as given); both are
            relative paths.

    Returns:
        A ``(predictor, cfg)`` tuple: the ``DefaultPredictor`` and the
        config it was built from.
    """
    config_file = f"{fruit_name.lower()}_config.yaml"
    weights_file = f"{fruit_name}_model.pth"

    cfg = get_cfg()
    cfg.merge_from_file(config_file)
    cfg.MODEL.WEIGHTS = weights_file
    # Keep only detections scoring at least 0.5, and run inference on CPU.
    cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST = 0.5
    cfg.MODEL.DEVICE = 'cpu'
    return DefaultPredictor(cfg), cfg
# Labels
# Class index -> fruit name for the fruit-type classifier output.
# NOTE: the spellings ("Peeper", lower-case "tomato") are runtime values that
# are also used to build the Detectron2 config/weights file names, so they
# are kept exactly as-is.
label_map_name = dict(enumerate((
    "Banana", "Cucumber", "Grape", "Kaki", "Papaya",
    "Peach", "Pear", "Peeper", "Strawberry", "Watermelon",
    "tomato",
)))
# Class index -> quality grade for the quality classifier output.
label_map_quality = dict(enumerate(("Good", "Mild", "Rotten")))
def predict_fruit(img):
    """Classify the fruit type and quality shown in an image array.

    Args:
        img: Image pixels as a NumPy array (or array-like): grayscale
            ``(H, W)`` / ``(H, W, 1)``, RGB ``(H, W, 3)`` or RGBA
            ``(H, W, 4)``.  Values are cast to ``uint8``.

    Returns:
        ``(predicted_name, predicted_quality, resized)`` where the first two
        are labels taken from ``label_map_name`` / ``label_map_quality`` and
        ``resized`` is the 224x224 RGB ``PIL.Image`` that was fed to the
        models.
    """
    pixels = np.asarray(img).astype('uint8')
    if pixels.ndim == 3 and pixels.shape[2] == 1:
        # Pillow cannot build an image from a trailing singleton channel.
        pixels = pixels[:, :, 0]

    # Let Pillow infer the mode from the array shape and convert afterwards.
    # Forcing mode 'RGB' onto the raw buffer misreads 4-channel (RGBA) data
    # and fails on 2-D grayscale data, both of which PNG uploads can produce.
    pil_img = Image.fromarray(pixels).convert('RGB')
    pil_img = pil_img.resize((224, 224))

    # Batch of one, with uint8 pixel values scaled into [0, 1].
    batch = np.expand_dims(image.img_to_array(pil_img), axis=0) / 255.0

    name_scores = model_name.predict(batch)
    quality_scores = model_quality.predict(batch)

    # argmax over the class axis; [0] selects the single batch entry.
    name_idx = int(np.argmax(name_scores, axis=1)[0])
    quality_idx = int(np.argmax(quality_scores, axis=1)[0])
    return label_map_name[name_idx], label_map_quality[quality_idx], pil_img
def main():
    """Render the Streamlit page: upload, classify, then segment damage.

    Flow: show the uploader; once a file is present, display it and offer an
    "Analyze" button; on click, run ``predict_fruit`` and, for fruit rated
    "Mild" or "Rotten", run the per-fruit Detectron2 model and display the
    instance predictions drawn on the resized image.
    """
    st.title("Automated Fruits Monitoring System")
    st.write("Upload an image of a fruit to detect its type, quality, and potential damage.")

    uploaded_file = st.file_uploader("Choose a fruit image...", type=["jpg", "jpeg", "png"])
    if uploaded_file is None:
        return

    # Named `uploaded_image` rather than `image` so the Keras `image` module
    # imported at file level is not shadowed inside this function.
    uploaded_image = Image.open(uploaded_file)
    # NOTE(review): newer Streamlit releases deprecate `use_column_width` in
    # favour of `use_container_width` - confirm against the pinned version.
    st.image(uploaded_image, caption="Uploaded Image", use_column_width=True)

    if not st.button("Analyze"):
        return

    # Uploads may be palette, grayscale, RGBA or CMYK; normalise to
    # 3-channel RGB before handing the pixels to the classifiers.
    rgb_pixels = np.array(uploaded_image.convert("RGB"))
    predicted_name, predicted_quality, img = predict_fruit(rgb_pixels)
    st.write(f"Fruits Type Detection: {predicted_name}")
    st.write(f"Fruits Quality Classification: {predicted_quality}")

    segmentable_fruits = (
        "kaki", "tomato", "strawberry", "peeper", "pear", "peach",
        "papaya", "watermelon", "grape", "banana", "cucumber",
    )
    needs_segmentation = (
        predicted_name.lower() in segmentable_fruits
        and predicted_quality in ("Mild", "Rotten")
    )
    if not needs_segmentation:
        st.write("No damage detection performed for this fruit or quality level.")
        return

    st.write("Segmentation of Defective Region:")
    try:
        predictor, cfg = load_detectron_model(predicted_name)
        resized_rgb = np.array(img)
        # The predictor is given BGR pixels; the visualizer draws on RGB.
        outputs = predictor(cv2.cvtColor(resized_rgb, cv2.COLOR_RGB2BGR))
        v = Visualizer(resized_rgb, MetadataCatalog.get(cfg.DATASETS.TRAIN[0]), scale=0.8)
        out = v.draw_instance_predictions(outputs["instances"].to("cpu"))
        st.image(out.get_image(), caption="Damage Detection Result", use_column_width=True)
    except Exception as e:
        # UI boundary: show any model/config failure in the page rather than
        # letting the script run die with a traceback.
        st.error(f"Error in damage detection: {str(e)}")


if __name__ == "__main__":
    main()