Spaces:
Running
Running
File size: 5,080 Bytes
a950ee6 89939b6 a950ee6 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 |
import matplotlib.pyplot as plt
import numpy as np
from PIL import Image, ImageEnhance, ImageDraw
import torch
import streamlit as st
from model.inference_cpu import inference_case
# Default rectangle pre-loaded onto the drawing canvas: top-left corner at
# (50, 50), 100 x 100, translucent orange fill (alpha 0.3) with a 3 px
# "#2909F1" stroke. The key set looks like a Fabric.js object serialization
# (version "4.4.0", originX/originY, strokeUniform, paintFirst, ...) —
# presumably what the canvas widget expects as its initial drawing; confirm
# against the widget in use before adding or removing keys.
initial_rectangle = dict(
    version="4.4.0",
    objects=[
        dict(
            type="rect",
            version="4.4.0",
            originX="left",
            originY="top",
            # Geometry (canvas pixels).
            left=50,
            top=50,
            width=100,
            height=100,
            # Appearance.
            fill="rgba(255, 165, 0, 0.3)",
            stroke="#2909F1",
            strokeWidth=3,
            strokeDashArray=None,
            strokeLineCap="butt",
            strokeDashOffset=0,
            strokeLineJoin="miter",
            strokeUniform=True,
            strokeMiterLimit=4,
            # Transform.
            scaleX=1,
            scaleY=1,
            angle=0,
            flipX=False,
            flipY=False,
            # Rendering.
            opacity=1,
            shadow=None,
            visible=True,
            backgroundColor="",
            fillRule="nonzero",
            paintFirst="fill",
            globalCompositeOperation="source-over",
            skewX=0,
            skewY=0,
            # Corner radii.
            rx=0,
            ry=0,
        )
    ],
)
def run():
    """Run model inference for the current data item using the enabled prompts.

    Reads the image tensors and prompt toggles from ``st.session_state``:
    - the text prompt is used when ``use_text_prompt`` is set;
    - points go through ``reflect_points_into_model``, only when
      ``use_point_prompt`` is set and at least one point exists;
    - the 3D box goes through ``reflect_box_into_model`` when
      ``use_box_prompt`` is set.

    It then calls ``inference_case.clear()`` (presumably resetting a Streamlit
    cache so the call really re-executes — confirm). Finally it stores the
    two values returned by ``inference_case`` in
    ``st.session_state.preds_3D`` and ``st.session_state.preds_3D_ori``.
    """
    state = st.session_state
    volume = state.data_item["image"].float()
    volume_zoom_out = state.data_item["zoom_out_image"].float()

    # Every prompt stays None unless its toggle is enabled.
    text = state.text_prompt if state.use_text_prompt else None

    wants_points = state.use_point_prompt and len(state.points) > 0
    points = reflect_points_into_model(state.points) if wants_points else None

    box = reflect_box_into_model(state.rectangle_3Dbox) if state.use_box_prompt else None

    inference_case.clear()
    preds, preds_ori = inference_case(
        volume,
        volume_zoom_out,
        text_prompt=text,
        _point_prompt=points,
        _box_prompt=box,
    )
    state.preds_3D = preds
    state.preds_3D_ori = preds_ori
def reflect_box_into_model(box_3d):
    """Rescale a ``[z1, y1, x1, z2, y2, x2]`` box into model-grid coordinates.

    x and y are multiplied by 256/325 and z by 32/325. Each result is
    truncated toward zero with ``int()``. The 325 divisor presumably is the
    canvas size in pixels — confirm against the UI.

    Returns:
        torch.Tensor of the six integers, in the same z/y/x order as the input.
    """
    z_lo, y_lo, x_lo, z_hi, y_hi, x_hi = box_3d

    def _to_grid(value, target_size):
        # Same evaluation order as ``value * target_size / 325.0``.
        return int(value * target_size / 325.0)

    scaled = [
        _to_grid(z_lo, 32.0),
        _to_grid(y_lo, 256.0),
        _to_grid(x_lo, 256.0),
        _to_grid(z_hi, 32.0),
        _to_grid(y_hi, 256.0),
        _to_grid(x_hi, 256.0),
    ]
    return torch.tensor(np.array(scaled))
