Circularmachines's picture
updates
32716f3
raw
history blame contribute delete
No virus
5.18 kB
from datasets.arrow_dataset import InMemoryTable
import streamlit as st
from PIL import Image, ImageDraw
from streamlit_image_coordinates import streamlit_image_coordinates
import numpy as np
from datasets import load_dataset
st.set_page_config(layout="wide")

ds = load_dataset("Circularmachines/batch_indexing_machine_green_test", split="test")

# Patch-grid geometry: each image is resized to image_size x image_size and
# covered by overlapping patch_size x patch_size patches whose top-left
# corners are `stride` pixels apart.
patch_size = 32
stride = 16
image_size = 512
# Derived rather than hard-coded so the three constants above stay the single
# source of truth (31 per side and 961 per image for the current values).
gridsize = (image_size - patch_size) // stride + 1
n_patches = gridsize * gridsize

# Embedding tables, reshaped to rows of 64 values. Rows are addressed as
# image_index * n_patches + patch_index by `find` and `patch`. The keys are
# the labels offered by the "Model" selectbox.
pred_dict = {
    'Trained on augmented images 230809': np.load('pred_all_green_random.npy').reshape(-1, 64),
    'Trained on unaugmented images 230805': np.load('pred_all_scratch.npy').reshape(-1, 64),
}
random_i = np.load('random.npy')

# Initialise every session-state key on its own. Previously "model" was only
# set together with "point", so a session that still had "point" but had lost
# "model" would make `find()` raise KeyError.
_session_defaults = {
    "point": (128, 64),
    "model": next(iter(pred_dict)),
    "img": 0,
    "draw": True,
}
for _key, _value in _session_defaults.items():
    if _key not in st.session_state:
        st.session_state[_key] = _value
def patch(ij):
    """Return a thumbnail crop around one patch of the dataset.

    ``ij`` is a flat row index into the embedding tables:
    ``ij // n_patches`` selects the image in ``ds`` and ``ij % n_patches``
    the patch. Patches are numbered row-major on a ``gridsize`` x ``gridsize``
    grid with ``stride`` pixel spacing, on the image resized to
    ``image_size`` x ``image_size``.

    The crop is ``4 * stride`` pixels per side and starts one stride above
    and to the left of the patch origin, so the patch is shown with some
    surrounding context. For border patches the box extends past the image
    edge; PIL pads that area.
    """
    # int() because callers may pass numpy integers (e.g. from argsort).
    image_index, patch_index = divmod(int(ij), n_patches)
    row, col = divmod(patch_index, gridsize)
    image = ds[image_index]['image'].resize(size=(image_size, image_size))
    return image.crop((
        (col - 1) * stride,
        (row - 1) * stride,
        (col + 3) * stride,
        (row + 3) * stride,
    ))
