# Upload metadata: author Dhiryashil, commit 1932a55 ("Upload 3 files"), verified
"""
Farm Disease Detection API - Gradio Interface
ViT and specialized models for plant disease detection
"""
import gradio as gr
import torch
import cv2
import numpy as np
from PIL import Image
import json
import base64
import io
import time
from typing import List, Dict, Any
# Import models
try:
from transformers import ViTImageProcessor, ViTForImageClassification
from transformers import AutoImageProcessor, AutoModelForImageClassification
MODELS_AVAILABLE = True
except ImportError:
MODELS_AVAILABLE = False
class DiseaseDetectionAPI:
    """Plant-health analysis backed by image-classification models.

    ``model_configs`` maps the short keys offered in the UI to model ids.
    When the ``transformers`` import succeeded, every configured model is
    loaded in ``__init__``; a key whose load failed is simply absent from
    ``self.models`` and requests for it return an error dict.
    """

    def __init__(self):
        # model_key -> loaded classifier / matching image processor
        self.models = {}
        self.processors = {}
        self.model_configs = {
            "vit_base_224": "google/vit-base-patch16-224",
            "vit_base_384": "google/vit-base-patch16-384",
            "plant_disease": "linkanjarad/mobilenet_v2_1.0_224-plant-disease-identification",
        }
        # Disease treatment database. Keys must be in normalized form
        # (lower case, words joined by single underscores) because model
        # labels are normalized before lookup; see _normalize_label.
        self.treatments = {
            "corn_blight": "Apply fungicide containing azoxystrobin or propiconazole",
            "tomato_late_blight": "Remove affected leaves, apply copper-based fungicide",
            "wheat_rust": "Apply triazole fungicides, improve air circulation",
            "potato_early_blight": "Use preventive fungicide spray, improve drainage",
            "apple_scab": "Apply sulfur-based fungicide, prune for air circulation",
            "healthy": "Continue current care routine, monitor regularly",
        }
        if MODELS_AVAILABLE:
            self.load_models()

    def load_models(self):
        """Load every model listed in ``model_configs`` (best effort).

        A failure for one model is printed and skipped so the remaining
        models stay usable.
        """
        for model_key, model_name in self.model_configs.items():
            try:
                print(f"Loading {model_name}...")
                if "vit" in model_key:
                    processor = ViTImageProcessor.from_pretrained(model_name)
                    model = ViTForImageClassification.from_pretrained(model_name)
                else:
                    processor = AutoImageProcessor.from_pretrained(model_name)
                    model = AutoModelForImageClassification.from_pretrained(model_name)
                self.processors[model_key] = processor
                self.models[model_key] = model
                print(f"✅ {model_name} loaded successfully")
            except Exception as e:
                print(f"❌ Failed to load {model_name}: {e}")

    @staticmethod
    def _normalize_label(label: str) -> str:
        """Lower-case *label* and collapse each run of non-alphanumerics to ``_``.

        ``"Tomato Late Blight"`` and ``"Tomato___Late_blight"`` both become
        ``"tomato_late_blight"``, the form used by ``self.treatments``.
        """
        spaced = "".join(ch if ch.isalnum() else " " for ch in label.lower())
        return "_".join(spaced.split())

    def _lookup_treatment(self, label: str) -> str:
        """Return the stored treatment for *label*, or a generic fallback."""
        return self.treatments.get(self._normalize_label(label), "Consult agricultural expert")

    def analyze_plant_health(
        self,
        image: "Image.Image",
        model_key: str = "plant_disease",
        top_k: int = 5,
        min_confidence: float = 0.1,
    ) -> Dict[str, Any]:
        """Classify *image* and summarize plant health.

        Args:
            image: Plant photo. A PIL image whose mode is not ``"RGB"``
                (e.g. an RGBA PNG) is converted to RGB first.
            model_key: Key into ``self.models``.
            top_k: Number of highest-probability classes to inspect; clamped
                to the model's class count (and to at least 1).
            min_confidence: Classes after the first must have a probability
                above this value to be reported. The top-1 class is always
                reported so there is always a primary finding.

        Returns:
            On success a dict with ``health_score``, ``primary_disease``
            (``name``/``confidence``/``severity``), ``diseases_detected``
            (at most 3), ``recommendations``, ``processing_time`` (seconds)
            and ``model_used``. On any failure ``{"error": <message>}``.
        """
        # self.models is only populated by load_models(), which __init__ runs
        # only when the transformers import succeeded, so membership alone
        # also covers the "transformers missing" case.
        if model_key not in self.models:
            return {"error": "Model not available"}
        start_time = time.perf_counter()
        try:
            processor = self.processors[model_key]
            model = self.models[model_key]

            # Processors normalize 3 channels; RGBA / grayscale input breaks them.
            if getattr(image, "mode", "RGB") != "RGB":
                image = image.convert("RGB")
            inputs = processor(images=image, return_tensors="pt")

            with torch.no_grad():
                outputs = model(**inputs)

            probabilities = torch.nn.functional.softmax(outputs.logits, dim=-1)
            # torch.topk raises if k exceeds the size of the class dimension.
            k = max(1, min(top_k, probabilities.shape[-1]))
            top = torch.topk(probabilities, k)
            id2label = model.config.id2label

            diseases_detected = []
            for rank, (score, idx) in enumerate(zip(top.values[0], top.indices[0])):
                confidence = float(score)
                # Keep the top-1 class unconditionally; a flat distribution
                # previously left this list empty and crashed below.
                if rank > 0 and confidence <= min_confidence:
                    continue
                disease_name = id2label[idx.item()]
                diseases_detected.append({
                    "disease": disease_name,
                    "confidence": confidence,
                    "treatment": self._lookup_treatment(disease_name),
                })

            # Health score: the top class's confidence if it is a "healthy"
            # class, otherwise one minus the top disease's confidence.
            primary = diseases_detected[0]
            if "healthy" in primary["disease"].lower():
                health_score = primary["confidence"]
            else:
                health_score = 1.0 - primary["confidence"]

            recommendations = self.generate_recommendations(diseases_detected, health_score)
            processing_time = time.perf_counter() - start_time

            return {
                "health_score": round(float(health_score), 2),
                "primary_disease": {
                    "name": primary["disease"],
                    "confidence": round(primary["confidence"], 2),
                    "severity": self.get_severity(primary["confidence"]),
                },
                "diseases_detected": diseases_detected[:3],  # Top 3
                "recommendations": recommendations,
                "processing_time": round(processing_time, 2),
                "model_used": model_key,
            }
        except Exception as e:
            # UI boundary: report the failure to the caller instead of raising.
            return {"error": str(e)}

    def get_severity(self, confidence: float) -> str:
        """Map a confidence in [0, 1] to a severity bucket.

        > 0.8 severe, > 0.5 moderate, > 0.3 mild, otherwise minimal.
        """
        if confidence > 0.8:
            return "severe"
        elif confidence > 0.5:
            return "moderate"
        elif confidence > 0.3:
            return "mild"
        else:
            return "minimal"

    def generate_recommendations(self, diseases: List[Dict], health_score: float) -> List[str]:
        """Build at most five recommendation strings.

        Three generic recommendations are picked by ``health_score`` band
        (> 0.8, > 0.5, otherwise). They are followed by the treatments of the
        first two entries in *diseases* whose confidence exceeds 0.3. A
        treatment already present is not repeated.
        """
        if health_score > 0.8:
            recommendations = [
                "Plant appears healthy - continue current care",
                "Monitor regularly for early disease signs",
                "Maintain proper watering and nutrition",
            ]
        elif health_score > 0.5:
            recommendations = [
                "Early intervention recommended",
                "Improve growing conditions",
                "Consider preventive treatments",
            ]
        else:
            recommendations = [
                "Immediate treatment required",
                "Isolate affected plants if possible",
                "Consult agricultural specialist",
            ]

        for disease in diseases[:2]:
            treatment = disease["treatment"]
            # Unknown labels share the fallback text; list it only once.
            if disease["confidence"] > 0.3 and treatment not in recommendations:
                recommendations.append(treatment)
        return recommendations[:5]  # Limit to 5 recommendations
