File size: 8,573 Bytes
758f3f5
 
 
 
 
88b9db3
758f3f5
 
 
 
 
 
 
 
 
 
 
ffe6f85
 
 
 
758f3f5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
601b2b5
758f3f5
ffe6f85
758f3f5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ffe6f85
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
758f3f5
 
 
ffe6f85
 
758f3f5
 
 
 
 
 
 
 
 
 
88b9db3
758f3f5
 
bbd60c8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ffe6f85
 
758f3f5
 
984e213
758f3f5
 
 
 
ffe6f85
 
758f3f5
 
 
 
 
 
 
 
 
88b9db3
 
 
 
 
 
 
 
 
 
202e823
 
 
8bcb05e
202e823
 
8bcb05e
 
202e823
 
8bcb05e
 
 
 
 
 
 
 
758f3f5
 
 
ffe6f85
 
758f3f5
ffe6f85
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
from fastapi import FastAPI, UploadFile, File, HTTPException
from fastapi.responses import JSONResponse
import tensorflow as tf
import numpy as np
import os
import requests
from tensorflow.keras.models import load_model
from tensorflow.keras.preprocessing import image
from tensorflow.keras.layers import Layer, Conv2D, Softmax, Concatenate
import shutil
import uvicorn

app = FastAPI()

# Directory holding the per-plant `.keras` model files.
MODEL_DIRECTORY = "dsanet_models"

# Scratch space for uploaded images; overridable through the TMP_DIR env var.
TMP_DIR = os.environ.get("TMP_DIR", "/app/temp")
os.makedirs(TMP_DIR, exist_ok=True)  # make sure the scratch directory exists

# Class labels per supported plant. Each list's order must match the output
# index order of the corresponding model in MODEL_DIRECTORY, because
# predictions are mapped back via argmax into these lists.
plant_disease_dict = {
    "Rice": ['Blight', 'Brown_Spots'],
    "Tomato": [
        'Tomato___Bacterial_spot',
        'Tomato___Early_blight',
        'Tomato___Late_blight',
        'Tomato___Leaf_Mold',
        'Tomato___Septoria_leaf_spot',
        'Tomato___Spider_mites Two-spotted_spider_mite',
        'Tomato___Target_Spot',
        'Tomato___Tomato_Yellow_Leaf_Curl_Virus',
        'Tomato___Tomato_mosaic_virus',
        'Tomato___healthy',
    ],
    "Strawberry": ['Strawberry___Leaf_scorch', 'Strawberry___healthy'],
    "Potato": ['Potato___Early_blight', 'Potato___Late_blight', 'Potato___healthy'],
    "Pepperbell": ['Pepper,_bell___Bacterial_spot', 'Pepper,_bell___healthy'],
    "Peach": ['Peach___Bacterial_spot', 'Peach___healthy'],
    "Grape": [
        'Grape___Black_rot',
        'Grape___Esca_(Black_Measles)',
        'Grape___Leaf_blight_(Isariopsis_Leaf_Spot)',
        'Grape___healthy',
    ],
    "Apple": ['Apple___Apple_scab', 'Apple___Black_rot', 'Apple___Cedar_apple_rust', 'Apple___healthy'],
    "Cherry": ['Cherry___Powdery_mildew', 'Cherry___healthy'],
    "Corn": [
        'Corn___Cercospora_leaf_spot Gray_leaf_spot',
        'Corn___Common_rust',
        'Corn___Northern_Leaf_Blight',
        'Corn___healthy',
    ],
    # Single-class entry: handled by the model-free fast path in the
    # prediction endpoint.
    "Blueberry": ["okk"],
}

# Custom Self-Attention Layer
@tf.keras.utils.register_keras_serializable()
class SelfAttention(Layer):
    """Self-attention layer that concatenates attended features onto its input.

    Output channel count is C + C // reduction_ratio, since the attended
    value tensor is concatenated with the raw input along the channel axis
    rather than replacing it.
    """

    def __init__(self, reduction_ratio=2, **kwargs):
        # reduction_ratio: factor by which the query/key/value projections
        # shrink the channel dimension relative to the input.
        super(SelfAttention, self).__init__(**kwargs)
        self.reduction_ratio = reduction_ratio

    def build(self, input_shape):
        # 1x1 convolutions act as per-position linear projections down to
        # n_channels for the query/key/value tensors.
        n_channels = input_shape[-1] // self.reduction_ratio
        self.query_conv = Conv2D(n_channels, kernel_size=1, use_bias=False)
        self.key_conv = Conv2D(n_channels, kernel_size=1, use_bias=False)
        self.value_conv = Conv2D(n_channels, kernel_size=1, use_bias=False)
        super(SelfAttention, self).build(input_shape)

    def call(self, inputs):
        query = self.query_conv(inputs)
        key = self.key_conv(inputs)
        value = self.value_conv(inputs)

        # Calculate attention scores.
        # NOTE(review): with 4-D (B, H, W, C) inputs, tf.matmul batches over
        # (B, H) and attends only across the W axis, and Softmax(axis=1)
        # normalizes over an unconventional axis (typically axis=-1). This
        # must be preserved exactly as-is: the saved model weights were
        # trained against this computation. TODO confirm intent with authors.
        attention_scores = tf.matmul(query, key, transpose_b=True)
        attention_scores = Softmax(axis=1)(attention_scores)

        # Apply attention to values, then concatenate with the raw input.
        attended_value = tf.matmul(attention_scores, value)
        concatenated_output = Concatenate(axis=-1)([inputs, attended_value])
        return concatenated_output

    def get_config(self):
        # Required for (de)serialization so saved .keras models can rebuild
        # this layer with the same reduction_ratio.
        config = super(SelfAttention, self).get_config()
        config.update({"reduction_ratio": self.reduction_ratio})
        return config


