Spaces:
Running
Running
File size: 8,573 Bytes
758f3f5 88b9db3 758f3f5 ffe6f85 758f3f5 601b2b5 758f3f5 ffe6f85 758f3f5 ffe6f85 758f3f5 ffe6f85 758f3f5 88b9db3 758f3f5 bbd60c8 ffe6f85 758f3f5 984e213 758f3f5 ffe6f85 758f3f5 88b9db3 202e823 8bcb05e 202e823 8bcb05e 202e823 8bcb05e 758f3f5 ffe6f85 758f3f5 ffe6f85 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 |
import os
import shutil
import tempfile

import numpy as np
import requests
import tensorflow as tf
import uvicorn
from fastapi import FastAPI, UploadFile, File, HTTPException
from fastapi.responses import JSONResponse
from tensorflow.keras.layers import Layer, Conv2D, Softmax, Concatenate
from tensorflow.keras.models import load_model
from tensorflow.keras.preprocessing import image
# ASGI application that exposes the health and prediction endpoints.
app = FastAPI()

# Folder holding one ``model_<Plant>.keras`` file per supported plant.
MODEL_DIRECTORY = "dsanet_models"

# Scratch folder for uploaded images; the TMP_DIR environment variable overrides it.
TMP_DIR = os.environ.get("TMP_DIR", "/app/temp")
os.makedirs(TMP_DIR, exist_ok=True)  # create it up front so uploads can be written
# Class labels for every supported plant, keyed by the plant name used in the
# URL and in the model file name. The predict endpoint indexes these lists with
# the argmax of the model output, so each list's order must match the order of
# the corresponding model's output units.
plant_disease_dict = dict(
    Rice=["Blight", "Brown_Spots"],
    Tomato=[
        "Tomato___Bacterial_spot",
        "Tomato___Early_blight",
        "Tomato___Late_blight",
        "Tomato___Leaf_Mold",
        "Tomato___Septoria_leaf_spot",
        "Tomato___Spider_mites Two-spotted_spider_mite",
        "Tomato___Target_Spot",
        "Tomato___Tomato_Yellow_Leaf_Curl_Virus",
        "Tomato___Tomato_mosaic_virus",
        "Tomato___healthy",
    ],
    Strawberry=["Strawberry___Leaf_scorch", "Strawberry___healthy"],
    Potato=["Potato___Early_blight", "Potato___Late_blight", "Potato___healthy"],
    Pepperbell=["Pepper,_bell___Bacterial_spot", "Pepper,_bell___healthy"],
    Peach=["Peach___Bacterial_spot", "Peach___healthy"],
    Grape=[
        "Grape___Black_rot",
        "Grape___Esca_(Black_Measles)",
        "Grape___Leaf_blight_(Isariopsis_Leaf_Spot)",
        "Grape___healthy",
    ],
    Apple=[
        "Apple___Apple_scab",
        "Apple___Black_rot",
        "Apple___Cedar_apple_rust",
        "Apple___healthy",
    ],
    Cherry=["Cherry___Powdery_mildew", "Cherry___healthy"],
    Corn=[
        "Corn___Cercospora_leaf_spot Gray_leaf_spot",
        "Corn___Common_rust",
        "Corn___Northern_Leaf_Blight",
        "Corn___healthy",
    ],
    # NOTE(review): "okk" looks like a placeholder label. Because this list has
    # a single entry, the predict endpoint answers without running any model.
    Blueberry=["okk"],
)
# Custom Self-Attention Layer
@tf.keras.utils.register_keras_serializable()
class SelfAttention(Layer):
    """Keras layer that appends a self-attention projection to its input.

    Query, key and value are produced by bias-free 1x1 convolutions with
    ``input_channels // reduction_ratio`` filters each. The attended values are
    concatenated to the original input along the last axis. The output
    therefore has ``input_channels + input_channels // reduction_ratio``
    channels.

    The class is registered with ``register_keras_serializable`` so that saved
    ``.keras`` models containing it can be deserialized.

    NOTE(review): inputs are presumably 4-D (batch, height, width, channels)
    feature maps. TODO confirm against the model definitions.
    NOTE(review): the attribute names ``query_conv``/``key_conv``/``value_conv``
    and the exact call structure presumably must stay as they are. Previously
    saved weights depend on them to load and to reproduce the same
    predictions. Verify before refactoring.
    """

    def __init__(self, reduction_ratio=2, **kwargs):
        """Store the channel reduction factor; forward other kwargs to ``Layer``.

        Args:
            reduction_ratio: Divisor applied to the input channel count to get
                the number of filters of each query/key/value convolution.
        """
        super(SelfAttention, self).__init__(**kwargs)
        self.reduction_ratio = reduction_ratio

    def build(self, input_shape):
        """Create the query/key/value 1x1 convolutions once the input shape is known."""
        # Channel count of each projection (integer division of the input channels).
        n_channels = input_shape[-1] // self.reduction_ratio
        self.query_conv = Conv2D(n_channels, kernel_size=1, use_bias=False)
        self.key_conv = Conv2D(n_channels, kernel_size=1, use_bias=False)
        self.value_conv = Conv2D(n_channels, kernel_size=1, use_bias=False)
        super(SelfAttention, self).build(input_shape)

    def call(self, inputs):
        """Return ``inputs`` concatenated (last axis) with the attention output."""
        query = self.query_conv(inputs)
        key = self.key_conv(inputs)
        value = self.value_conv(inputs)
        # Calculate attention scores
        # tf.matmul treats every axis except the last two as a batch axis. If the
        # inputs are 4-D (B, H, W, C), the scores are (B, H, W, W), i.e.
        # attention among positions within each row. TODO confirm the input rank.
        attention_scores = tf.matmul(query, key, transpose_b=True)
        # NOTE(review): the softmax is taken over axis 1, not over the last (key)
        # axis as in standard attention. It is left unchanged because existing
        # trained models were presumably built with this behavior. Confirm
        # before "fixing" it.
        # A new (weightless) Softmax layer object is created on every call.
        attention_scores = Softmax(axis=1)(attention_scores)
        # Apply attention to values
        attended_value = tf.matmul(attention_scores, value)
        # Keep the original features and append the attended ones as extra channels.
        concatenated_output = Concatenate(axis=-1)([inputs, attended_value])
        return concatenated_output

    def get_config(self):
        """Return the base layer config plus ``reduction_ratio``, so the layer can be re-created on load."""
        config = super(SelfAttention, self).get_config()
        config.update({"reduction_ratio": self.reduction_ratio})
        return config
