Spaces:
Runtime error
Runtime error
# AUTOGENERATED! DO NOT EDIT! File to edit: telecom_object_detection.ipynb.

# %% auto 0
# Names exported by ``from <module> import *``: the page constants and example
# data, followed by the image/prediction helpers and the Gradio callbacks.
__all__ = [
    'title', 'css', 'urls', 'imgs', 'img_samples',
    'fig2img', 'custom_vision_detect_objects',
    'set_example_url', 'set_example_image', 'detect_objects',
]
# %% telecom_object_detection.ipynb 2 | |
import gradio as gr | |
import numpy as np | |
import os | |
import io | |
import requests | |
from pathlib import Path | |
# %% telecom_object_detection.ipynb 6 | |
from azure.cognitiveservices.vision.customvision.prediction import CustomVisionPredictionClient | |
from msrest.authentication import ApiKeyCredentials | |
from matplotlib import pyplot as plt | |
from PIL import Image, ImageDraw, ImageFont | |
from dotenv import load_dotenv | |
# %% telecom_object_detection.ipynb 11 | |
def fig2img(fig):
    """Render a Matplotlib figure into a PIL image.

    The figure is written with ``fig.savefig`` into an in-memory buffer. The
    buffer is rewound and opened with ``PIL.Image.open``.

    Args:
        fig: the figure to render; only its ``savefig`` method is used.

    Returns:
        The image object returned by ``Image.open`` for the buffer. It keeps a
        reference to the buffer, so the buffer stays alive with the image.
    """
    stream = io.BytesIO()
    fig.savefig(stream)
    stream.seek(0)  # rewind so PIL reads from the start of the saved data
    return Image.open(stream)
def custom_vision_detect_objects(image_file: Path):
    """Run Azure Custom Vision object detection on an image file and render the hits.

    Connection settings are read from the environment after ``load_dotenv()``
    has loaded any ``.env`` file: ``PredictionEndpoint``, ``PredictionKey``,
    ``ProjectID`` and ``ModelName``. Each prediction with a probability above
    50% is drawn as a cyan box on the image and labelled with its tag name and
    rounded percentage.

    Args:
        image_file: path (``Path`` or ``str``) of the image to analyse.

    Returns:
        A PIL image of the rendered Matplotlib figure (produced by ``fig2img``).

    Raises:
        RuntimeError: if any of the required environment settings is missing
            or empty.
    """
    dpi = 100

    # Get configuration settings
    load_dotenv()
    prediction_endpoint = os.getenv('PredictionEndpoint')
    prediction_key = os.getenv('PredictionKey')
    project_id = os.getenv('ProjectID')
    model_name = os.getenv('ModelName')
    missing = [
        name
        for name, value in (
            ('PredictionEndpoint', prediction_endpoint),
            ('PredictionKey', prediction_key),
            ('ProjectID', project_id),
            ('ModelName', model_name),
        )
        if not value
    ]
    if missing:
        # Fail early with a readable message instead of an obscure client/HTTP error.
        raise RuntimeError('Missing Custom Vision settings: ' + ', '.join(missing))

    # Authenticate a client for the prediction API
    credentials = ApiKeyCredentials(in_headers={"Prediction-key": prediction_key})
    prediction_client = CustomVisionPredictionClient(endpoint=prediction_endpoint, credentials=credentials)

    print('Detecting objects in', image_file)
    # Load the pixels eagerly and close the file handle. image.size is
    # (width, height) for every image mode, unlike np.array(image).shape,
    # which has only two entries for single-channel images.
    with Image.open(image_file) as source:
        image = source.copy()
    w, h = image.size

    # Detect objects in the image
    with open(image_file, mode="rb") as image_data:
        results = prediction_client.detect_image(project_id, model_name, image_data)

    # Create a figure whose size is proportional to the image size
    fig = plt.figure(figsize=(w / dpi, h / dpi))
    try:
        plt.axis('off')
        # Boxes are drawn directly onto the PIL image shown by imshow below
        draw = ImageDraw.Draw(image)
        # int(w / 800) is 0 for images narrower than 800 px; keep at least 1 px
        line_width = max(1, int(w / 800))
        color = 'cyan'
        for prediction in results.predictions:
            # Only show objects with a > 50% probability
            if prediction.probability * 100 <= 50:
                continue
            box = prediction.bounding_box
            # Box coordinates and dimensions are proportional - convert to absolutes
            left = box.left * w
            top = box.top * h
            width = box.width * w
            height = box.height * h
            # Draw the box as a closed polyline
            points = (
                (left, top),
                (left + width, top),
                (left + width, top + height),
                (left, top + height),
                (left, top),
            )
            draw.line(points, fill=color, width=line_width)
            # Add the tag name and probability. The y offset 1.372*h/dpi is an
            # empirical constant, presumably chosen to place the label just
            # above the box.
            plt.annotate(
                prediction.tag_name + ": {0:.0f}%".format(prediction.probability * 100),
                (left, top - 1.372 * h / dpi),
                backgroundcolor=color,
                fontsize=max(w / dpi, h / dpi),
                fontfamily='monospace',
            )
        plt.imshow(image)
        plt.tight_layout(pad=0)
        return fig2img(fig)
    finally:
        # pyplot keeps every figure alive until it is closed. Close it so
        # repeated calls (one per web request) do not accumulate figures.
        plt.close(fig)
