"""Streamlit app that classifies uploaded plant photos with a PyTorch model.

The user uploads an image; the app shows the predicted class, its softmax
confidence, and care recommendations for that class.
"""

import streamlit as st
import torch
from torchvision import transforms
from PIL import Image
import numpy as np

# Path to the trained weights; the file is loaded with load_state_dict, so it
# must contain a state_dict matching PlantDiseaseClassifier.
MODEL_PATH = "best_model.pth"

# Number of output logits the checkpoint was trained with. Must match the
# saved weights AND the number of labels in CLASS_NAMES (checked in main()).
NUM_CLASSES = 2  # Update with your class count

# Preprocessing for every uploaded image, built once instead of per call.
# The 224x224 crop is what the classifier's final Linear layer is sized for.
# NOTE(review): mean/std are the commonly used ImageNet channel statistics —
# confirm they match what the model was trained with.
PREPROCESS = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
])


# Custom model class (replace with your actual architecture)
class PlantDiseaseClassifier(torch.nn.Module):
    """Small CNN mapping a (N, 3, 224, 224) image batch to class logits.

    Example architecture only -- REPLACE WITH YOUR ACTUAL MODEL so that its
    layers match the ones stored in the checkpoint.
    """

    def __init__(self, num_classes=2):
        super().__init__()
        # Example architecture - REPLACE WITH YOUR ACTUAL MODEL
        # The attribute name `model` is part of the state_dict keys
        # ("model.0.weight", ...), so it must match the checkpoint.
        self.model = torch.nn.Sequential(
            torch.nn.Conv2d(3, 16, kernel_size=3, padding=1),
            torch.nn.ReLU(),
            torch.nn.MaxPool2d(2),
            torch.nn.Conv2d(16, 32, kernel_size=3, padding=1),
            torch.nn.ReLU(),
            torch.nn.MaxPool2d(2),
            torch.nn.Flatten(),
            # padding=1 keeps spatial size through each conv; the two 2x2
            # pools shrink 224x224 to 56x56. Adjust if either changes.
            torch.nn.Linear(32 * 56 * 56, num_classes),
        )

    def forward(self, x):
        """Return raw (pre-softmax) logits of shape (N, num_classes)."""
        return self.model(x)


@st.cache_resource
def _load_checkpoint(model_path, num_classes):
    """Build the classifier, load its weights, and cache it across reruns.

    Exceptions are deliberately not caught here: Streamlit does not cache a
    raised exception, so a failed load is retried on the next rerun instead
    of being remembered.
    """
    model = PlantDiseaseClassifier(num_classes=num_classes)
    # NOTE(review): torch.load unpickles the file — only load checkpoints you
    # trust. On torch versions that support it, weights_only=True is safer.
    state_dict = torch.load(model_path, map_location="cpu")
    model.load_state_dict(state_dict)
    # Switch layers such as Dropout/BatchNorm (if the real architecture has
    # them) to inference behaviour.
    model.eval()
    return model


def load_model(model_path=MODEL_PATH, num_classes=NUM_CLASSES):
    """Return the trained classifier in eval mode, or None if loading fails.

    Reports success or the failure reason in the app. Returning None (rather
    than an untrained model) lets callers skip prediction on failure.
    """
    try:
        model = _load_checkpoint(model_path, num_classes)
    except Exception as e:  # UI boundary: show any load failure to the user
        st.error(f"Error loading model: {str(e)}")
        return None
    st.success("Model loaded successfully!")
    return model


def predict(image, model, class_names):
    """Run prediction and return top class.

    Args:
        image: RGB PIL image.
        model: classifier returning logits of shape (1, num_classes) for a
            single preprocessed image.
        class_names: mapping (or sequence) from class index to label.

    Returns:
        Tuple (label, probability) for the most likely class. If the model
        predicts an index with no entry in class_names, the label falls back
        to "Class <index>" instead of raising.
    """
    input_tensor = PREPROCESS(image).unsqueeze(0)  # add batch dimension
    with torch.no_grad():
        output = model(input_tensor)
    probabilities = torch.nn.functional.softmax(output[0], dim=0)
    top_prob, top_class = torch.topk(probabilities, 1)
    index = top_class.item()
    try:
        label = class_names[index]
    except (KeyError, IndexError):
        # Model has more outputs than configured labels.
        label = f"Class {index}"
    return label, top_prob.item()


def main():
    """Render the classifier page and handle a single uploaded image."""
    st.title("🌱 Plant Disease Classifier")

    # Update with your actual class names and care tips
    CLASS_NAMES = {
        0: "Healthy",
        1: "Late Blight",
        2: "Powdery Mildew",
        # Add all your classes
    }
    CARE_TIPS = {
        "Healthy": ["Continue regular watering", "Monitor plant growth"],
        "Late Blight": ["Remove infected leaves", "Apply fungicide"],
        "Powdery Mildew": ["Improve air circulation", "Apply sulfur spray"],
    }

    # A model with fewer outputs than labels can never predict the extra
    # labels; one with more outputs yields unnamed classes. Flag the mismatch.
    if len(CLASS_NAMES) != NUM_CLASSES:
        st.warning(
            f"Model has {NUM_CLASSES} outputs but {len(CLASS_NAMES)} class "
            "names are configured; update NUM_CLASSES or CLASS_NAMES."
        )

    model = load_model()

    uploaded_file = st.file_uploader(
        "Upload plant image", type=["jpg", "png", "jpeg"]
    )
    if not uploaded_file or model is None:
        return

    try:
        image = Image.open(uploaded_file).convert("RGB")
    except OSError as e:  # includes PIL's UnidentifiedImageError / truncation
        st.error(f"Could not read image: {e}")
        return

    col1, col2 = st.columns(2)
    with col1:
        # NOTE(review): newer Streamlit releases deprecate use_column_width —
        # confirm against the installed version.
        st.image(image, caption="Uploaded Image", use_column_width=True)

    with st.spinner("Analyzing..."):
        predicted_class, confidence = predict(image, model, CLASS_NAMES)

    with col2:
        message = f"Prediction: {predicted_class} ({confidence*100:.1f}%)"
        if "healthy" in predicted_class.lower():
            st.success(message)
        else:
            st.error(message)

        st.subheader("Care Recommendations")
        for tip in CARE_TIPS.get(predicted_class, ["No specific recommendations"]):
            st.write(f"• {tip}")


if __name__ == "__main__":
    main()