yunusajib commited on
Commit
2436faa
·
verified ·
1 Parent(s): 8c474ce

update app.py

Browse files
Files changed (1) hide show
  1. app.py +15 -207
app.py CHANGED
@@ -1,217 +1,25 @@
1
- # streamlit_app/app.py
2
  import streamlit as st
3
  import torch
4
  from torchvision import transforms
5
  from PIL import Image
6
- import yaml
7
- from pathlib import Path
8
- import sys
9
- import os
10
  import time
11
  from datetime import datetime
12
  import pandas as pd
13
-
14
- # Add project root to path and get absolute paths
15
- ROOT_DIR = Path(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
16
- CONFIG_PATH = ROOT_DIR / 'configs' / 'config.yaml'
17
- MODEL_PATH = ROOT_DIR / 'models' / 'saved_models' / 'best_model.pth'
18
-
19
- sys.path.append(str(ROOT_DIR))
20
- from src.model import PlantDiseaseModel
21
-
22
- # Care tips and utilities
23
- plant_care_tips = {
24
- "Corn_(maize)___healthy": {
25
- "short_term": [
26
- "Continue regular watering schedule",
27
- "Monitor for any changes in leaf color",
28
- "Maintain good air circulation"
29
- ],
30
- "long_term": [
31
- "Regular soil testing",
32
- "Crop rotation planning",
33
- "Preventive pest management"
34
- ]
35
- },
36
- "Tomato___Late_blight": {
37
- "short_term": [
38
- "Remove infected leaves immediately",
39
- "Apply appropriate fungicide",
40
- "Improve air circulation around plants"
41
- ],
42
- "long_term": [
43
- "Use resistant varieties next season",
44
- "Improve soil drainage",
45
- "Practice crop rotation"
46
- ]
47
- }
48
- }
49
-
50
- def load_model():
51
- """Load the trained model and class mappings"""
52
- try:
53
- # Load config
54
- with open(CONFIG_PATH) as f:
55
- config = yaml.safe_load(f)
56
-
57
- print(f"Loading model with {config['model']['num_classes']} classes")
58
-
59
- # Initialize model
60
- model = PlantDiseaseModel(num_classes=config['model']['num_classes'])
61
-
62
- # Load trained weights
63
- print(f"Loading checkpoint from: {MODEL_PATH}")
64
- checkpoint = torch.load(MODEL_PATH, map_location=torch.device('cpu'))
65
-
66
- model.load_state_dict(checkpoint['model_state_dict'])
67
- model.eval()
68
-
69
- return model, checkpoint['class_to_idx']
70
-
71
- except Exception as e:
72
- print(f"Error in load_model: {str(e)}")
73
- raise e
74
-
75
- def process_image(image):
76
- """Process and display image transformation steps"""
77
- st.write("🔍 Image Processing Steps:")
78
- col1, col2, col3 = st.columns(3)
79
-
80
- with col1:
81
- resized = transforms.Resize(256)(image)
82
- st.image(resized, caption="Resized Image")
83
-
84
- with col2:
85
- cropped = transforms.CenterCrop(224)(resized)
86
- st.image(cropped, caption="Cropped Image")
87
-
88
- with col3:
89
- st.write("Final Processing:")
90
- st.write("• Converted to tensor")
91
- st.write("• Normalized")
92
-
93
- transform = transforms.Compose([
94
- transforms.Resize(256),
95
- transforms.CenterCrop(224),
96
- transforms.ToTensor(),
97
- transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
98
- ])
99
- return transform(image).unsqueeze(0)
100
-
101
- def main():
102
- st.set_page_config(page_title="Plant Disease Classifier", layout="wide")
103
 
104
- # Initialize session state
105
- if 'history' not in st.session_state:
106
- st.session_state.history = []
107
-
108
- # Create tabs
109
- tab1, tab2 = st.tabs(["Classifier", "Model Info"])
110
-
111
- with tab1:
112
- st.title("🌿 Plant Disease Classifier")
113
-
114
- try:
115
- model, class_to_idx = load_model()
116
- idx_to_class = {v: k for k, v in class_to_idx.items()}
117
-
118
- except Exception as e:
119
- st.error(f"Error loading model: {str(e)}")
120
- st.write("Debugging info:")
121
- st.write(f"Config path exists: {CONFIG_PATH.exists()}")
122
- st.write(f"Model path exists: {MODEL_PATH.exists()}")
123
- return
124
-
125
- col1, col2 = st.columns([1, 1])
126
-
127
- with col1:
128
- uploaded_file = st.file_uploader("Choose an image...", type=["jpg", "jpeg", "png"])
129
-
130
- if uploaded_file is not None:
131
- image = Image.open(uploaded_file).convert('RGB')
132
- st.image(image, caption='Uploaded Image', use_column_width=True)
133
-
134
- with col2:
135
- if uploaded_file is not None:
136
- with st.spinner('Analyzing image...'):
137
- progress_bar = st.progress(0)
138
- for i in range(100):
139
- time.sleep(0.01)
140
- progress_bar.progress(i + 1)
141
-
142
- try:
143
- input_tensor = process_image(image)
144
-
145
- with torch.no_grad():
146
- output = model(input_tensor)
147
- probabilities = torch.nn.functional.softmax(output, dim=1)
148
- predicted_idx = output.argmax(1).item()
149
- confidence = probabilities[0][predicted_idx].item()
150
-
151
- predicted_class = idx_to_class[predicted_idx]
152
-
153
- # Store prediction in history
154
- st.session_state.history.append({
155
- 'timestamp': datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
156
- 'prediction': predicted_class,
157
- 'confidence': confidence
158
- })
159
-
160
- # Show prediction and confidence
161
- st.subheader("📊 Analysis Results")
162
- if "healthy" in predicted_class.lower():
163
- st.success(f"🌱 Prediction: {predicted_class.replace('_', ' ')}")
164
- else:
165
- st.warning(f"⚠️ Prediction: {predicted_class.replace('_', ' ')}")
166
-
167
- # Show class probabilities
168
- for idx, prob in enumerate(probabilities[0]):
169
- class_name = idx_to_class[idx].replace('_', ' ')
170
- st.write(f"{class_name}: {prob*100:.2f}%")
171
- st.progress(float(prob))
172
-
173
- # Show care recommendations
174
- if predicted_class in plant_care_tips:
175
- st.subheader("🌱 Care Recommendations")
176
-
177
- col_short, col_long = st.columns(2)
178
- with col_short:
179
- st.write("Immediate Actions:")
180
- for tip in plant_care_tips[predicted_class]["short_term"]:
181
- st.write(f"• {tip}")
182
- with col_long:
183
- st.write("Long-term Prevention:")
184
- for tip in plant_care_tips[predicted_class]["long_term"]:
185
- st.write(f"• {tip}")
186
-
187
- except Exception as e:
188
- st.error(f"Error processing image: {str(e)}")
189
-
190
- with tab2:
191
- st.header("Model Architecture")
192
- st.write("""
193
- This classifier uses a ResNet50 architecture with transfer learning:
194
- - Pre-trained on ImageNet
195
- - Fine-tuned on plant disease dataset
196
- - 2 disease classes
197
- - Input size: 224x224 pixels
198
- """)
199
-
200
- st.subheader("Performance Metrics")
201
- metrics_df = pd.DataFrame({
202
- 'Metric': ['Training Accuracy', 'Validation Accuracy', 'Number of Parameters'],
203
- 'Value': ['99.0%', '100%', '23.5M']
204
- })
205
- st.table(metrics_df)
206
 
207
- # Sidebar
208
- st.sidebar.title("Recent Predictions")
209
- if st.session_state.history:
210
- for item in reversed(st.session_state.history[-5:]):
211
- st.sidebar.write(f"Time: {item['timestamp']}")
212
- st.sidebar.write(f"Prediction: {item['prediction'].replace('_', ' ')}")
213
- st.sidebar.write(f"Confidence: {item['confidence']*100:.2f}%")
214
- st.sidebar.divider()
215
 
216
- if __name__ == "__main__":
217
- main()
 
 
1
  import streamlit as st
2
  import torch
3
  from torchvision import transforms
4
  from PIL import Image
 
 
 
 
5
  import time
6
  from datetime import datetime
7
  import pandas as pd
8
+ import torch.nn as nn
9
+ import torchvision.models as models
10
+
11
+ # Define model directly in app
12
+ class PlantDiseaseModel(nn.Module):
13
+ def __init__(self, num_classes=2):
14
+ super(PlantDiseaseModel, self).__init__()
15
+ self.model = models.resnet50(pretrained=True)
16
+ num_ftrs = self.model.fc.in_features
17
+ self.model.fc = nn.Linear(num_ftrs, num_classes)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
18
 
19
+ def forward(self, x):
20
+ return self.model(x)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
21
 
22
+ # Simplified paths
23
+ MODEL_PATH = 'best_model.pth'
 
 
 
 
 
 
24
 
25
+ # [Rest of your original code...]