Spaces:
Sleeping
Sleeping
update app.py
Browse files
app.py
CHANGED
@@ -1,42 +1,57 @@
|
|
1 |
import streamlit as st
|
2 |
import torch
|
3 |
from torchvision import transforms
|
4 |
-
from PIL import Image
|
5 |
import numpy as np
|
6 |
import time
|
7 |
|
8 |
-
# Simplified model definition
|
9 |
class PlantDiseaseModel(torch.nn.Module):
|
10 |
def __init__(self, num_classes=2):
|
11 |
super().__init__()
|
|
|
12 |
self.model = torch.hub.load('ultralytics/yolov5', 'yolov5s', pretrained=True)
|
13 |
self.model.classes = num_classes # Set your number of classes
|
14 |
|
15 |
def forward(self, x):
|
16 |
return self.model(x)
|
17 |
|
18 |
-
# Load model
|
19 |
@st.cache_resource
|
20 |
def load_model():
|
21 |
-
model = PlantDiseaseModel(num_classes=2) # Update
|
22 |
try:
|
23 |
model.load_state_dict(torch.load('best_model.pth', map_location='cpu'))
|
24 |
except:
|
25 |
-
st.warning("
|
26 |
return model
|
27 |
|
28 |
-
#
|
29 |
-
def
|
30 |
-
|
31 |
-
|
32 |
-
|
33 |
-
|
34 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
35 |
|
36 |
-
# Main app
|
37 |
def main():
|
38 |
st.set_page_config(page_title="Plant Disease Detector", layout="wide")
|
39 |
-
st.title("🌱 Plant Disease Detection")
|
40 |
|
41 |
model = load_model()
|
42 |
|
@@ -44,32 +59,40 @@ def main():
|
|
44 |
uploaded_file = st.file_uploader("Upload a plant image...", type=["jpg", "jpeg", "png"])
|
45 |
|
46 |
if uploaded_file is not None:
|
47 |
-
#
|
48 |
image = Image.open(uploaded_file).convert("RGB")
|
49 |
-
|
|
|
|
|
|
|
50 |
|
51 |
# Process and predict
|
52 |
with st.spinner("Analyzing..."):
|
53 |
-
#
|
54 |
-
|
|
|
|
|
|
|
|
|
55 |
|
56 |
# Predict
|
57 |
with torch.no_grad():
|
58 |
results = model(input_tensor)
|
59 |
|
60 |
-
# Convert results to
|
61 |
-
|
62 |
-
|
63 |
-
|
64 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
65 |
|
66 |
-
|
67 |
-
|
68 |
-
if hasattr(results, 'pandas'):
|
69 |
-
df = results.pandas().xyxy[0] # Convert to pandas DataFrame
|
70 |
-
st.dataframe(df[['name', 'confidence', 'xmin', 'ymin', 'xmax', 'ymax']])
|
71 |
-
else:
|
72 |
-
st.write("No detections found")
|
73 |
|
74 |
if __name__ == "__main__":
|
75 |
main()
|
|
|
1 |
import streamlit as st
|
2 |
import torch
|
3 |
from torchvision import transforms
|
4 |
+
from PIL import Image, ImageDraw, ImageFont
|
5 |
import numpy as np
|
6 |
import time
|
7 |
|
8 |
+
# Simplified YOLO-style model definition (Pillow-only version)
|
9 |
class PlantDiseaseModel(torch.nn.Module):
|
10 |
def __init__(self, num_classes=2):
|
11 |
super().__init__()
|
12 |
+
# Example backbone (replace with your actual model architecture)
|
13 |
self.model = torch.hub.load('ultralytics/yolov5', 'yolov5s', pretrained=True)
|
14 |
self.model.classes = num_classes # Set your number of classes
|
15 |
|
16 |
def forward(self, x):
|
17 |
return self.model(x)
|
18 |
|
19 |
+
# Load model
|
20 |
@st.cache_resource
|
21 |
def load_model():
|
22 |
+
model = PlantDiseaseModel(num_classes=2) # Update class count
|
23 |
try:
|
24 |
model.load_state_dict(torch.load('best_model.pth', map_location='cpu'))
|
25 |
except:
|
26 |
+
st.warning("Using pretrained weights (custom weights not found)")
|
27 |
return model
|
28 |
|
29 |
+
# Draw bounding boxes with Pillow
|
30 |
+
def draw_boxes_pillow(image, predictions):
|
31 |
+
"""Draw boxes/labels on image using Pillow only"""
|
32 |
+
draw = ImageDraw.Draw(image)
|
33 |
+
try:
|
34 |
+
font = ImageFont.load_default()
|
35 |
+
for _, row in predictions.iterrows():
|
36 |
+
xmin, ymin, xmax, ymax = row['xmin'], row['ymin'], row['xmax'], row['ymax']
|
37 |
+
label = f"{row['name']} {row['confidence']:.2f}"
|
38 |
+
|
39 |
+
# Draw rectangle
|
40 |
+
draw.rectangle([xmin, ymin, xmax, ymax], outline="red", width=3)
|
41 |
+
|
42 |
+
# Draw label background
|
43 |
+
text_width, text_height = font.getsize(label)
|
44 |
+
draw.rectangle([xmin, ymin-text_height, xmin+text_width, ymin], fill="red")
|
45 |
+
|
46 |
+
# Draw text
|
47 |
+
draw.text((xmin, ymin-text_height), label, fill="white", font=font)
|
48 |
+
except Exception as e:
|
49 |
+
st.error(f"Error drawing boxes: {str(e)}")
|
50 |
+
return image
|
51 |
|
|
|
52 |
def main():
|
53 |
st.set_page_config(page_title="Plant Disease Detector", layout="wide")
|
54 |
+
st.title("🌱 Plant Disease Detection (Tomato or Corn Maiza)")
|
55 |
|
56 |
model = load_model()
|
57 |
|
|
|
59 |
uploaded_file = st.file_uploader("Upload a plant image...", type=["jpg", "jpeg", "png"])
|
60 |
|
61 |
if uploaded_file is not None:
|
62 |
+
# Load with Pillow
|
63 |
image = Image.open(uploaded_file).convert("RGB")
|
64 |
+
col1, col2 = st.columns(2)
|
65 |
+
|
66 |
+
with col1:
|
67 |
+
st.image(image, caption="Original Image", use_column_width=True)
|
68 |
|
69 |
# Process and predict
|
70 |
with st.spinner("Analyzing..."):
|
71 |
+
# Convert to tensor (Pillow-compatible preprocessing)
|
72 |
+
transform = transforms.Compose([
|
73 |
+
transforms.Resize(640),
|
74 |
+
transforms.ToTensor(),
|
75 |
+
])
|
76 |
+
input_tensor = transform(image).unsqueeze(0)
|
77 |
|
78 |
# Predict
|
79 |
with torch.no_grad():
|
80 |
results = model(input_tensor)
|
81 |
|
82 |
+
# Convert results to Pandas
|
83 |
+
try:
|
84 |
+
results_df = results.pandas().xyxy[0]
|
85 |
+
|
86 |
+
# Draw boxes using Pillow
|
87 |
+
output_image = image.copy()
|
88 |
+
output_image = draw_boxes_pillow(output_image, results_df)
|
89 |
+
|
90 |
+
with col2:
|
91 |
+
st.image(output_image, caption="Detection Results", use_column_width=True)
|
92 |
+
st.dataframe(results_df[['name', 'confidence', 'xmin', 'ymin', 'xmax', 'ymax']])
|
93 |
|
94 |
+
except Exception as e:
|
95 |
+
st.error(f"Prediction error: {str(e)}")
|
|
|
|
|
|
|
|
|
|
|
96 |
|
97 |
if __name__ == "__main__":
|
98 |
main()
|