# %% telecom_object_detection.ipynb 15
# Page heading (passed to gr.Markdown) and the CSS that centres it.
title = """<h1 id="title">Telecom Object Detection with Azure Custom Vision</h1>"""
css = '''
h1#title {
text-align: center;
}
'''

# %% telecom_object_detection.ipynb 16
# Example image URL offered in the "Image URL" tab.
urls = ["https://c8.alamy.com/comp/J2AB4K/the-new-york-stock-exchange-on-the-wall-street-in-new-york-J2AB4K.jpg"]
# Example images: every *.jpg under ./images (searched recursively), in sorted order.
# The directory is walked once, and both lists are derived from the same result.
imgs = [path.as_posix() for path in sorted(Path('images').rglob('*.jpg'))]
# One sample row per image, each row holding one value per gr.Dataset component.
img_samples = [[img] for img in imgs]
# %% telecom_object_detection.ipynb 17
def set_example_url(example: list) -> dict:
    """Copy the URL from a clicked example row into the URL textbox.

    Args:
        example: the clicked ``gr.Dataset`` sample row; its first entry is the URL.

    Returns:
        A Gradio update that sets the textbox value.
    """
    # gr.update works in Gradio 3.x and 4.x; the component-specific
    # gr.Textbox.update was removed in Gradio 4.
    return gr.update(value=example[0])
def set_example_image(example: list) -> dict:
    """Load the image from a clicked example row into the upload component.

    Args:
        example: the clicked ``gr.Dataset`` sample row; its first entry is the
            image path.

    Returns:
        A Gradio update that sets the image component's value.
    """
    # gr.update works in Gradio 3.x and 4.x; the component-specific
    # gr.Image.update was removed in Gradio 4.
    return gr.update(value=example[0])
def detect_objects(image_input: "Image.Image | str"):
    """Prepare an input image and run Custom Vision object detection on it.

    Images whose longest side exceeds 1200 px are downscaled first; smaller
    images are never upscaled. The prepared image is saved as
    ``input_object_detection.jpg`` and passed to
    ``custom_vision_detect_objects``.

    Args:
        image_input: a PIL image (from the upload tab) or an image URL string
            (from the URL tab).

    Returns:
        The annotated image returned by ``custom_vision_detect_objects``.

    Raises:
        ValueError: if no image or URL was provided.
    """
    max_side = 1_200  # cap on the longest side, in pixels
    if image_input is None or (isinstance(image_input, str) and not image_input.strip()):
        raise ValueError("Please provide an image or an image URL.")

    if isinstance(image_input, str):
        # NOTE(review): this fetches an arbitrary user-supplied URL from the
        # server; consider restricting schemes/hosts if the app is public.
        response = requests.get(image_input.strip(), timeout=30)
        response.raise_for_status()
        image_input = Image.open(io.BytesIO(response.content))

    w, h = image_input.size
    # Downscale so the longest side is at most max_side; never upscale.
    factor = min(1.0, max_side / max(w, h))
    if factor < 1.0:
        size = (int(w * factor), int(h * factor))
        image_input = image_input.resize(size, resample=Image.Resampling.BILINEAR)

    resized_image_path = "input_object_detection.jpg"
    # JPEG cannot store alpha or palette modes (e.g. RGBA/P images fetched from URLs).
    image_input.convert("RGB").save(resized_image_path)
    return custom_vision_detect_objects(resized_image_path)
# %% telecom_object_detection.ipynb 19
# Two tabs: upload an image, or paste an image URL. Each tab has example
# samples that fill its input when clicked, and a "Detect" button that runs
# detect_objects on that input.
with gr.Blocks(css=css) as demo:
    gr.Markdown(title)
    with gr.Tabs():
        with gr.TabItem("Image Upload"):
            with gr.Row():
                image_input = gr.Image(type='pil')
                image_output = gr.Image(shape=(650,650))
            with gr.Row():
                example_images = gr.Dataset(components=[image_input], samples=img_samples)
            image_button = gr.Button("Detect")
        with gr.TabItem("Image URL"):
            with gr.Row():
                url_input = gr.Textbox(lines=2, label='Enter valid image URL here..')
                img_output_from_url = gr.Image(shape=(650,650))
            with gr.Row():
                example_url = gr.Dataset(components=[url_input], samples=[[str(url)] for url in urls])
            url_button = gr.Button("Detect")

    # Event wiring is done inside the Blocks context.
    url_button.click(detect_objects, inputs=[url_input], outputs=img_output_from_url)
    image_button.click(detect_objects, inputs=[image_input], outputs=image_output)
    example_url.click(fn=set_example_url, inputs=[example_url], outputs=[url_input])
    example_images.click(fn=set_example_image, inputs=[example_images], outputs=[image_input])

# Launch only when run as a script, so importing this module does not start a
# blocking web server.
if __name__ == "__main__":
    demo.launch()