flask-app-ocr / main.py
WaiYanLinn
fixed
71aab31
import os
from flask import Flask, request, redirect, jsonify
import numpy as np
from flask import render_template
from asgiref.wsgi import WsgiToAsgi
import numpy as np
import cv2
from sklearn.preprocessing import LabelEncoder
import imutils
from imutils.contours import sort_contours
from keras.models import load_model
import warnings
from flask_cors import CORS
# Silence DeprecationWarnings raised from within the TensorFlow and Keras packages.
for _noisy_module in ("tensorflow", "keras"):
    warnings.filterwarnings("ignore", category=DeprecationWarning, module=_noisy_module)
del _noisy_module

# The weights file sits next to this script. Anchor the path to the script's own
# directory so loading does not depend on the process's working directory.
script_dir = os.path.dirname(os.path.abspath(__file__))
model_path = os.path.join(script_dir, "./ocr_perfecto_experiment.h5")
model = load_model(model_path)
# Class names the LabelEncoder is fitted on. The model's argmax index is treated
# as the LabelEncoder code of one of these names.
_TRAIN_LABELS = [
    "0", "1", "2", "3", "4", "5", "6", "7", "8", "9",
    "_", "-", "[", "]", "+", "%",
]
# Symbols returned to clients, position-aligned with _TRAIN_LABELS.
# NOTE(review): "_", "[", "]", "%" presumably stood in for "*", "(", ")", "/"
# in the training data — TODO confirm.
_REAL_LABELS = [
    "0", "1", "2", "3", "4", "5", "6", "7", "8", "9",
    "*", "-", "(", ")", "+", "/",
]
# Build the code -> display-symbol table once, not on every request.
_CLASS_TO_SYMBOL = {
    int(code): symbol
    for code, symbol in zip(LabelEncoder().fit_transform(_TRAIN_LABELS), _REAL_LABELS)
}

_MIN_CONTOUR_AREA = 1000  # contour area threshold (on the half-size image)
_MIN_CHAR_WIDTH = 20  # bounding boxes narrower than this are skipped
_MODEL_INPUT_SIZE = 28  # the model is fed (N, 28, 28, 1) arrays


