Circularmachines's picture
updates
af55614
raw
history blame
4.44 kB
from datasets.arrow_dataset import InMemoryTable
import streamlit as st
from PIL import Image, ImageDraw
from streamlit_image_coordinates import streamlit_image_coordinates
import numpy as np
from datasets import load_dataset
ds = load_dataset("Circularmachines/batch_indexing_machine_test", split="test")

# Side length, in pixels, of one square patch.
patch_size = 32
# Square working resolution the app resizes dataset images to
# (2304 was used as an alternative value).
image_size = 512
# Number of patches along each side of the resized image (16 for 512 / 32).
# Derived instead of hard-coded so it cannot drift out of sync with
# image_size / patch_size; it matches the grid built by donut() whenever
# image_size / patch_size is even.
gridsize = image_size // patch_size
def donut(patch_size, img_size, lower_limit=0.40, upper_limit=1):
    """Select the patches lying in an annulus ("donut") around the image centre.

    The image is split into a square grid of ``2 * g`` patches per side, where
    ``g = img_size // 2 // patch_size`` (patches per half side).

    Each patch centre is expressed in patch units relative to the image
    centre. A patch is kept when its distance ``r`` from the centre satisfies
    ``g * lower_limit < r < g * upper_limit``.

    Args:
        patch_size: side length of one square patch, in pixels.
        img_size: side length of the square image, in pixels.
        lower_limit: inner radius as a fraction of ``g`` (exclusive).
        upper_limit: outer radius as a fraction of ``g`` (exclusive).

    Returns:
        coords: float array of shape ``(2g, 2g, 2)`` with
            ``coords[row, col] == (col - g + 0.5, row - g + 0.5)``.
        keep: indices, into the row-major flattening of the grid, of the
            kept patches.
        keep_bool: boolean mask of shape ``(2g, 2g)`` marking kept patches.

    If ``img_size < 2 * patch_size`` the grid is empty and all three results
    are empty arrays (shapes ``(0, 0, 2)``, ``(0,)``, ``(0, 0)``).
    """
    gridsize = img_size // 2 // patch_size
    # Offsets of the patch centres from the image centre, in patch units.
    offsets = np.arange(-gridsize, gridsize) + 0.5
    # Default 'xy' indexing: xs varies along columns, ys along rows, so
    # coords[row, col] holds (x, y) exactly like the original nested loops.
    xs, ys = np.meshgrid(offsets, offsets)
    coords = np.stack((xs, ys), axis=-1)
    norm = np.linalg.norm(coords, axis=2)
    # We only care about the ring where the parts are; patches close to the
    # centre and far from it are disregarded.
    keep_bool = (norm > gridsize * lower_limit) & (norm < gridsize * upper_limit)
    keep = np.flatnonzero(keep_bool)
    return coords, keep, keep_bool
# Grid of patch-centre offsets plus the annulus of patches where parts sit.
coords, keep, keep_bool = donut(patch_size, image_size)
# How many grid patches fall inside that annulus.
n_patches = len(keep)

# Precomputed model outputs loaded from disk.
# NOTE(review): `pred` is not referenced anywhere else in this file as seen
# here -- confirm whether it is still needed.
pred = np.load('pred.npy')
# One 64-wide row per patch (presumably patch embeddings -- confirm);
# find() indexes rows as image * gridsize**2 + row * gridsize + col.
pred_all = np.load('pred_all.npy').reshape(-1, 64)
# Per-session defaults:
#   point -- pixel (x, y) of the top-left corner of the selected patch
#   img   -- dataset row currently shown
#   draw  -- whether to outline the selected patch
_session_defaults = {"point": (200, 200), "img": 0, "draw": False}
for _state_key, _state_value in _session_defaults.items():
    if _state_key not in st.session_state:
        st.session_state[_state_key] = _state_value
def patch(ij):
    """Return a 3x3-patch crop centred on the patch with flat index *ij*.

    ``ij`` addresses a patch across the whole dataset:
    ``ij // gridsize**2`` is the dataset row, and ``ij % gridsize**2`` is the
    patch inside that image, laid out row-major (``row * gridsize + col``).

    The image is resized to ``image_size`` square (the same resolution the
    main view uses). A ``3 * patch_size`` square window is then cut around
    the patch, so the neighbouring context is visible.

    For patches on the border, the window extends past the image. Pillow's
    ``Image.crop`` fills that area with zero-valued pixels.
    """
    # int() so numpy integers coming from argsort() index the dataset cleanly.
    image_idx, p = divmod(int(ij), gridsize**2)
    row, col = divmod(p, gridsize)
    img = ds[image_idx]['image'].resize(size=(image_size, image_size))
    box = (
        (col - 1) * patch_size,
        (row - 1) * patch_size,
        (col + 2) * patch_size,
        (row + 2) * patch_size,
    )
    return img.crop(box)
