trafficlight / app.py
mattb512's picture
removing samples
a3a37f1
raw history blame
No virus
4.83 kB
import os

import gradio as gr
import numpy as np
import torch
from PIL import Image, ImageDraw
from torch import nn
from transformers import SegformerFeatureExtractor, SegformerForSemanticSegmentation
# Preprocessor and segmentation model are loaded from the same checkpoint so
# their settings (normalisation, label set) match. Per its name, the checkpoint
# is SegFormer-B5 fine-tuned on Cityscapes at 1024x1024.
# NOTE(review): newer transformers releases deprecate SegformerFeatureExtractor
# in favour of SegformerImageProcessor -- confirm against the pinned version.
_CHECKPOINT = "nvidia/segformer-b5-finetuned-cityscapes-1024-1024"
feature_extractor = SegformerFeatureExtractor.from_pretrained(_CHECKPOINT)
model = SegformerForSemanticSegmentation.from_pretrained(_CHECKPOINT)
# https://github.com/NielsRogge/Transformers-Tutorials/blob/master/SegFormer/Segformer_inference_notebook.ipynb
def cityscapes_palette():
    """Return the 19 Cityscapes display colours as ``[R, G, B]`` lists.

    Entry ``i`` is the colour used for label id ``i`` (the same order as
    ``cityscapes_classes()``). A new list is built on every call, so callers
    may mutate the result without affecting later calls.
    """
    rgb_table = (
        (128, 64, 128), (244, 35, 232), (70, 70, 70), (102, 102, 156),
        (190, 153, 153), (153, 153, 153), (250, 170, 30), (220, 220, 0),
        (107, 142, 35), (152, 251, 152), (70, 130, 180), (220, 20, 60),
        (255, 0, 0), (0, 0, 142), (0, 0, 70), (0, 60, 100), (0, 80, 100),
        (0, 0, 230), (119, 11, 32),
    )
    return [list(rgb) for rgb in rgb_table]
def cityscapes_classes():
    """Return the 19 Cityscapes class names, indexed by label id.

    The position of a name in the returned list is the model's label id for
    that class (e.g. index 6 is 'traffic light'). A new list is returned on
    every call.
    """
    names = (
        'road', 'sidewalk', 'building', 'wall', 'fence', 'pole',
        'traffic light', 'traffic sign', 'vegetation', 'terrain', 'sky',
        'person', 'rider', 'car', 'truck', 'bus', 'train', 'motorcycle',
        'bicycle',
    )
    return list(names)
def annotation(image: "Image.Image", color_seg: np.ndarray, threshold: int = 100000) -> None:
    """Outline the 256x256 grid cells of *image* that contain enough mask pixels.

    The 1024x1024 canvas is split into a 4x4 grid. For each cell, the colour
    channels of ``color_seg`` are summed over all of its pixels. When that total
    exceeds ``threshold``, the cell's top-left corner is printed and a 3-pixel
    outline is drawn on ``image`` in place.

    Args:
        image: 1024x1024 PIL image to annotate; modified in place.
        color_seg: (1024, 1024, 3) colour mask; all-zero pixels contribute nothing.
        threshold: minimum summed channel intensity for a cell to be outlined.
            The default is the value that was previously hard-coded.

    Raises:
        ValueError: if ``image`` or ``color_seg`` does not have the expected size.
    """
    size = 1024
    blocks = 4  # 4x4 sub grid
    step_size = size // blocks  # sub square edge size (256)

    # Explicit checks instead of `assert`, which is stripped under `python -O`.
    if image.size != (size, size):
        raise ValueError(f"expected a {size}x{size} image, got {image.size}")
    if color_seg.shape != (size, size, 3):
        raise ValueError(
            f"expected color_seg of shape {(size, size, 3)}, got {color_seg.shape}"
        )

    # Collapse the colour channels once, not once per cell. For small integer
    # dtypes such as uint8, numpy's sum accumulates in a platform-sized integer,
    # so the per-cell totals cannot wrap around.
    reduced_seg = color_seg.sum(axis=2)
    draw = ImageDraw.Draw(image)

    # x in the outer loop keeps the original column-major reporting order.
    for x in range(0, size, step_size):
        for y in range(0, size, step_size):
            if reduced_seg[y:y + step_size, x:x + step_size].sum() > threshold:
                print("light found at square ", x, y)
                # PIL rectangle bounds are inclusive: the far corner must be the
                # cell's last pixel, not the first pixel of the neighbouring cell.
                draw.rectangle(
                    [(x, y), (x + step_size - 1, y + step_size - 1)],
                    outline=128,
                    width=3,
                )
def call(image):
    """Tint traffic-light pixels in *image* and outline the grid cells containing them.

    Args:
        image: input picture, either an (H, W[, C]) uint8 numpy array (what the
            Gradio ``"image"`` input supplies) or a ``PIL.Image``.

    Returns:
        A 1024x1024 RGB ``PIL.Image``. The resized input is blended 50/50 with a
        mask in which pixels predicted as "traffic light" carry that class's
        Cityscapes palette colour and all other pixels are black. Grid cells
        with enough traffic-light pixels are then outlined by ``annotation``.
    """
    size = 1024  # annotation() expects a fixed 1024x1024 canvas

    pil_image = image if isinstance(image, Image.Image) else Image.fromarray(image)
    # Force three channels: uploads may be RGBA or greyscale.
    resized_image = np.array(pil_image.convert("RGB").resize((size, size)))

    inputs = feature_extractor(images=resized_image, return_tensors="pt")
    # Inference only: don't build the autograd graph (saves memory and time).
    with torch.no_grad():
        outputs = model(**inputs)

    # Logits come out at a quarter of the input resolution; upsample them back to
    # the image size before taking the per-pixel argmax over the class dimension.
    interpolated_logits = nn.functional.interpolate(
        outputs.logits,
        size=[size, size],  # (height, width)
        mode="bilinear",
        align_corners=False,
    )
    seg = interpolated_logits.argmax(dim=1)[0].cpu().numpy()  # (H, W) label ids

    # Paint only the traffic-light class; every other pixel stays black.
    traffic_light = cityscapes_classes().index("traffic light")
    color_seg = np.zeros((size, size, 3), dtype=np.uint8)
    color_seg[seg == traffic_light] = cityscapes_palette()[traffic_light]

    # No BGR flip: both the PIL image and Gradio's output are RGB, and flipping
    # the mask made the overlay the wrong colour.
    img = (resized_image * 0.5 + color_seg * 0.5).astype(np.uint8)
    out_im_file = Image.fromarray(img)
    annotation(out_im_file, color_seg)
    return out_im_file
# original_image = Image.open("./examples/1.jpg")
# print(f"{np.array(original_image).shape=}") # eg 729, 1000, 3
# out = call(original_image)
# out.save("out2.jpeg")
title = "Traffic Light Detector"
description = "Experiment traffic light detection to evaluate the value of captcha security controls"

# Example images live next to this file, so resolve them relative to it rather
# than to the current working directory. (Examples 3-6 are currently disabled.)
_app_dir = os.path.dirname(__file__)
_example_files = ("examples/1.jpg", "examples/2.jpg")

iface = gr.Interface(
    fn=call,
    inputs="image",
    outputs="image",
    title=title,
    description=description,
    examples=[os.path.join(_app_dir, name) for name in _example_files],
    thumbnail="thumbnail.webp",
)
iface.launch()