yunusajib commited on
Commit
35170ca
·
verified ·
1 Parent(s): 40003bf

update app.py

Browse files
Files changed (1) hide show
  1. app.py +53 -30
app.py CHANGED
@@ -1,42 +1,57 @@
1
  import streamlit as st
2
  import torch
3
  from torchvision import transforms
4
- from PIL import Image
5
  import numpy as np
6
  import time
7
 
8
- # Simplified model definition
9
  class PlantDiseaseModel(torch.nn.Module):
10
  def __init__(self, num_classes=2):
11
  super().__init__()
 
12
  self.model = torch.hub.load('ultralytics/yolov5', 'yolov5s', pretrained=True)
13
  self.model.classes = num_classes # Set your number of classes
14
 
15
  def forward(self, x):
16
  return self.model(x)
17
 
18
- # Load model (simplified for Hugging Face)
19
  @st.cache_resource
20
  def load_model():
21
- model = PlantDiseaseModel(num_classes=2) # Update with your actual class count
22
  try:
23
  model.load_state_dict(torch.load('best_model.pth', map_location='cpu'))
24
  except:
25
- st.warning("Couldn't load custom weights, using pretrained")
26
  return model
27
 
28
- # Image preprocessing
29
- def preprocess_image(image):
30
- transform = transforms.Compose([
31
- transforms.Resize(640),
32
- transforms.ToTensor(),
33
- ])
34
- return transform(image).unsqueeze(0)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
35
 
36
- # Main app
37
  def main():
38
  st.set_page_config(page_title="Plant Disease Detector", layout="wide")
39
- st.title("🌱 Plant Disease Detection")
40
 
41
  model = load_model()
42
 
@@ -44,32 +59,40 @@ def main():
44
  uploaded_file = st.file_uploader("Upload a plant image...", type=["jpg", "jpeg", "png"])
45
 
46
  if uploaded_file is not None:
47
- # Display original image
48
  image = Image.open(uploaded_file).convert("RGB")
49
- st.image(image, caption="Uploaded Image", use_column_width=True)
 
 
 
50
 
51
  # Process and predict
52
  with st.spinner("Analyzing..."):
53
- # Preprocess
54
- input_tensor = preprocess_image(image)
 
 
 
 
55
 
56
  # Predict
57
  with torch.no_grad():
58
  results = model(input_tensor)
59
 
60
- # Convert results to image
61
- results_img = np.array(results.render()[0]) # Get first image with boxes
62
-
63
- # Display results
64
- st.image(results_img, caption="Detection Results", use_column_width=True)
 
 
 
 
 
 
65
 
66
- # Show predictions
67
- st.subheader("📊 Detection Results")
68
- if hasattr(results, 'pandas'):
69
- df = results.pandas().xyxy[0] # Convert to pandas DataFrame
70
- st.dataframe(df[['name', 'confidence', 'xmin', 'ymin', 'xmax', 'ymax']])
71
- else:
72
- st.write("No detections found")
73
 
74
  if __name__ == "__main__":
75
  main()
 
1
  import streamlit as st
2
  import torch
3
  from torchvision import transforms
4
+ from PIL import Image, ImageDraw, ImageFont
5
  import numpy as np
6
  import time
7
 
8
+ # Simplified YOLO-style model definition (Pillow-only version)
9
  class PlantDiseaseModel(torch.nn.Module):
10
  def __init__(self, num_classes=2):
11
  super().__init__()
12
+ # Example backbone (replace with your actual model architecture)
13
  self.model = torch.hub.load('ultralytics/yolov5', 'yolov5s', pretrained=True)
14
  self.model.classes = num_classes # Set your number of classes
15
 
16
  def forward(self, x):
17
  return self.model(x)
18
 
19
+ # Load model
20
  @st.cache_resource
21
  def load_model():
22
+ model = PlantDiseaseModel(num_classes=2) # Update class count
23
  try:
24
  model.load_state_dict(torch.load('best_model.pth', map_location='cpu'))
25
  except:
26
+ st.warning("Using pretrained weights (custom weights not found)")
27
  return model
28
 
29
+ # Draw bounding boxes with Pillow
30
+ def draw_boxes_pillow(image, predictions):
31
+ """Draw boxes/labels on image using Pillow only"""
32
+ draw = ImageDraw.Draw(image)
33
+ try:
34
+ font = ImageFont.load_default()
35
+ for _, row in predictions.iterrows():
36
+ xmin, ymin, xmax, ymax = row['xmin'], row['ymin'], row['xmax'], row['ymax']
37
+ label = f"{row['name']} {row['confidence']:.2f}"
38
+
39
+ # Draw rectangle
40
+ draw.rectangle([xmin, ymin, xmax, ymax], outline="red", width=3)
41
+
42
+ # Draw label background
43
+ text_width, text_height = font.getsize(label)
44
+ draw.rectangle([xmin, ymin-text_height, xmin+text_width, ymin], fill="red")
45
+
46
+ # Draw text
47
+ draw.text((xmin, ymin-text_height), label, fill="white", font=font)
48
+ except Exception as e:
49
+ st.error(f"Error drawing boxes: {str(e)}")
50
+ return image
51
 
 
52
  def main():
53
  st.set_page_config(page_title="Plant Disease Detector", layout="wide")
54
+ st.title("🌱 Plant Disease Detection (Tomato or Corn Maiza)")
55
 
56
  model = load_model()
57
 
 
59
  uploaded_file = st.file_uploader("Upload a plant image...", type=["jpg", "jpeg", "png"])
60
 
61
  if uploaded_file is not None:
62
+ # Load with Pillow
63
  image = Image.open(uploaded_file).convert("RGB")
64
+ col1, col2 = st.columns(2)
65
+
66
+ with col1:
67
+ st.image(image, caption="Original Image", use_column_width=True)
68
 
69
  # Process and predict
70
  with st.spinner("Analyzing..."):
71
+ # Convert to tensor (Pillow-compatible preprocessing)
72
+ transform = transforms.Compose([
73
+ transforms.Resize(640),
74
+ transforms.ToTensor(),
75
+ ])
76
+ input_tensor = transform(image).unsqueeze(0)
77
 
78
  # Predict
79
  with torch.no_grad():
80
  results = model(input_tensor)
81
 
82
+ # Convert results to Pandas
83
+ try:
84
+ results_df = results.pandas().xyxy[0]
85
+
86
+ # Draw boxes using Pillow
87
+ output_image = image.copy()
88
+ output_image = draw_boxes_pillow(output_image, results_df)
89
+
90
+ with col2:
91
+ st.image(output_image, caption="Detection Results", use_column_width=True)
92
+ st.dataframe(results_df[['name', 'confidence', 'xmin', 'ymin', 'xmax', 'ymax']])
93
 
94
+ except Exception as e:
95
+ st.error(f"Prediction error: {str(e)}")
 
 
 
 
 
96
 
97
  if __name__ == "__main__":
98
  main()