Spaces:
Running
Running
from flask import Flask, render_template, request, redirect, url_for, jsonify | |
import cv2 | |
import numpy as np | |
from tensorflow.lite.python.interpreter import Interpreter | |
import os | |
# Locations of the TFLite detector and its label map. These are relative
# paths, so they are resolved against the current working directory.
MODEL_PATH, LABEL_PATH = "detect.tflite", "labelmap.txt"
# Function to load the TFLite model and labels
def load_model(model_path=None, label_path=None):
    """Load the TFLite detector and its label map.

    Args:
        model_path: Path to the ``.tflite`` model file. Defaults to
            ``MODEL_PATH`` when omitted.
        label_path: Path to a text file with one class name per line.
            Defaults to ``LABEL_PATH`` when omitted.

    Returns:
        Tuple ``(interpreter, input_details, output_details, height, width,
        labels)``. ``height`` and ``width`` are dimensions 1 and 2 of the
        first input tensor's shape. ``labels`` is the list of stripped lines
        from the label file.
    """
    if model_path is None:
        model_path = MODEL_PATH
    if label_path is None:
        label_path = LABEL_PATH

    interpreter = Interpreter(model_path=model_path)
    # Tensors must be allocated before inputs can be set or the model invoked.
    interpreter.allocate_tensors()
    input_details = interpreter.get_input_details()
    output_details = interpreter.get_output_details()
    # Shape is read as (batch, height, width, channels). This assumes NHWC
    # input; TODO confirm for models other than the bundled one.
    height = input_details[0]['shape'][1]
    width = input_details[0]['shape'][2]

    # Explicit encoding so label parsing does not depend on the platform locale.
    with open(label_path, 'r', encoding='utf-8') as f:
        labels = [line.strip() for line in f]

    print(f"Model loaded. Input shape: {input_details[0]['shape']}")
    return interpreter, input_details, output_details, height, width, labels
# Function to preprocess the image for the model
def preprocess_image(image, input_details, height, width):
    """Turn an OpenCV BGR image into a one-image batch for the detector.

    The image is converted BGR -> RGB, resized to ``width`` x ``height``, and
    given a leading batch axis of size 1. If the first input tensor of the
    model is float32, the pixels are cast to float32 and mapped from
    [0, 255] to [-1, 1] (for 8-bit input). Otherwise they are passed
    through with their original dtype.
    """
    rgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
    # cv2.resize takes the target size as (width, height).
    resized = cv2.resize(rgb, (width, height))
    batch = resized[np.newaxis, ...]

    wants_float = input_details[0]['dtype'] == np.float32
    if wants_float:
        # Centre on 127.5 and scale so 0 -> -1.0 and 255 -> 1.0.
        batch = (batch.astype(np.float32) - 127.5) / 127.5

    print(f"Image preprocessed: shape {batch.shape}, dtype {batch.dtype}")
    return batch
# Function to perform object detection and draw bounding boxes
def detect_objects(image, interpreter, input_details, output_details, labels,
                   min_conf=0.1):
    """Run the TFLite detector on ``image`` and draw labelled boxes onto it.

    Args:
        image: BGR image as decoded by OpenCV. It is annotated in place.
        interpreter: TFLite interpreter whose tensors are already allocated.
        input_details: Result of ``interpreter.get_input_details()``.
        output_details: Result of ``interpreter.get_output_details()``.
        labels: Class names, indexed by the class id the model outputs.
        min_conf: Detections whose score is not strictly greater than this
            are skipped (default 0.1, the previously hard-coded threshold).

    Returns:
        The same ``image`` array, with boxes and labels drawn on it.
    """
    # Derive the model input size from input_details instead of reading
    # module-level globals, so the function depends only on its arguments.
    model_height = input_details[0]['shape'][1]
    model_width = input_details[0]['shape'][2]
    input_data = preprocess_image(image, input_details, model_height, model_width)
    interpreter.set_tensor(input_details[0]['index'], input_data)
    interpreter.invoke()

    # NOTE(review): output order scores=0, boxes=1, classes=3 is the layout
    # commonly produced by TF2 Object Detection API exports. TF1 exports
    # usually use boxes=0, classes=1, scores=2. Confirm against detect.tflite.
    boxes = interpreter.get_tensor(output_details[1]['index'])[0]  # bounding box coordinates
    classes = interpreter.get_tensor(output_details[3]['index'])[0]  # class index
    scores = interpreter.get_tensor(output_details[0]['index'])[0]  # confidence scores
    print(f"Detections: {len(scores)} objects detected")

    img_h, img_w = image.shape[:2]
    font = cv2.FONT_HERSHEY_SIMPLEX
    for i, (score, box, class_id) in enumerate(zip(scores, boxes, classes)):
        if score <= min_conf:
            continue

        # Coordinates are fractions of the image size (they are scaled by the
        # image dimensions here). Clamp them to the image bounds.
        ymin, xmin, ymax, xmax = box
        ymin = int(max(0, ymin * img_h))
        xmin = int(max(0, xmin * img_w))
        ymax = int(min(img_h, ymax * img_h))
        xmax = int(min(img_w, xmax * img_w))
        cv2.rectangle(image, (xmin, ymin), (xmax, ymax), (0, 255, 0), 2)

        # Guard against class ids outside the label map instead of raising
        # IndexError mid-request.
        class_idx = int(class_id)
        name = labels[class_idx] if 0 <= class_idx < len(labels) else str(class_idx)
        label = f'{name}: {score * 100:.2f}%'

        # Keep the text baseline at least one text-height below the top edge,
        # so labels on boxes touching the top of the image stay visible.
        (_, text_h), _ = cv2.getTextSize(label, font, 0.6, 2)
        text_y = max(ymin - 10, text_h)
        cv2.putText(image, label, (xmin, text_y), font, 0.6, (255, 255, 255), 2)
        print(f"Object {i}: {label} at [{xmin}, {ymin}, {xmax}, {ymax}]")
    return image
