"""CropGuard: FastAPI service that classifies plant-leaf images with a ResNet9 model."""

import io

import torch
import torch.nn.functional as F
import torchvision.transforms as transforms
from fastapi import FastAPI, File, UploadFile
from fastapi.concurrency import run_in_threadpool
from fastapi.responses import JSONResponse
from PIL import Image

from models import ResNet9

app = FastAPI(title="CropGuard - Plant Disease Detection")

# Output labels: CLASS_NAMES[i] names the model's i-th output logit, and the
# model is built with len(CLASS_NAMES) outputs, so order and count (38) matter.
# Every label has the form "<crop>___<condition>"; the triple underscore is the
# separator _clean_label() turns into " - ". Several entries had lost
# underscores (e.g. "Corn(maize)healthy"), which broke that formatting; they
# are restored here without changing their position in the list.
CLASS_NAMES = [
    'Apple___Apple_scab',
    'Apple___Black_rot',
    'Apple___Cedar_apple_rust',
    'Apple___healthy',
    'Blueberry___healthy',
    'Cherry_(including_sour)___Powdery_mildew',
    'Cherry_(including_sour)___healthy',
    'Corn_(maize)___Cercospora_leaf_spot Gray_leaf_spot',
    'Corn_(maize)___Common_rust',
    'Corn_(maize)___Northern_Leaf_Blight',
    'Corn_(maize)___healthy',
    'Grape___Black_rot',
    'Grape___Esca_(Black_Measles)',
    'Grape___Leaf_blight_(Isariopsis_Leaf_Spot)',
    'Grape___healthy',
    'Orange___Haunglongbing_(Citrus_greening)',
    'Peach___Bacterial_spot',
    'Peach___healthy',
    'Pepper,_bell___Bacterial_spot',
    'Pepper,_bell___healthy',
    'Potato___Early_blight',
    'Potato___Late_blight',
    'Potato___healthy',
    'Raspberry___healthy',
    'Soybean___healthy',
    'Squash___Powdery_mildew',
    'Strawberry___Leaf_scorch',
    'Strawberry___healthy',
    'Tomato___Bacterial_spot',
    'Tomato___Early_blight',
    'Tomato___Late_blight',
    'Tomato___Leaf_Mold',
    'Tomato___Septoria_leaf_spot',
    'Tomato___Spider_mites Two-spotted_spider_mite',
    'Tomato___Target_Spot',
    'Tomato___Tomato_Yellow_Leaf_Curl_Virus',
    'Tomato___Tomato_mosaic_virus',
    'Tomato___healthy',
]

# Path of the saved state dict, resolved relative to the working directory.
MODEL_PATH = "plant-disease-model-state-dict.pth"

# Number of highest-probability classes returned by /predict.
TOP_K = 5

# Preprocessing for every uploaded image. Built once at import time instead of
# on every request.
_TRANSFORM = transforms.Compose([
    transforms.Resize((256, 256)),
    transforms.ToTensor(),
])

# The loaded network; stays None until load_model() succeeds.
model = None


def load_model():
    """Load the ResNet9 weights from MODEL_PATH into the module-level ``model``.

    Does nothing if ``model`` is already set. Any error raised by
    ``torch.load``/``load_state_dict`` propagates to the caller.
    """
    global model
    if model is None:
        net = ResNet9(3, len(CLASS_NAMES))
        state_dict = torch.load(MODEL_PATH, map_location="cpu")
        net.load_state_dict(state_dict)
        net.eval()
        # Publish only a fully initialised network. If loading fails, ``model``
        # stays None, so a later call retries instead of silently serving
        # random weights.
        model = net


def _get_model():
    """Return the loaded model, loading it first if necessary.

    Raises:
        RuntimeError: if the model is still unavailable after loading.
    """
    if model is None:
        load_model()
    if model is None:
        raise RuntimeError("Model failed to load.")
    return model


def _clean_label(class_name):
    """Turn a raw label such as "Apple___Black_rot" into "Apple - Black rot"."""
    return class_name.replace('___', ' - ').replace('_', ' ').strip()


def _decode_image(image_bytes):
    """Decode raw upload bytes into an RGB PIL image.

    Raises whatever PIL raises for bad input. For example, undecodable data
    raises UnidentifiedImageError (an OSError subclass), and oversized images
    raise Image.DecompressionBombError.
    """
    return Image.open(io.BytesIO(image_bytes)).convert("RGB")


def _classify(image):
    """Run the model on an RGB PIL *image*.

    Returns a dict mapping display label to probability for the top-k classes,
    in descending order of probability.
    """
    # ToTensor yields (C, H, W); add the batch dimension the model expects.
    batch = _TRANSFORM(image).unsqueeze(0)
    net = _get_model()
    with torch.no_grad():
        logits = net(batch)
    probabilities = F.softmax(logits[0], dim=0)
    # Never ask topk for more entries than there are classes.
    k = min(TOP_K, probabilities.numel())
    top_prob, top_idx = torch.topk(probabilities, k)
    return {
        _clean_label(CLASS_NAMES[i]): p
        for p, i in zip(top_prob.tolist(), top_idx.tolist())
    }


load_model()


@app.post("/predict")
async def predict(file: UploadFile = File(...)):
    """Classify an uploaded leaf image.

    Responses:
        200 ``{"predictions": {label: probability, ...}}`` for the top-k classes.
        400 ``{"error": message}`` when the upload cannot be decoded as an image.
        500 ``{"error": message}`` for any other failure.
    """
    try:
        image_bytes = await file.read()
        # Decoding and inference are synchronous CPU work. Run them in the
        # threadpool so they do not block the event loop for other requests.
        try:
            image = await run_in_threadpool(_decode_image, image_bytes)
        except (OSError, Image.DecompressionBombError) as e:
            # A bad upload is a client error, not a server fault.
            return JSONResponse(content={"error": str(e)}, status_code=400)
        results = await run_in_threadpool(_classify, image)
        return JSONResponse(content={"predictions": results})
    except Exception as e:
        # Top-level boundary: keep the JSON error shape for any other failure.
        return JSONResponse(content={"error": str(e)}, status_code=500)


@app.get("/")
def root():
    """Return a short message confirming the service is up and pointing to /predict."""
    return {"message": "CropGuard FastAPI is running. Use /predict to POST an image."}