import streamlit as st
from PIL import Image, ImageDraw
from streamlit_image_coordinates import streamlit_image_coordinates
import numpy as np
from datasets import load_dataset
st.set_page_config(layout="wide")
ds = load_dataset("Circularmachines/batch_indexing_machine_green_test", split="test")
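# Patch grid geometry: each frame is resized to 512x512 and covered by 32x32 patches
# at a stride of 16 pixels, giving a 31x31 grid of 961 overlapping patches per frame.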
patch_size=32
stride=16
#image_size=2304
image_size=512
gridsize=31
n_patches=961
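# Precomputed 64-dimensional patch embeddings for the two model variants,
# flattened to shape (n_frames * n_patches, 64).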
pred_dict = {'Trained on color images': np.load('pred_all_green.npy').reshape(-1, 64),
             'Trained on chromakey images': np.load('pred_all_chroma.npy').reshape(-1, 64)}
random_i=np.load('random.npy')
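# Initialise the Streamlit session state on the first run.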
if "point" not in st.session_state:
st.session_state["point"] = (128,64)
st.session_state["model"] = tuple(pred_dict.keys())[0]
if "img" not in st.session_state:
st.session_state["img"] = 0
if "draw" not in st.session_state:
st.session_state["draw"] = True
def patch(ij):
    """Return the 64x64 crop around global patch index ij."""
    img_idx = ij // n_patches        # which frame the patch belongs to
    imm = ds[int(img_idx)]['image'].resize(size=(512, 512))
    p = ij % n_patches               # patch index within the frame
    y = p // gridsize
    x = p % gridsize
    # Crop a 64x64 window centred on the 32x32 patch.
    imc = imm.crop(((x - 1) * stride, (y - 1) * stride, (x + 3) * stride, (y + 3) * stride))
    return imc
def find():
    """Rank all patches by embedding distance to the selected patch and keep the
    closest patch from each of the first four distinct batches (the first hit is
    the selected patch itself, shown as the target)."""
    st.session_state["sideix"] = []
    point = st.session_state["point"]
    point = (point[0] // stride, point[1] // stride)
    i = st.session_state["img"]
    p = point[1] * gridsize + point[0]
    pred = pred_dict[st.session_state["model"]]
    # Euclidean distance from the selected patch embedding to every patch embedding.
    diff = np.linalg.norm(pred[np.newaxis, i * n_patches + p, :] - pred, axis=-1)
    order = diff.argsort()
    rank = 0
    ix = 0
    batches = []
    while ix < 4:
        batch = order[rank] // n_patches // 20
        if batch not in batches:
            batches.append(batch)
            st.session_state["sideimg"][ix] = patch(order[rank])
            ix += 1
        rank += 1
    st.session_state["sideix"] = batches
def button_click():
    # Jump to a random frame and hide the selection rectangle until the user clicks again.
    st.session_state["img"] = np.random.randint(100)
    st.session_state["draw"] = False
if "sideimg" not in st.session_state:
st.session_state["sideimg"] = [patch(i) for i in range(4)]
if "sideix" not in st.session_state:
find()
def get_ellipse_coords(point):
    # (point: tuple[int, int]) -> tuple[int, int, int, int]
    # Despite the name, this returns the selected patch's bounding box:
    # the clicked top-left corner extended by patch_size in each direction.
    center = point
    return (
        center[0],
        center[1],
        center[0] + patch_size,
        center[1] + patch_size,
    )
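# Page layout: the full frame with the selected patch on the left,
# matched patches in the two right-hand columns.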
col1, col2, col3 = st.columns([3, 1, 1])
with col1:
    current_image = ds[st.session_state["img"]]['image'].resize(size=(512, 512))
    draw = ImageDraw.Draw(current_image)
    if st.session_state["draw"]:
        # Outline the currently selected patch.
        point = st.session_state["point"]
        coords = get_ellipse_coords(point)
        draw.rectangle(coords, outline="green", width=2)
    # Render the frame and capture click coordinates.
    value = streamlit_image_coordinates(current_image, key="pil")
    if value is not None:
        # Snap the click to the patch grid (offset by half a stride so the box is centred on the click).
        point = (value["x"] - 8) // stride * stride, (value["y"] - 8) // stride * stride
        if point != st.session_state["point"]:
            st.session_state["point"] = point
            st.session_state["draw"] = True
            st.experimental_rerun()
#subcol1, subcol2 = st.columns(2)
#with subcol1:
#st.button('Previous Frame', on_click=button_click)
scol1, scol2 = st.columns(2)
with scol1:
st.button('Change Image', on_click=button_click)
st.selectbox("Model",tuple(pred_dict.keys()),key="model")
with scol2:
st.button('Find similar parts', on_click=find)
#st.write("Currently viewing frame "+str(random_i[st.session_state["img"]%20])+" in batch "+str(st.session_state["img"]//20))
#st.write(st.session_state["img"])
#st.write(st.session_state["point"])
#st.write(st.session_state["draw"])
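# col2 / col3: the target patch (index 0) and the three closest matches from other batches.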
with col2:
    for i in [0, 2]:
        if i == 0:
            st.write("Target in batch " + str(st.session_state["sideix"][i]))
        else:
            st.write("Match #" + str(i) + " in batch " + str(st.session_state["sideix"][i]))
        st.image(st.session_state["sideimg"][i].resize((192, 192)))
with col3:
    for i in [1, 3]:
        st.write("Match #" + str(i) + " in batch " + str(st.session_state["sideix"][i]))
        st.image(st.session_state["sideimg"][i].resize((192, 192)))
"The batch indexing machine shakes parts while recording a video."
"The machine processed 20 batches of random parts, with each batch running for 30 seconds."
"INSCTRUCTIONS:"
"Click in the image to set target part"
"Click “Find similar parts” to find the best matches in other batches"
"The model is trained completely unsupervised using a CNN with a custom contrastive loss. Open source code to be released soon. "
"johan.lagerloef@gmail.com"