File size: 2,513 Bytes
9ab711c
 
 
 
b683920
9ab711c
b683920
 
 
cf41825
9ab711c
 
 
b683920
 
 
cf41825
b683920
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
99d3d67
b683920
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9ab711c
 
 
 
 
 
b683920
9ab711c
 
 
 
 
b683920
 
9ab711c
 
b683920
9ab711c
 
 
 
 
 
 
 
 
 
 
 
 
b683920
9ab711c
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
import datetime
import json
import uuid
from typing import List

import boto3
import matplotlib
import matplotlib.patches as patches
import matplotlib.pyplot as plt
import numpy.typing as npt
import streamlit as st

# Module-level S3 client used by save_object_to_s3. It is created once at import
# time with boto3's default credential/region resolution (no explicit config here).
client = boto3.client('s3')


def plot_img_with_rects(
    img: npt.ArrayLike, boxes: List[List], threshold: float = 0.5, coef: int = 400
) -> matplotlib.figure.Figure:
    """
    Plot an image with its predicted bounding boxes drawn on top.

    Args:
        img: image to display (anything accepted by ``Axes.imshow``)
        boxes: bounding boxes, each unpacked as ``[label, prob, xc, yc, w, h]``,
            where ``(xc, yc)`` is the box center and ``w``/``h`` its size
        threshold: only boxes whose probability (``box[1]``) is strictly greater
            than this value are drawn
        coef: factor applied to the box coordinates and sizes before drawing.
            Presumably the boxes are normalized, so this should match the image
            size in pixels; change it when the image has a different size

    Returns:
        the matplotlib figure containing the image and the boxes
    """
    fig, ax = plt.subplots(1, figsize=(4, 4))

    # Display the image
    ax.imshow(img)

    # Class ids rendered with a text name instead of their number.
    special_labels = {10: 'penis', 11: 'junk'}

    # Filter before unpacking so low-probability boxes are skipped untouched.
    for box in (b for b in boxes if b[1] > threshold):
        label, _, xc, yc, w, h = box
        # Scale to pixels and convert from center-based to top-left-corner coordinates.
        w, h = w * coef, h * coef
        x = xc * coef - w / 2
        y = yc * coef - h / 2

        label = int(label)
        label = special_labels.get(label, label)

        ax.add_patch(
            patches.Rectangle((x, y), w, h, linewidth=2, edgecolor='blue', facecolor='none')
        )
        # Draw on ``ax`` explicitly rather than pyplot's global "current axes",
        # which is shared state (and not safe across concurrent Streamlit sessions).
        ax.text(x + w, y, f'{label}', color='blue')
    return fig


def save_object_to_s3(filename: str, s3_filename: str, bucket: str = 'digitdrawdetect') -> None:
    """
    Upload a local file to S3 using the module-level ``client``.

    Args:
        filename: path of the local file to upload
        s3_filename: object key under which the file is stored
        bucket: name of the target bucket; defaults to the application's bucket

    Exceptions raised by ``client.upload_file`` propagate unchanged.
    """
    client.upload_file(filename, bucket, s3_filename)


@st.cache(show_spinner=False)
def save_image(image: npt.ArrayLike, pred: List[List]) -> str:
    """
    Save the image and its predicted boxes locally, then upload both to S3.

    The image alone (no boxes drawn) is written to ``<name>.png``. The boxes are
    written to ``<name>.json`` as ``{"<name>.png": pred}``. Both files go to the
    current working directory and are uploaded under the ``images/`` and
    ``labels/`` prefixes. Because of ``st.cache``, repeated calls with the same
    arguments return the cached name instead of saving/uploading again.

    Args:
        image: image array to save (anything accepted by ``Axes.imshow``)
        pred: predicted bboxes; must be JSON-serializable

    Returns:
        the generated base file name (without extension)
    """
    # create a figure and save it
    fig, ax = plt.subplots(1, figsize=(4, 4))
    try:
        ax.imshow(image)
        # Date-prefixed unique name. uuid4 is random; uuid1 would embed the
        # host's MAC address and a timestamp in publicly stored object names.
        file_name = str(datetime.datetime.today().date()) + str(uuid.uuid4())
        fig.savefig(f'{file_name}.png')
    finally:
        # pyplot keeps every figure alive until closed; without this each call
        # leaks a figure in the long-running Streamlit process.
        plt.close(fig)

    # dump bboxes in a local file
    with open(f'{file_name}.json', 'w', encoding='utf-8') as f:
        json.dump({f'{file_name}.png': pred}, f)

    # upload the image and the bboxes to s3.
    save_object_to_s3(f'{file_name}.png', f'images/{file_name}.png')
    save_object_to_s3(f'{file_name}.json', f'labels/{file_name}.json')

    return file_name