from datasets.arrow_dataset import InMemoryTable  # noqa: F401  (unused; kept from original)
import streamlit as st
from PIL import Image, ImageDraw
from streamlit_image_coordinates import streamlit_image_coordinates
import numpy as np
from datasets import load_dataset

# Interactive viewer: click a patch in a frame, then search all frames for the
# patches whose embeddings are closest to it, keeping at most one per batch.

st.set_page_config(layout="wide")

ds = load_dataset("Circularmachines/batch_indexing_machine_green_test", split="test")

patch_size = 32   # side length of one patch, in pixels of the resized frame
stride = 16       # step between neighbouring patches, in pixels
image_size = 512  # frames are resized to image_size x image_size before use
# Patches per side: (512 - 32) // 16 + 1 == 31, i.e. 31 * 31 == 961 per frame.
gridsize = (image_size - patch_size) // stride + 1
n_patches = gridsize * gridsize
frames_per_batch = 20  # frame_index // frames_per_batch is the batch number
n_matches = 4          # side images shown: the target itself plus three matches

# One 64-dim embedding per patch; row index is frame * n_patches + patch index.
pred_dict = {
    'Trained on augmented images 230809': np.load('pred_all_green_random.npy').reshape(-1, 64),
    'Trained on unaugmented images 230805': np.load('pred_all_scratch.npy').reshape(-1, 64),
}
random_i = np.load('random.npy')

if "point" not in st.session_state:
    st.session_state["point"] = (128, 64)  # top-left pixel of the selected patch
    st.session_state["model"] = tuple(pred_dict.keys())[0]
if "img" not in st.session_state:
    st.session_state["img"] = 0
if "draw" not in st.session_state:
    st.session_state["draw"] = True


def patch(ij):
    """Return the image crop around the global patch index ``ij``.

    ``ij`` encodes ``frame * n_patches + row * gridsize + col``. The crop is
    ``4 * stride`` pixels square: the patch plus a one-stride margin on every
    side, taken from the frame resized to ``image_size`` x ``image_size``.
    """
    frame, p = divmod(int(ij), n_patches)
    row, col = divmod(p, gridsize)
    img = ds[frame]['image'].resize(size=(image_size, image_size))
    return img.crop(((col - 1) * stride, (row - 1) * stride,
                     (col + 3) * stride, (row + 3) * stride))


def find():
    """Fill the side panel with the selected patch and its best matches.

    Ranks every patch of every frame by Euclidean distance to the selected
    patch's embedding under the current model. Walks that ranking and keeps
    the first patch of each not-yet-seen batch until ``n_matches`` are found.
    The first hit is the selected patch itself (distance 0).
    Writes ``st.session_state["sideimg"]`` (crops) and ``["sideix"]`` (batch
    numbers).
    """
    st.session_state["sideix"] = []
    x, y = st.session_state["point"]
    col, row = x // stride, y // stride
    target = st.session_state["img"] * n_patches + row * gridsize + col

    preds = pred_dict[st.session_state["model"]]
    diff = np.linalg.norm(preds - preds[target], axis=-1)
    order = diff.argsort()  # sort once, not once per loop iteration

    batches = []
    for idx in order:
        batch = idx // n_patches // frames_per_batch
        if batch in batches:
            continue
        st.session_state["sideimg"][len(batches)] = patch(idx)
        batches.append(batch)
        if len(batches) == n_matches:
            break
    # NOTE(review): the display code assumes n_matches entries, i.e. at least
    # n_matches distinct batches in the data.
    st.session_state["sideix"] = batches


def button_click():
    """Jump to a random frame and hide the selection box until the next click."""
    # NOTE(review): only frames 0-99 are reachable; confirm this matches len(ds).
    st.session_state["img"] = np.random.randint(100)
    st.session_state["draw"] = False


if "sideimg" not in st.session_state:
    st.session_state["sideimg"] = [patch(i) for i in range(n_matches)]
if "sideix" not in st.session_state:
    find()


def get_ellipse_coords(point):
    """Return the (x0, y0, x1, y1) box of the patch whose top-left is ``point``.

    Despite the name, the caller draws this box as a rectangle.
    """
    x, y = point
    return (x, y, x + patch_size, y + patch_size)


def _click_to_point(value):
    """Snap a click to the top-left pixel of a patch that lies on the grid."""
    def snap(c):
        # Offsetting by half a stride picks the patch whose centre is nearest;
        # clamping keeps edge clicks from producing off-grid (wrapping) indices.
        k = (c - stride // 2) // stride
        return min(max(k, 0), gridsize - 1) * stride
    return snap(value["x"]), snap(value["y"])


# Newer Streamlit releases provide st.rerun; st.experimental_rerun is deprecated.
_rerun = getattr(st, "rerun", None) or st.experimental_rerun

col1, col2, col3 = st.columns([3, 1, 1])

with col1:
    current_image = ds[st.session_state["img"]]['image'].resize(size=(image_size, image_size))
    draw = ImageDraw.Draw(current_image)
    if st.session_state["draw"]:
        coords = get_ellipse_coords(st.session_state["point"])
        draw.rectangle(coords, outline="green", width=2)

    value = streamlit_image_coordinates(current_image, key="pil")
    if value is not None:
        point = _click_to_point(value)
        if point != st.session_state["point"]:
            st.session_state["point"] = point
            st.session_state["draw"] = True
            _rerun()

    scol1, scol2 = st.columns(2)
    with scol1:
        st.button('Change Image', on_click=button_click)
        st.selectbox("Model", tuple(pred_dict.keys()), key="model")
    with scol2:
        st.button('Find similar parts', on_click=find)

with col2:
    for i in [0, 2]:
        if i == 0:
            st.write("Target in batch " + str(st.session_state["sideix"][i]))
        else:
            st.write("Match #" + str(i) + " in batch " + str(st.session_state["sideix"][i]))
        st.image(st.session_state["sideimg"][i].resize((192, 192)))
with col3:
    # Matches #1 and #3; the target and match #2 are shown in the middle column.
    for i in [1, 3]:
        st.write("Match #" + str(i) + " in batch " + str(st.session_state["sideix"][i]))
        st.image(st.session_state["sideimg"][i].resize((192, 192)))

# The bare string literals below are rendered into the page by Streamlit "magic".
"The batch indexing machine shakes parts while recording a video."
"The machine processed 20 batches of random parts, with each batch running for 30 seconds."
"INSTRUCTIONS:"
"Click in the image to set target part"
"Click “Find similar parts” to find the best matches in other batches"
"The model is trained completely unsupervised using a CNN with a custom contrastive loss."
"https://github.com/circularmachines/batch_indexing_machine/"
"johan.lagerloef@gmail.com"