# Side-panel state: four thumbnail crops and the flat pred_all row index of
# each. Seeded with patches 0-3 of the first image until find() replaces them.
if "sideimg" not in st.session_state:
    st.session_state["sideimg"] = list(map(patch, range(4)))
if "sideix" not in st.session_state:
    st.session_state["sideix"] = list(range(4))
def button_click():
    """Jump to a random dataset image and hide the selection outline.

    The new index is drawn uniformly from ``[0, 100)``.
    NOTE(review): this assumes the dataset (and pred_all) cover at least
    100 images -- confirm.
    """
    state = st.session_state
    state["img"] = np.random.randint(100)
    state["draw"] = False
def find():
    """Fill the side panel with the patches most similar to the selected one.

    Converts the selected pixel position ``st.session_state["point"]`` into a
    grid cell of the current image ``st.session_state["img"]`` and takes that
    patch's row of ``pred_all`` as the query.

    Every row of ``pred_all`` is ranked by Euclidean distance to the query.
    The four closest rows are stored in ``st.session_state["sideix"]``, and
    their crops (via ``patch``) in ``st.session_state["sideimg"]``. The
    query itself has distance 0, so it normally appears first.
    """
    x, y = st.session_state["point"]
    col, row = x // patch_size, y // patch_size
    query = st.session_state["img"] * gridsize**2 + row * gridsize + col
    diff = np.linalg.norm(pred_all[query] - pred_all, axis=-1)
    # Sort once; previously the full distance vector was re-sorted twice per
    # slot.
    nearest = diff.argsort()[:4]
    for slot, idx in enumerate(nearest):
        st.session_state["sideimg"][slot] = patch(idx)
        st.session_state["sideix"][slot] = idx
def get_ellipse_coords(point):
    """Return the ``(x0, y0, x1, y1)`` box of the patch whose top-left is *point*.

    Despite the name, this is the bounding box the caller draws as a
    rectangle: it starts at ``point`` and spans ``patch_size`` pixels in
    each direction.
    """
    left, top = point[0], point[1]
    right = left + patch_size
    bottom = top + patch_size
    return (left, top, right, bottom)
col1, col2 = st.columns([5, 1])

with col1:
    # Current dataset image at the same working resolution patch() uses.
    current_image = ds[st.session_state["img"]]['image'].resize(
        size=(image_size, image_size)
    )
    draw = ImageDraw.Draw(current_image)
    if st.session_state["draw"]:
        # Outline the selected patch. Named `rect` so it does not shadow the
        # module-level `coords` grid returned by donut().
        rect = get_ellipse_coords(st.session_state["point"])
        draw.rectangle(rect, outline="green", width=2)

    value = streamlit_image_coordinates(current_image, key="pil")
    if value is not None:
        # Snap the click to the top-left corner of the patch containing it.
        point = (
            value["x"] // patch_size * patch_size,
            value["y"] // patch_size * patch_size,
        )
        if point != st.session_state["point"]:
            st.session_state["point"] = point
            st.session_state["draw"] = True
            # st.experimental_rerun is deprecated/removed in newer Streamlit;
            # prefer st.rerun when this version provides it.
            rerun = getattr(st, "rerun", None) or st.experimental_rerun
            rerun()

    st.button('Change Batch', on_click=button_click)
    st.button('Find similar parts', on_click=find)
    st.write(st.session_state["img"])
    st.write(st.session_state["point"])
    st.write(st.session_state["draw"])

with col2:
    for k in range(4):
        ix = st.session_state["sideix"][k]
        st.image(st.session_state["sideimg"][k].resize((128, 128)))
        # Image index uses gridsize**2 patches per image, the same mapping
        # patch() uses to render the thumbnail above. (It previously divided
        # by n_patches, the annulus count.) The // 20 presumably groups
        # images into batches of 20 -- confirm.
        st.write(ix // gridsize**2 // 20)
        st.write(ix)