Spaces:
Runtime error
Runtime error
File size: 6,447 Bytes
5869455 1912fb8 d3a06f0 c8fa30f 1912fb8 5869455 1912fb8 5869455 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 |
import cv2
import numpy as np
import os
import torch
import onnxruntime as ort
import time
from functools import wraps
import argparse
from PIL import Image
from io import BytesIO
import streamlit as st
# Parse command-line arguments
#parser = argparse.ArgumentParser()
#parser.add_argument("--mosaic", help="Enable mosaic processing mode", action="store_true")
#args = parser.parse_args()
#mosaic = args.mosaic # Set this based on your command line argument
# For Streamlit use, mosaic is simply hard-coded to True; the command-line parsing above is left
# commented out for anyone who wants to run this file as a plain script instead.
# When True, the uploaded image is split into 416-px tiles that are processed one by one and
# stitched back together; when False, only a single centre crop is processed.
mosaic = True
def center_crop(img, new_height, new_width):
    """Return the central ``new_height`` x ``new_width`` window of ``img``.

    Args:
        img: Image array whose first two axes are (height, width); any
            trailing axes (e.g. channels) are kept as they are.
        new_height: Requested crop height in pixels.
        new_width: Requested crop width in pixels.

    Returns:
        A view of ``img`` covering the centred window. If ``img`` is smaller
        than the requested size along an axis, the whole extent of that axis
        is returned, so the result can be smaller than requested.
    """
    height, width = img.shape[:2]
    # Clamp at 0: a negative start would be interpreted by NumPy as an offset
    # from the end of the axis and silently return a sliver of the image.
    start_y = max(height // 2 - new_height // 2, 0)
    start_x = max(width // 2 - new_width // 2, 0)
    return img[start_y:start_y + new_height, start_x:start_x + new_width]
def mosaic_crop(img, size):
    """Cut ``img`` into ``size`` x ``size`` tiles, zero-padding the bottom and right edges.

    Args:
        img: Image array with exactly three axes (height, width, channels).
        size: Edge length of each square tile, in pixels.

    Returns:
        A tuple ``(tiles, rows, cols, padding_height, padding_width)``:
        the tiles in row-major order (left to right, then top to bottom),
        the number of tile rows and columns, and how many padding rows and
        columns were appended at the bottom and right.
    """
    height, width, _ = img.shape
    # Rows/columns missing to reach the next multiple of `size` (0 when already aligned).
    pad_bottom = -height % size
    pad_right = -width % size
    padded = cv2.copyMakeBorder(
        img, 0, pad_bottom, 0, pad_right, cv2.BORDER_CONSTANT, value=[0, 0, 0]
    )
    padded_height, padded_width = padded.shape[0], padded.shape[1]

    tiles = []
    for top in range(0, padded_height, size):
        for left in range(0, padded_width, size):
            tiles.append(padded[top:top + size, left:left + size])
    return tiles, padded_height // size, padded_width // size, pad_bottom, pad_right
def stitch_tiles(tiles, rows, cols, size):
    """Reassemble a row-major list of tiles into a single image.

    Args:
        tiles: Sequence of arrays; tile ``(i, j)`` is expected at index
            ``i * cols + j``.
        rows: Number of tile rows.
        cols: Number of tile columns.
        size: Tile edge length. Not used by the computation; kept so the
            signature mirrors the tiling call.

    Returns:
        One array built by joining each row of tiles along axis 1 and then
        stacking the rows along axis 0.
    """
    stitched_rows = []
    for row in range(rows):
        first = row * cols
        row_tiles = [tiles[first + col] for col in range(cols)]
        stitched_rows.append(np.concatenate(row_tiles, axis=1))
    return np.concatenate(stitched_rows, axis=0)
def timing_decorator(func):
    """Wrap ``func`` so that every call prints how long it took.

    Args:
        func: The callable to time.

    Returns:
        A wrapper with the same name and docstring as ``func`` that forwards
        all arguments, prints the elapsed time to stdout and returns the
        wrapped call's result. Nothing is printed if ``func`` raises.
    """
    @wraps(func)
    def wrapper(*args, **kwargs):
        # perf_counter is monotonic and high-resolution; time.time() follows
        # the wall clock and can jump when the system clock is adjusted.
        start_time = time.perf_counter()
        result = func(*args, **kwargs)
        duration = time.perf_counter() - start_time
        print(f"Function '{func.__name__}' took {duration:.6f} seconds")
        return result
    return wrapper
@timing_decorator
def process_image(session, img, colors, mosaic=False):
    """Run the segmentation model on one image and blend the class colours over it.

    Args:
        session: Inference session; its first declared input is fed the blob
            and the first output is used as the prediction.
        img: Three-channel image array (channels are swapped by the blob
            conversion, so BGR input is presumably expected - confirm against
            the model's training pipeline).
        colors: Mapping from mask value (the model output multiplied by 122)
            to the colour painted for that value.
        mosaic: When False, the image is first centre-cropped to 416x416.
            When True, ``img`` is taken as an already-cut tile.

    Returns:
        The 416x416 input blended 60/40 with the colour mask.
    """
    input_size = 416  # the blob and the colour mask below are both built at this size

    if not mosaic:
        # Crop the center of the image to 416x416 pixels
        img = center_crop(img, input_size, input_size)
    if img.shape[:2] != (input_size, input_size):
        # An input smaller than the model size would otherwise leave `img` and the
        # mask with different shapes, breaking the indexing and blending below.
        img = cv2.resize(img, (input_size, input_size))

    blob = cv2.dnn.blobFromImage(img, 1/255.0, (input_size, input_size), swapRB=True, crop=False)

    # Perform inference
    input_name = session.get_inputs()[0].name
    output = session.run(None, {input_name: blob})

    # Drop the batch axis and move channels last, then scale the raw values so
    # that they land on the keys used in `colors` (0, 122, 244 in this app).
    output_img = output[0].squeeze(0).transpose(1, 2, 0)
    output_img = (output_img * 122).clip(0, 255).astype(np.uint8)
    output_mask = output_img.max(axis=2)

    # Assign specific colors to the classes in the mask
    output_mask_color = np.zeros((input_size, input_size, 3), dtype=np.uint8)
    for class_idx in np.unique(output_mask):
        if class_idx in colors:
            output_mask_color[output_mask == class_idx] = colors[class_idx]

    # Mask for the transparent class, expanded to three channels so it can
    # index the colour images directly.
    transparent_mask = (output_mask == 122)
    transparent_mask = np.stack([transparent_mask] * 3, axis=-1)
    # Where the mask is True, show the input image itself instead of a colour
    output_mask_color[transparent_mask] = img[transparent_mask]

    # Make the colorful mask semi-transparent
    overlay = cv2.addWeighted(img, 0.6, output_mask_color, 0.4, 0)
    return overlay
# --- Page header -----------------------------------------------------------
st.title("OpenLander ONNX app")
st.write("Upload an image to process with the ONNX OpenLander model!")
st.write("Bear in mind that this model is **much less refined** than the embedded models at the moment.")

# Human-readable model label -> path of the ONNX file to load
models = {
    "Embedded model better trained: DeeplabV3+, MobilenetV2, 416px resolution": "20230608_onnx_416_mbnv2_dl3/end2end.onnx",
    "test model 24k: DV3+ MBv2, 416px": "test_24000.onnx",
    "test model 48k: DV3+ MBv2, 416px": "test_48000.onnx",
    "test model 72k: DV3+ MBv2, 416px": "test_72000.onnx",
    "test model 96k: DV3+ MBv2, 416px": "test_96000.onnx",
    "test model 120k: DV3+ MBv2, 416px": "test_120000.onnx",
    "test model 144k: DV3+ MBv2, 416px": "test_144000.onnx",
    "test model 168k: DV3+ MBv2, 416px": "test_168000.onnx",
    "test model 192k: DV3+ MBv2, 416px": "test_192000.onnx",
    "test model 216k: DV3+ MBv2, 416px": "test_216000.onnx",
    "test model 240k: DV3+ MBv2, 416px": "test_240000.onnx",
}

# Create a Streamlit radio button to select the desired model
selected_model = st.radio("Select a model", list(models.keys()))

# cuda is True when torch can see an NVIDIA GPU
cuda = torch.cuda.is_available()
if cuda:
    print("We have a GPU!")

# torch seeing a GPU does not mean the installed onnxruntime build ships the
# CUDA provider, so confirm that too and always keep the CPU provider as a fallback.
if cuda and 'CUDAExecutionProvider' in ort.get_available_providers():
    providers = ['CUDAExecutionProvider', 'CPUExecutionProvider']
else:
    providers = ['CPUExecutionProvider']

# Get the selected model's path
model_path = models[selected_model]
session = ort.InferenceSession(model_path, providers=providers)

# Define colors for classes 0, 122 and 244 (channel order is BGR, as used by OpenCV)
colors = {0: (0, 0, 255), 122: (0, 0, 0), 244: (0, 255, 255)}  # Red, Black, Yellow
def load_image(uploaded_file):
    """Decode an uploaded file into a three-channel BGR array.

    Args:
        uploaded_file: File-like object accepted by ``PIL.Image.open``.

    Returns:
        The image as a NumPy array in BGR channel order, or ``None`` if the
        file could not be decoded (the error is shown in the Streamlit page).
    """
    try:
        # Normalise to RGB first: grayscale and palette images would otherwise
        # produce a 2-D array that the RGB->BGR conversion rejects.
        image = Image.open(uploaded_file).convert("RGB")
        return cv2.cvtColor(np.array(image), cv2.COLOR_RGB2BGR)
    except Exception as e:
        # Deliberately broad: this is the UI boundary, and any decoding failure
        # should be reported to the user rather than crash the app.
        st.write("Could not load image: ", e)
        return None
# --- Upload, process and display -------------------------------------------
uploaded_file = st.file_uploader("Choose an image...", type=["jpg", "png"])
if uploaded_file is not None:
    img = load_image(uploaded_file)
    # load_image returns None when decoding fails (it has already reported the
    # error on the page), so there is nothing further to do in that case.
    if img is not None:
        if img.ndim == 3 and img.shape[2] == 4:
            img = img[:, :, :3]  # Drop the alpha channel if it exists
        img_processed = None
        if st.button('Process'):
            with st.spinner('Processing...'):
                start = time.time()
                if mosaic:
                    tiles, rows, cols, padding_height, padding_width = mosaic_crop(img, 416)
                    processed_tiles = [process_image(session, tile, colors, mosaic=True) for tile in tiles]
                    overlay = stitch_tiles(processed_tiles, rows, cols, 416)
                    # Crop the padding back out
                    overlay = overlay[:overlay.shape[0]-padding_height, :overlay.shape[1]-padding_width]
                    img_processed = overlay
                else:
                    img_processed = process_image(session, img, colors)
                end = time.time()
            st.write(f"Processing time: {end - start} seconds")
        # Images are held in BGR order, so convert back to RGB for display.
        st.image(cv2.cvtColor(img, cv2.COLOR_BGR2RGB), caption='Uploaded Image.', use_column_width=True)
        if img_processed is not None:
            st.image(cv2.cvtColor(img_processed, cv2.COLOR_BGR2RGB), caption='Processed Image.', use_column_width=True)
            st.write("Red => obstacle ||| Yellow => Human obstacle ||| no color => clear for landing or delivery ")
|