File size: 3,322 Bytes
8c474ce
 
 
018c2a2
40003bf
2436faa
018c2a2
 
2436faa
40003bf
018c2a2
 
 
 
 
 
 
 
 
 
 
 
2436faa
 
8c474ce
40003bf
 
018c2a2
40003bf
 
018c2a2
 
 
40003bf
 
018c2a2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
40003bf
 
018c2a2
40003bf
018c2a2
 
 
 
 
 
 
 
 
 
 
 
40003bf
018c2a2
 
40003bf
018c2a2
40003bf
35170ca
 
 
018c2a2
40003bf
 
018c2a2
40003bf
018c2a2
 
 
 
 
35170ca
018c2a2
 
 
8c474ce
40003bf
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
import streamlit as st
import torch
from torchvision import transforms
from PIL import Image
import numpy as np

# Custom model class (replace with your actual architecture)
class PlantDiseaseClassifier(torch.nn.Module):
    """Small convolutional image classifier used as a placeholder network.

    Two ``Conv2d -> ReLU -> MaxPool2d(2)`` stages are followed by a flatten
    and a single linear layer producing ``num_classes`` scores. The layers
    live in ``self.model`` (a ``Sequential``), so parameter names in the
    state dict are ``model.<index>.*``.

    Args:
        num_classes: Number of output scores per image.
    """

    def __init__(self, num_classes=2):
        super().__init__()
        nn = torch.nn

        # Example architecture - REPLACE WITH YOUR ACTUAL MODEL
        feature_layers = [
            nn.Conv2d(3, 16, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2),
            nn.Conv2d(16, 32, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2),
        ]
        # 32 channels at 56x56: a 224x224 input halved by each of the two
        # pooling layers. Adjust if the input size or the layers change.
        head_layers = [
            nn.Flatten(),
            nn.Linear(32 * 56 * 56, num_classes),
        ]
        self.model = nn.Sequential(*feature_layers, *head_layers)

    def forward(self, x):
        """Return the raw (pre-softmax) class scores for batch ``x``."""
        scores = self.model(x)
        return scores

@st.cache_resource
def load_model():
    """Build the classifier and load its trained weights from ``best_model.pth``.

    Decorated with ``st.cache_resource`` so Streamlit can reuse the result
    instead of rebuilding the model on every call.

    Returns:
        The model with weights loaded and switched to evaluation mode, or
        ``None`` when the checkpoint could not be loaded. Returning ``None``
        (rather than the randomly initialised network) lets callers skip
        prediction instead of showing results from untrained weights.
    """
    model = PlantDiseaseClassifier(num_classes=2)  # Update with your class count
    try:
        # NOTE(security): torch.load unpickles the file, so only load
        # checkpoints from a trusted source (consider weights_only=True on
        # PyTorch versions that support it).
        state_dict = torch.load('best_model.pth', map_location='cpu')
        model.load_state_dict(state_dict)
    except Exception as e:
        # UI boundary: report the failure to the user rather than crash.
        st.error(f"Error loading model: {str(e)}")
        return None

    # Inference only: put layers such as dropout/batch-norm (if the real
    # architecture has any) into evaluation behaviour.
    model.eval()
    st.success("Model loaded successfully!")
    return model

def predict(image, model, class_names):
    """Classify a single image and return the winning label and its probability.

    Args:
        image: Image accepted by the torchvision transforms below (``main``
            passes an RGB PIL image).
        model: Callable that maps a batched image tensor to class scores.
        class_names: Mapping from class index to display label.

    Returns:
        A ``(label, probability)`` tuple for the highest-scoring class, where
        ``probability`` is a Python float taken from the softmax output.

    Raises:
        KeyError: If the predicted index is not a key of ``class_names``
            (when it is a dict).
    """
    # Preprocessing pipeline: resize, centre-crop to 224x224, convert to a
    # tensor, then normalise each channel with the given mean/std.
    preprocess = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ])

    # The model is fed a batch, so add a leading batch dimension of size 1.
    batch = preprocess(image).unsqueeze(0)

    with torch.no_grad():
        scores = model(batch)
        # Scores of the only item in the batch -> probabilities over classes.
        probs = torch.softmax(scores[0], dim=0)
        best_prob, best_index = probs.topk(1)

    label = class_names[best_index.item()]
    return label, best_prob.item()

def main():
    """Render the app page: upload an image, classify it, and list care tips.

    Nothing beyond the title and the uploader is drawn until a file has been
    uploaded and ``load_model()`` has returned something other than ``None``.
    """
    st.title("🌱 Plant Disease Classifier")

    # Update with your actual class names and care tips
    # NOTE(review): three labels are listed here while load_model() builds the
    # network with num_classes=2 - confirm which count is intended.
    class_names = dict(enumerate([
        "Healthy",
        "Late Blight",
        "Powdery Mildew",  # Add all your classes
    ]))

    care_tips = {
        "Healthy": ["Continue regular watering", "Monitor plant growth"],
        "Late Blight": ["Remove infected leaves", "Apply fungicide"],
        "Powdery Mildew": ["Improve air circulation", "Apply sulfur spray"],
    }

    model = load_model()
    uploaded_file = st.file_uploader("Upload plant image", type=["jpg", "png", "jpeg"])

    # Guard clause: nothing to analyse without both an upload and a model.
    if not uploaded_file or model is None:
        return

    image = Image.open(uploaded_file).convert("RGB")
    left_col, right_col = st.columns(2)

    with left_col:
        st.image(image, caption="Uploaded Image", use_column_width=True)

    with st.spinner("Analyzing..."):
        label, confidence = predict(image, model, class_names)

        with right_col:
            headline = f"Prediction: {label} ({confidence*100:.1f}%)"
            # Success styling for a healthy result, error styling otherwise.
            show_result = st.success if "healthy" in label.lower() else st.error
            show_result(headline)

            st.subheader("Care Recommendations")
            tips = care_tips.get(label, ["No specific recommendations"])
            for tip in tips:
                st.write(f"• {tip}")

# Run the app only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()