def reflect_json_data_to_3D_box(json_data, view):
    """Copy the first drawn rectangle into ``st.session_state.rectangle_3Dbox``.

    In the ``'xy'`` view the rectangle is written to box indices 1, 2, 4 and 5:
    - index 1 gets ``top``;
    - index 2 gets ``left``;
    - index 4 gets ``top + height * scaleY``;
    - index 5 gets ``left + width * scaleX``.

    These are the y/x slots of a ``[z1, y1, x1, z2, y2, x2]`` box, matching
    the unpacking order in ``reflect_box_into_model``. Any other ``view``
    leaves the box unchanged.

    The box is also left unchanged when ``json_data`` is ``None`` or has no
    objects, e.g. before anything is drawn or after the rectangle is
    deleted. The original code raised ``TypeError``/``IndexError`` there.

    Args:
        json_data: canvas JSON with an ``'objects'`` list whose first entry
            has ``top``/``left``/``width``/``height``/``scaleX``/``scaleY``;
            may be ``None``.
        view: name of the view the rectangle was drawn in.
    """
    if view == 'xy' and json_data and json_data.get('objects'):
        rect = json_data['objects'][0]
        box = st.session_state.rectangle_3Dbox
        box[1] = rect['top']
        box[2] = rect['left']
        # The canvas reports unscaled width/height plus a scale factor after a
        # user resize, so the effective extent is size * scale.
        box[4] = rect['top'] + rect['height'] * rect['scaleY']
        box[5] = rect['left'] + rect['width'] * rect['scaleX']
    print(st.session_state.rectangle_3Dbox)
def reflect_points_into_model(points):
    """Rescale ``(z, y, x)`` click points into model-grid coordinates.

    Scaling matches ``reflect_box_into_model``: x and y are multiplied by
    256/325 and z by 32/325, each truncated toward zero with ``int()``.
    Every point gets label 1. Both arrays are printed before being returned.

    Returns:
        Tuple ``(coords, labels)``:
        - ``coords`` is a tensor built from an ``(N, 3)`` array of
          ``[z, y, x]`` ints;
        - ``labels`` is a tensor of ``N`` ones (NumPy's default float dtype).
    """
    coords = np.array([
        [int(z * 32.0 / 325.0), int(y * 256.0 / 325.0), int(x * 256.0 / 325.0)]
        for z, y, x in points
    ])
    labels = np.ones(coords.shape[0])
    print(coords, labels)
    return torch.tensor(coords), torch.tensor(labels)
def show_points(points_ax, points_label, ax):
    """Scatter a single prompt point onto a Matplotlib-style axis.

    The point is plotted at ``(points_ax[0], points_ax[1])`` as a size-200
    circle. It is red when ``points_label == 0`` and blue otherwise.
    """
    if points_label == 0:
        color = 'red'
    else:
        color = 'blue'
    horizontal, vertical = points_ax[0], points_ax[1]
    ax.scatter(horizontal, vertical, c=color, marker='o', s=200)
def make_fig(image, preds, point_axs=None, current_idx=None, view=None,
             alpha=None, radius=5):
    """Render an image slice as an RGB PIL image with optional overlays.

    Args:
        image: numpy array converted via ``Image.fromarray`` after scaling
            by 255. Values are assumed to lie in [0, 1] (TODO confirm with
            callers) and are clipped to that range first, because casting
            out-of-range floats to uint8 is ill-defined. The result gets
            2x contrast enhancement.
        preds: optional array; pixels equal to 1 are tinted yellow by
            blending a yellow/black mask over the image. It must yield an
            image the same size as ``image`` (an ``Image.blend`` requirement).
        point_axs: optional iterable of ``(z, y, x)`` points.
            - In the ``'xy'`` view, points with ``z == current_idx`` are
              drawn at ``(x, y)``.
            - In the ``'xz'`` view, points with ``y == current_idx`` are
              drawn at ``(x, z)``.
            - Any other view draws nothing.
        current_idx: slice index compared against each point's z (``'xy'``)
            or y (``'xz'``) coordinate.
        view: ``'xy'``, ``'xz'`` or anything else (no points drawn).
        alpha: blend weight of the mask overlay. ``None`` (default) reads
            ``st.session_state.transparency``, as before.
        radius: radius in pixels of each blue point marker (default 5).

    Returns:
        The rendered RGB ``PIL.Image.Image``.
    """
    # Grayscale/float array -> 8-bit RGB image with boosted contrast.
    pixels = (np.clip(image, 0.0, 1.0) * 255).astype(np.uint8)
    rendered = Image.fromarray(pixels).convert("RGB")
    rendered = ImageEnhance.Contrast(rendered).enhance(2.0)

    if preds is not None:
        # Yellow = full red + full green where preds == 1; blue channel empty.
        mask = np.where(preds == 1, 255, 0).astype(np.uint8)
        overlay = Image.merge(
            "RGB",
            (
                Image.fromarray(mask),
                Image.fromarray(mask),
                Image.fromarray(np.zeros_like(mask, dtype=np.uint8)),
            ),
        )
        if alpha is None:
            # Read lazily so callers that pass alpha need no Streamlit state.
            alpha = st.session_state.transparency
        rendered = Image.blend(rendered, overlay, alpha=alpha)

    if point_axs is not None:
        draw = ImageDraw.Draw(rendered)
        for z, y, x in point_axs:
            if view == 'xy' and z == current_idx:
                cx, cy = x, y
            elif view == 'xz' and y == current_idx:
                cx, cy = x, z
            else:
                continue
            draw.ellipse((cx - radius, cy - radius, cx + radius, cy + radius), fill="blue")
    return rendered