# NOTE: this source was captured from a Hugging Face Space page whose status banner read "Runtime error".
import streamlit as st | |
import sys | |
import openai | |
import toml | |
from openai import OpenAI | |
import pandas as pd | |
import os | |
import random | |
import glob | |
import re | |
from io import BytesIO | |
from six import BytesIO | |
import cv2 | |
import warnings | |
warnings.filterwarnings('ignore') | |
from io import BytesIO | |
import tempfile | |
import time | |
import matplotlib.pyplot as plt | |
import matplotlib.colors as mcolors | |
import seaborn as sns | |
from PIL import Image | |
from PIL import ImageColor | |
from PIL import ImageDraw | |
from PIL import ImageFont | |
from PIL import ImageOps | |
import json | |
import numpy as np | |
np.random.seed(42) | |
import tensorflow as tf | |
tf.random.set_seed(42) | |
import tensorflow.keras as k | |
k.utils.set_random_seed(42) # idem keras | |
from keras.backend import manual_variable_initialization | |
manual_variable_initialization(True) # https://github.com/keras-team/keras/issues/4875#issuecomment-296696536 | |
from tensorflow.keras.applications.xception import preprocess_input | |
from tensorflow.keras.applications.xception import Xception | |
from scipy.stats import mode | |
from tensorflow.keras.applications.mobilenet import MobileNet | |
from tensorflow.keras.applications.mobilenet import preprocess_input as mobilenet_preprocess | |
from tensorflow.keras.applications.xception import preprocess_input as xception_preprocess | |
import tensorflow_hub as hub | |
@st.cache_resource
def load_models():
    """Load every model and lookup table the app needs.

    Decorated with ``st.cache_resource`` so the expensive work (TF Hub
    download, three Keras model loads, CSV reads, OpenAI client construction)
    runs once per server process instead of on every Streamlit rerun.

    Returns:
        tuple: ``(client, detector, dis_percentage, details, cnn_model,
        xception_model, mobilenet_model, class_labels)``; ``class_labels``
        maps an integer class index to its label string.
    """
    # OpenAI client; the API key is read from Streamlit secrets.
    client_d = OpenAI(api_key=st.secrets["OPENAI_API_KEY"])

    # Faster R-CNN (Inception-ResNet-V2, Open Images V4) detector from TF Hub.
    module_handle = "https://tfhub.dev/google/faster_rcnn/openimages_v4/inception_resnet_v2/1"
    detector_d = hub.load(module_handle).signatures['default']

    # Reference tables used by classify_image().
    file_path = '.vscode/inputs/'  # folder with input files
    dis_percentage_d = pd.read_csv(os.path.join(file_path, 'Spots_Percentage_results.csv'))
    details_d = pd.read_csv(os.path.join(file_path, 'Plant_details.csv'))

    # The three classifiers that are ensembled in classify_image().
    print("Loading CNN Model")
    cnn_model_d = k.models.load_model('.vscode/model/CNN_weights.hdf5')
    print("Loading Xception Model")
    xception_model_d = k.models.load_model('.vscode/model/XCeption_weights.hdf5')
    print("Loading Mobilenet Model")
    mobilenet_model_d = k.models.load_model('.vscode/model/MobileNet_weights.hdf5')
    print("finished loading models")

    # The JSON file maps label -> index; invert it so predicted indices can be
    # turned back into label names.
    with open(os.path.join(file_path, 'Xception_0422_labels.json'), 'r') as labels_file:
        label_to_index = {label: int(index) for label, index in json.load(labels_file).items()}
    class_labels_d = {index: label for label, index in label_to_index.items()}

    return (client_d, detector_d, dis_percentage_d, details_d,
            cnn_model_d, xception_model_d, mobilenet_model_d, class_labels_d)
