bunnybala committed on
Commit 3cdad53 · verified · 1 Parent(s): 4b1ffd0

Update app.py

Files changed (1)
  1. app.py +33 -129
app.py CHANGED
@@ -1,143 +1,47 @@
  import streamlit as st
- import torch
- import torch.nn as nn
- import torchvision.transforms as transforms
  from PIL import Image
- import os
- import requests

- # ======= Groq API Key (Embed Yours Here) =======
- GROQ_API_KEY = "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"

- # ======= Device & Model Setup =======
- device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

- class PlantDiseaseCNN(nn.Module):
-     def __init__(self, num_classes):
-         super(PlantDiseaseCNN, self).__init__()
-         self.conv1 = nn.Conv2d(3, 32, kernel_size=3, padding=1)
-         self.pool = nn.MaxPool2d(2, 2)
-         self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1)
-         self.conv3 = nn.Conv2d(64, 64, kernel_size=3, padding=1)
-         self.fc1 = nn.Linear(64 * 32 * 32, 128)
-         self.fc2 = nn.Linear(128, num_classes)
-         self.relu = nn.ReLU()

-     def forward(self, x):
-         x = self.pool(self.relu(self.conv1(x)))
-         x = self.pool(self.relu(self.conv2(x)))
-         x = self.pool(self.relu(self.conv3(x)))
-         x = x.view(x.size(0), -1)
-         x = self.relu(self.fc1(x))
-         x = self.fc2(x)
-         return x

- # ======= Class Names for Each Crop =======
- class_names = {
-     "Tomato Disease Detection": [
-         "Bacterial spot", "Early blight", "Late blight", "Leaf Mold",
-         "Septoria leaf spot", "Spider mites", "Target Spot", "Yellow Leaf Curl Virus",
-         "Mosaic virus", "Healthy"
-     ],
-     "Pepper Disease Detection": [
-         "Bacterial spot", "Healthy"
-     ],
-     "Potato Disease Detection": [
-         "Early blight", "Late blight", "Healthy"
-     ]
- }

- # ======= Load Models =======
- models_info = {
-     "Tomato Disease Detection": "tomato-detection-model.pth",
-     "Pepper Disease Detection": "pepper-detection-model.pth",
-     "Potato Disease Detection": "potato-detection-model.pth"
- }

- models = {}

- for crop, model_path in models_info.items():
-     num_classes = len(class_names[crop])
-     model = PlantDiseaseCNN(num_classes)
-     state_dict = torch.load(model_path, map_location=device)
-     model.load_state_dict(state_dict)
-     model.to(device)
-     model.eval()
-     models[crop] = model

- # ======= Preprocessing =======
- transform = transforms.Compose([
-     transforms.Resize((256, 256)),
-     transforms.ToTensor(),
-     transforms.Normalize([0.5], [0.5])
- ])

- def predict_image(image, model, crop_name):
-     image = image.convert("RGB")
-     image = transform(image).unsqueeze(0).to(device)
-     with torch.no_grad():
-         output = model(image)
-         _, predicted = torch.max(output, 1)
-     return class_names[crop_name][predicted.item()]

- # ======= Groq LLM Remedies =======
- def get_remedies_from_groq(disease_name):
-     headers = {
-         "Authorization": f"Bearer {GROQ_API_KEY}",
-         "Content-Type": "application/json",
-     }
-     payload = {
-         "model": "llama-3.1-8b-instant",
-         "messages": [
-             {
-                 "role": "user",
-                 "content": f"""The plant disease is **{disease_name}**.
- Give me professional agricultural advice in this format:
- 1. Insecticides to use:
- 2. Pesticides to use:
- 3. Herbicides to use:
- 4. Natural remedies:
- 5. Preventive measures:
- Make it clear and line-by-line."""
-             }
-         ],
-         "temperature": 0.7
-     }

-     response = requests.post("https://api.groq.com/openai/v1/chat/completions", headers=headers, json=payload)
-     if response.status_code == 200:
-         return response.json()["choices"][0]["message"]["content"]
-     else:
-         return "❌ Failed to fetch remedies. Check your API key or network connection."

- # ======= Streamlit UI =======
- st.set_page_config(page_title="Plant Disease Detection", layout="centered")
- st.title("🌿 Plant Disease Detection & Remedies ")

- tab1, tab2, tab3 = st.tabs(["🍅 Tomato", "🌶️ Pepper", "🥔 Potato"])

- for tab, crop in zip([tab1, tab2, tab3], models_info.keys()):
-     with tab:
-         st.subheader(crop)
-         method = st.radio("Select Input Method:", ["Upload Image", "Use Camera"], key=f"{crop}_method")

-         image = None
-         if method == "Upload Image":
-             uploaded_file = st.file_uploader("Upload a leaf image", type=["jpg", "jpeg", "png"], key=f"{crop}_upload")
-             if uploaded_file:
-                 image = Image.open(uploaded_file)

-         elif method == "Use Camera":
-             captured_image = st.camera_input("Capture a leaf image", key=f"{crop}_camera")
-             if captured_image:
-                 image = Image.open(captured_image)

-         if image:
-             st.image(image, caption="Leaf Image", use_column_width=True)
-             prediction = predict_image(image, models[crop], crop)
-             st.success(f"🔍 Prediction: **{prediction}**")

-             with st.spinner("🤖 Fetching remedies ..."):
-                 advice = get_remedies_from_groq(prediction)
-                 st.markdown("### 🌱 Remedies & Prevention")
-                 st.markdown(advice)
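The removed file embeds the Groq API key in plain text. A minimal sketch of loading it from the environment instead; the `GROQ_API_KEY` environment variable name is an assumption, not something this commit configures:

```python
import os

# Sketch only: read the key from an environment variable rather than
# hard-coding it in app.py (the variable name "GROQ_API_KEY" is assumed).
GROQ_API_KEY = os.environ.get("GROQ_API_KEY", "")
if not GROQ_API_KEY:
    raise RuntimeError("Set the GROQ_API_KEY environment variable before running the app.")
```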
 
  import streamlit as st
+ from ultralytics import YOLO
  from PIL import Image
+ import tempfile
+ import numpy as np

+ st.set_page_config(page_title="YOLOv8 Inference App", layout="centered")
+ st.title("🔍 YOLOv8 Image Detector")

+ MODEL_PATH = "best.pt"

+ @st.cache_resource
+ def load_model():
+     return YOLO(MODEL_PATH)

+ model = load_model()

+ # Option to choose input method
+ input_method = st.radio("Choose Image Input Method:", ("📤 Upload", "📷 Camera"))

+ if input_method == "📤 Upload":
+     image_file = st.file_uploader("Upload an image", type=["jpg", "jpeg", "png"])
+     if image_file:
+         image = Image.open(image_file)

+ elif input_method == "📷 Camera":
+     camera_image = st.camera_input("Take a picture")
+     if camera_image:
+         image = Image.open(camera_image)

+ # Proceed if image is available
+ if ('image' in locals()):
+     st.image(image, caption="🖼️ Input Image", use_column_width=True)

+     with st.spinner("🔍 Detecting objects..."):
+         results = model(image)
+         result_img = results[0].plot()

+     st.image(result_img, caption="🧠 Detection Result", use_column_width=True)

+     # Save detection result
+     with tempfile.NamedTemporaryFile(delete=False, suffix=".jpg") as tmp:
+         Image.fromarray(result_img).save(tmp.name)
+         st.download_button("📥 Download Output", data=open(tmp.name, "rb").read(), file_name="detected.jpg")
+ else:
+     st.info("Please provide an image using Upload or Camera.")
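A note on the new script: `results[0].plot()` in Ultralytics returns the annotated image as a NumPy array in BGR channel order, so passing it straight to `st.image` or `Image.fromarray` leaves red and blue swapped. A minimal standalone sketch of converting to RGB before saving; the `best.pt` weights path comes from the script, while `example.jpg` is a placeholder input:

```python
from PIL import Image
from ultralytics import YOLO

model = YOLO("best.pt")              # same weights file the app loads
results = model("example.jpg")       # placeholder input image
result_bgr = results[0].plot()       # annotated image as a BGR NumPy array
result_rgb = result_bgr[..., ::-1]   # reverse the channel axis -> RGB
Image.fromarray(result_rgb).save("detected.jpg")
```

Inside the app, `st.image(result_img, channels="BGR")` would also display the unconverted array with correct colors. To try the app locally, assuming `best.pt` sits next to `app.py`: `pip install streamlit ultralytics pillow`, then `streamlit run app.py`.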