# Flask application object. Generated images are written to and served from
# the "static" folder.
app = Flask(__name__, static_folder="static")

# Load the detector once at import time; request handlers reuse these
# module-level objects.
(
    interpreter,
    input_details,
    output_details,
    height,
    width,
    labels,
) = load_model()
# Without this decorator the view is never registered and the app has no routes.
@app.route('/', methods=['GET', 'POST'])
def upload_and_detect():
    """Serve the upload page (GET) or run detection on an uploaded image (POST).

    On POST, the ``file`` form field is decoded with OpenCV, annotated by
    ``detect_objects`` and written to ``<static folder>/detected.jpg``.

    Returns:
        For GET, the rendered ``index.html``. For POST, JSON
        ``{"image_url": ...}`` pointing at the saved image. A missing,
        empty or undecodable upload redirects back to the request URL. If
        the result cannot be written, the response is JSON ``{"error": ...}``
        with status 500.
    """
    if request.method != 'POST':
        return render_template('index.html')

    if 'file' not in request.files:
        print("No file part in the request")
        return redirect(request.url)
    file = request.files['file']
    if file.filename == '':
        print("No selected file")
        return redirect(request.url)

    # cv2.imdecode raises (an assertion failure) on an empty buffer instead
    # of returning None, so reject zero-byte uploads up front.
    data = file.read()
    if not data:
        print("Uploaded file is empty")
        return redirect(request.url)
    image = cv2.imdecode(np.frombuffer(data, np.uint8), cv2.IMREAD_COLOR)
    if image is None:
        print("Failed to read image")
        return redirect(request.url)
    print(f"Image uploaded: {file.filename}, shape: {image.shape}")

    # NOTE(review): the module-level TFLite interpreter is shared by all
    # requests. If the server runs handlers concurrently, confirm that
    # calls need not be serialised.
    processed_image = detect_objects(image, interpreter, input_details, output_details, labels)

    # exist_ok avoids the race between an exists() check and makedirs().
    os.makedirs(app.static_folder, exist_ok=True)
    # NOTE(review): every request writes the same file, so concurrent uploads
    # can overwrite each other's result.
    save_path = os.path.join(app.static_folder, 'detected.jpg')
    if not cv2.imwrite(save_path, processed_image):
        print(f"Failed to save processed image to: {save_path}")
        return jsonify({'error': 'Failed to save processed image'}), 500
    print(f"Processed image saved at: {save_path}")

    # The file name never changes, so append its modification time as a query
    # parameter. Otherwise browsers may keep displaying a previously cached result.
    version = os.stat(save_path).st_mtime_ns
    return jsonify({'image_url': url_for('static', filename='detected.jpg', v=version)})
if __name__ == "__main__":
    # Run directly (not imported by a WSGI server): start Flask's built-in
    # server listening on all network interfaces, port 8000.
    app.run(port=8000, host="0.0.0.0")