File size: 4,086 Bytes
a06267d
 
 
 
 
 
 
 
 
 
010e04f
a06267d
 
 
 
 
 
 
 
 
b0fa478
 
a06267d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
# ! pip install gradio

# Gradio demo: runs two pre-trained Keras segmentation models (object classes
# and lane markings, trained on CARLA simulator imagery) on an uploaded image
# and renders the predicted mask as a matplotlib figure.

import gradio as gr

import tensorflow as tf 
from tensorflow import keras
from tensorflow.keras.models import Model, load_model

import numpy as np

# import cv2

from PIL import Image

import matplotlib.pyplot as plt
import matplotlib.patches as mpatches

from pathlib import Path

# Resolve model files relative to this script so the app works from any CWD.
current_directory_path = Path(__file__).parent.resolve()
object_detection_model_path = current_directory_path / "carla-image-segmentation-model.h5"
lane_detection_model_path = current_directory_path / "lane-detection-for-carla-model.h5"

# Class-id -> human-readable label for the 13-class semantic model.
label_map_object = {0: 'Unlabeled', 1: 'Building', 2: 'Fence', 3: 'Other', 
                             4: 'Pedestrian', 5: 'Pole', 6: 'RoadLine', 7: 'Road', 8: 'SideWalk',
                             9: 'Vegetation', 10: 'Vehicles', 11: 'Wall', 12: 'TrafficSign'}

# Class-id -> label for the 3-class lane model.
lane_label_map = {0: 'Unlabeled', 1: 'Left Lane', 2: 'Right Lane'}

# Load the object detection model
object_detection_model = load_model(object_detection_model_path)

# Load the lane detection model
lane_detection_model = load_model(lane_detection_model_path)


def create_mask(object_detection_model, lane_detection_model, image):
    """Run both segmentation models on an uploaded image file.

    Parameters
    ----------
    object_detection_model : keras.Model
        Per-pixel classifier over the 13 semantic classes in
        ``label_map_object``.
    lane_detection_model : keras.Model
        Per-pixel classifier over the 3 classes in ``lane_label_map``.
    image :
        Upload wrapper with a ``.name`` attribute holding a path on disk
        (Gradio ``type='file'``); assumed to be a PNG — TODO(review):
        ``tf.image.decode_image`` would also accept JPEG, confirm inputs.

    Returns
    -------
    tuple[np.ndarray, np.ndarray]
        ``(object_mask, lane_mask)`` of shape ``(1, 256, 256, 1)`` each,
        holding integer class ids (argmax over the class axis).
    """
    # Decode, scale to float32 in [0, 1], and resize to the 256x256 input
    # resolution the models expect.  'nearest' avoids blending pixel values.
    data = tf.io.read_file(image.name)
    decoded = tf.image.decode_png(data, channels=3)
    decoded = tf.image.convert_image_dtype(decoded, tf.float32)
    resized = tf.image.resize(decoded, (256, 256), method='nearest')

    # Add a leading batch dimension -> NHWC, as Keras predict() expects.
    input_tensor = tf.expand_dims(resized, axis=0)

    def _predict_mask(model):
        # Collapse per-class scores to one class-id channel per pixel.
        scores = model.predict(input_tensor)
        mask = tf.expand_dims(tf.argmax(scores, axis=-1), axis=-1)
        return np.array(mask)

    return _predict_mask(object_detection_model), _predict_mask(lane_detection_model)


def segment_object(image): 
    """Gradio handler: segment the uploaded image and return a plot.

    Runs both models via ``create_mask`` and renders the object-detection
    mask as a matplotlib figure with a legend for the classes present.

    Parameters
    ----------
    image :
        Gradio file upload (passed straight through to ``create_mask``).

    Returns
    -------
    matplotlib.figure.Figure
        The object-detection mask figure (consumed by the "plot" output).
    """
    pred_masks_object_detect, pred_masks_lane_detect = create_mask(
        object_detection_model, lane_detection_model, image)

    def _plot_mask(mask, label_map):
        # Render one mask with a legend patch per class id that actually
        # occurs in it, colored to match the imshow colormap.
        used_classes = np.unique(mask)
        fig = plt.figure()
        im = plt.imshow(tf.keras.preprocessing.image.array_to_img(mask))
        patches = [
            mpatches.Patch(color=im.cmap(im.norm(int(cls))),
                           label="{}".format(label_map[int(cls)]))
            for cls in used_classes
        ]
        # put those patches as legend-handles into the legend
        plt.legend(handles=patches, bbox_to_anchor=(1.05, 1), loc=2, borderaxespad=0.)
        plt.axis("off")
        return fig

    fig_object = _plot_mask(pred_masks_object_detect[0], label_map_object)

    # The lane figure was previously created and then discarded without
    # plt.close(), leaking an open figure per request; close it explicitly.
    # TODO(review): wire it up as a second Gradio output if lanes should be shown.
    fig_lane = _plot_mask(pred_masks_lane_detect[0], lane_label_map)
    plt.close(fig_lane)

    return fig_object


# Wire the handler into a Gradio app.
# NOTE(review): gr.inputs.Image / gr.interface.Interface are legacy Gradio
# APIs (removed in Gradio 4.x) — confirm the pinned gradio version before
# upgrading; modern equivalents are gr.Image and gr.Interface.
webcam = gr.inputs.Image(shape=(800, 600), source="upload", type='file') #upload

webapp = gr.interface.Interface(fn=segment_object, inputs=webcam, outputs="plot") #, live=False

# debug=True keeps the process attached and prints errors to the console.
webapp.launch(debug=True)