Spaces:
Sleeping
Sleeping
File size: 2,264 Bytes
4aebec9 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 |
import os
import shutil
from pathlib import Path

import torch
from fastapi import FastAPI, UploadFile, File, HTTPException
from fastapi.concurrency import run_in_threadpool
from fastapi.responses import JSONResponse

from plant_app import Plant_Agent
# --- Application and model setup --------------------------------------------

# FastAPI application that the route decorators in this module register on.
app = FastAPI()

# Inference device plus the model-weight and class-label files handed to the
# agent. The paths are relative; presumably Plant_Agent resolves them against
# the current working directory -- confirm if the server is started elsewhere.
device = torch.device("cpu")
model_path = "plant-disease-model.pth"
label_path = "class_labels.json"
agent = Plant_Agent(model_path, label_path, device)

# Uploaded images are written here; nothing in this module deletes them.
PERMANENT_DIR = "uploaded_images"
Path(PERMANENT_DIR).mkdir(parents=True, exist_ok=True)

# Accepted upload extensions: lower-case, without the leading dot.
ALLOWED_EXTENSIONS = {"jpg", "jpeg", "png"}
def verify_image_extension(file_name: str) -> str:
    """Check that an uploaded file name carries an allowed image extension.

    Args:
        file_name: Client-supplied file name. Falsy values (``""`` or
            ``None``) are treated as having no extension and rejected.

    Returns:
        The lower-cased extension without its leading dot, e.g. ``"jpg"``.

    Raises:
        HTTPException: 400 if the name has no ``.suffix`` or the suffix is
            not in ``ALLOWED_EXTENSIONS``.
    """
    # rpartition leaves `sep` empty when the name contains no dot, so a bare
    # name such as "jpg" is no longer mistaken for an extension (the old
    # split(".")[-1] returned the whole name in that case).
    _, sep, extension = (file_name or "").rpartition(".")
    extension = extension.lower() if sep else ""
    if extension not in ALLOWED_EXTENSIONS:
        raise HTTPException(
            status_code=400,
            detail=f"Invalid file type: {extension}. Only JPG and PNG are allowed.",
        )
    return extension
def save_file_permanently(file: UploadFile, directory: str) -> Path:
    """Copy an uploaded file's contents into ``directory``.

    Only the final component of the client-supplied file name is used, so
    names such as ``"../../app.py"`` or ``"/etc/x.png"`` cannot place the
    file outside ``directory``. An existing file with the same name is
    overwritten.

    Args:
        file: The upload whose ``file`` stream is copied.
        directory: Destination directory (expected to exist already).

    Returns:
        Path of the written file.

    Raises:
        HTTPException: 400 if the upload has no usable file name.
    """
    # file.filename is untrusted. Joining it directly would allow "../"
    # traversal, and an absolute name would replace `directory` entirely in
    # the Path join. Keep only the basename.
    safe_name = Path(file.filename or "").name
    if safe_name in ("", ".", ".."):
        raise HTTPException(status_code=400, detail="Invalid file name.")
    file_path = Path(directory) / safe_name
    with open(file_path, "wb") as buffer:
        shutil.copyfileobj(file.file, buffer)
    return file_path
async def _run_prediction(file: UploadFile, predict) -> dict:
    """Validate, save and classify one uploaded image.

    Shared pipeline for the prediction endpoints below. Saving the file and
    calling the model are blocking operations, so both run in FastAPI's
    thread pool rather than directly inside the coroutine, where they would
    stall the event loop for every other request.

    NOTE(review): because of this offloading, two requests can now call the
    agent at the same time from different threads. Confirm Plant_Agent is
    safe to call concurrently.

    Args:
        file: The uploaded image.
        predict: Agent method that takes the saved image path as a str.

    Returns:
        ``{"response": <value returned by predict>}``.

    Raises:
        HTTPException: 400 for a rejected file name/type; 500 if ``predict``
            raises.
    """
    verify_image_extension(file.filename)
    saved_path = await run_in_threadpool(save_file_permanently, file, PERMANENT_DIR)
    try:
        answer = await run_in_threadpool(predict, str(saved_path))
    except Exception as e:
        # Report model failures as HTTP 500, chaining so the server-side
        # traceback keeps the original cause.
        raise HTTPException(status_code=500, detail=str(e)) from e
    return {"response": answer}


@app.post("/predict_disease")
async def predict_disease(file: UploadFile = File(...)):
    """Predict disease from the uploaded image."""
    return await _run_prediction(file, agent.predict_disease)


@app.post("/predict_specie")
async def predict_specie(file: UploadFile = File(...)):
    """Predict species from the uploaded image."""
    return await _run_prediction(file, agent.predict_specie)
|