File size: 3,657 Bytes
530f5ca 6bf5d09 530f5ca ed885a1 d64c5a6 6bf5d09 8b93c35 d64c5a6 530f5ca |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 |
import os
from flask import Flask, render_template, request
from flask_cors import cross_origin, CORS
import tensorflow.lite as lite
from tensorflow.keras.applications.mobilenet_v2 import preprocess_input
from utils.make_upload_dir import make_upload_dir
from tensorflow.keras.utils import load_img, save_img, img_to_array
from numpy import array, round
from utils.delete_file import delete_file
from requests import get
from functools import lru_cache
# Flask application object; template and static folders are given by their
# relative names 'templates' and 'static'.
app = Flask(
    import_name=__name__,
    static_folder='static',
    template_folder='templates',
)
# Enable CORS for the whole app using flask_cors defaults.
CORS(app)
# Maps each vegetable class name to its integer index. The names are listed
# in index order, so position in the mapping equals the stored index.
CLASS_INDICES = {
    label: position
    for position, label in enumerate((
        'Bean',
        'Bitter_Gourd',
        'Bottle_Gourd',
        'Brinjal',
        'Broccoli',
        'Cabbage',
        'Capsicum',
        'Carrot',
        'Cauliflower',
        'Cucumber',
        'Papaya',
        'Potato',
        'Pumpkin',
        'Radish',
        'Tomato',
    ))
}
@lru_cache(maxsize=1)
def load_model(model_path):
    """
    builds a TFLite interpreter for the model file at model_path

    The call is memoised by lru_cache (maxsize=1): calling again with the
    same argument returns the cached interpreter instead of building a new
    one. Prints 'loaded model' each time an interpreter is actually built.
    """
    tflite_interpreter = lite.Interpreter(model_path=model_path)
    print('loaded model')
    return tflite_interpreter
# Build the module-level interpreter once at import time; predict() reads it.
_model_file = os.path.join("models", "vegetable_classification_model_mnet.tflite")
interpreter = load_model(model_path=_model_file)
def predict(test_image):
    """
    runs the module-level TFLite interpreter on test_image

    test_image is cast to the dtype reported by the model's first input
    tensor and fed to the interpreter. Only the first row of the first
    output tensor is read (presumably one score per class for a single
    image -- confirm against the model).

    Returns a string of the form '<class name> (<score>%)', where the class
    name is the CLASS_INDICES key at the argmax position and the score is
    the maximum output value times 100, rounded to 2 decimals.
    """
    # Tensors are (re)allocated on every call before the model is used.
    interpreter.allocate_tensors()
    input_spec = interpreter.get_input_details()[0]
    output_spec = interpreter.get_output_details()[0]

    model_input = test_image.astype(input_spec["dtype"])
    interpreter.set_tensor(input_spec["index"], model_input)
    interpreter.invoke()

    scores = interpreter.get_tensor(output_spec["index"])[0]
    # Relies on CLASS_INDICES being declared in index order.
    class_names = list(CLASS_INDICES)
    label = class_names[scores.argmax()]
    percent = round(scores.max() * 100, 2)
    return f'{label} ({percent}%)'
@app.route("/")
@cross_origin()
def index():
    """
    displays the index.html page

    Before rendering, calls delete_file on the uploads directory and on the
    previously generated display image, then renders the page with
    static/white.png as the image to show.
    """
    stale_paths = (
        os.path.join('.', 'static', 'uploads'),
        os.path.join('.', 'static', 'output', 'display.jpg'),
    )
    for stale_path in stale_paths:
        delete_file(stale_path)

    placeholder = os.path.join("static", "white.png")
    return render_template("index.html", display_image=placeholder)
@app.route("/", methods=["POST"])
@cross_origin()
def file_prediction():
    """
    downloads the image at the submitted URL, classifies it and renders the result

    Steps:
      1. read the image URL from the 'fileinput' form field
      2. download it to ./static/uploads/input.jpg
      3. load it as 224x224 RGB, apply the MobileNetV2 preprocess_input and
         wrap it in a batch of one
      4. save the processed image to ./static/output/display.jpg
      5. pass the batch to predict() and render the returned text

    Any exception raised along the way (missing form field, network error,
    HTTP error status, undecodable image, ...) is caught and passed to the
    template as `error`, with static/white.png shown instead of the
    processed image. The downloaded file is removed in both cases.
    """
    upload_file_path = ""
    result = ''
    error = ""
    try:
        upload_file_path = os.path.join('.', 'static', 'uploads')
        make_upload_dir(upload_file_path)
        upload_filename = "input.jpg"
        upload_file_path = os.path.join(upload_file_path, upload_filename)

        file_url = request.form['fileinput']
        # Browser-like User-Agent sent with the download request.
        headers = {'User-Agent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/102.0.0.0 Safari/537.36'}
        # NOTE(review): file_url is user-supplied and fetched server-side,
        # so this endpoint can be pointed at internal hosts (SSRF). Consider
        # restricting the allowed hosts if this is deployed publicly.
        # requests has no default timeout; without one a stalled server
        # would block this request handler forever.
        r = get(file_url, headers=headers, timeout=10)
        # Surface 4xx/5xx responses as an error instead of saving the error
        # page as "input.jpg" and failing later while decoding it.
        r.raise_for_status()
        with open(upload_file_path, "wb") as file:
            file.write(r.content)

        test_image = load_img(upload_file_path, color_mode="rgb", target_size=(224, 224))
        test_image = img_to_array(test_image)
        test_image = preprocess_input(test_image)
        # The model is fed a batch, so wrap the single image in an array.
        test_image = array([test_image])

        # Save the processed image so the page can display what was classified.
        display_image = os.path.join('.', 'static', 'output', 'display.jpg')
        save_img(path=display_image, x=test_image[0])

        result = predict(test_image)
        delete_file(upload_file_path)
    except Exception as e:
        # Top-level boundary for this request: report the failure on the
        # page rather than returning a 500.
        result = ''
        error = e
        display_image = os.path.join("static", "white.png")
        delete_file(upload_file_path)
    return render_template("index.html", result=result, error=error, display_image=display_image)
if __name__ == "__main__":
    # When run as a script, serve on all network interfaces at port 5002.
    app.run(port=5002, host="0.0.0.0")
|