File size: 3,104 Bytes
9ab711c
 
2c16c94
9ab711c
 
b683920
9ab711c
b683920
 
 
cf41825
9ab711c
2c16c94
9ab711c
cdb8c6a
 
2c16c94
 
 
 
 
 
 
 
 
 
 
 
cdb8c6a
 
b683920
 
 
cf41825
b683920
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
99d3d67
b683920
 
 
 
 
 
ca67f90
 
b683920
 
 
 
 
 
 
 
 
 
 
9ab711c
 
 
 
f0facef
9ab711c
b683920
9ab711c
 
 
 
 
b683920
 
9ab711c
 
b683920
9ab711c
 
 
 
 
 
 
2c16c94
 
9ab711c
 
 
 
b683920
9ab711c
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
import datetime
import json
import os
import uuid
from typing import List

import boto3
import matplotlib
import matplotlib.patches as patches
import matplotlib.pyplot as plt
import numpy.typing as npt
import streamlit as st
import tomli

# S3 credentials. Sources are tried in order and later ones win:
#   1. Streamlit secrets (``st.secrets``), best effort;
#   2. a ``config.toml`` file in the current working directory, if present
#      (it must then contain both keys).
# If neither source provides them, both stay empty strings.
AWS_ACCESS_KEY_ID = ''
AWS_SECRET_ACCESS_KEY = ''
try:
    AWS_ACCESS_KEY_ID = st.secrets['AWS_ACCESS_KEY_ID']
    AWS_SECRET_ACCESS_KEY = st.secrets['AWS_SECRET_ACCESS_KEY']
except Exception:
    # Best effort: no secrets file, or the keys are not in it. Catch
    # ``Exception`` (not ``BaseException``) so KeyboardInterrupt / SystemExit
    # still propagate.
    pass

if os.path.exists('config.toml'):
    with open('config.toml', 'rb') as f:
        config = tomli.load(f)
        AWS_ACCESS_KEY_ID = config['AWS_ACCESS_KEY_ID']
        AWS_SECRET_ACCESS_KEY = config['AWS_SECRET_ACCESS_KEY']

if AWS_ACCESS_KEY_ID or AWS_SECRET_ACCESS_KEY:
    client = boto3.client(
        's3',
        aws_access_key_id=AWS_ACCESS_KEY_ID,
        aws_secret_access_key=AWS_SECRET_ACCESS_KEY,
    )
else:
    # No explicit keys found: omit them so boto3 resolves credentials through
    # its default chain (environment variables, shared config, instance role).
    # Passing '' here would make botocore sign with empty, invalid keys.
    client = boto3.client('s3')


def plot_img_with_rects(
    img: npt.ArrayLike, boxes: List[List], threshold: float = 0.5, coef: int = 400
) -> matplotlib.figure.Figure:
    """
    Plot an image with its bounding boxes drawn on top.

    Args:
        img: image as a numpy array (anything ``Axes.imshow`` accepts)
        boxes: the list of bboxes, each laid out as
            ``[label, score, x_center, y_center, width, height]``
        threshold: only boxes whose score is strictly greater than this are drawn
        coef: factor the box coordinates are multiplied by to get pixel
            coordinates. Can be changed when the original image is a different size

    Returns:
        the figure containing the image and the drawn bboxes
    """
    # Labels that are shown as text instead of their class number.
    special_labels = {10: 'censored', 11: 'other'}

    fig, ax = plt.subplots(1, figsize=(4, 4))

    # Display the image
    ax.imshow(img)

    # Filter before unpacking, so that boxes below the threshold are never unpacked.
    confident_boxes = (b for b in boxes if b[1] > threshold)
    for label, _, xc, yc, w, h in confident_boxes:
        w, h = w * coef, h * coef
        # the coordinates from center-based to left top corner
        x = xc * coef - w / 2
        y = yc * coef - h / 2

        label = int(label)
        text = special_labels.get(label, label)

        box_patch = patches.Rectangle(
            (x, y), w, h, linewidth=2, edgecolor='blue', facecolor='none'
        )
        # Draw the label at the top-right corner of the box, on this figure's axes
        # rather than whatever pyplot currently considers "current".
        ax.text(x + w, y, f'{text}', color='blue')
        ax.add_patch(box_patch)
    return fig


def save_object_to_s3(filename: str, s3_filename: str, bucket: str = 'digitdrawdetect') -> None:
    """
    Upload a local file to S3 using the module-level ``client``.

    Args:
        filename: path of the local file to upload
        s3_filename: object key to store the file under in the bucket
        bucket: target bucket name (defaults to the app's ``digitdrawdetect`` bucket)

    Errors raised by ``client.upload_file`` are not caught here and propagate
    to the caller.
    """
    client.upload_file(filename, bucket, s3_filename)


@st.cache_data(show_spinner=False)
def save_image(image: npt.ArrayLike, pred: List[List]) -> str:
    """
    Save the image and its bboxes locally, then upload both to S3.

    The image is rendered through matplotlib (the bboxes are NOT drawn on it)
    into ``<name>.png``. The bboxes are written to ``<name>.json`` as
    ``{"<name>.png": pred}``. The two files are uploaded to ``images/<name>.png``
    and ``labels/<name>.json``, and both are left in the current working directory.

    Because of ``st.cache_data``, a repeated call with the same arguments
    returns the cached name without running the body (so nothing is re-uploaded).

    Args:
        image: np.array with image
        pred: bboxes; must be JSON-serializable
            (NOTE(review): numpy scalars are not — confirm callers pass plain numbers)

    Returns:
        the generated base file name (no extension): today's date followed by a uuid1
    """
    file_name = str(datetime.datetime.today().date()) + str(uuid.uuid1())
    png_name = f'{file_name}.png'
    json_name = f'{file_name}.json'

    # create a figure and save it
    fig, ax = plt.subplots(1, figsize=(4, 4))
    try:
        ax.imshow(image)
        fig.savefig(png_name)
    finally:
        # pyplot keeps every figure alive until it is closed; without this each
        # call would leak one figure.
        plt.close(fig)

    # dump bboxes in a local file
    with open(json_name, 'w', encoding='utf-8') as j_f:
        json.dump({png_name: pred}, j_f)

    # upload the image and the bboxes to s3.
    save_object_to_s3(png_name, f'images/{png_name}')
    save_object_to_s3(json_name, f'labels/{json_name}')

    return file_name