Spaces:
Runtime error
Runtime error
File size: 3,104 Bytes
9ab711c 2c16c94 9ab711c b683920 9ab711c b683920 cf41825 9ab711c 2c16c94 9ab711c cdb8c6a 2c16c94 cdb8c6a b683920 cf41825 b683920 99d3d67 b683920 ca67f90 b683920 9ab711c f0facef 9ab711c b683920 9ab711c b683920 9ab711c b683920 9ab711c 2c16c94 9ab711c b683920 9ab711c |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 |
import datetime
import json
import os
import uuid
from typing import List
import boto3
import matplotlib
import matplotlib.patches as patches
import matplotlib.pyplot as plt
import numpy.typing as npt
import streamlit as st
import tomli
# AWS credentials for the S3 client. Sources are tried in order and later ones
# win: Streamlit secrets first, then a local ``config.toml``. An empty string
# means "not configured".
AWS_ACCESS_KEY_ID = ''
AWS_SECRET_ACCESS_KEY = ''

try:
    # Read both keys before assigning either, so a secrets store that has only
    # one of them cannot leave a mismatched key pair behind.
    _secrets_key_id = st.secrets['AWS_ACCESS_KEY_ID']
    _secrets_access_key = st.secrets['AWS_SECRET_ACCESS_KEY']
except Exception:
    # Deliberate best effort: a missing secrets file or missing keys just means
    # this source is not configured. Catch Exception rather than BaseException
    # so Ctrl-C / SystemExit still propagate.
    pass
else:
    AWS_ACCESS_KEY_ID = _secrets_key_id
    AWS_SECRET_ACCESS_KEY = _secrets_access_key

if os.path.exists('config.toml'):
    with open('config.toml', 'rb') as f:
        config = tomli.load(f)
    # A config.toml that exists but lacks these keys is a configuration error,
    # so a KeyError here is intentional.
    AWS_ACCESS_KEY_ID = config['AWS_ACCESS_KEY_ID']
    AWS_SECRET_ACCESS_KEY = config['AWS_SECRET_ACCESS_KEY']

# botocore uses explicitly passed keys whenever they are not None, even if they
# are empty strings. Passing None for "not configured" lets boto3 fall back to
# its default credential chain (environment variables, shared config, instance
# role) instead of signing requests with empty credentials.
client = boto3.client(
    's3',
    aws_access_key_id=AWS_ACCESS_KEY_ID or None,
    aws_secret_access_key=AWS_SECRET_ACCESS_KEY or None,
)
def plot_img_with_rects(
    img: npt.ArrayLike, boxes: List[List], threshold: float = 0.5, coef: int = 400
) -> matplotlib.figure.Figure:
    """
    Plot an image with its bounding boxes drawn on top.

    Args:
        img: image as a numpy array (anything ``Axes.imshow`` accepts)
        boxes: the list of the bboxes, each unpacked as
            ``[label, probability, x_center, y_center, width, height]``.
            Coordinates are multiplied by ``coef`` before drawing (presumably
            they are relative to the image size -- confirm with the model output).
        threshold: only boxes whose probability is strictly greater than this
            are drawn
        coef: coefficient to multiply box coordinates by. Can be changed when
            the original image is a different size
    Returns:
        the figure containing the image with bboxes
    """
    # Class ids shown with a name; every other id is shown as its integer value.
    named_labels = {10: 'censored', 11: 'other'}

    fig, ax = plt.subplots(1, figsize=(4, 4))
    ax.imshow(img)

    # Filter on probability before unpacking, as the caller's boxes are only
    # required to be well-formed when they pass the threshold.
    visible = (b for b in boxes if b[1] > threshold)
    for label, _, xc, yc, w, h in visible:
        # Scale to pixels, then move from the box center to its top-left corner.
        w, h = w * coef, h * coef
        x = xc * coef - w / 2
        y = yc * coef - h / 2

        label_id = int(label)
        text = named_labels.get(label_id, label_id)
        # Draw on this figure's own Axes. plt.text() would target pyplot's
        # global "current" Axes, which a concurrent Streamlit session can change.
        ax.text(x + w, y, f'{text}', color='blue')
        ax.add_patch(
            patches.Rectangle((x, y), w, h, linewidth=2, edgecolor='blue', facecolor='none')
        )
    return fig
def save_object_to_s3(filename: str, s3_filename: str, bucket: str = 'digitdrawdetect') -> None:
    """
    Upload a local file to S3 using the module-level ``client``.

    Args:
        filename: path of the local file to upload
        s3_filename: object key to store the file under
        bucket: destination bucket; defaults to the app's ``digitdrawdetect``
            bucket so existing callers are unaffected
    """
    client.upload_file(filename, bucket, s3_filename)
@st.cache_data(show_spinner=False)
def save_image(image: npt.ArrayLike, pred: List[List]) -> str:
    """
    Save the image and its bboxes locally and upload both to s3.

    The image is written to ``<name>.png`` and the bboxes to ``<name>.json``
    (as ``{"<name>.png": pred}``) in the working directory. They are then
    uploaded to ``images/<name>.png`` and ``labels/<name>.json``.

    Because of ``st.cache_data``, a repeat call with the same ``image`` and
    ``pred`` returns the cached name without running the body, so nothing is
    re-uploaded.

    Args:
        image: np.array with image
        pred: bboxes; must be JSON-serializable
    Returns:
        image name (without extension)
    """
    # Date prefix + random UUID. uuid4 instead of uuid1 so object names in the
    # bucket do not embed this host's MAC address and a timestamp.
    file_name = str(datetime.datetime.today().date()) + str(uuid.uuid4())
    image_path = f'{file_name}.png'
    labels_path = f'{file_name}.json'

    # Render the image to a PNG file.
    fig, ax = plt.subplots(1, figsize=(4, 4))
    try:
        ax.imshow(image)
        fig.savefig(image_path)
    finally:
        # pyplot keeps every figure it creates alive until it is closed; without
        # this each saved drawing would leak one figure.
        plt.close(fig)

    # Dump bboxes in a local file, keyed by the image file name.
    with open(labels_path, 'w', encoding='utf-8') as j_f:
        json.dump({image_path: pred}, j_f)

    # Upload the image and the bboxes to s3.
    save_object_to_s3(image_path, f'images/{file_name}.png')
    save_object_to_s3(labels_path, f'labels/{file_name}.json')
    # NOTE(review): the local .png/.json copies are left on disk after upload;
    # delete them here if no caller reads them.
    return file_name
|