|
import base64
import os

import cv2
import numpy as np
import tensorflow as tf
from flask import Flask, render_template, request, jsonify

from object_detection.builders import model_builder
from object_detection.utils import config_util
from object_detection.utils import label_map_util
from object_detection.utils import visualization_utils as viz_utils
|
|
|
# Flask application object; the routes below are registered on it via @app.route.
app = Flask(__name__)
|
|
|
|
|
# Root of the TF Object Detection API workspace. All paths below are relative,
# so they are resolved against the current working directory at use time.
_WORKSPACE_DIR = os.path.join('Tensorflow', 'workspace')

# Name of the trained model directory and where its artefacts live.
CUSTOM_MODEL_NAME = 'my_ssd_mobnet'
CHECKPOINT_PATH = os.path.join(_WORKSPACE_DIR, 'models', CUSTOM_MODEL_NAME)
LABELMAP_PATH = os.path.join(_WORKSPACE_DIR, 'annotations', 'label_map.pbtxt')

# NOTE(review): by their names these are detection-drawing limits (minimum
# score to show a box, maximum number of boxes); confirm where they are used.
MIN_SCORE_THRESH = 0.4
MAX_BOXES_TO_DRAW = 10
|
|
|
|
|
def load_model(checkpoint_name='ckpt-7', checkpoint_dir=None, labelmap_path=None):
    """Build the detection model from its pipeline config and restore its weights.

    Args:
        checkpoint_name: Checkpoint prefix inside ``checkpoint_dir`` to restore
            (default ``'ckpt-7'``). Pass ``None`` to restore the most recent
            checkpoint recorded in ``checkpoint_dir``.
        checkpoint_dir: Directory holding ``pipeline.config`` and the
            checkpoint files. Defaults to ``CHECKPOINT_PATH`` (read at call time).
        labelmap_path: Path to the label map ``.pbtxt``. Defaults to
            ``LABELMAP_PATH`` (read at call time).

    Returns:
        A ``(detection_model, category_index)`` tuple. ``category_index`` is the
        mapping built by ``label_map_util.create_category_index_from_labelmap``.

    Raises:
        FileNotFoundError: if ``checkpoint_name`` is ``None`` and no checkpoint
            is recorded in ``checkpoint_dir``.
    """
    # Resolve defaults at call time so later changes to the module constants
    # are honoured, exactly as when the globals were read directly.
    checkpoint_dir = CHECKPOINT_PATH if checkpoint_dir is None else checkpoint_dir
    labelmap_path = LABELMAP_PATH if labelmap_path is None else labelmap_path

    configs = config_util.get_configs_from_pipeline_file(
        os.path.join(checkpoint_dir, 'pipeline.config'))
    detection_model = model_builder.build(model_config=configs['model'], is_training=False)

    if checkpoint_name is None:
        checkpoint_prefix = tf.train.latest_checkpoint(checkpoint_dir)
        if checkpoint_prefix is None:
            raise FileNotFoundError(f'No checkpoint found in {checkpoint_dir!r}')
    else:
        checkpoint_prefix = os.path.join(checkpoint_dir, checkpoint_name)

    ckpt = tf.compat.v2.train.Checkpoint(model=detection_model)
    # expect_partial() silences warnings about checkpoint values that the
    # inference-only model does not match.
    ckpt.restore(checkpoint_prefix).expect_partial()

    category_index = label_map_util.create_category_index_from_labelmap(labelmap_path)
    return detection_model, category_index
|
|
|
|
|
@tf.function
def detect_fn(image: tf.Tensor) -> dict:
    """Run preprocess -> predict -> postprocess on a batch of images.

    Args:
        image: Batched image tensor. The ``/detect`` route passes a float32
            tensor with a leading batch axis of size 1.

    Returns:
        The dict of detection tensors produced by
        ``detection_model.postprocess`` (not a single tensor).

    Note:
        Reads the module-global ``detection_model``, which must be assigned
        before the first call. It is captured when the function is traced.
    """
    preprocessed_image, true_shapes = detection_model.preprocess(image)
    prediction_dict = detection_model.predict(preprocessed_image, true_shapes)
    return detection_model.postprocess(prediction_dict, true_shapes)
