# new-plants / app.py
# Uploaded by ahmedsaber - commit bb878a2 "Update app.py" (verified)
from flask import Flask, request, jsonify
from flask_cors import CORS
import tensorflow as tf
import numpy as np
from io import BytesIO
from PIL import Image
app = Flask(__name__)
CORS(app)  # enable flask-cors on the app with its default settings

# One Keras model per supported plant, loaded once at import time from the
# current working directory. Keys are the plant names used by the routes.
models = {
    plant: tf.keras.models.load_model(path)
    for plant, path in (
        ('potato', './potato_model.h5'),
        ('tomato', './tomato_model.h5'),
        ('grape', './grape_model.h5'),
        ('corn', './corn_model.h5'),
        ('pepper', './pepper_model.h5'),
    )
}
def load_image(file):
    """Decode raw image bytes into a normalized single-image batch.

    Args:
        file: The raw bytes of an uploaded image (``predict_route`` passes
            the result of ``FileStorage.read()``).

    Returns:
        A NumPy array of shape ``(1, 224, 224, 3)`` with pixel values scaled
        from ``[0, 255]`` to ``[0, 1]``.

    Raises:
        PIL.UnidentifiedImageError: if the bytes are not a decodable image
            (an ``OSError`` subclass).
    """
    with Image.open(BytesIO(file)) as src:
        # PNG uploads may be RGBA or palette images, and JPEGs may be
        # grayscale or CMYK. Without this conversion those produce arrays
        # whose channel count is not 3. For plain RGB input, convert() is a
        # no-op copy, so results are unchanged.
        # NOTE(review): assumes every model was trained on 3-channel RGB
        # input - confirm against the training pipeline.
        img = src.convert('RGB').resize((224, 224))
    img_array = tf.keras.preprocessing.image.img_to_array(img)  # (224, 224, 3)
    img_array = np.expand_dims(img_array, axis=0)  # add the batch axis
    # Scale to [0, 1]; presumably mirrors the training-time preprocessing.
    return img_array / 255.0
# Human-readable labels for each plant model. predict() indexes these lists
# with the argmax of the model output, so each list's order must match the
# order of that model's output units.
class_labels = {
    'potato': [
        "Early_Blight",
        "Healthy",
        "Late_Blight",
    ],
    # NOTE(review): the two labels with a leading underscore look like they
    # lost a "Tomato" prefix. They are returned verbatim to clients, so
    # confirm before renaming.
    'tomato': [
        "Bacterial_spot",
        "Early_blight",
        "Late_blight",
        "Leaf_Mold",
        "Septoria_leaf_spot",
        "Spider_mites Two-spotted_spider_mite",
        "Target_Spot",
        "_Yellow_Leaf_Curl_Virus",
        "_mosaic_virus",
        "healthy",
    ],
    'grape': [
        "Black Rot",
        "ESCA",
        "Healthy",
        "Leaf Blight",
    ],
    'corn': [
        "Blight",
        "Common_Rust",
        "Gray_Leaf_Spot",
        "Healthy",
    ],
    'pepper': [
        "Bacterial_spot",
        "healthy",
    ],
}
def predict(file, plant):
    """Classify an image with the model registered for ``plant``.

    Args:
        file: Raw image bytes, forwarded to ``load_image``.
        plant: Key into ``models`` and ``class_labels`` (e.g. ``'potato'``).

    Returns:
        A ``(label, confidence)`` tuple. ``label`` is the entry of
        ``class_labels[plant]`` with the highest model score. ``confidence``
        is that score (presumably a softmax probability) as a built-in
        ``float``.

    Raises:
        KeyError: if ``plant`` has no registered model.
        Any error raised by ``load_image`` for undecodable input.
    """
    batch = load_image(file)
    scores = models[plant].predict(batch)[0]  # scores for the single image
    best = int(np.argmax(scores))
    # Model outputs are NumPy scalars. np.float32 is not JSON-serializable by
    # Flask's jsonify, so hand callers a plain Python float.
    return class_labels[plant][best], float(scores[best])
def allowed_file(filename):
    """Return True when *filename* ends in .png, .jpg or .jpeg (any case).

    Only the text after the last '.' is considered. Names without any '.'
    are rejected.
    """
    if '.' not in filename:
        return False
    extension = filename.rpartition('.')[2]
    return extension.lower() in {'png', 'jpg', 'jpeg'}
# Routes for each plant. Each endpoint accepts a POST whose uploaded image is
# in the "file" form field (read via request.files in predict_route) and
# delegates all validation and inference to predict_route().
# Flask derives each endpoint name (e.g. "predict_potato") from the function
# name, so renaming these functions changes url_for() targets.
@app.route('/predict/potato', methods=['POST'])
def predict_potato():
    """Classify an uploaded image with the potato model."""
    return predict_route('potato')
@app.route('/predict/tomato', methods=['POST'])
def predict_tomato():
    """Classify an uploaded image with the tomato model."""
    return predict_route('tomato')
@app.route('/predict/grape', methods=['POST'])
def predict_grape():
    """Classify an uploaded image with the grape model."""
    return predict_route('grape')
@app.route('/predict/corn', methods=['POST'])
def predict_corn():
    """Classify an uploaded image with the corn model."""
    return predict_route('corn')
@app.route('/predict/pepper', methods=['POST'])
def predict_pepper():
    """Classify an uploaded image with the pepper model."""
    return predict_route('pepper')
def predict_route(plant):
    """Validate the uploaded image and return the model's prediction as JSON.

    Args:
        plant: Key of the model to use (one of the keys of ``models``).

    Returns:
        On success, a JSON response
        ``{"predicted_class": str, "confidence": float}``. Otherwise a
        ``(JSON {"error": str}, status)`` tuple:
        - 400 when the upload is missing, has an unsupported extension, or
          cannot be decoded as an image;
        - 500 for any other failure during prediction.
    """
    if 'file' not in request.files:
        return jsonify({"error": "No file part"}), 400
    file = request.files['file']
    # `not` also covers a missing (None) filename, which would otherwise
    # crash allowed_file() outside the try block.
    if not file.filename:
        return jsonify({"error": "No selected file"}), 400
    if not allowed_file(file.filename):
        return jsonify({"error": "Unsupported file type"}), 400
    try:
        predicted_class, confidence = predict(file.read(), plant)
    except OSError as e:
        # Pillow raises UnidentifiedImageError (an OSError subclass) for bytes
        # that are not a decodable image. That is a client error, not a 500.
        app.logger.warning("Could not decode uploaded image for %s: %s", plant, e)
        return jsonify({"error": "Invalid image file"}), 400
    except Exception as e:
        # Keep the traceback in the server log instead of discarding it.
        app.logger.exception("Prediction failed for %s", plant)
        return jsonify({"error": str(e)}), 500
    else:
        # Model scores may be NumPy scalars, and jsonify cannot serialize
        # np.float32, so coerce to a built-in float at the JSON boundary.
        return jsonify({"predicted_class": predicted_class, "confidence": float(confidence)})
# Run the Flask app
#if __name__ == '__main__':
# app.run(debug=True)