|
from datasets.arrow_dataset import InMemoryTable |
|
import streamlit as st |
|
from PIL import Image, ImageDraw |
|
|
|
from streamlit_image_coordinates import streamlit_image_coordinates |
|
|
|
import numpy as np |
|
|
|
|
|
from datasets import load_dataset |
|
|
|
|
|
|
|
|
|
|
|
# NOTE(review): this module body re-runs on every Streamlit rerun, so the
# dataset and .npy files are reloaded each interaction; consider caching them
# (e.g. st.cache_resource) if the installed Streamlit version supports it.
ds = load_dataset("Circularmachines/batch_indexing_machine_test", split="test")

# Sliding-window geometry. Dataset images are resized to 512x512 before use.
patch_size = 32   # side length of each square patch, in pixels
stride = 16       # step between neighbouring patch origins, in pixels
image_size = 512  # side length of the resized images, in pixels

# Patches per grid row/column and per image follow from the values above:
# (512 - 32) // 16 + 1 = 31 and 31 * 31 = 961. Derived so they cannot drift.
gridsize = (image_size - patch_size) // stride + 1
n_patches = gridsize * gridsize

# Precomputed patch embeddings: row ``img * n_patches + patch`` holds the
# embedding of ``patch`` (row-major in the grid) of dataset image ``img``.
embedding_dim = 64
pred_all = np.load('pred_all.npy').reshape(-1, embedding_dim)

# Frame numbers shown in the status line, looked up by ``img % 20``.
random_i = np.load('random.npy')

# Per-session defaults:
#   point - top-left pixel of the selected patch (a multiple of ``stride``)
#   img   - index of the dataset image being viewed
#   draw  - whether to outline the selected patch on the image
if "point" not in st.session_state:
    st.session_state["point"] = (128, 64)

if "img" not in st.session_state:
    st.session_state["img"] = 0

if "draw" not in st.session_state:
    st.session_state["draw"] = True
|
|
|
def patch(ij):
    """Return an image crop showing the grid patch with flat index ``ij``.

    ``ij`` uses the same flat layout as the rows of ``pred_all``:
    ``ij // n_patches`` selects the dataset image, and ``ij % n_patches``
    selects the patch within the ``gridsize`` x ``gridsize`` grid in
    row-major order.

    The patch itself spans ``patch_size`` pixels from
    ``(col * stride, row * stride)``. The crop adds one ``stride`` of context
    on every side, so it is ``4 * stride`` pixels square. Areas that fall
    outside the image (edge patches) are padded by PIL.

    Args:
        ij: flat patch index (Python or NumPy integer).

    Returns:
        A ``PIL.Image`` crop.
    """
    image_index, patch_index = divmod(int(ij), n_patches)
    image = ds[image_index]['image'].resize(size=(image_size, image_size))

    row, col = divmod(patch_index, gridsize)
    left = (col - 1) * stride
    top = (row - 1) * stride
    return image.crop((left, top, left + 4 * stride, top + 4 * stride))
