Spaces:
Sleeping
Sleeping
File size: 3,322 Bytes
8c474ce 018c2a2 40003bf 2436faa 018c2a2 2436faa 40003bf 018c2a2 2436faa 8c474ce 40003bf 018c2a2 40003bf 018c2a2 40003bf 018c2a2 40003bf 018c2a2 40003bf 018c2a2 40003bf 018c2a2 40003bf 018c2a2 40003bf 35170ca 018c2a2 40003bf 018c2a2 40003bf 018c2a2 35170ca 018c2a2 8c474ce 40003bf |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 |
import streamlit as st
import torch
from torchvision import transforms
from PIL import Image
import numpy as np
# Custom model class (replace with your actual architecture)
class PlantDiseaseClassifier(torch.nn.Module):
    """Small two-stage CNN that maps a batch of RGB images to per-class logits.

    Args:
        num_classes: Number of output logits produced by the final linear layer.
        input_size: Side length, in pixels, of the square images the model is
            fed. Each ``Conv2d(kernel_size=3, padding=1)`` keeps the spatial size
            and each ``MaxPool2d(2)`` halves it (rounding down), so the flattened
            feature count is ``32 * (input_size // 4) ** 2``. The default of 224
            reproduces the original hard-coded ``32*56*56`` and matches the
            ``CenterCrop(224)`` used by ``predict``.

    Raises:
        ValueError: If ``input_size`` is too small to survive both pooling stages.
    """

    def __init__(self, num_classes=2, input_size=224):
        super().__init__()
        # Two 2x2 max-pools, each flooring the size: floor(floor(s/2)/2) == s // 4.
        pooled = input_size // 4
        if pooled < 1:
            raise ValueError(f"input_size must be at least 4, got {input_size}")
        # Example architecture - REPLACE WITH YOUR ACTUAL MODEL.
        # Layer order is kept unchanged so state_dict keys (model.0.*, model.3.*,
        # model.7.*) of checkpoints saved from the original layout still load.
        self.model = torch.nn.Sequential(
            torch.nn.Conv2d(3, 16, kernel_size=3, padding=1),
            torch.nn.ReLU(),
            torch.nn.MaxPool2d(2),
            torch.nn.Conv2d(16, 32, kernel_size=3, padding=1),
            torch.nn.ReLU(),
            torch.nn.MaxPool2d(2),
            torch.nn.Flatten(),
            torch.nn.Linear(32 * pooled * pooled, num_classes),
        )

    def forward(self, x):
        """Return logits of shape ``(batch, num_classes)``.

        ``x`` must be ``(batch, 3, input_size, input_size)``; any other spatial
        size makes the final linear layer's input width mismatch.
        """
        return self.model(x)
@st.cache_resource
def load_model():
    """Build the classifier and load trained weights from ``best_model.pth``.

    Decorated with ``st.cache_resource`` so the checkpoint is loaded once and the
    same model object is reused across reruns. Note that a failed load (``None``)
    is cached too, so a fixed checkpoint is only picked up after the cache is
    cleared or the app restarts.

    Returns:
        The model in evaluation mode, or ``None`` if the weights could not be
        loaded (an error message is shown in the app in that case).
    """
    # NOTE(review): num_classes must match both the checkpoint and the class-name
    # mapping shown in the UI — TODO confirm against the trained model.
    model = PlantDiseaseClassifier(num_classes=2)  # Update with your class count
    try:
        state_dict = torch.load('best_model.pth', map_location='cpu')
        model.load_state_dict(state_dict)
    except Exception as e:  # UI boundary: report the failure instead of crashing
        st.error(f"Error loading model: {str(e)}")
        # Returning the randomly initialised model here would silently produce
        # meaningless predictions; None lets the caller skip inference.
        return None
    # Inference mode: dropout/batch-norm layers in a real model must not run
    # with training behaviour.
    model.eval()
    st.success("Model loaded successfully!")
    return model
def predict(image, model, class_names):
    """Classify a PIL image and return the most likely label and its probability.

    Args:
        image: A PIL image. It is converted to RGB first so grayscale, RGBA or
            palette images match the 3-channel normalisation below.
        model: A torch module mapping a ``(1, 3, 224, 224)`` tensor to class
            logits. It is switched to eval mode before inference.
        class_names: Mapping (or sequence) from class index to display name.

    Returns:
        ``(label, probability)`` for the top class. If the predicted index has no
        entry in ``class_names``, the label falls back to ``"Class <index>"``
        instead of raising.
    """
    transform = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        # Standard ImageNet mean/std — presumably what the model was trained
        # with; TODO confirm it matches the training preprocessing.
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ])
    # Add the batch dimension the model expects.
    input_tensor = transform(image.convert("RGB")).unsqueeze(0)

    model.eval()
    with torch.no_grad():
        logits = model(input_tensor)
    probabilities = torch.nn.functional.softmax(logits[0], dim=0)
    top_prob, top_index = probabilities.max(dim=0)

    index = top_index.item()
    try:
        label = class_names[index]
    except (KeyError, IndexError):
        # Model emits more classes than the UI knows names for.
        label = f"Class {index}"
    return label, top_prob.item()
def main():
    """Render the app: upload a plant photo, classify it, and show care tips.

    Nothing is classified until a file has been uploaded and a model is
    available (the ``load_model()`` result is not ``None``).
    """
    st.title("🌱 Plant Disease Classifier")

    # Update with your actual class names and care tips.
    # NOTE(review): load_model() builds the model with num_classes=2, but three
    # names are listed here — confirm both match the trained checkpoint.
    CLASS_NAMES = {
        0: "Healthy",
        1: "Late Blight",
        2: "Powdery Mildew",  # Add all your classes
    }
    # Keyed by the display names above; unknown names get a default tip.
    CARE_TIPS = {
        "Healthy": ["Continue regular watering", "Monitor plant growth"],
        "Late Blight": ["Remove infected leaves", "Apply fungicide"],
        "Powdery Mildew": ["Improve air circulation", "Apply sulfur spray"],
    }

    model = load_model()
    uploaded_file = st.file_uploader("Upload plant image", type=["jpg", "png", "jpeg"])
    if not uploaded_file or model is None:
        return

    try:
        image = Image.open(uploaded_file).convert("RGB")
    except (OSError, Image.DecompressionBombError) as e:
        # Corrupt or non-image uploads (PIL's UnidentifiedImageError is an
        # OSError) and oversized images would otherwise crash the page.
        st.error(f"Could not read the uploaded image: {e}")
        return

    col1, col2 = st.columns(2)
    with col1:
        st.image(image, caption="Uploaded Image", use_column_width=True)

    with st.spinner("Analyzing..."):
        predicted_class, confidence = predict(image, model, CLASS_NAMES)

    with col2:
        message = f"Prediction: {predicted_class} ({confidence*100:.1f}%)"
        # Success styling for healthy plants, error styling for any disease.
        if "healthy" in predicted_class.lower():
            st.success(message)
        else:
            st.error(message)
        st.subheader("Care Recommendations")
        for tip in CARE_TIPS.get(predicted_class, ["No specific recommendations"]):
            st.write(f"• {tip}")
# `streamlit run <file>` executes this script as __main__, so the app renders.
if __name__ == "__main__":
    main()