# Load all models and reference data up front. load_models() is meant to be
# cached by Streamlit, so the heavy loading only happens on the first run.
(
    client,
    detector,
    Dis_percentage,
    Details,
    cnn_model,
    xception_model,
    mobilenet_model,
    class_labels,
) = load_models()
# Identify extent of spot or lesion coverage on leaf
def identify_spots_or_lesions(img):
    """Estimate the percentage of the image covered by spots or lesions.

    Pipeline: take the A channel of the LAB colour space, blur it, binarise it
    with Otsu thresholding, clean it with a morphological opening, run Canny
    edge detection, and keep contours whose area is below ``max_area``. The
    summed area of those contours relative to the image size is reported, and
    a five-panel diagnostic figure is rendered with ``st.pyplot``.

    Args:
        img: leaf image; converted with ``np.array`` and treated as RGB.

    Returns:
        float: spot/lesion coverage as a percentage of all pixels.
    """
    cv_image = cv2.cvtColor(np.array(img), cv2.COLOR_RGB2BGR)
    lab_image = cv2.cvtColor(cv_image, cv2.COLOR_BGR2Lab)
    _, a_channel, _ = cv2.split(lab_image)  # A = green-red axis of LAB
    blur = cv2.GaussianBlur(a_channel, (3, 3), 0)
    thresh = cv2.threshold(blur, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU)[1]

    # Morphological clean-up: opening = erosion followed by dilation.
    kernel = np.ones((3, 3), np.uint8)
    cleaned = cv2.morphologyEx(thresh, cv2.MORPH_OPEN, kernel, iterations=1)
    edges = cv2.Canny(cleaned, 100, 300)

    # Contours of max_area px or more are excluded from the spot total.
    contours, _ = cv2.findContours(edges, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE)
    max_area = 18000
    filtered_contours = [cnt for cnt in contours if cv2.contourArea(cnt) < max_area]

    # Percentage of spots/lesions relative to the whole image.
    spot_pixels = sum(cv2.contourArea(cnt) for cnt in filtered_contours)
    total_pixels = edges.shape[0] * edges.shape[1]
    percentage_spots = (spot_pixels / total_pixels) * 100
    st.write(f"Percentage of spots/lesions: {percentage_spots:.2f}%")

    contoured_image = cv2.drawContours(cv_image.copy(), filtered_contours, -1, (0, 255, 0), 1)

    # Visualisation. Colour panels are converted from the BGR working copies
    # back to RGB so matplotlib shows true colours. The first panel uses the
    # function's own input (not any module-level image).
    panels = (
        (cv2.cvtColor(cv_image, cv2.COLOR_BGR2RGB), None, 'Original Image'),
        (a_channel, 'gray', 'LAB - A channel'),
        (edges, 'gray', 'Edge Detection'),
        (cleaned, 'gray', 'Thresholded & Cleaned'),
        (cv2.cvtColor(contoured_image, cv2.COLOR_BGR2RGB), None, 'Spots or Lesions Identified'),
    )
    mfig = plt.figure(figsize=(25, 8))
    for position, (panel, cmap, title) in enumerate(panels, start=1):
        axes = mfig.add_subplot(1, 5, position)
        axes.imshow(panel, cmap=cmap)
        axes.set_title(title)
    st.pyplot(mfig)
    # pyplot keeps every figure alive until closed; release it after rendering.
    plt.close(mfig)
    return percentage_spots