def _find_character_boxes(img_gray):
    """Return bounding boxes ``(x, y, w, h)`` of candidate characters, left to right.

    Args:
        img_gray: single-channel image to segment.

    Returns:
        A list of boxes whose contour area and width pass the minimum thresholds.
        The list is empty if nothing was found.
    """
    edged = cv2.Canny(img_gray, 30, 150)
    # Heavy dilation merges the strokes of one character into a single blob.
    dilated = cv2.dilate(edged, None, iterations=6)
    normalized_image = cv2.normalize(dilated, None, 0, 255, cv2.NORM_MINMAX)
    contours = imutils.grab_contours(
        cv2.findContours(normalized_image, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
    )
    if not contours:
        # sort_contours cannot unpack an empty sequence; a blank image has no characters.
        return []
    contours = sort_contours(contours, method="left-to-right")[0]

    boxes = []
    for contour in contours:
        if cv2.contourArea(contour) < _MIN_CONTOUR_AREA:
            continue
        x, y, w, h = cv2.boundingRect(contour)
        if w >= _MIN_CHAR_WIDTH:
            boxes.append((x, y, w, h))
    return boxes


def _roi_to_model_input(roi):
    """Binarise a grayscale character crop and letterbox it to a 28x28 float image.

    Returns:
        A ``(28, 28)`` float array with values in ``[0, 1]`` (ink is bright,
        because the Otsu threshold is inverted).
    """
    thresh = cv2.threshold(roi, 0, 255, cv2.THRESH_BINARY_INV | cv2.THRESH_OTSU)[1]
    th, tw = thresh.shape
    # Scale the longer side to 28 px, keeping the aspect ratio. Square crops are
    # not scaled here; the final resize below handles them.
    if tw > th:
        thresh = imutils.resize(thresh, width=_MODEL_INPUT_SIZE)
    elif th > tw:
        thresh = imutils.resize(thresh, height=_MODEL_INPUT_SIZE)
    th, tw = thresh.shape
    dx = int(max(0, _MODEL_INPUT_SIZE - tw) / 2.0)
    dy = int(max(0, _MODEL_INPUT_SIZE - th) / 2.0)
    padded = cv2.copyMakeBorder(
        thresh,
        top=dy,
        bottom=dy,
        left=dx,
        right=dx,
        borderType=cv2.BORDER_CONSTANT,
        value=(0, 0, 0),
    )
    # Odd padding remainders and unscaled square crops are corrected by this resize.
    padded = cv2.resize(padded, (_MODEL_INPUT_SIZE, _MODEL_INPUT_SIZE))
    return padded / 255.0


def test_pipeline(image_data):
    """Recognise the characters in an encoded image, left to right.

    Args:
        image_data: encoded image bytes as a 1-D ``np.uint8`` array (the upload
            routes build it with ``np.frombuffer``).

    Returns:
        A list of recognised symbols drawn from ``_REAL_LABELS``, ordered by the
        x position of each character. The list is empty when no character-sized
        contour is found.

    Raises:
        ValueError: if ``image_data`` cannot be decoded as an image.
    """
    img = cv2.imdecode(image_data, cv2.IMREAD_COLOR)
    if img is None:
        # imdecode signals failure with None rather than raising.
        raise ValueError("Could not decode the uploaded image.")
    img_gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
    height, width = img_gray.shape
    # Work at half resolution; the area/width thresholds are tuned for this scale.
    img_gray = cv2.resize(img_gray, (round(width / 2), round(height / 2)))
    img_gray = cv2.GaussianBlur(img_gray, (5, 5), 0)

    boxes = _find_character_boxes(img_gray)
    if not boxes:
        return []

    # Crops are taken from the blurred half-size image, as in the original pipeline.
    batch = np.stack(
        [_roi_to_model_input(img_gray[y : y + h, x : x + w]) for (x, y, w, h) in boxes]
    )
    batch = batch[..., np.newaxis]  # (N, 28, 28) -> (N, 28, 28, 1)
    # A single batched predict call instead of one call per character.
    predictions = np.argmax(model.predict(batch, verbose=0), axis=1)
    return [_CLASS_TO_SYMBOL[int(p)] for p in predictions]
# File extensions accepted by the upload endpoints (checked case-insensitively
# by allowed_file).
ALLOWED_EXTENSIONS = {"png", "jpg", "jpeg"}

app = Flask(__name__, template_folder="./src/templates", static_folder="./src/public")

# Never ship a hard-coded secret key. Read it from the environment, and otherwise
# fall back to a random per-process key. With the fallback, signed data such as
# sessions does not survive restarts and is not shared between worker processes,
# so set FLASK_SECRET_KEY in any deployment that relies on it.
app.secret_key = os.environ.get("FLASK_SECRET_KEY") or os.urandom(32).hex()

# Allow cross-origin requests from any origin on every route.
cors = CORS(app, resources={r"/*": {"origins": "*"}})
def allowed_file(filename):
    """Return True if ``filename``'s extension is listed in ``ALLOWED_EXTENSIONS``.

    Only the text after the last dot is considered, and the comparison is
    case-insensitive. A name with no dot is always rejected.
    """
    if "." not in filename:
        return False
    _, _, extension = filename.rpartition(".")
    return extension.lower() in ALLOWED_EXTENSIONS
@app.route("/")
def index():
    """Serve the front-end page ``index.html`` from the app's template folder."""
    page = "index.html"
    return render_template(page)
class _InvalidUpload(ValueError):
    """The request's ``file`` part is missing, unnamed, or of a disallowed type."""


def _read_uploaded_image():
    """Validate the ``file`` part of the current multipart request.

    Returns:
        The uploaded bytes as a 1-D ``np.uint8`` array.

    Raises:
        _InvalidUpload: if the part is missing, has an empty filename, or its
            extension is not accepted by ``allowed_file``.
    """
    if "file" not in request.files:
        raise _InvalidUpload("File not found in the request.")
    file = request.files["file"]
    if file.filename == "":
        raise _InvalidUpload("Empty filename in the request.")
    if not (file and allowed_file(file.filename)):
        raise _InvalidUpload("Invalid file type.")
    return np.frombuffer(file.read(), np.uint8)


def _handle_ocr_upload():
    """Run OCR on the uploaded image and build the HTTP response.

    Returns:
        ``(body, status)``. The possible outcomes are:
        - the recognised symbols as a JSON list with 200;
        - a plain-text message with 400 when the upload itself is invalid;
        - a generic plain-text message with 500 when processing fails.
    """
    try:
        results = test_pipeline(_read_uploaded_image())
    except _InvalidUpload as e:
        # Problems with the request are client errors: report them as 400, not 500.
        return f"Error processing file: {e}", 400
    except Exception:
        # Keep the details in the server log; don't echo internal error text to clients.
        app.logger.exception("OCR processing failed")
        return "Error processing file: the image could not be processed.", 500
    return jsonify(results), 200


@app.route("/api/photo-upload", methods=["POST"])
def upload_file():
    """Recognise the characters in a multipart image upload (form field ``file``)."""
    return _handle_ocr_upload()


@app.route("/predict", methods=["POST"])
def predict():
    """Alias of ``/api/photo-upload``: identical validation and response format."""
    return _handle_ocr_upload()
def create_app():
    """Return the Flask application configured in this module.

    This is not a real factory: every call returns the same module-level ``app``.
    """
    return app


# asgiref adapter that exposes the WSGI Flask app as an ASGI application.
wsgi = WsgiToAsgi(create_app())