File size: 3,399 Bytes
9705209 df20613 beb3e34 b73b4e1 df20613 8cc7c1a 26b8b00 31b5ef1 86efe2e 31b5ef1 86efe2e 33632ae 8797870 73dca24 775cb18 31b5ef1 83f2c3c 33632ae 83f2c3c 8cc7c1a 26b8b00 83f2c3c faf8b4b 26b8b00 775fea9 42bc3b0 3d79c89 fdef182 2814195 33632ae 83f2c3c 9e5d34c 83f2c3c 33632ae 2814195 83f2c3c 12ae958 83f2c3c 26b8b00 b2b0a36 775fea9 fca50af 8015cdb faf8b4b 33632ae cc301d6 1ac8545 33632ae fdef182 0b03953 fdef182 83f2c3c faf8b4b 12ae958 8015cdb 8cc7c1a df5841e 33632ae df5841e e86e66f de70cb6 df5841e beb3e34 854f030 aa4905a 391663d 12ae958 3d79c89 86efe2e 12ae958 854f030 12ae958 854f030 12ae958 faf8b4b 12ae958 aca188e 12ae958 aca188e 12ae958 33632ae df5841e faf8b4b 12ae958 df5841e e2813d1 5ff37df e2813d1 12ae958 faf8b4b 12ae958 ddac224 12ae958 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 |
from datasets.arrow_dataset import InMemoryTable
import streamlit as st
from PIL import Image, ImageDraw
from streamlit_image_coordinates import streamlit_image_coordinates
import numpy as np
from datasets import load_dataset
# Dataset whose images are browsed and searched by this app.
ds = load_dataset("Circularmachines/batch_indexing_machine_test", split="test")

# Grid geometry: each image is treated as gridsize x gridsize square patches,
# each patch_size pixels wide.
gridsize = 16
n_patches = 164  # NOTE(review): not referenced elsewhere in this file
patch_size = 32

# Precomputed arrays loaded from disk.
pred = np.load("pred.npy")  # NOTE(review): not referenced elsewhere in this file
# One 64-value row per patch; ``find`` addresses row ``img * gridsize**2 + patch``.
pred_all = np.load("pred_all.npy").reshape(-1, 64)

# Flat indices of the truthy entries of keep_bool, plus a gridsize**2 lookup
# table that stores each kept index at its own position (0 elsewhere).
# NOTE(review): keep / keep_i are not referenced elsewhere in this file.
keep_bool = np.load("keep_bool.npy")
keep = np.flatnonzero(keep_bool)
keep_i = np.zeros(gridsize**2)
keep_i[keep] = keep
#st.set_page_config(
# page_title="Streamlit Image Coordinates: Image Update",
# page_icon="🎯",
# layout="wide",
#)
#"# :dart: Streamlit Image Coordinates: Image Update"

# Seed per-session state on the first run only; later reruns keep whatever
# values the callbacks and click handling have stored.
for _state_key, _state_default in (
    ("point", (200, 200)),  # pixel position (x, y) of the selected patch
    ("img", 0),             # index into ``ds`` of the image being shown
    ("draw", False),        # whether the selected patch is outlined
):
    if _state_key not in st.session_state:
        st.session_state[_state_key] = _state_default
def patch(ij):
    """Return a crop of 3x3 patches centred on the patch with flat index *ij*.

    *ij* enumerates patches across the whole dataset:
    ``ij // gridsize**2`` is the image index into ``ds``, and
    ``ij % gridsize**2`` is the patch index within that image. The patch
    index is row-major (``row * gridsize + col``), the same layout ``find``
    uses to address rows of ``pred_all``.

    The image is resized so that a ``gridsize`` x ``gridsize`` grid of
    ``patch_size``-pixel patches covers it exactly. The selected patch plus
    one patch of context on every side is then cropped out. For border
    patches, part of the crop box lies outside the image.
    NOTE(review): PIL presumably fills that area with zeros (black) — confirm.
    """
    image_idx, patch_idx = divmod(int(ij), gridsize**2)
    row, col = divmod(patch_idx, gridsize)
    # Derive the resize target from the grid constants rather than a literal
    # 512. The crop arithmetic below is only correct when
    # side == gridsize * patch_size.
    side = gridsize * patch_size
    image = ds[image_idx]["image"].resize(size=(side, side))
    return image.crop(
        (
            (col - 1) * patch_size,
            (row - 1) * patch_size,
            (col + 2) * patch_size,
            (row + 2) * patch_size,
        )
    )