# Plot disease percentage
def plot_dis_percentage(row, percentage):
    """Render a Mild/Moderate/Severe gauge for *percentage* and return its title.

    The gauge spans ``row['min']``..``row['max']``. It is split at
    ``row['Q1']`` and ``row['Q3']`` into Mild (yellow), Moderate (orange) and
    Severe (dark red) segments, and an arrow marks *percentage*.

    Args:
        row: record (e.g. a DataFrame row) with keys ``'Plant'``, ``'min'``,
            ``'max'``, ``'Q1'`` and ``'Q3'``.
        percentage: spot/lesion coverage in percent.

    Returns:
        str: ``"<category> - <plant>"``.
    """
    # Severity category (and arrow colour) from the quartile boundaries.
    if percentage < row['Q1']:
        category, color = 'Mild', 'yellow'
    elif row['Q1'] <= percentage <= row['Q3']:
        category, color = 'Moderate', 'orange'
    else:
        category, color = 'Severe', 'darkred'

    min_val = row['min']
    range_val = row['max'] - min_val

    def normalize(value):
        """Map *value* from [min, max] onto [0, 1], clamped to the gauge."""
        if range_val == 0:
            # Degenerate range (min == max): avoid dividing by zero.
            return 0.0
        # Clamp so an out-of-range value stays visible at the gauge edge
        # instead of being drawn outside the [0, 1] x-limits.
        return min(max((value - min_val) / range_val, 0.0), 1.0)

    q1_norm = normalize(row['Q1'])
    q3_norm = normalize(row['Q3'])
    percentage_norm = normalize(percentage)

    fig, ax = plt.subplots(figsize=(6, 1))
    # Coloured segments for the Mild, Moderate and Severe ranges.
    ax.axhline(0, xmin=0, xmax=q1_norm, color='yellow', linewidth=4, label='Mild')
    ax.axhline(0, xmin=q1_norm, xmax=q3_norm, color='orange', linewidth=4, label='Moderate')
    ax.axhline(0, xmin=q3_norm, xmax=1, color='darkred', linewidth=4, label='Severe')
    # The measured percentage, drawn as an arrow.
    ax.annotate('', xy=(percentage_norm, 0.1), xytext=(percentage_norm, -0.1),
                arrowprops=dict(facecolor=color, shrink=0.05, width=1, headwidth=10))
    ax.set_yticks([])
    ax.set_xticks([])  # no raw percentage figures on the axis
    ax.set_xlim([0, 1])  # normalised range
    titlet = f'{category} - {row["Plant"]}'
    ax.set_title(titlet)
    ax.set_xlabel('Value (Normalized)')
    ax.legend(loc='center left', bbox_to_anchor=(1, 0.5))
    fig.tight_layout()
    st.pyplot(fig)
    # pyplot keeps every figure alive until closed; release it after rendering.
    plt.close(fig)
    return titlet
def resize_image(image, target_size=(224, 224)):
    """Return *image* resized to *target_size*.

    *target_size* is passed unchanged to ``image.resize``. For PIL images it
    is a ``(width, height)`` pair. The default matches the 224x224 input size
    used by classify_image().
    """
    new_size = target_size
    resized = image.resize(new_size)
    return resized
def _ensemble_probabilities(image):
    """Return the mean class-probability vector of the three models for *image*."""
    resized_image = cv2.resize(np.array(image), (224, 224), interpolation=cv2.INTER_LINEAR)
    img_array = np.expand_dims(np.asarray(resized_image, dtype='float32'), axis=0)
    # NOTE(review): the single image is tiled into 32 identical copies, and
    # only row 0 of the predictions is used. If the saved models accept a
    # dynamic batch size, this could be a batch of 1 (32x less inference
    # work). Confirm against the models' input shapes before changing it.
    img_batch = np.tile(img_array, (32, 1, 1, 1))
    # Model-specific input scaling: the MobileNet/Xception preprocess_input
    # helpers (Xception scales to [-1, 1]), and /255 for the custom CNN.
    mobilenet_preds = mobilenet_model(mobilenet_preprocess(np.copy(img_batch)), training=False)
    xception_preds = xception_model(xception_preprocess(np.copy(img_batch)), training=False)
    cnn_preds = cnn_model(img_batch / 255.0, training=False)
    averaged_probs = (mobilenet_preds + xception_preds + cnn_preds) / 3
    return averaged_probs.numpy()[0]