def find():
    """Fill the side panel with the selected patch and its closest matches.

    Reads the selected pixel point (``st.session_state["point"]``), the
    current image (``st.session_state["img"]``) and the chosen embedding
    table (``st.session_state["model"]``). It ranks every row of that table
    by Euclidean distance to the selected patch's embedding, then keeps the
    best-ranked patch from each of the first four distinct batches. A batch
    is 20 consecutive images.

    Writes ``st.session_state["sideimg"][k]`` (thumbnail crops) and
    ``st.session_state["sideix"]`` (the batch number of each thumbnail).
    If fewer than four distinct batches exist, fewer slots are filled.
    """
    st.session_state["sideix"] = []
    pred = pred_dict[st.session_state["model"]]

    # Convert the pixel point to grid coordinates. Clamp them so that a point
    # on the image border cannot address a patch outside the grid: index -1
    # or gridsize would wrap into the neighbouring row or even another image.
    px, py = st.session_state["point"][0], st.session_state["point"][1]
    gx = min(max(px // stride, 0), gridsize - 1)
    gy = min(max(py // stride, 0), gridsize - 1)
    target = st.session_state["img"] * n_patches + gy * gridsize + gx

    diff = np.linalg.norm(pred - pred[target], axis=-1)

    # Sort once; the previous loop re-ran argsort twice per iteration.
    order = np.argsort(diff)
    batches = []
    for flat_index in order:
        batch = flat_index // n_patches // 20  # 20 images per batch
        if batch in batches:
            continue
        st.session_state["sideimg"][len(batches)] = patch(flat_index)
        batches.append(batch)
        if len(batches) == 4:
            break
    st.session_state["sideix"] = batches
def button_click():
    """Callback for "Change Image": pick a random image index in [0, 100) and hide the outline."""
    state = st.session_state
    state["img"] = np.random.randint(100)
    state["draw"] = False
# Side-panel state. Start with placeholder thumbnails (the first four patches
# of image 0). `find` fills in real matches whenever no batch labels exist yet.
if "sideimg" not in st.session_state:
    st.session_state["sideimg"] = list(map(patch, range(4)))
if "sideix" not in st.session_state:
    find()
def get_ellipse_coords(point, size=None):
    """Return the ``(x0, y0, x1, y1)`` box outlining the patch at ``point``.

    Despite the name, the box is passed to ``ImageDraw.rectangle``. It starts
    at ``point`` (the patch's top-left pixel; only its first two items are
    used) and extends ``size`` pixels right and down. When ``size`` is
    omitted, the module-level ``patch_size`` is used.
    """
    if size is None:
        size = patch_size
    x, y = point[0], point[1]
    return (x, y, x + size, y + size)
col1, col2, col3 = st.columns([3, 1, 1])

with col1:
    # Main view: the current image at the working resolution. The selected
    # patch is outlined in green unless drawing is switched off.
    current_image = ds[st.session_state["img"]]['image'].resize(size=(image_size, image_size))
    draw = ImageDraw.Draw(current_image)
    if st.session_state["draw"]:
        point = st.session_state["point"]
        draw.rectangle(get_ellipse_coords(point), outline="green", width=2)

    value = streamlit_image_coordinates(current_image, key="pil")
    if value is not None:
        # Subtracting half a stride before flooring snaps the click to the
        # patch whose central region contains it. Without the clamp, clicks
        # within half a stride of the left/top edge give origin -stride, and
        # clicks near the right/bottom edge give an origin one step past the
        # grid. Both select the wrong patch in `find`.
        max_origin = (gridsize - 1) * stride
        point = tuple(
            min(max((value[axis] - stride // 2) // stride * stride, 0), max_origin)
            for axis in ("x", "y")
        )
        if point != st.session_state["point"]:
            st.session_state["point"] = point
            st.session_state["draw"] = True
            # Newer Streamlit releases replace `st.experimental_rerun` with
            # `st.rerun`; use whichever this install provides.
            rerun = getattr(st, "rerun", None) or st.experimental_rerun
            rerun()

    scol1, scol2 = st.columns(2)
    with scol1:
        st.button('Change Image', on_click=button_click)
        st.selectbox("Model", tuple(pred_dict.keys()), key="model")
    with scol2:
        st.button('Find similar parts', on_click=find)
def _side_caption(slot):
    """Caption for side-panel slot `slot`; slot 0 is labelled as the target."""
    batch = str(st.session_state["sideix"][slot])
    if slot == 0:
        return "Target in batch " + batch
    return "Match #" + str(slot) + " in batch " + batch


# Side panel: slots 0 and 2 in the middle column, slots 1 and 3 on the right.
for side_col, slots in ((col2, (0, 2)), (col3, (1, 3))):
    with side_col:
        for slot in slots:
            st.write(_side_caption(slot))
            st.image(st.session_state["sideimg"][slot].resize((192, 192)))
# Description and usage notes. Written with explicit st.write calls rather
# than relying on Streamlit "magic" to render bare string literals.
st.write("The batch indexing machine shakes parts while recording a video.")
st.write("The machine processed 20 batches of random parts, with each batch running for 30 seconds.")
st.write("INSTRUCTIONS:")
st.write("Click in the image to set target part")
st.write("Click “Find similar parts” to find the best matches in other batches")
st.write("The model is trained completely unsupervised using a CNN with a custom contrastive loss.")
st.write("https://github.com/circularmachines/batch_indexing_machine/")
st.write("johan.lagerloef@gmail.com")