# app.py — person detection and re-identification demo (person-reid Space)
import numpy as np
import cv2
import tensorflow as tf
from tensorflow.keras.applications import ResNet50
from tensorflow.keras.applications.resnet50 import preprocess_input
from sklearn.metrics.pairwise import cosine_similarity
from filterpy.kalman import KalmanFilter
import gradio as gr
# Load the frozen inference graph
frozen_graph_path = "frozen_inference_graph.pb"
with tf.io.gfile.GFile(frozen_graph_path, "rb") as f:
    graph_def = tf.compat.v1.GraphDef()
    graph_def.ParseFromString(f.read())
# Convert the frozen graph to a function
def wrap_frozen_graph(graph_def, inputs, outputs):
    def _imports_graph_def():
        tf.compat.v1.import_graph_def(graph_def, name="")
    wrapped_import = tf.compat.v1.wrap_function(_imports_graph_def, [])
    return wrapped_import.prune(
        tf.nest.map_structure(wrapped_import.graph.as_graph_element, inputs),
        tf.nest.map_structure(wrapped_import.graph.as_graph_element, outputs))
# Define input and output tensors
inputs = ["image_tensor:0"]
outputs = ["detection_boxes:0", "detection_scores:0", "detection_classes:0", "num_detections:0"]
# Get the detection function
detection_fn = wrap_frozen_graph(graph_def, inputs, outputs)
# TensorFlow function for detection
@tf.function(input_signature=[tf.TensorSpec(shape=[1, None, None, 3], dtype=tf.uint8)])
def detect_objects(image):
    return detection_fn(image)
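# A minimal sketch of smoke-testing the detector, assuming the output shapes
# follow the TF Object Detection API convention for the tensors named above:
#
#   dummy = tf.zeros([1, 480, 640, 3], dtype=tf.uint8)
#   boxes, scores, classes, num = detect_objects(dummy)
#   # boxes: [1, N, 4] in normalized [ymin, xmin, ymax, xmax] order,
#   # scores/classes: [1, N], num: [1]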
# Load ResNet50 for feature extraction
resnet_model = ResNet50(weights="imagenet", include_top=False, pooling="avg")
# Initialize variables to store features and identities
person_features = []
person_identities = []
person_colors = {}
kalman_filters = {}
next_person_id = 1 # Starting unique ID for persons
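# Note: these are module-level globals, so identities and colors persist
# across Gradio requests until the process restarts (and are shared between users).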
# Function to generate unique colors based on person ID
def get_color(person_id):
    np.random.seed(person_id)  # Seed with the ID so each person gets a stable, unique color
    color = np.random.randint(0, 256, size=3)  # Random RGB triple
    return (int(color[0]), int(color[1]), int(color[2]))  # OpenCV needs a tuple of Python ints
def extract_features(person_roi):
    # Resize and preprocess the ROI for ResNet50 input
    person_roi_resized = cv2.resize(person_roi, (224, 224))
    person_roi_preprocessed = preprocess_input(person_roi_resized)
    # Add a batch dimension for ResNet50 input
    input_tensor = np.expand_dims(person_roi_preprocessed, axis=0)
    # Extract a pooled 2048-dim embedding (verbose=0 silences per-call logging)
    features = resnet_model.predict(input_tensor, verbose=0)
    return features
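# A quick sketch (crop_a and crop_b are hypothetical person crops): two crops of
# the same person should score high cosine similarity, which match_and_identify
# below relies on.
#
#   f1 = extract_features(crop_a)  # shape (1, 2048) with pooling="avg"
#   f2 = extract_features(crop_b)
#   sim = cosine_similarity(f1, f2)[0][0]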
def initialize_kalman_filter(bbox):
    # SORT-style constant-velocity filter: state holds the 4 bbox coordinates
    # plus 3 velocity terms; only the coordinates are measured.
    kf = KalmanFilter(dim_x=7, dim_z=4)
    kf.F = np.array([[1, 0, 0, 0, 1, 0, 0],
                     [0, 1, 0, 0, 0, 1, 0],
                     [0, 0, 1, 0, 0, 0, 1],
                     [0, 0, 0, 1, 0, 0, 0],
                     [0, 0, 0, 0, 1, 0, 0],
                     [0, 0, 0, 0, 0, 1, 0],
                     [0, 0, 0, 0, 0, 0, 1]])
    # Observe the first four state entries (the bbox coordinates), matching
    # the 4-dimensional measurement passed to kf.update()
    kf.H = np.array([[1, 0, 0, 0, 0, 0, 0],
                     [0, 1, 0, 0, 0, 0, 0],
                     [0, 0, 1, 0, 0, 0, 0],
                     [0, 0, 0, 1, 0, 0, 0]])
    kf.R[2:, 2:] *= 10.    # less trust in the size-related measurements
    kf.P[4:, 4:] *= 1000.  # high initial uncertainty on the unobserved velocities
    kf.P *= 10.
    kf.Q[-1, -1] *= 0.01   # low process noise on the velocity terms
    kf.Q[4:, 4:] *= 0.01
    kf.x[:4] = bbox.reshape((4, 1))
    return kf
def predict_bbox(kf):
    kf.predict()
    return kf.x[:4].reshape((4,))

def update_kalman_filter(kf, bbox):
    kf.update(bbox.reshape((4, 1)))
    return kf
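# A minimal sketch (hypothetical normalized coordinates) of one track cycle:
#
#   kf = initialize_kalman_filter(np.array([0.10, 0.20, 0.30, 0.40]))
#   pred = predict_bbox(kf)  # propagate the constant-velocity model one step
#   kf = update_kalman_filter(kf, np.array([0.12, 0.21, 0.31, 0.41]))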
def match_and_identify(features, bbox):
    global next_person_id
    similarity_threshold = 0.7  # Adjust as needed
    # Check the new embedding against every known identity
    for feat, identity in zip(person_features, person_identities):
        # Cosine similarity between the stored and new feature vectors
        similarity = cosine_similarity(
            np.array(feat).reshape(1, -1),
            np.array(features).reshape(1, -1)
        )[0][0]
        # Above the threshold, treat the detection as the same person
        if similarity > similarity_threshold:
            # Assign a color if one isn't already assigned
            if identity in person_colors:
                color = person_colors[identity]
            else:
                color = get_color(identity)
                person_colors[identity] = color
            # Update this identity's Kalman filter with the new detection
            kalman_filters[identity] = update_kalman_filter(kalman_filters[identity], bbox)
            return identity, color
    # No match found: register a new identity
    person_features.append(features)
    person_identities.append(next_person_id)
    color = get_color(next_person_id)
    person_colors[next_person_id] = color
    # Initialize a Kalman filter for the new track
    kalman_filters[next_person_id] = initialize_kalman_filter(bbox)
    identity = next_person_id
    next_person_id += 1
    return identity, color
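# Note: matching is greedy (the first stored embedding above the threshold wins),
# so crowded scenes with similar-looking people can cause identity switches.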
def process_image(image):
    if image is None:
        print("Input image is None")
        return None
    # Convert the image to 3-channel RGB if it isn't already
    if len(image.shape) == 2:  # Grayscale
        image = cv2.cvtColor(image, cv2.COLOR_GRAY2RGB)
    elif image.shape[2] == 4:  # RGBA
        image = cv2.cvtColor(image, cv2.COLOR_RGBA2RGB)
    # Ensure the image is uint8
    if image.dtype != np.uint8:
        image = (image * 255).astype(np.uint8)
    # Prepare the image tensor (batch of one)
    input_tensor = np.expand_dims(np.array(image), axis=0)
    try:
        # Run inference
        detections = detect_objects(input_tensor)
        # Extract the output tensors and convert them to numpy arrays
        boxes = detections[0].numpy()[0]
        scores = detections[1].numpy()[0]
        classes = detections[2].numpy()[0]
        num_detections = int(detections[3].numpy()[0])
        print(f"Number of detections: {num_detections}")
        # Keep only 'person' detections (COCO class 1) above the score threshold
        threshold = 0.3  # Adjust this threshold as needed
        h, w, _ = image.shape
        for i in range(num_detections):
            if int(classes[i]) != 1 or scores[i] <= threshold:
                continue
            ymin, xmin, ymax, xmax = boxes[i]
            left, right, top, bottom = int(xmin * w), int(xmax * w), int(ymin * h), int(ymax * h)
            # Extract the person ROI; skip degenerate boxes that would crash cv2.resize
            person_roi = image[top:bottom, left:right]
            if person_roi.size == 0:
                continue
            # Extract appearance features for re-identification
            features = extract_features(person_roi)
            # Normalized bbox used to seed or update this identity's Kalman filter
            bbox = np.array([xmin, ymin, xmax, ymax])
            # Match against known identities, or register a new one
            identity, color = match_and_identify(features, bbox)
            # Draw the bounding box and identity label
            cv2.rectangle(image, (left, top), (right, bottom), color, 2)
            cv2.putText(image, f'Person {identity}', (left, top - 10),
                        cv2.FONT_HERSHEY_SIMPLEX, 0.5, color, 2)
            print(f"Detected person {identity} at ({left}, {top}, {right}, {bottom})")
    except Exception as e:
        print(f"Error during processing: {str(e)}")
        return image  # Return the unannotated image if something fails
    return image
def gradio_interface(input_image):
    if input_image is None:
        print("Input image is None")
        return None
    # Convert a PIL Image to a numpy array if necessary
    if hasattr(input_image, 'convert'):
        input_image = np.array(input_image.convert('RGB'))
    # Process the input image
    output_image = process_image(input_image)
    if output_image is None:
        print("Output image is None")
        return None
    print(f"Output image shape: {output_image.shape}")
    print(f"Output image dtype: {output_image.dtype}")
    # Ensure the output is uint8 for Gradio
    if output_image.dtype != np.uint8:
        output_image = (output_image * 255).astype(np.uint8)
    return output_image
# Create the Gradio interface
iface = gr.Interface(
    fn=gradio_interface,
    inputs=gr.Image(),
    outputs=gr.Image(),
    title="Person Detection and Tracking",
    description="Upload an image to detect and track persons.",
)
# Launch the interface
if __name__ == "__main__":
    iface.launch()
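# A sketch (hypothetical file path) of running the same pipeline over video
# frames, where the per-identity Kalman filters actually accumulate state:
#
#   cap = cv2.VideoCapture("input.mp4")
#   while True:
#       ok, frame = cap.read()
#       if not ok:
#           break
#       annotated = process_image(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
#   cap.release()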