# Classify the image
def classify_image(image):
    """Classify a leaf image with the three-model ensemble and report on it.

    Writes the predicted class and its confidence to Streamlit. Images
    classified as healthy or as background get no further analysis. Otherwise
    the function measures spot coverage and, when severity statistics exist
    for the class, plots the disease severity.

    Args:
        image: uploaded leaf image. PIL images that are not RGB (e.g. RGBA
            PNGs or greyscale) are converted to RGB first, because the models
            and the spot analysis expect three channels.

    Returns:
        ``(severity_title, top_class_prob, second_class_name)`` when the
        predicted class has an entry in ``Details``, otherwise ``None``.
        ``severity_title`` is the predicted class name when no severity
        statistics exist for that class.
    """
    if isinstance(image, Image.Image) and image.mode != "RGB":
        image = image.convert("RGB")

    # Rank classes by averaged probability; keep the top two.
    averaged_probs = _ensemble_probabilities(image)
    ranked = np.argsort(-averaged_probs)
    top_class_index, second_class_index = ranked[0], ranked[1]
    top_class_prob = averaged_probs[top_class_index]
    second_class_prob = averaged_probs[second_class_index]
    predicted_class_name = class_labels[top_class_index]
    second_class_name = class_labels[second_class_index]

    st.write("Image class:", predicted_class_name)
    st.write(f"Confidence: {top_class_prob:.2%}")

    if "healthy" in predicted_class_name:
        st.write(f"{predicted_class_name} is healthy, skipping further analysis.")
        return None
    if "Background_without_leaves" in predicted_class_name:
        st.write(f"{predicted_class_name} is not recognized as a plant image, skipping further analysis.")
        return None

    spots_percentage = identify_spots_or_lesions(image)

    # Without severity statistics there is no category to report. Fall back
    # to the class name so the returned tuple always holds a usable label.
    severity_disease = predicted_class_name
    if predicted_class_name in Dis_percentage['Plant'].values:
        row = Dis_percentage.loc[Dis_percentage['Plant'] == predicted_class_name].iloc[0]
        severity_disease = plot_dis_percentage(row, spots_percentage)

    if predicted_class_name in Details['Plant'].values:
        st.write("----------------------------------")
        return severity_disease, top_class_prob, second_class_name

    st.write("No data available for this plant disease in DataFrame.")
    if top_class_prob < 0.999:  # threshold close to 1 to handle floating-point precision issues
        st.write("Second predicted class:", second_class_name)
        st.write(f"Second class confidence: {second_class_prob:.3%}")
    else:
        st.write("Second predicted class: None")
    return None
def display_image(image):
    """Show *image* in a 12x6-inch matplotlib figure with the grid turned off.

    Uses ``plt.show()``. The Streamlit page itself renders figures through
    ``st.pyplot`` instead, so this helper is only useful outside the app
    (e.g. while debugging).
    """
    _, axes = plt.subplots(figsize=(12, 6))
    axes.grid(False)
    axes.imshow(image)
    plt.show()
def draw_bounding_box_on_image(image, ymin, xmin, ymax, xmax, color, font, thickness=4, display_str_list=()):
    """Draw one labelled bounding box onto *image* in place.

    Args:
        image: PIL image to draw on (modified in place).
        ymin, xmin, ymax, xmax: box edges as fractions of the image
            height/width.
        color: outline and label-background colour.
        font: PIL font used for the label text.
        thickness: outline width in pixels.
        display_str_list: label strings. They are stacked above the box, or
            stacked downward from its top edge when there is not enough room
            above.
    """
    draw = ImageDraw.Draw(image)
    im_width, im_height = image.size
    left, right = xmin * im_width, xmax * im_width
    top, bottom = ymin * im_height, ymax * im_height
    draw.line([(left, top), (left, bottom), (right, bottom), (right, top), (left, top)],
              width=thickness, fill=color)

    # If the stacked labels (each with a 5% margin above and below) would run
    # past the top of the image, start them below the box's top edge instead.
    display_str_heights = [font.getbbox(ds)[3] for ds in display_str_list]
    total_display_str_height = (1 + 2 * 0.05) * sum(display_str_heights)
    text_bottom = top if top > total_display_str_height else top + total_display_str_height

    # Labels are drawn bottom-up, so walk the list in reverse.
    for display_str in reversed(display_str_list):
        _, _, text_width, text_height = font.getbbox(display_str)
        margin = np.ceil(0.05 * text_height)
        draw.rectangle([(left, text_bottom - text_height - 2 * margin), (left + text_width, text_bottom)],
                       fill=color)
        draw.text((left + margin, text_bottom - text_height - margin), display_str, fill="black", font=font)
        # Move up by the full height of the label just drawn (text plus both
        # margins) so the next label does not overlap it.
        text_bottom -= text_height + 2 * margin
