|
|
|
|
|
import gradio as gr |
|
|
|
import tensorflow as tf |
|
from tensorflow import keras |
|
from tensorflow.keras.models import Model, load_model |
|
|
|
import numpy as np |
|
|
|
import cv2 |
|
|
|
from PIL import Image |
|
|
|
import matplotlib.pyplot as plt |
|
import matplotlib.patches as mpatches |
|
|
|
from pathlib import Path |
|
|
|
# Directory containing this script; the .h5 model files are loaded from here.
current_directory_path = Path(__file__).parent.resolve()

# Join with pathlib's "/" operator: a Path cannot be concatenated to a str with
# "+" (TypeError). str() keeps the values plain string paths for load_model.
object_detection_model_path = str(current_directory_path / "carla-image-segmentation-model.h5")

lane_detection_model_path = str(current_directory_path / "lane-detection-for-carla-model.h5")
|
|
|
# Class index -> display name for the object segmentation model's output.
# The position in the list is the class index.
label_map_object = dict(enumerate([
    'Unlabeled',
    'Building',
    'Fence',
    'Other',
    'Pedestrian',
    'Pole',
    'RoadLine',
    'Road',
    'SideWalk',
    'Vegetation',
    'Vehicles',
    'Wall',
    'TrafficSign',
]))

# Class index -> display name for the lane segmentation model's output.
lane_label_map = dict(enumerate(['Unlabeled', 'Left Lane', 'Right Lane']))
|
|
|
|
|
# Deserialize both Keras models once, at import time, so every request handled
# by segment_object reuses the same in-memory models. The object model is loaded
# first, then the lane model.
object_detection_model, lane_detection_model = map(
    load_model,
    (object_detection_model_path, lane_detection_model_path),
)
|
|
|
|
|
def _predict_class_mask(model, input_tensor):
    """Return argmax over the last axis of ``model.predict``, keeping that axis with size 1, as NumPy."""
    scores = model.predict(input_tensor)
    return np.array(tf.expand_dims(tf.argmax(scores, axis=-1), axis=-1))


def create_mask(object_detection_model, lane_detection_model, image):
    """Run both segmentation models on one image file and return per-pixel class ids.

    Args:
        object_detection_model: Keras model whose ``predict`` output carries
            per-class scores on its last axis (argmax-ed here).
        lane_detection_model: Keras model with the same output convention.
        image: the uploaded image, either a file-like object exposing ``.name``
            (e.g. the temp file Gradio passes for ``type='file'``) or a path
            given as ``str`` / ``os.PathLike``.

    Returns:
        Tuple ``(object_mask, lane_mask)`` of NumPy integer arrays, each shaped
        ``(1, H, W, 1)`` and holding the argmax class index per pixel. H and W
        are the models' output spatial size, presumably 256x256 to match the
        resized input; confirm against the models.
    """
    import os

    # A pathlib.Path must not go through ``.name`` (that is only its final
    # component), so plain paths are handled before the file-object case.
    if isinstance(image, (str, os.PathLike)):
        image_path = os.fspath(image)
    else:
        image_path = image.name

    raw_bytes = tf.io.read_file(image_path)
    # decode_image detects the format (PNG/JPEG/GIF/BMP). expand_animations=False
    # keeps the result 3-D so tf.image.resize accepts it.
    decoded = tf.io.decode_image(raw_bytes, channels=3, expand_animations=False)
    # uint8 pixels -> float32 scaled to [0, 1].
    decoded = tf.image.convert_image_dtype(decoded, tf.float32)
    resized = tf.image.resize(decoded, (256, 256), method='nearest')
    # Add a leading batch axis: (256, 256, 3) -> (1, 256, 256, 3).
    input_tensor = tf.expand_dims(resized, axis=0)

    pred_masks_object_detect = _predict_class_mask(object_detection_model, input_tensor)
    pred_masks_lane_detect = _predict_class_mask(lane_detection_model, input_tensor)

    return pred_masks_object_detect, pred_masks_lane_detect
|
|
|
|
|
def _plot_class_mask(ax, mask, label_map, title):
    """Draw one (H, W, 1) integer class mask on *ax* with a legend of the classes present."""
    class_ids = np.squeeze(mask, axis=-1)
    # Fixed vmin/vmax tie every class index to one colour. The legend swatches
    # go through the same cmap/norm, so they match the image. Nearest
    # interpolation keeps class boundaries from being blended into other colours.
    im = ax.imshow(class_ids, vmin=0, vmax=max(len(label_map) - 1, 1),
                   interpolation="nearest")
    patches = [
        mpatches.Patch(color=im.cmap(im.norm(cls)),
                       label="{}".format(label_map.get(cls, "Class {}".format(cls))))
        for cls in (int(c) for c in np.unique(class_ids))
    ]
    ax.legend(handles=patches, bbox_to_anchor=(1.05, 1), loc=2, borderaxespad=0.)
    ax.set_title(title)
    ax.axis("off")


def segment_object(image):
    """Segment an uploaded image with both models and plot the resulting masks.

    The object-class mask and the lane mask are drawn as two stacked panels of
    one figure. Each panel has a legend listing only the classes present in
    that mask.

    Args:
        image: the uploaded image from the Gradio input, forwarded unchanged to
            ``create_mask``.

    Returns:
        A single ``matplotlib.figure.Figure`` for Gradio's ``"plot"`` output.
    """
    # Build the figure directly rather than via pyplot, so pyplot's global
    # figure registry does not keep one figure alive per request.
    from matplotlib.figure import Figure

    pred_masks_object_detect, pred_masks_lane_detect = create_mask(
        object_detection_model, lane_detection_model, image)

    fig = Figure(figsize=(8, 9))
    # Index 0 drops the batch axis added in create_mask.
    _plot_class_mask(fig.add_subplot(2, 1, 1), pred_masks_object_detect[0],
                     label_map_object, "Objects")
    _plot_class_mask(fig.add_subplot(2, 1, 2), pred_masks_lane_detect[0],
                     lane_label_map, "Lanes")
    # Reserve space on the right for the legends drawn outside each axes.
    fig.subplots_adjust(right=0.7, hspace=0.25)
    return fig
|
|
|
|
|
# Image input widget. Despite the name it is upload-only (source="upload").
# type='file' makes Gradio hand segment_object a file object, whose ``.name``
# create_mask reads.
webcam = gr.inputs.Image(
    shape=(800, 600),
    source="upload",
    type='file',
)

# Wire the upload widget to segment_object and render its matplotlib figure.
webapp = gr.Interface(
    fn=segment_object,
    inputs=webcam,
    outputs="plot",
)

# Start the web server. debug=True keeps the call blocking and surfaces errors.
webapp.launch(debug=True)
|
|