# Hugging Face Spaces
# Status: Sleeping
"""Automated Fruits Monitoring System (Streamlit app).

This opening section prepares the runtime:

- imports the UI and image libraries;
- imports TensorFlow/Keras, stopping the app with a visible error if that fails;
- installs PyTorch and Detectron2 with pip when they are missing;
- silences library warnings and TensorFlow log output.
"""
import importlib
import os
import subprocess
import sys
import warnings

import cv2
import numpy as np
import streamlit as st
from PIL import Image

# Suppress library warnings. A blanket "ignore" also covers the FutureWarning
# and UserWarning categories that were previously filtered individually.
warnings.filterwarnings("ignore")

# TensorFlow/Keras powers both classifiers. If it cannot be imported, report it
# in the UI and stop this script run instead of failing on a later import.
try:
    import tensorflow as tf
    from tensorflow.keras.models import load_model
    from tensorflow.keras.preprocessing import image
except ImportError:
    st.error("Failed to import TensorFlow. Please make sure it's installed correctly.")
    st.stop()

tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.ERROR)


def _pip_install(*requirements):
    """Install *requirements* with pip into the interpreter running this app.

    Runs ``sys.executable -m pip install ...`` with an argument list (no shell),
    so packages land in the same environment that will import them.

    Raises:
        subprocess.CalledProcessError: if pip exits with a non-zero status.
    """
    subprocess.check_call([sys.executable, "-m", "pip", "install", *requirements])


# PyTorch and Detectron2 are installed on first run when missing. Detectron2 is
# installed straight from its GitHub repository.
try:
    import torch
    import detectron2
except ImportError:
    try:
        with st.spinner("Installing PyTorch and Detectron2..."):
            _pip_install("torch", "torchvision")
            _pip_install("git+https://github.com/facebookresearch/detectron2.git")
    except subprocess.CalledProcessError as err:
        st.error(f"Installing PyTorch/Detectron2 failed (pip exit code {err.returncode}).")
        st.stop()
    # Let the import system notice packages installed after interpreter start-up.
    importlib.invalidate_caches()
    import torch
    import detectron2

from detectron2.config import get_cfg
from detectron2.data import MetadataCatalog
from detectron2.engine import DefaultPredictor
from detectron2.utils.visualizer import Visualizer
@st.cache_resource
def load_models():
    """Load the fruit-type and fruit-quality Keras classifiers.

    Streamlit re-executes this script on every widget interaction. Caching with
    ``st.cache_resource`` keeps both Inception models in memory instead of
    reloading them from disk on every rerun.

    Returns:
        tuple: ``(model_name, model_quality)``.
            - ``model_name`` is loaded from ``name_model_inception.h5``.
            - ``model_quality`` is loaded from ``type_model_inception.h5``.
    """
    model_name = load_model('name_model_inception.h5')
    model_quality = load_model('type_model_inception.h5')
    return model_name, model_quality


# Module-level handles used by predict_fruit().
model_name, model_quality = load_models()
# Detectron2 setup
@st.cache_resource
def load_detectron_model(fruit_name):
    """Build a CPU Detectron2 predictor for the damage model of *fruit_name*.

    The config is read from ``<fruit_name lower-cased>_config.yaml`` and the
    weights from ``<fruit_name>_model.pth``, both relative to the working
    directory. Detections scoring below 0.5 are dropped. The result is cached
    per ``fruit_name`` with ``st.cache_resource``, so repeated analyses of the
    same fruit do not rebuild the model.

    Args:
        fruit_name: A label from ``label_map_name`` (e.g. ``"Banana"``).

    Returns:
        tuple: ``(predictor, cfg)``, the ``DefaultPredictor`` and the config it
        was built from.
    """
    cfg = get_cfg()
    cfg.merge_from_file(f"{fruit_name.lower()}_config.yaml")
    # NOTE(review): the weights filename keeps the label's casing while the
    # config filename is lower-cased. Confirm this matches the files on disk.
    cfg.MODEL.WEIGHTS = f"{fruit_name}_model.pth"
    cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST = 0.5
    cfg.MODEL.DEVICE = 'cpu'
    predictor = DefaultPredictor(cfg)
    return predictor, cfg
