File size: 2,264 Bytes
4aebec9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
import logging
import os
import shutil
import uuid
from pathlib import Path

import torch
from fastapi import FastAPI, UploadFile, File, HTTPException
from fastapi.concurrency import run_in_threadpool
from fastapi.responses import JSONResponse

from plant_app import Plant_Agent

# ---------------------------------------------------------------------------
# Application and model setup (executed once, when this module is imported)
# ---------------------------------------------------------------------------
app = FastAPI()

# Inference is pinned to the CPU. Both paths below are relative; presumably
# Plant_Agent resolves them against the working directory — confirm.
model_path = "plant-disease-model.pth"
label_path = "class_labels.json"
device = torch.device("cpu")
agent = Plant_Agent(model_path, label_path, device)

# ---------------------------------------------------------------------------
# Upload storage and validation
# ---------------------------------------------------------------------------
# Directory (relative to the working directory) where uploads are written;
# created up front so request handlers can assume it exists.
PERMANENT_DIR = "uploaded_images"
Path(PERMANENT_DIR).mkdir(parents=True, exist_ok=True)

# Lower-case file extensions accepted by verify_image_extension.
ALLOWED_EXTENSIONS = {"jpg", "jpeg", "png"}

def verify_image_extension(file_name: str) -> str:
    """Check that an uploaded file name carries an allowed image extension.

    The extension is the text after the last ``.`` in *file_name*,
    lower-cased and compared against ``ALLOWED_EXTENSIONS``. A name with no
    ``.`` at all has no extension and is rejected.

    Args:
        file_name: Client-supplied file name. ``None`` or ``""`` is treated as
            having no extension rather than crashing.

    Returns:
        The lower-cased extension, e.g. ``"jpg"`` for ``"leaf.JPG"``.

    Raises:
        HTTPException: 400 if the extension is missing or not allowed.
    """
    # rpartition reports whether a separator was found, so "png" (no dot)
    # yields an empty extension instead of being mistaken for one.
    _, dot, extension = (file_name or "").rpartition(".")
    extension = extension.lower() if dot else ""
    if extension not in ALLOWED_EXTENSIONS:
        raise HTTPException(
            status_code=400,
            detail=f"Invalid file type: {extension or '(none)'}. Only JPG and PNG are allowed.",
        )
    return extension

def save_file_permanently(file: UploadFile, directory: str) -> Path:
    """Copy an uploaded file into *directory* under a unique, sanitized name.

    Only the final path component of the client-supplied file name is kept,
    so a name such as ``"../../etc/x.png"`` cannot write outside *directory*.
    A random hex prefix keeps concurrent uploads that share a name (e.g. two
    clients both sending ``image.jpg``) from overwriting each other before
    the saved file is read back for prediction.

    Args:
        file: The uploaded file; its ``file`` stream is copied from its
            current position to EOF.
        directory: Destination directory; it must already exist.

    Returns:
        Path of the newly written file inside *directory*.
    """
    # filename may be None; Path("").name is "", which still yields a valid
    # single-component name once the unique prefix is added.
    original_name = Path(file.filename or "").name
    file_path = Path(directory) / f"{uuid.uuid4().hex}_{original_name}"
    with open(file_path, "wb") as buffer:
        shutil.copyfileobj(file.file, buffer)
    return file_path

@app.post("/predict_disease")
async def predict_disease(file: UploadFile = File(...)):
    """Predict the disease shown in an uploaded JPG/PNG image.

    The upload is validated, persisted under ``PERMANENT_DIR`` and handed by
    path to ``agent.predict_disease``. Saving and inference are blocking
    calls, so they run in FastAPI's thread pool instead of on the event loop;
    otherwise one slow prediction would stall every other request.

    Returns:
        ``{"response": <answer from agent.predict_disease>}``.

    Raises:
        HTTPException: 400 from ``verify_image_extension`` for a disallowed
            extension; 500 if the agent raises, with the exception text as
            the detail.
    """
    verify_image_extension(file.filename)

    saved_path = await run_in_threadpool(
        save_file_permanently, file, PERMANENT_DIR
    )

    # NOTE(review): predictions can now run concurrently in worker threads;
    # confirm Plant_Agent is safe to call from multiple threads.
    try:
        answer = await run_in_threadpool(agent.predict_disease, str(saved_path))
    except Exception as e:
        # Keep the full traceback server-side; the client only sees str(e).
        logging.getLogger(__name__).exception(
            "Disease prediction failed for %s", saved_path
        )
        raise HTTPException(status_code=500, detail=str(e)) from e
    return {"response": answer}

@app.post("/predict_specie")
async def predict_specie(file: UploadFile = File(...)):
    """Predict the plant species shown in an uploaded JPG/PNG image.

    The upload is validated, persisted under ``PERMANENT_DIR`` and handed by
    path to ``agent.predict_specie``. Saving and inference are blocking
    calls, so they run in FastAPI's thread pool instead of on the event loop;
    otherwise one slow prediction would stall every other request.

    Returns:
        ``{"response": <answer from agent.predict_specie>}``.

    Raises:
        HTTPException: 400 from ``verify_image_extension`` for a disallowed
            extension; 500 if the agent raises, with the exception text as
            the detail.
    """
    verify_image_extension(file.filename)

    saved_path = await run_in_threadpool(
        save_file_permanently, file, PERMANENT_DIR
    )

    # NOTE(review): predictions can now run concurrently in worker threads;
    # confirm Plant_Agent is safe to call from multiple threads.
    try:
        answer = await run_in_threadpool(agent.predict_specie, str(saved_path))
    except Exception as e:
        # Keep the full traceback server-side; the client only sees str(e).
        logging.getLogger(__name__).exception(
            "Species prediction failed for %s", saved_path
        )
        raise HTTPException(status_code=500, detail=str(e)) from e
    return {"response": answer}