def draw_boxes(image, boxes, class_names, scores, max_boxes=3, min_score=0.1):
    """Overlay the highest-scoring labelled detection boxes on *image*.

    Args:
        image: PIL image to draw on (modified in place).
        boxes: per-detection ``(ymin, xmin, ymax, xmax)`` in relative
            coordinates.
        class_names: per-detection labels as bytes.
        scores: per-detection confidence scores.
        max_boxes: maximum number of boxes drawn.
        min_score: detections scoring below this are ignored.

    Returns:
        The annotated image as a NumPy array when *image* is a PIL image,
        otherwise *image* itself.
    """
    colors = list(ImageColor.colormap.values())
    font = ImageFont.load_default()

    detections = [(box, score, name)
                  for box, score, name in zip(boxes, scores, class_names)
                  if score >= min_score]
    # Highest score first, so the max_boxes cut keeps the most confident
    # detections rather than the least confident ones.
    detections.sort(key=lambda det: det[1], reverse=True)

    for box, score, class_name in detections[:max_boxes]:
        ymin, xmin, ymax, xmax = tuple(box)
        # utf-8 (a superset of ascii) matches how the caller decodes these labels.
        display_str = "{}: {:.2f}%".format(class_name.decode("utf-8"), score * 100)
        # NOTE: hash() of bytes is salted per process, so the colour chosen
        # for a label is stable within one server run only.
        color = colors[hash(class_name) % len(colors)]
        draw_bounding_box_on_image(image, ymin, xmin, ymax, xmax, color, font, display_str_list=[display_str])

    return np.array(image) if isinstance(image, Image.Image) else image
# ----------------------------------------------------------------------------------------------------// | |
# Streamlit app | |
def openai_remedy(searchval):
    """Ask the OpenAI chat API for remediation steps and render them as markdown.

    Args:
        searchval: disease description inserted into the prompt.

    Returns:
        None. Output goes to the Streamlit page. An API failure is reported
        with ``st.error`` instead of aborting the whole script run.
    """
    prompt = "List out the most relevant remediation steps for " + searchval + " in 7 bullet points"
    try:
        completion = client.chat.completions.create(
            model="gpt-4-turbo",
            messages=[{"role": "user", "content": prompt}],
            temperature=0.1,
            max_tokens=2000,
            top_p=0.1,
        )
    except openai.OpenAIError as err:
        st.error(f"Could not fetch disease management steps: {err}")
        return

    # message.content can be None (e.g. an empty or filtered response).
    content = completion.choices[0].message.content
    if content:
        st.markdown(content)
    else:
        st.warning("No disease management steps were returned.")
    return
tab1, tab2, tab3 = st.tabs(["Home", "Solution", "Team"])