# Labels: classifier output index (argmax position) -> display label.
# NOTE(review): "Peeper" and lowercase "tomato" look like typos. These exact
# strings are also matched in main() and used to build Detectron2 file names in
# load_detectron_model(), so rename the files and callers together if they are
# ever corrected.
label_map_name = dict(enumerate([
    "Banana", "Cucumber", "Grape", "Kaki", "Papaya",
    "Peach", "Pear", "Peeper", "Strawberry", "Watermelon",
    "tomato",
]))
label_map_quality = dict(enumerate(["Good", "Mild", "Rotten"]))
def predict_fruit(img):
    """Classify the fruit type and quality shown in an image.

    Args:
        img: The image as a NumPy array, e.g. ``np.array(PIL.Image)``.
            Grayscale ``(H, W)``, RGB ``(H, W, 3)`` and RGBA ``(H, W, 4)``
            arrays are accepted; all are converted to 3-channel RGB first.

    Returns:
        tuple: ``(predicted_name, predicted_quality, img)``.
            - ``predicted_name`` is a label from ``label_map_name``.
            - ``predicted_quality`` is a label from ``label_map_quality``.
            - ``img`` is the 224x224 RGB ``PIL.Image`` fed to both models.
    """
    # Let PIL infer the mode from the array's shape, then normalise to RGB.
    # Forcing mode 'RGB' garbles 4-channel arrays and fails on 2-D arrays.
    img = Image.fromarray(img.astype('uint8')).convert('RGB')
    img = img.resize((224, 224))
    # Build a batch of one, scaled to [0, 1].
    x = np.expand_dims(image.img_to_array(img), axis=0) / 255.0

    pred_name = model_name.predict(x)
    pred_quality = model_quality.predict(x)
    predicted_name = label_map_name[int(np.argmax(pred_name, axis=1)[0])]
    predicted_quality = label_map_quality[int(np.argmax(pred_quality, axis=1)[0])]
    return predicted_name, predicted_quality, img
def main():
    """Render the app: upload a fruit image, classify it, and segment damage.

    Shows the uploaded image. When "Analyze" is pressed, it displays the type
    and quality returned by ``predict_fruit``. For "Mild" or "Rotten" fruit of
    a type listed in ``damage_model_fruits``, the matching Detectron2 model is
    loaded and its instance predictions are drawn on the 224x224 image returned
    by ``predict_fruit``. Any error in that damage-detection step is reported
    with ``st.error``.
    """
    st.title("Automated Fruits Monitoring System")
    st.write("Upload an image of a fruit to detect its type, quality, and potential damage.")
    uploaded_file = st.file_uploader("Choose a fruit image...", type=["jpg", "jpeg", "png"])
    if uploaded_file is None:
        return

    # Not named ``image``: that name refers to the Keras preprocessing module.
    uploaded_image = Image.open(uploaded_file)
    st.image(uploaded_image, caption="Uploaded Image", use_column_width=True)
    if not st.button("Analyze"):
        return

    # Uploads may be RGBA, grayscale or palette PNGs (or CMYK JPEGs); the
    # models expect 3-channel RGB.
    rgb_upload = np.array(uploaded_image.convert("RGB"))
    predicted_name, predicted_quality, img = predict_fruit(rgb_upload)
    st.write(f"Fruits Type Detection: {predicted_name}")
    st.write(f"Fruits Quality Classification: {predicted_quality}")

    # Fruit types with a Detectron2 damage model (currently every classifier label).
    damage_model_fruits = {
        "kaki", "tomato", "strawberry", "peeper", "pear", "peach",
        "papaya", "watermelon", "grape", "banana", "cucumber",
    }
    if predicted_name.lower() in damage_model_fruits and predicted_quality in ("Mild", "Rotten"):
        st.write("Segmentation of Defective Region:")
        try:
            predictor, cfg = load_detectron_model(predicted_name)
            rgb = np.array(img)
            # The predictor is fed a BGR copy; the Visualizer draws on the RGB array.
            outputs = predictor(cv2.cvtColor(rgb, cv2.COLOR_RGB2BGR))
            v = Visualizer(rgb, MetadataCatalog.get(cfg.DATASETS.TRAIN[0]), scale=0.8)
            out = v.draw_instance_predictions(outputs["instances"].to("cpu"))
            st.image(out.get_image(), caption="Damage Detection Result", use_column_width=True)
        except Exception as e:
            # UI boundary: surface any damage-detection failure to the user.
            st.error(f"Error in damage detection: {str(e)}")
    else:
        st.write("No damage detection performed for this fruit or quality level.")


if __name__ == "__main__":
    main()