Spaces:
Sleeping
Sleeping
| from fastapi import FastAPI, UploadFile, File, HTTPException | |
| from fastapi.responses import JSONResponse | |
| from pathlib import Path | |
| import shutil | |
| import os | |
| import torch | |
| from plant_app import Plant_Agent | |
# FastAPI application instance for this service.
app = FastAPI()

# Plant_Agent wiring: inference device, then the weights and class-label files
# it is built from. The paths are relative; presumably resolved against the
# process working directory — confirm if the service is launched elsewhere.
device = torch.device("cpu")
model_path, label_path = "plant-disease-model.pth", "class_labels.json"
agent = Plant_Agent(model_path, label_path, device)

# Uploaded images are kept here; the directory (and any parents) is created
# at import time, and an existing directory is not an error.
PERMANENT_DIR = "uploaded_images"
Path(PERMANENT_DIR).mkdir(parents=True, exist_ok=True)

# Accepted upload extensions: lower-case, without the leading dot.
ALLOWED_EXTENSIONS = {"png", "jpeg", "jpg"}
def verify_image_extension(file_name: str) -> str:
    """Check that an uploaded file name carries an allowed image extension.

    Args:
        file_name: Client-supplied file name. It may be ``None`` or empty
            when the client does not send one.

    Returns:
        The extension, lower-cased and without the leading dot.

    Raises:
        HTTPException: 400 if the name has no extension, or if its extension
            is not in ``ALLOWED_EXTENSIONS``.
    """
    # A missing name, or one without a dot, has no extension to check.
    # Reject it as a client error instead of crashing on None or treating
    # the whole name (e.g. "png") as the extension.
    if not file_name or "." not in file_name:
        raise HTTPException(
            status_code=400,
            detail="Invalid file type: missing extension. Only JPG and PNG are allowed.",
        )
    extension = file_name.rsplit(".", 1)[1].lower()
    if extension not in ALLOWED_EXTENSIONS:
        raise HTTPException(
            status_code=400,
            detail=f"Invalid file type: {extension}. Only JPG and PNG are allowed.",
        )
    return extension
def save_file_permanently(file: "UploadFile", directory: str) -> Path:
    """Write an uploaded file's contents into *directory*.

    Only the final component of the client-supplied filename is used. A name
    such as ``../../etc/x.png`` or ``/abs/path.png`` therefore cannot escape
    *directory*. An existing file with the same base name is overwritten.

    Args:
        file: Upload object; its ``filename`` and ``file`` stream are read.
        directory: Destination directory, which must already exist.

    Returns:
        Path of the written file.

    Raises:
        ValueError: If the filename is missing or has no usable base name.
    """
    raw_name = file.filename or ""
    # Convert Windows separators first so that Path(...).name also strips
    # them on POSIX, then discard every directory component.
    safe_name = Path(raw_name.replace("\\", "/")).name
    if safe_name in ("", ".", ".."):
        raise ValueError(f"Unusable upload filename: {raw_name!r}")
    file_path = Path(directory) / safe_name
    with open(file_path, "wb") as buffer:
        shutil.copyfileobj(file.file, buffer)
    return file_path
async def predict_disease(file: UploadFile = File(...)):
    """Predict the plant disease shown in an uploaded JPG/PNG image.

    The image is validated, saved under ``PERMANENT_DIR`` and passed to
    ``agent.predict_disease``.

    Returns:
        ``{"response": <agent answer>}``.

    Raises:
        HTTPException: 400 for a disallowed file type, or 500 if the agent's
            prediction raises.
    """
    # NOTE(review): no route decorator (e.g. @app.post(...)) is visible for
    # this handler or predict_specie; confirm both are registered on `app`.
    return _run_agent_prediction(file, "predict_disease")


async def predict_specie(file: UploadFile = File(...)):
    """Predict the plant species shown in an uploaded JPG/PNG image.

    The image is validated, saved under ``PERMANENT_DIR`` and passed to
    ``agent.predict_specie``.

    Returns:
        ``{"response": <agent answer>}``.

    Raises:
        HTTPException: 400 for a disallowed file type, or 500 if the agent's
            prediction raises.
    """
    return _run_agent_prediction(file, "predict_specie")


def _run_agent_prediction(file: UploadFile, method_name: str) -> dict:
    """Validate and store *file*, then run ``agent.<method_name>`` on its path.

    Shared by the prediction endpoints so that validation, storage and error
    mapping stay identical between them.
    """
    verify_image_extension(file.filename)
    saved_path = save_file_permanently(file, PERMANENT_DIR)
    # NOTE(review): the agent call is synchronous inside an async endpoint,
    # so it blocks the event loop for the whole inference. Consider offloading
    # it if concurrent requests matter (check Plant_Agent thread-safety first).
    try:
        # The method lookup stays inside the try so a failure there is also
        # reported as a 500 with detail.
        answer = getattr(agent, method_name)(str(saved_path))
    except Exception as e:
        # The detail is the raw exception text, which exposes internal error
        # messages to clients. Chaining keeps the original as __cause__.
        raise HTTPException(status_code=500, detail=str(e)) from e
    return {"response": answer}