# First tab ("Home"): application title and background on plant disease.
# Each element is written by calling the method directly on the tab container.
tab1.title("Plant Disease Identification")
tab1.image(".vscode/inputs/plantIcon.jpg", width=100)  # plant-care icon
tab1.markdown("Plant diseases are a significant threat to agricultural productivity worldwide, causing substantial crop losses and economic damage. These diseases can be caused by various factors, including fungi, bacteria, viruses, and environmental stressors. Recognizing the symptoms of plant diseases early is crucial for implementing effective management strategies and minimizing the impact on crop yield and quality.")
tab1.write("""
### Importance of Early Detection
Early detection of plant diseases is paramount for farmers to protect their crops and livelihoods. By identifying diseases at their onset, farmers can implement timely interventions, such as targeted pesticide applications or cultural practices, to prevent the spread of diseases and reduce crop losses. Early detection also reduces the need for excessive chemical inputs, promoting sustainable agriculture practices and environmental stewardship.
""")
tab1.image(".vscode/inputs/Plant-disease-classifier-with-ai-blog-banner.jpg", width=700)  # banner
tab1.write("With more than 50% of the population in India still relying on agriculture and with the average farm sizes and incomes being very small, we believe that cost effective solutions for early detection and treatment solutions for disease could significantly improve the quality of produce and lives of farmers. With smartphones being ubiquitous, we believe providing solutions to farmers over a smartphone is the most penetrative form.")
# Second tab: image upload, plant detection, disease classification and
# remedy suggestions.
with tab2:
    st.title("Plant classification, Disease detection and management")
    uploaded_file = st.file_uploader("Upload Leaf Image...", type=["jpg", "jpeg", "png"], key="uploader")
    if uploaded_file is not None:
        print("Image successfully uploaded!")
        st.image(uploaded_file, caption='Uploaded Image', width=300)
        # Normalise to 3-channel RGB: PNG uploads may be RGBA (or greyscale),
        # which the detector and the classifiers cannot consume.
        image = Image.open(uploaded_file).convert("RGB")
        image_for_drawing = image.copy()

        # PIL -> float32 tensor scaled to [0, 1] with a leading batch axis.
        img = tf.convert_to_tensor(np.asarray(image))
        converted_img = tf.image.convert_image_dtype(img, tf.float32)[tf.newaxis, ...]
        result = {key: value.numpy() for key, value in detector(converted_img).items()}
        detection_scores = result["detection_scores"]
        detection_class_entities = result["detection_class_entities"]

        # Show the detections drawn on a copy of the upload.
        image_with_boxes = draw_boxes(image_for_drawing, result["detection_boxes"],
                                      detection_class_entities, detection_scores)
        st.image(image_with_boxes, caption='Uploaded Image', width=300)

        # Classify once, using the first "Plant" among the top-3 detections.
        # Breaking out avoids re-running the classifier and the OpenAI call
        # when several top detections are "Plant".
        for idx in np.argsort(-detection_scores)[:3]:
            if detection_class_entities[idx].decode('utf-8') == "Plant":
                plant_score = detection_scores[idx]
                st.write(f"Plant Probability score using Faster R-CNN Inception Resnet V2 Object detection model : {plant_score:.2%}")
                result1 = classify_image(image)
                if result1 is not None:
                    search_term = result1[0].replace("_", " ").replace("-", " ")
                    st.markdown("Fetching disease management steps for " + ":red[" + search_term + "]... :eyes:")
                    openai_remedy(search_term)
                break
        else:
            st.write("No plant detected in the top detections; skipping disease analysis.")
    else:
        print("No file uploaded.")

    # Disclaimer
    st.write("""
### Disclaimer
While our disease identification system strives for accuracy and reliability, it is essential to note its limitations. False positives or false negatives may occur, and users are encouraged to consult with agricultural experts for professional advice and decision-making.
""")
# Third tab: project team, one member per row with dividers between rows.
with tab3:
    st.title("CDS Batch 6 - Group 2:")
    st.divider()
    for member in (
        "Abhinav Singh",
        "Ankit Kourav",
        "Challoju Anurag.",
        "Madhucchand Darbha",
        "Neha Gupta",
        "Pradeep Rajagopal",
        "Rakesh Vegesana",
        "Sachin Sharma",
        "Shashank Srivastava",
    ):
        st.write(member)
        st.divider()