# Registry of models kept in memory for the lifetime of the process:
# plant name -> Keras model. Plants whose model file is missing or failed to
# load have no entry here.
loaded_models = {}


def load_all_models():
    """Fill ``loaded_models`` with one model per plant in ``plant_disease_dict``.

    For each plant, ``<MODEL_DIRECTORY>/model_<plant>.keras`` is loaded if the
    file exists. Missing files and load errors are reported with ``print`` and
    skipped, so a single bad model never aborts startup.
    """
    for plant in plant_disease_dict:
        path = os.path.join(MODEL_DIRECTORY, f"model_{plant}.keras")
        if not os.path.isfile(path):
            print(f"⚠ Warning: Model file '{path}' not found!")
            continue
        # The Rice model is loaded plainly. Every other model is loaded with
        # the custom SelfAttention layer supplied through custom_objects.
        if plant == "Rice":
            load_kwargs = {}
        else:
            load_kwargs = {"custom_objects": {"SelfAttention": SelfAttention}}
        try:
            loaded_models[plant] = load_model(path, **load_kwargs)
            print(f"✅ Model for {plant} loaded successfully!")
        except Exception as e:
            print(f"❌ Error loading model '{plant}': {e}")


# Populate the registry once, at import time, before any request is served.
load_all_models()
@app.get("/health")
async def api_health_check():
    """Liveness probe: report that the service is up."""
    payload = {"status": "Service is running"}
    return JSONResponse(content=payload)
# Base URL of the external service that returns details for a disease class.
EXTERNAL_API_BASE_URL = "https://navpan2-sarva-ai-back.hf.space/kotlinback"
# Seconds to wait for the external service. Without a timeout, a stalled
# upstream would hang the request indefinitely.
EXTERNAL_API_TIMEOUT = 10


