import albumentations as A
import base64
import cv2
import gradio as gr
import inspect
import io
import numpy as np
import os

from dataclasses import dataclass
from loguru import logger
from copy import deepcopy
from functools import wraps
from PIL import Image, ImageDraw
from typing import get_type_hints, Optional

from pydantic_core._pydantic_core import ValidationError

# from mixpanel import Mixpanel

from utils import is_not_supported_transform

# Some constants for Albumentations
PositionType = A.PadIfNeeded.PositionType

# MIXPANEL_TOKEN = os.getenv("MIXPANEL_TOKEN")
# mp = Mixpanel(MIXPANEL_TOKEN)

HEADER = f"""

# Albumentations Demo ({A.__version__})

[Documentation](https://albumentations.ai/docs/) &nbsp;&nbsp; [GitHub Repository](https://github.com/albumentations-team/albumentations)
"""

DEFAULT_TRANSFORM = "Rotate"
NO_OPERATION_TRANSFORM = "NoOp"

DEFAULT_IMAGE_PATH = "images/doctor.webp"
DEFAULT_IMAGE = np.array(Image.open(DEFAULT_IMAGE_PATH))
DEFAULT_IMAGE_HEIGHT = DEFAULT_IMAGE.shape[0]
DEFAULT_IMAGE_WIDTH = DEFAULT_IMAGE.shape[1]

DEFAULT_BOXES = [
    [265, 121, 326, 177],  # Mask
    [192, 169, 401, 395],  # Coverall
]

mask_keypoints = [[270, 123], [320, 130], [270, 151], [321, 158]]
pocket_keypoints = [[226, 379], [272, 386], [307, 388], [364, 380]]
arm_keypoints = [[215, 194], [372, 192], [214, 322], [378, 330]]
DEFAULT_KEYPOINTS = mask_keypoints + pocket_keypoints + arm_keypoints

BASE64_DEFAULT_MASKS = [
    {
        "label": "Coverall",
        # light green color
        "color": (144, 238, 144),
        "mask": "iVBORw0KGgoAAAANSUhEUgAAAlgAAAGQCAAAAABXXkFEAAAF+ElEQVR4nO3dwXLjNhBFUSg1///LziLj1Iwt26KkFhuvz9kkWVhFAJcAqXGStQAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAACe5+3sCyCUsq745+wLSKCsz4T1DMr6RFiUENZT2LI+EhYlhPWnt7t3nvt/MtSvsy+gkcfaeFuXJ11HBJPx7r+s7piPP3o0m/9zFP729tdfjv/gnT8dy1G41npeEc7Dd3astR7q6uOP2rT+4wb7mMLBGfkckildyyw8HMa1HWr8pK7pz1hF55YnrdE315dVHZmTb9IcPLVr7Oi/36oOTMpPe97Q+R16FL7wzW3sqThw2DdkdfOs3JTowDmeN+gbN6tbp+XJHxdk1pAPnIE3TczNnzdrmteaNeJjj1Y3zMyRD5w00WuNGu/RR/afpubZn5dlymjvexH8Znbu+cApk73WlLE+8v3C1Rm69wNnTPdaM0ba6hcOJkz4WhPG2SqrtSZM+Vr5o2yX1Vr5k75W+hhbZrVW+rSvlT3Ctlmt7Hlfa0X/anLnrnpf3DPkhtV86Zpf3sNyw+ouvKzYsPqvW/8rfERsWJxLWOeJ3rKERQlhnSh5y0oNK3nNtpAa1h6C8w8NK3jFNpEZlq5OlxnWNnLvgMiwcpdrH5FhbST2HkgMK3axdpIY1lZS7wJhUSIwrM32gM0u91aBYdFBXlihO8Bu4sLSVQ9xYe0n81ZICytzlTaUFtaOIm8GYVEiLKzIm39LYWHRhbAoIawGEg9wYVEiK6zEW39TWWHtKvCGEBYlhEUJYbWQdxYKixLCokRUWHkHyr6iwqIPYVEiKSwnYSNJYdGIsCghLEoIixLCokRQWF4KOwkKa2txd4WwmkgrS1iUEFYXYVuWsCiRE1bYHb+7nLBoRViUEBYlhEUJYVFCWG1kvdYKixIxYWXd7/tLCUtXzaSEFeBy9gU8VUhYNqxuMsLSVTsZYdFORFgZG1bGKN5FhEU/wqJEQlhZZ0iIhLBoSFiUEBYlhEUJYVEiIaysP70NkRAWDQmLEsLqI+qLXmFRQliUEBYlhEWJX2dfwK4ua4U9bj+XsA66/P0P0vqCo/CQy8dv+X3r/wVhHXElI2VdJ6wDiiOKalRYlBDWo6L2mecRVhtZhQrrUb5wuEpYlBBWF1knYUZYYWsSISKsM8vyiHVdRlivYWM8ICQsa95NSFinleUk/EJKWDQjrCbSDvOYsM5ZGCfhV2LCohdhPcKG9SVh3e5TRk/sKu0RKyis1y+N/eobOWG9nK6+ExTWa7esN119y79XeBdV/URYd5DVz4R1wNtlqepGUa+5+6551DKstaIe3ulEWJQQFiWERQlhUSIqrLx3q31FhUUfwqJEVljOwjaywqINYXUQuNOGhRW4QpsKC4su0sKyZTWRFtaWZe14zT+JCytylTYUuQyb/cJf5Brk7Vhrt5Xa62pvFRnWVmu107UekBlW6mptJDQsZZ0tNaxtpN4BqeN6fzW8/PH3LaUuQOq4PuqaVuz8OwopEXvHfNRyywqe/eCh/aVjV9Fz7z8KcpborIR1jvCo1hLWGQZk5a3wBCO6EtbLzehKWNQQFiWERQlhUUJYlBAWJYRFCWG9Wsc/Di8gLEoMCWvINtHIkLB4NWFRQliUENbLzXjeExYlhEUJYVFiRlgzHmtamREWLycsSgjr9UYczMKihLAoISxKCIsSwqKEsCghLEoI6wQTvsgSFiWERQlhUUJYZxjwkCUsSgiLEiPCGnDytDMirH7yU58QVv4qNjQhLE4gLEoIixLCooSwzhH/QjEhrMuQ/31NKxPCktYJBs14s9MnfOZn7FhrLdvWaw0KS1qvNG+qu5yI4TMfPryrmqSVPfWjjsLfnIgvMHWOO+xa0XM/ccdaK3xRO5gaFsWERQlhUWJqWB0e3qNNDauD6LiFRYmhYUVvFi3MDEtX5UaGpat6I8Oi3r8KSpCuwVpGmQAAAABJRU5ErkJggg==",
    },
    {
        "label": "Mask",
        # light blue color
        "color": (173, 216, 230),
        "mask": "iVBORw0KGgoAAAANSUhEUgAAAlgAAAGQCAAAAABXXkFEAAAB4ElEQVR4nO3csQ6CMBSG0avv/864OFhoobW9UeM5i4ML+fOFkmCMAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAIbcPn0BX257ftppkMEqtuY35uplqYN2VhFhsU5m2rnIKsJmXYy00xFWRBjuin1KvV2F6c7dP30Bv2ugwT8krPcp64SwJmzSahJWYbQUZbUIqzD8QK6sBmEVdLKKsEghLFIIa5LDs05YpBAWKYQ1y1lYJSxSCIsUwiKFsF55XlpGWLP83q9KWKQQ1qt37j6OzyphkUJYpBBWwZP4KsIqKWsRYZFCWDtuWWsIixTC2nPLWkJYB8paQVhHY2XpsMosdV0vaozXZpsG/+s3x0BN9bQM1sdOJ45pmauXpU4VadlqgLEubRF2AgAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAGCZBwKLGEVAl/J/AAAAAElFTkSuQmCC",
    },
]

# Get all the transforms from the albumentations library
transforms_map = {
    name: cls
    for name, cls in vars(A).items()
    if (
        inspect.isclass(cls)
        and issubclass(cls, (A.DualTransform, A.ImageOnlyTransform))
        and not is_not_supported_transform(cls)
    )
}
transforms_map.pop("DualTransform", None)
transforms_map.pop("ImageOnlyTransform", None)
transforms_map.pop("ReferenceBasedTransform", None)
transforms_map.pop("ToFloat", None)
transforms_map.pop("Normalize", None)
transforms_keys = list(sorted(transforms_map.keys()))

# Decode the masks
for mask in BASE64_DEFAULT_MASKS:
    mask["mask"] = np.array(
        Image.open(io.BytesIO(base64.b64decode(mask["mask"]))).convert("L")
    )


@dataclass
class RequestParams:
    user_ip: str
    transform_name: Optional[str]


def track_event(event_name, user_id="unknown", properties=None):
    if properties is None:
        properties = {}
    # mp.track(user_id, event_name, properties)
    logger.info(f"Event tracked: {event_name} - {properties}")


def get_params(request: gr.Request) -> RequestParams:
    """Parse input request parameters."""
    ip = request.client.host
    transform_name = request.query_params.get("transform", None)
    params = RequestParams(user_ip=ip, transform_name=transform_name)
    track_event(
        "app_opened",
        user_id=params.user_ip,
        properties={"transform_name": params.transform_name},
    )
    return params


def run_with_retry(compose):
    @wraps(compose)
    def wrapper(*args, **kwargs):
        processors = deepcopy(compose.processors)
        result = None  # stays None only if every retry raises NotImplementedError
        for _ in range(4):
            try:
                result = compose(*args, **kwargs)
                break
            except NotImplementedError as e:
                # Drop the unsupported target and retry with the remaining ones
                print(f"Caught NotImplementedError: {e}")
                if "bbox" in str(e):
                    kwargs.pop("bboxes", None)
                    kwargs.pop("category_id", None)
                    compose.processors.pop("bboxes", None)
                if "keypoint" in str(e):
                    kwargs.pop("keypoints", None)
                    compose.processors.pop("keypoints", None)
                if "mask" in str(e):
                    kwargs.pop("mask", None)
            except (ValueError, ValidationError) as e:
                raise gr.Error(str(e))
            except Exception as e:
                compose.processors = processors
                raise e
        compose.processors = processors
        return result

    return wrapper


def draw_boxes(image, boxes, color=(255, 0, 0), thickness=1) -> np.ndarray:
    """Draw boxes with PIL."""
    pil_image = Image.fromarray(image)
    draw = ImageDraw.Draw(pil_image)
    for box in boxes:
        x_min, y_min, x_max, y_max = box
        draw.rectangle([x_min, y_min, x_max, y_max], outline=color, width=thickness)
    return np.array(pil_image)


def draw_keypoints(image, keypoints, color=(255, 0, 0), radius=2):
    """Draw keypoints with PIL."""
    pil_image = Image.fromarray(image)
    draw = ImageDraw.Draw(pil_image)
    for keypoint in keypoints:
        x, y = keypoint
        draw.ellipse([x - radius, y - radius, x + radius, y + radius], fill=color)
    return np.array(pil_image)


def get_rgb_mask(masks):
    """Get the RGB mask from the binary masks."""
    rgb_mask = np.zeros((DEFAULT_IMAGE_HEIGHT, DEFAULT_IMAGE_WIDTH, 3), dtype=np.uint8)
    for data in masks:
        mask = data["mask"]
        rgb_mask[mask > 0] = np.array(data["color"])
    return rgb_mask


def draw_mask(image, mask):
    """Draw the mask on the image."""
    image_with_mask = cv2.addWeighted(image, 0.5, mask, 0.5, 0)
    return image_with_mask


def draw_not_implemented_image(image: np.ndarray, annotation_type: str):
    """Draw the image with a text in the middle."""
    pil_image = Image.fromarray(image)
    draw = ImageDraw.Draw(pil_image)
    # align the text in the center
    text = f'Transform NOT working with "{annotation_type.upper()}" annotations.'
    length = draw.textlength(text)
    draw.text(
        (DEFAULT_IMAGE_WIDTH // 2 - length // 2, DEFAULT_IMAGE_HEIGHT // 2),
        text,
        fill=(255, 0, 0),
        align="center",
    )
    return np.array(pil_image)


def get_formatted_signature(function_or_class, indentation=4):
    signature = inspect.signature(function_or_class)
    type_hints = get_type_hints(function_or_class)
    args = []
    for param in signature.parameters.values():
        if param.name == "p":
            str_param = "p=1.0,"
        elif param.default == inspect.Parameter.empty:
            str_param = f"{param.name}=,"
        else:
            if isinstance(param.default, str):
                str_param = f'{param.name}="{param.default}",'
            else:
                str_param = f"{param.name}={param.default},"
        annotation = type_hints.get(param.name, param.annotation)
        if isinstance(param.annotation, type):
            str_param += f" # {param.annotation.__name__}"
        else:
            str_annotation = str(annotation).replace("typing.", "")
            str_param += f" # {str_annotation}"
        str_param = "\n" + " " * indentation + str_param
        args.append(str_param)
    result = "(" + "".join(args) + "\n" + " " * (indentation - 4) + ")"
    return result


def get_formatted_transform(transform_name):
    track_event("transform_selected", properties={"transform_name": transform_name})
    transform = transforms_map[transform_name]
    return f"A.{transform.__name__}{get_formatted_signature(transform)}"


def get_formatted_transform_docs(transform_name):
    transform = transforms_map[transform_name]
    return transform.__doc__.strip("\n")


def update_augmented_images(image, code):
    if "=," in code:
        raise gr.Error("You have to fill in some parameters to apply the transform!")
    try:
        augmentation = eval(code)
    except ValidationError as e:
        raise gr.Error(str(e))

    track_event(
        "transform_applied",
        properties={"transform_name": augmentation.__class__.__name__, "code": code},
    )

    compose = A.Compose(
        [augmentation],
        bbox_params=A.BboxParams(format="pascal_voc", label_fields=["category_id"]),
        keypoint_params=A.KeypointParams(format="xy"),
    )
    compose = run_with_retry(compose)  # to prevent NotImplementedError

    keypoints = DEFAULT_KEYPOINTS
    bboxes = DEFAULT_BOXES
    mask = get_rgb_mask(BASE64_DEFAULT_MASKS)

    augmented = compose(
        image=image,
        mask=mask,
        keypoints=keypoints,
        bboxes=bboxes,
        category_id=range(len(bboxes)),
    )

    image = augmented["image"]
    mask = augmented.get("mask", None)
    bboxes = augmented.get("bboxes", None)
    keypoints = augmented.get("keypoints", None)

    # Draw the augmented images (or replace by placeholder if not implemented)
    if mask is not None:
        image_with_mask = draw_mask(image.copy(), mask)
    else:
        image_with_mask = draw_not_implemented_image(image.copy(), "mask")

    if bboxes is not None:
        image_with_bboxes = draw_boxes(image.copy(), bboxes)
    else:
        image_with_bboxes = draw_not_implemented_image(image.copy(), "boxes")

    if keypoints is not None:
        image_with_keypoints = draw_keypoints(image.copy(), keypoints)
    else:
        image_with_keypoints = draw_not_implemented_image(image.copy(), "keypoints")

    return [
        (image_with_mask, "Mask"),
        (image_with_bboxes, "Boxes"),
        (image_with_keypoints, "Keypoints"),
    ]


def update_image_info(image):
    h, w = image.shape[:2]
    dtype = image.dtype
    max_, min_ = image.max(), image.min()
    return f"Image info:\n\t - shape: {h}x{w}\n\t - dtype: {dtype}\n\t - min/max: {min_}/{max_}"


def update_code_and_docs(select):
    code = get_formatted_transform(select)
    docs = get_formatted_transform_docs(select)
    return code, docs


def update_code_and_docs_on_start(url_params: gr.Request):
    params = get_params(url_params)
    if params.transform_name is not None and params.transform_name not in transforms_map:
        gr.Warning(
            f"Sorry, `{params.transform_name}` transform is not supported at the moment :("
        )
        transform_name = NO_OPERATION_TRANSFORM
    elif params.transform_name in transforms_map:
        transform_name = params.transform_name
    else:
        transform_name = DEFAULT_TRANSFORM
    return gr.update(value=transform_name)


with gr.Blocks() as demo:
    gr.Markdown(HEADER)
    with gr.Row():
        with gr.Column():
            with gr.Group():
                # gr.Markdown(
                #     (" " * 4) + \
                #     "If a component is loading on start, please try to refresh the page a few times. [Working on fix...]"
                # )
                select = gr.Dropdown(
                    label="Select a transformation",
                    choices=transforms_keys,
                    value=DEFAULT_TRANSFORM,
                    type="value",
                    interactive=True,
                )
                with gr.Accordion("Documentation (click to expand)", open=False):
                    docs = gr.TextArea(
                        get_formatted_transform_docs(DEFAULT_TRANSFORM),
                        show_label=False,
                        interactive=False,
                    )
                code = gr.Code(
                    label="Code",
                    language="python",
                    value=get_formatted_transform(DEFAULT_TRANSFORM),
                    interactive=True,
                    lines=5,
                )
                info = gr.TextArea(
                    value=f"Image size: {DEFAULT_IMAGE_HEIGHT} x {DEFAULT_IMAGE_WIDTH} (height x width)",
                    show_label=False,
                    lines=1,
                    max_lines=1,
                )
                button = gr.Button("Apply!")
        image = gr.Image(
            value=DEFAULT_IMAGE_PATH,
            type="numpy",
            height=500,
            width=300,
            sources=[],
        )
    with gr.Row():
        augmented_image = gr.Gallery(
            value=update_augmented_images(DEFAULT_IMAGE, "A.NoOp()"),
            rows=1,
            columns=3,
            show_label=False,
        )

    select.change(fn=update_code_and_docs, inputs=[select], outputs=[code, docs])
    button.click(
        fn=update_augmented_images, inputs=[image, code], outputs=[augmented_image]
    )
    demo.load(
        update_code_and_docs_on_start, inputs=None, outputs=[select], queue=False
    )


if __name__ == "__main__":
    demo.launch()
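
# A minimal, illustrative sketch (not used by the app): driving the same
# Compose + run_with_retry pipeline outside the Gradio UI, reusing the names
# defined in this module. The transform choice (A.HorizontalFlip) is an
# assumption for the example, not a value the app itself uses.
#
#   augmentation = A.HorizontalFlip(p=1.0)
#   compose = run_with_retry(
#       A.Compose(
#           [augmentation],
#           bbox_params=A.BboxParams(format="pascal_voc", label_fields=["category_id"]),
#           keypoint_params=A.KeypointParams(format="xy"),
#       )
#   )
#   out = compose(
#       image=DEFAULT_IMAGE,
#       mask=get_rgb_mask(BASE64_DEFAULT_MASKS),
#       keypoints=DEFAULT_KEYPOINTS,
#       bboxes=DEFAULT_BOXES,
#       category_id=list(range(len(DEFAULT_BOXES))),
#   )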