Spaces:
Runtime error
Runtime error
| import datetime | |
| import json | |
| import uuid | |
| from typing import List | |
| import boto3 | |
| import matplotlib | |
| import matplotlib.patches as patches | |
| import matplotlib.pyplot as plt | |
| import numpy.typing as npt | |
| import streamlit as st | |
| client = boto3.client('s3') | |
| def plot_img_with_rects( | |
| img: npt.ArrayLike, boxes: List[List], threshold: float = 0.5, coef: int = 400 | |
| ) -> matplotlib.figure.Figure: | |
| """ | |
| Plot image with rectangles. | |
| Args: | |
| img: image as a numpy array | |
| boxes: the list of the bboxes | |
| threshold: threshold for bbox probability | |
| coef: coefficient to multiply images. Can be changed when the original image is a different size | |
| Returns: | |
| image with bboxes | |
| """ | |
| fig, ax = plt.subplots(1, figsize=(4, 4)) | |
| # Display the image | |
| ax.imshow(img) | |
| # Create a Rectangle patch | |
| for _, rect in enumerate(b for b in boxes if b[1] > threshold): | |
| label, _, xc, yc, w, h = rect | |
| xc, yc, w, h = xc * coef, yc * coef, w * coef, h * coef | |
| # the coordinates from center-based to left top corner | |
| x = xc - w / 2 | |
| y = yc - h / 2 | |
| label = int(label) | |
| label = label if label != 10 else 'penis' | |
| label = label if label != 11 else 'junk' | |
| rect = [x, y, x + w, y + h] | |
| rect_ = patches.Rectangle( | |
| (rect[0], rect[1]), rect[2] - rect[0], rect[3] - rect[1], linewidth=2, edgecolor='blue', facecolor='none' | |
| ) | |
| plt.text(rect[2], rect[1], f'{label}', color='blue') | |
| # Add the patch to the Axes | |
| ax.add_patch(rect_) | |
| return fig | |
| def save_object_to_s3(filename, s3_filename): | |
| client.upload_file(filename, 'digitdrawdetect', s3_filename) | |
| def save_image(image: npt.ArrayLike, pred: List[List]) -> str: | |
| """ | |
| Save the image and upload the image with bboxes to s3. | |
| Args: | |
| image: np.array with image | |
| pred: bboxes | |
| Returns: | |
| image name | |
| """ | |
| # create a figure and save it | |
| fig, ax = plt.subplots(1, figsize=(4, 4)) | |
| ax.imshow(image) | |
| file_name = str(datetime.datetime.today().date()) + str(uuid.uuid1()) | |
| fig.savefig(f'{file_name}.png') | |
| # dump bboxes in a local file | |
| with open(f'{file_name}.json', 'w') as f: | |
| json.dump({f'{file_name}.png': pred}, f) | |
| # upload the image and the bboxes to s3. | |
| save_object_to_s3(f'{file_name}.png', f'images/{file_name}.png') | |
| save_object_to_s3(f'{file_name}.json', f'labels/{file_name}.json') | |
| return file_name | |