# Module-level analysis service used by predict_disease. Constructing it calls
# load_models() synchronously for every entry in model_configs (only when the
# transformers import succeeded), so importing this module blocks until those
# loads finish.
api = DiseaseDetectionAPI()
def _format_results(results: Dict[str, Any]) -> str:
    """Render a successful ``analyze_plant_health`` result as Markdown-style text."""
    health_score = results["health_score"]
    primary = results["primary_disease"]

    # Traffic-light marker for the overall score.
    if health_score > 0.7:
        health_color = "🟢"
    elif health_score > 0.4:
        health_color = "🟡"
    else:
        health_color = "🔴"

    lines = [
        "🩺 **Plant Health Analysis**",
        f"{health_color} **Health Score**: {health_score:.1%}",
        f"🦠 **Primary Issue**: {primary['name']} ({primary['confidence']:.1%} confidence)",
        f"⚠️ **Severity**: {primary['severity'].title()}",
        "",
        "**🔬 Detected Issues**:",
    ]
    lines += [
        f"{i}. **{disease['disease']}** ({disease['confidence']:.1%})"
        for i, disease in enumerate(results["diseases_detected"], 1)
    ]
    lines += ["", "**💡 Recommendations**:"]
    lines += [f"{i}. {rec}" for i, rec in enumerate(results["recommendations"], 1)]
    return "\n".join(lines)