|
|
|
|
|
def _ensure_model_loaded():
    """Load the model into the module globals if it has not been loaded yet.

    The ``__main__`` guard loads the model before ``app.run``. When the app is
    served another way (e.g. ``flask run`` or a WSGI server), that code never
    runs, and ``detect_fn`` would fail with a NameError on ``detection_model``.
    """
    global detection_model, category_index
    if globals().get('detection_model') is None:
        detection_model, category_index = load_model()


def _draw_detections(image_np, detections):
    """Return a copy of ``image_np`` with boxes and labels drawn on it.

    ``detections`` is the dict returned by ``detect_fn`` for a batch of one
    image. Only the first ``num_detections`` entries of each output are used.
    """
    num = int(detections['num_detections'][0])
    boxes = detections['detection_boxes'][0, :num].numpy()
    # Raw ``postprocess`` output uses 0-based class indices, while label-map ids
    # start at 1, hence the +1 offset (as in the TF OD API inference tutorial).
    classes = detections['detection_classes'][0, :num].numpy().astype(np.int64) + 1
    scores = detections['detection_scores'][0, :num].numpy()

    annotated = image_np.copy()  # the visualizer draws in place
    viz_utils.visualize_boxes_and_labels_on_image_array(
        annotated,
        boxes,
        classes,
        scores,
        category_index,
        use_normalized_coordinates=True,
        max_boxes_to_draw=MAX_BOXES_TO_DRAW,
        min_score_thresh=MIN_SCORE_THRESH,
        agnostic_mode=False,
    )
    return annotated


@app.route('/detect', methods=['POST'])
def detect():
    """Run object detection on an uploaded image.

    Expects a multipart form upload with a file field named ``image``.

    Returns:
        200 with the base64-encoded JPEG of the annotated image as the body;
        400 with a JSON ``error`` when the upload is missing, empty or cannot
        be decoded; 500 with a JSON ``error`` if inference or encoding fails.
    """
    file = request.files.get('image')
    if file is None:
        return jsonify({'error': "missing 'image' file field"}), 400

    data = file.read()
    if not data:
        return jsonify({'error': 'uploaded image is empty'}), 400

    # imdecode returns None (rather than raising) for undecodable bytes.
    img = cv2.imdecode(np.frombuffer(data, np.uint8), cv2.IMREAD_COLOR)
    if img is None:
        return jsonify({'error': 'could not decode image'}), 400

    try:
        _ensure_model_loaded()
        input_tensor = tf.convert_to_tensor(np.expand_dims(img, 0), dtype=tf.float32)
        detections = detect_fn(input_tensor)
        annotated = _draw_detections(img, detections)

        ok, buffer = cv2.imencode('.jpg', annotated)
        if not ok:
            raise RuntimeError('failed to encode annotated image')
        img_str = base64.b64encode(buffer.tobytes()).decode('ascii')
    except Exception as e:
        app.logger.exception('Detection failed')
        return jsonify({'error': str(e)}), 500

    return img_str
|
|
|
|
|
@app.route('/', methods=['GET'])
def index():
    """Serve the landing page rendered from the ``index.html`` template."""
    page_template = 'index.html'
    return render_template(page_template)
|
|
|
if __name__ == "__main__":
    # NOTE(review): debug=True enables the interactive Werkzeug debugger, which
    # permits arbitrary code execution - never expose this server publicly.
    run_debug = True

    # With debug=True, Werkzeug's reloader re-executes this script in a child
    # process (marked by WERKZEUG_RUN_MAIN=true) that actually serves requests;
    # the parent only watches files. Loading the model only in the serving
    # process avoids building it twice. On GPU machines it also avoids TF
    # reserving device memory in a process that never runs inference.
    if not run_debug or os.environ.get('WERKZEUG_RUN_MAIN') == 'true':
        detection_model, category_index = load_model()

    app.run(debug=run_debug)