# Cache of preloaded Keras models, keyed by plant name.
loaded_models = {}

def load_all_models():
    """Populate `loaded_models` from MODEL_DIRECTORY.

    Expects one file per plant named `model_<PlantName>.keras`. The Rice
    model loads with default settings; every other model needs the
    SelfAttention custom layer supplied at load time. Missing files and
    load failures are reported but never raise, so startup always completes.
    """
    global loaded_models
    for plant_name in plant_disease_dict:
        model_path = os.path.join(MODEL_DIRECTORY, f"model_{plant_name}.keras")

        if not os.path.isfile(model_path):
            print(f"⚠ Warning: Model file '{model_path}' not found!")
            continue

        # Rice was saved without the custom layer; all others require it.
        custom_objects = None if plant_name == "Rice" else {"SelfAttention": SelfAttention}
        try:
            loaded_models[plant_name] = load_model(model_path, custom_objects=custom_objects)
            print(f"✅ Model for {plant_name} loaded successfully!")
        except Exception as e:
            print(f"❌ Error loading model '{plant_name}': {e}")

# Preload everything at import time so requests never pay model-load latency.
load_all_models()


@app.get("/health")
async def api_health_check():
    """Liveness probe: reports that the service process is up."""
    payload = {"status": "Service is running"}
    return JSONResponse(content=payload)


def _fetch_disease_info(disease_label: str) -> dict:
    """Fetch description/remedy details for a disease label from the external
    knowledge-base service.

    Returns the decoded JSON payload on success, or an {"error": ...} dict on
    any failure so callers can still build a (degraded) response.
    """
    try:
        # timeout prevents a slow upstream from hanging the worker forever.
        # NOTE(review): blocking `requests` call inside async endpoints —
        # acceptable at low traffic, but consider httpx.AsyncClient later.
        response = requests.get(
            f"https://navpan2-sarva-ai-back.hf.space/kotlinback/{disease_label}",
            timeout=10,
        )
        if response.status_code == 200:
            return response.json()
        return {"error": "Failed to fetch external data"}
    except Exception as e:
        return {"error": str(e)}


def _build_disease_response(plant_name: str, disease_label: str, external_data: dict) -> JSONResponse:
    """Shape the external-service payload into this API's response schema,
    substituting defaults for any missing fields.
    """
    disease_desc = external_data.get("diseaseDesc", {})
    return JSONResponse(content={
        "plantName": external_data.get("plantName", plant_name),
        "botanicalName": external_data.get("botanicalName", "Unknown"),
        "diseaseDesc": {
            "diseaseName": disease_desc.get("diseaseName", disease_label),
            "symptoms": disease_desc.get("symptoms", "Not Available"),
            "diseaseCauses": disease_desc.get("diseaseCauses", "Not Available"),
        },
        "diseaseRemedyList": [
            {
                "title": remedy.get("title", "Unknown"),
                "diseaseRemedyShortDesc": remedy.get("diseaseRemedyShortDesc", "Not Available"),
                "diseaseRemedy": remedy.get("diseaseRemedy", "Not Available"),
            }
            for remedy in external_data.get("diseaseRemedyList", [])
        ],
    })


@app.post("/predict/{plant_name}")
async def predict_plant_disease(plant_name: str, file: UploadFile = File(...)):
    """
    API endpoint to predict plant disease from an uploaded image.

    Args:
        plant_name (str): The plant type (must match a key in `plant_disease_dict`).
        file (UploadFile): The image file uploaded by the user.

    Returns:
        JSON response with the predicted class and additional details from an
        external API.

    Raises:
        HTTPException: 400 if the plant name is unknown or its model did not load.
    """
    # Fast path: a plant with exactly one known class needs no model — the
    # only possible label is returned directly, enriched with external data.
    if len(plant_disease_dict.get(plant_name, [])) == 1:
        single_disease = plant_disease_dict[plant_name][0]
        external_data = _fetch_disease_info(single_disease)
        return _build_disease_response(plant_name, single_disease, external_data)

    if plant_name not in loaded_models:
        raise HTTPException(status_code=400, detail=f"Invalid plant name or model not loaded: {plant_name}")

    # basename() guards against path traversal through a crafted filename
    # (e.g. "../../etc/passwd"); fall back to a fixed name if none was sent.
    safe_name = os.path.basename(file.filename or "upload.img")
    temp_path = os.path.join(TMP_DIR, safe_name)
    with open(temp_path, "wb") as buffer:
        shutil.copyfileobj(file.file, buffer)

    try:
        # Retrieve the preloaded model
        model = loaded_models[plant_name]

        # Preprocess exactly as at training time: 224x224, scaled to [0, 1].
        img = image.load_img(temp_path, target_size=(224, 224))
        img_array = image.img_to_array(img)
        img_array = np.expand_dims(img_array, axis=0)  # add batch dimension
        img_array = img_array / 255.0  # Normalize

        # Map the argmax output index back into the plant's class-label list.
        prediction = model.predict(img_array)
        class_label = plant_disease_dict[plant_name][np.argmax(prediction)]

        external_data = _fetch_disease_info(class_label)
        return _build_disease_response(plant_name, class_label, external_data)
    finally:
        # Clean up the temporary upload regardless of outcome.
        os.remove(temp_path)


if __name__ == "__main__":
    # Bind to all interfaces on port 7860 — presumably the Hugging Face
    # Spaces convention (the external API also lives on hf.space); confirm
    # before changing.
    uvicorn.run(app, host="0.0.0.0", port=7860)