# On first load, fill the side column with the first four patches of dataset
# image 0. "Find similar parts" later overwrites these entries.
if "sideimg" not in st.session_state:
    st.session_state["sideimg"] = list(map(patch, range(4)))
def button_click():
    """Switch to a randomly chosen image and clear the patch outline.

    Used as the ``on_click`` callback of the "Change Batch" button.
    """
    # The original bound was a hard-coded 100. Clamp it to the number of
    # images actually available in two places:
    # - ``ds``, which the layout code indexes with the chosen value;
    # - ``pred_all``, which ``find`` indexes as
    #   ``img * gridsize**2 + patch_index``.
    # A smaller dataset or prediction file therefore can no longer yield an
    # out-of-range image index. With >= 100 images in both, behavior is
    # unchanged.
    images_with_preds = len(pred_all) // gridsize**2
    upper = min(100, len(ds), images_with_preds)
    st.session_state["img"] = np.random.randint(upper)
    st.session_state["draw"] = False
def find():
    """Replace the side images with the patches most similar to the selection.

    Used as the ``on_click`` callback of the "Find similar parts" button.

    The selected patch is identified by two session-state values:
    - ``st.session_state["img"]``: the image index into ``ds``;
    - ``st.session_state["point"]``: pixel coordinates (x, y) on the resized
      display image.

    The patch's row of ``pred_all`` is compared with every row by Euclidean
    distance. The closest patches, rendered by ``patch``, are stored in
    ``st.session_state["sideimg"]``.
    """
    x_px, y_px = st.session_state["point"]
    col, row = x_px // patch_size, y_px // patch_size
    query = st.session_state["img"] * gridsize**2 + row * gridsize + col
    # Distance from the query row to every row of pred_all (all images).
    diff = np.linalg.norm(pred_all[np.newaxis, query, :] - pred_all, axis=-1)
    # Sort once. The original re-sorted the full array on every iteration.
    # The query row has distance 0, so it is normally the first match.
    nearest = diff.argsort()
    for ix in range(4):
        st.session_state["sideimg"][ix] = patch(nearest[ix])
def get_ellipse_coords(point):  # (x, y) -> (x0, y0, x1, y1)
    """Return the bounding box of the patch whose top-left corner is *point*.

    The box spans ``patch_size`` pixels to the right of and below *point*.
    Despite its name, this file uses the result with
    ``ImageDraw.rectangle``, not an ellipse.
    """
    left, top = point[0], point[1]
    right, bottom = left + patch_size, top + patch_size
    return (left, top, right, bottom)
# Page layout: the clickable image and controls on the left, and a narrow
# column of similar-patch thumbnails on the right.
col1, col2 = st.columns([5, 1])

with col1:
    # Show the current image at the size the patch grid assumes
    # (gridsize x gridsize patches of patch_size pixels).
    side = gridsize * patch_size
    current_image = ds[st.session_state["img"]]["image"].resize(size=(side, side))
    draw = ImageDraw.Draw(current_image)
    if st.session_state["draw"]:
        # Outline the currently selected patch.
        point = st.session_state["point"]
        coords = get_ellipse_coords(point)
        draw.rectangle(coords, outline="green", width=2)

    value = streamlit_image_coordinates(current_image, key="pil")
    if value is not None:
        # Snap the click to the top-left corner of the grid cell it falls in.
        point = (
            value["x"] // patch_size * patch_size,
            value["y"] // patch_size * patch_size,
        )
        if point != st.session_state["point"]:
            st.session_state["point"] = point
            st.session_state["draw"] = True
            # st.experimental_rerun is deprecated in favour of st.rerun and
            # removed in newer Streamlit releases. Prefer st.rerun when it
            # exists and fall back for older installs.
            rerun = getattr(st, "rerun", None) or st.experimental_rerun
            rerun()

    st.button("Change Batch", on_click=button_click)
    st.button("Find similar parts", on_click=find)
    # Debug readout of the session state that drives this page.
    st.write(st.session_state["img"])
    st.write(st.session_state["point"])
    st.write(st.session_state["draw"])

with col2:
    for i in range(4):
        st.image(st.session_state["sideimg"][i].resize((128, 128)))
|