|
|
|
def find():
    """Find patches similar to the current selection and show them in the side column.

    Takes the patch whose top-left pixel is ``st.session_state["point"]`` in
    dataset image ``st.session_state["img"]``. It then ranks every row of
    ``pred_all`` by Euclidean distance to that patch's embedding. Walking the
    ranking, it keeps the first hit from each distinct batch (20 consecutive
    dataset images per batch) until four batches are collected. The
    selection itself normally ranks first (distance 0), so slot 0 shows the
    current selection and slots 1-3 show the closest matches from other
    batches.

    Side effects:
        * ``st.session_state["sideimg"][k]`` is overwritten in place with the
          crop for slot ``k``. The list must already have four entries.
        * ``st.session_state["sideix"]`` is set to the list of batch numbers,
          one per filled slot. It may be shorter than four if the ranking is
          exhausted first.
    """
    frames_per_batch = 20  # matches the batch numbering in the status line
    n_results = 4

    x_px, y_px = st.session_state["point"]
    # Clamp to the valid grid: an out-of-range column would alias to a patch
    # in the next row, and a negative index would wrap to the end of pred_all.
    col = min(max(x_px // stride, 0), gridsize - 1)
    row = min(max(y_px // stride, 0), gridsize - 1)
    query = st.session_state["img"] * n_patches + row * gridsize + col

    diff = np.linalg.norm(pred_all[query] - pred_all, axis=-1)
    ranking = np.argsort(diff)  # sort once, not on every loop iteration

    batches = []
    for ij in ranking:
        batch = ij // n_patches // frames_per_batch
        if batch in batches:
            continue
        st.session_state["sideimg"][len(batches)] = patch(ij)
        batches.append(batch)
        if len(batches) == n_results:
            break

    st.session_state["sideix"] = batches
|
|
|
|
|
def button_click():
    """'Change Image' callback: jump to a random image (index 0-99) and hide the selection box."""
    state = st.session_state
    state["draw"] = False
    state["img"] = np.random.randint(100)
|
|
|
# Populate the side column once per session. First seed four crops (patches
# 0-3), which ``find`` overwrites in place. Then run an initial search for
# the default selection.
if "sideimg" not in st.session_state:
    st.session_state["sideimg"] = list(map(patch, range(4)))

if "sideix" not in st.session_state:
    find()
|
|
|
def get_ellipse_coords(point, size=None):
    """Return the ``(left, top, right, bottom)`` box of a square patch.

    Despite the name, the result is a plain bounding box (it is passed to
    ``ImageDraw.rectangle``), and ``point`` is the box's top-left corner, not
    its centre.

    Args:
        point: ``(x, y)`` pixel coordinates of the top-left corner.
        size: side length of the box in pixels; defaults to ``patch_size``.

    Returns:
        ``(x, y, x + size, y + size)``.
    """
    if size is None:
        size = patch_size
    left, top = point[0], point[1]
    return (left, top, left + size, top + size)
|
|
|
|
|
# Bare string expressions are rendered on the page by Streamlit "magic".
"The batch indexing machine shakes parts while recording a video."

"The machine processed 20 batches of random parts, with each batch running for 30 seconds."

"The model is trained completely unsupervised using a CNN with a custom contrastive loss. Open source code to be released soon. "

frames_per_batch = 20  # consecutive dataset images per batch

col1, col2 = st.columns([5, 1])

with col1:
    current_image = ds[st.session_state["img"]]['image'].resize(size=(image_size, image_size))

    if st.session_state["draw"]:
        # Outline the currently selected patch.
        draw = ImageDraw.Draw(current_image)
        coords = get_ellipse_coords(st.session_state["point"])
        draw.rectangle(coords, outline="green", width=2)

    value = streamlit_image_coordinates(current_image, key="pil")

    if value is not None:
        # Subtracting 8 (= patch_size/2 - stride/2) and flooring to the stride
        # grid picks the patch whose centre is nearest the click. Clamp so
        # clicks near the edges still map to a patch inside the grid
        # (origins 0 .. (gridsize - 1) * stride).
        max_origin = (gridsize - 1) * stride
        point = tuple(
            min(max((value[axis] - 8) // stride * stride, 0), max_origin)
            for axis in ("x", "y")
        )

        # Only rerun on an actual change; otherwise the component's sticky
        # last-click value would trigger a rerun loop.
        if point != st.session_state["point"]:
            st.session_state["point"] = point
            st.session_state["draw"] = True
            # st.rerun superseded st.experimental_rerun in newer Streamlit.
            rerun = getattr(st, "rerun", None) or st.experimental_rerun
            rerun()

    scol1, scol2 = st.columns(2)

    with scol1:
        st.button('Change Image', on_click=button_click)

    with scol2:
        st.button('Find similar parts', on_click=find)

    img = st.session_state["img"]
    st.write(f"Currently viewing frame {random_i[img % frames_per_batch]} in batch {img // frames_per_batch}")

with col2:
    # One caption per slot filled by ``find``: slot 0 is the selection itself.
    labels = (
        "current selection in batch ",
        "Best match found in batch ",
        "Second best match found in batch ",
        "Third best match found in batch ",
    )
    # zip stops early (instead of raising IndexError) if fewer slots were filled.
    for label, batch, crop in zip(labels, st.session_state["sideix"], st.session_state["sideimg"]):
        st.write(label + str(batch))
        st.image(crop.resize((128, 128)))

"johan.lagerloef@gmail.com"
|
|
|
|