def predict_disease(image, model_choice):
    """Gradio callback: analyze *image* with the model keyed by *model_choice*.

    Returns:
        ``(display_image, analysis_text)``. When no image is supplied or the
        analysis reports an error, ``display_image`` is ``None`` and the text
        carries the message.
    """
    if image is None:
        return None, "Please upload an image"

    # The UI supplies PIL images, but accept raw numpy arrays as well.
    if isinstance(image, np.ndarray):
        image = Image.fromarray(image)

    results = api.analyze_plant_health(image, model_choice)
    if "error" in results:
        return None, f"Error: {results['error']}"

    # Nothing is drawn on the image yet; a copy keeps the displayed image
    # independent of the uploaded object.
    return image.copy(), _format_results(results)
# Markdown for the "API Documentation" tab. Kept at column 0 so no leading
# indentation turns its lines into a Markdown code block.
API_DOCS_MARKDOWN = """
## 🚀 API Endpoint
**POST** `/api/predict`

### Request Format
```json
{
  "data": ["<base64_image>", "<model_choice>"]
}
```

### Model Options
- **plant_disease**: Specialized plant disease model (recommended)
- **vit_base_224**: Fast Vision Transformer (general ImageNet classifier, not disease-specific)
- **vit_base_384**: High resolution Vision Transformer (general ImageNet classifier, not disease-specific)

### Response Format
The endpoint returns the two UI outputs: the displayed image and the
formatted analysis text.
```json
{
  "data": ["<image>", "<analysis_text>"]
}
```

The analysis text is rendered from a result of this shape:
```json
{
  "health_score": 0.08,
  "primary_disease": {
    "name": "corn_blight",
    "confidence": 0.92,
    "severity": "severe"
  },
  "diseases_detected": [...],
  "recommendations": [...],
  "processing_time": 1.2,
  "model_used": "plant_disease"
}
```
"""

# Gradio Interface
with gr.Blocks(title="🩺 Farm Disease Detection API") as app:
    gr.Markdown("# 🩺 Farm Disease Detection API")
    gr.Markdown("AI-powered plant disease detection and health assessment")

    with gr.Tab("🌱 Plant Analysis"):
        with gr.Row():
            with gr.Column():
                image_input = gr.Image(type="pil", label="Upload Plant Image")
                model_choice = gr.Dropdown(
                    choices=["plant_disease", "vit_base_224", "vit_base_384"],
                    value="plant_disease",
                    label="Select Model",
                )
                analyze_btn = gr.Button("🔍 Analyze Plant Health", variant="primary")
            with gr.Column():
                output_image = gr.Image(label="Plant Image")
                results_text = gr.Textbox(label="Health Analysis", lines=15)
        analyze_btn.click(
            predict_disease,
            inputs=[image_input, model_choice],
            outputs=[output_image, results_text],
        )

    with gr.Tab("📡 API Documentation"):
        gr.Markdown(API_DOCS_MARKDOWN)

if __name__ == "__main__":
    app.launch()