def _fetch_disease_details(disease_name):
    """Fetch extra information about ``disease_name`` from the external API.

    This is best effort: it never raises for network or decoding problems.
    On failure it returns ``{"error": <message>}`` so that callers fall back
    to their defaults.
    """
    try:
        response = requests.get(
            f"{EXTERNAL_API_BASE_URL}/{disease_name}",
            timeout=EXTERNAL_API_TIMEOUT,
        )
        if response.status_code != 200:
            return {"error": "Failed to fetch external data"}
        data = response.json()
    except (requests.RequestException, ValueError) as e:
        # ValueError covers invalid JSON bodies.
        return {"error": str(e)}
    # A JSON body that is not an object (e.g. a list) cannot be queried with .get().
    if not isinstance(data, dict):
        return {"error": "Unexpected external data format"}
    return data


def _build_disease_response(external_data, plant_name, disease_name):
    """Build the endpoint's JSON response for ``disease_name``.

    Fields missing from ``external_data`` are filled with placeholder defaults.
    This includes the ``{"error": ...}`` fallback and malformed nested values.
    """
    desc = external_data.get("diseaseDesc")
    if not isinstance(desc, dict):
        desc = {}
    remedies = external_data.get("diseaseRemedyList")
    if not isinstance(remedies, list):
        remedies = []
    return JSONResponse(content={
        "plantName": external_data.get("plantName", plant_name),
        "botanicalName": external_data.get("botanicalName", "Unknown"),
        "diseaseDesc": {
            "diseaseName": desc.get("diseaseName", disease_name),
            "symptoms": desc.get("symptoms", "Not Available"),
            "diseaseCauses": desc.get("diseaseCauses", "Not Available"),
        },
        "diseaseRemedyList": [
            {
                "title": remedy.get("title", "Unknown"),
                "diseaseRemedyShortDesc": remedy.get("diseaseRemedyShortDesc", "Not Available"),
                "diseaseRemedy": remedy.get("diseaseRemedy", "Not Available"),
            }
            for remedy in remedies
            if isinstance(remedy, dict)
        ],
    })


@app.post("/predict/{plant_name}")
async def predict_plant_disease(plant_name: str, file: UploadFile = File(...)):
    """
    API endpoint to predict plant disease from an uploaded image.

    Args:
        plant_name (str): The plant type (must match a key in `plant_disease_dict`).
        file (UploadFile): The image file uploaded by the user.

    Returns:
        JSON response with the predicted class and additional details from an external API.

    Raises:
        HTTPException: 400 if the plant is unknown, its model is not loaded,
            or the upload cannot be read as an image.
    """
    classes = plant_disease_dict.get(plant_name, [])

    # A plant with exactly one known class needs no inference: answer directly.
    if len(classes) == 1:
        single_disease = classes[0]
        return _build_disease_response(
            _fetch_disease_details(single_disease), plant_name, single_disease
        )

    if plant_name not in loaded_models:
        raise HTTPException(status_code=400, detail=f"Invalid plant name or model not loaded: {plant_name}")

    # Store the upload under a unique server-generated name inside TMP_DIR.
    # The client-supplied filename is deliberately not used as a path: it may
    # contain "../" (path traversal), be None, or collide with a concurrent
    # upload of the same name.
    fd, temp_path = tempfile.mkstemp(dir=TMP_DIR)
    try:
        with os.fdopen(fd, "wb") as buffer:
            shutil.copyfileobj(file.file, buffer)

        model = loaded_models[plant_name]

        # Load and preprocess the image: resize, add a batch axis, scale to [0, 1].
        try:
            img = image.load_img(temp_path, target_size=(224, 224))
        except (OSError, ValueError) as e:
            raise HTTPException(status_code=400, detail=f"Could not read uploaded image: {e}") from e
        img_array = image.img_to_array(img)
        img_array = np.expand_dims(img_array, axis=0)
        img_array = img_array / 255.0

        # NOTE(review): model.predict and the requests call are blocking and run
        # on the event loop inside this async endpoint. Consider offloading
        # them if throughput matters.
        prediction = model.predict(img_array)
        class_label = classes[np.argmax(prediction)]

        return _build_disease_response(
            _fetch_disease_details(class_label), plant_name, class_label
        )
    finally:
        # Always remove the temporary upload, even if writing or inference failed.
        os.remove(temp_path)
if __name__ == "__main__":
    # Serve the app directly when this module is executed as a script.
    uvicorn.run(app, port=7860, host="0.0.0.0")
|