File size: 4,942 Bytes
9705209
df20613
beb3e34
b73b4e1
 
df20613
8cc7c1a
26b8b00
31b5ef1
66b9c9d
540d02b
838f247
540d02b
 
 
8996350
540d02b
 
31b5ef1
540d02b
83f2c3c
66b9c9d
 
 
 
 
540d02b
66b9c9d
26b8b00
8996350
dc7c18b
c954bb1
26b8b00
66b9c9d
83f2c3c
 
faf8b4b
4f45f95
dc0f5b0
26b8b00
775fea9
 
42bc3b0
3d79c89
ba82fe8
3d79c89
dc0f5b0
3bc8e5f
 
fdef182
b355815
66b9c9d
584248d
83f2c3c
4f0f93d
83f2c3c
584248d
 
 
 
 
83f2c3c
 
8015cdb
f1ecf3e
faf8b4b
4549b91
cc301d6
1ac8545
 
 
6000267
3bc8e5f
065bb16
 
 
 
 
 
 
66b9c9d
065bb16
 
 
 
 
 
 
 
 
 
f1ecf3e
83f2c3c
 
f1ecf3e
 
 
faf8b4b
f1ecf3e
 
8015cdb
f1ecf3e
 
8cc7c1a
df5841e
 
33632ae
df5841e
e86e66f
 
de70cb6
 
df5841e
beb3e34
854f030
dd657c6
524c7bd
ed36122
391663d
12ae958
3d79c89
9b0c7a0
 
86efe2e
12ae958
854f030
12ae958
854f030
12ae958
 
faf8b4b
12ae958
 
aca188e
12ae958
aca188e
12ae958
9303628
df5841e
faf8b4b
 
12ae958
 
df5841e
e2813d1
 
5ff37df
561ccb3
 
 
 
a09f857
dc0f5b0
552a48c
561ccb3
 
 
9b0c7a0
12ae958
e5bac18
 
 
12ae958
 
d4c5cf6
7f0d2be
9acca80
d4c5cf6
9b0c7a0
 
 
 
1d85c4d
7f0d2be
 
 
 
 
 
 
 
1d85c4d
7f0d2be
 
d4c5cf6
065bb16
12ae958
8997343
 
 
9b0c7a0
 
 
8997343
 
 
 
 
 
dd657c6
 
12ae958
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
from datasets.arrow_dataset import InMemoryTable
import streamlit as st
from PIL import Image, ImageDraw

from streamlit_image_coordinates import streamlit_image_coordinates

import numpy as np


from datasets import load_dataset

st.set_page_config(layout="wide")

# NOTE(review): Streamlit re-executes this whole script on every interaction,
# so the dataset and the prediction arrays below are reloaded on each rerun.
# Consider st.cache_resource if the installed Streamlit version provides it.
ds = load_dataset("Circularmachines/batch_indexing_machine_green_test", split="test")

patch_size = 32   # side length of one model patch, in pixels
stride = 16       # step between neighbouring patch origins, in pixels
image_size = 512  # frames are resized to image_size x image_size before patching

# Patch positions per axis and per image. Derived rather than hard-coded so they
# cannot drift from the values above: (512 - 32) // 16 + 1 == 31, 31 * 31 == 961.
gridsize = (image_size - patch_size) // stride + 1
n_patches = gridsize * gridsize

# Per-model embeddings, flattened to one 64-value row per patch; row
# img * n_patches + patch_index belongs to patch `patch_index` of image `img`.
pred_dict = {'Trained on color images': np.load('pred_all_green.npy').reshape(-1, 64),
             'Trained on chromakey images': np.load('pred_all_chroma.npy').reshape(-1, 64)}

random_i = np.load('random.npy')

# Session defaults: selected pixel point, active model, displayed image, and
# whether the selection rectangle is drawn. Each key is initialized on its own
# so one missing key never depends on another being absent.
if "point" not in st.session_state:
  st.session_state["point"] = (128, 64)

if "model" not in st.session_state:
  st.session_state["model"] = next(iter(pred_dict))

if "img" not in st.session_state:
  st.session_state["img"] = 0

if "draw" not in st.session_state:
  st.session_state["draw"] = True




def patch(ij):
  """Return the image crop surrounding flat patch index *ij*.

  *ij* indexes the flattened prediction arrays: ``ij // n_patches`` selects the
  dataset image, and ``ij % n_patches`` selects the patch position on the
  gridsize x gridsize grid in row-major order. The image is resized to
  image_size x image_size first. The crop is a square 4 * stride pixels wide,
  starting one stride before the patch origin. For patch_size == 2 * stride,
  that leaves one stride of context on every side of the patch. Near the
  border the box extends past the image, and Pillow fills that region.
  """
  img_idx, p = divmod(int(ij), n_patches)
  imm = ds[img_idx]['image'].resize(size=(image_size, image_size))

  y, x = divmod(p, gridsize)
  box = ((x - 1) * stride, (y - 1) * stride, (x + 3) * stride, (y + 3) * stride)
  return imm.crop(box)

def find():
  """Find the closest patches to the selected point, one per batch, up to 4 batches.

  Reads st.session_state["img"], ["point"] (pixel coordinates) and ["model"].
  Ranks every patch of every image by Euclidean distance between embeddings.
  For each of the first 4 distinct batches encountered, the best-ranked crop
  is written into st.session_state["sideimg"][k] and the batch numbers into
  st.session_state["sideix"]. The first entry is normally the query patch
  itself, since its distance is 0.
  """
  st.session_state["sideix"] = []

  px, py = st.session_state["point"]
  # Pixel -> grid coordinates, clamped to the valid grid. Unclamped, a click
  # near the border gives -1 or gridsize, which would select a patch from the
  # wrong row or even the neighbouring image.
  gx = min(max(px // stride, 0), gridsize - 1)
  gy = min(max(py // stride, 0), gridsize - 1)

  img_idx = st.session_state["img"]
  preds = pred_dict[st.session_state["model"]]
  query = preds[img_idx * n_patches + gy * gridsize + gx]
  diff = np.linalg.norm(preds - query, axis=-1)

  # Sort once; the original re-sorted the full array on every iteration.
  order = diff.argsort()

  batches = []
  for idx in order:
    # Image index // 20 is treated as the batch number (20 images per batch,
    # per this file's arithmetic); keep only the best match of each batch.
    batch = idx // n_patches // 20
    if batch in batches:
      continue
    st.session_state["sideimg"][len(batches)] = patch(idx)
    batches.append(batch)
    if len(batches) == 4:
      break

  st.session_state["sideix"] = batches


def button_click():
    """'Change Image' callback: show a random image and hide the selection box."""
    # NOTE(review): only image indices 0..99 can be chosen; confirm this matches
    # the images covered by the prediction arrays.
    state = st.session_state
    state["img"] = np.random.randint(100)
    state["draw"] = False

# First load only: fill the four side-panel slots with crops of the first four
# patches of image 0, then compute real matches for the default point.
if "sideimg" not in st.session_state:
  st.session_state["sideimg"] = list(map(patch, range(4)))

if "sideix" not in st.session_state:
  find()

def get_ellipse_coords(point):
    """Return the (left, top, right, bottom) box of a patch_size square at *point*.

    *point* is the square's top-left corner in pixels. Despite the name, the
    box is a plain bounding rectangle (it is passed to ImageDraw.rectangle).
    """
    left = point[0]
    top = point[1]
    right, bottom = left + patch_size, top + patch_size
    return (left, top, right, bottom)




col1, col2, col3 = st.columns([3, 1, 1])

with col1:
  # Main view: the current frame, with the selected patch outlined in green
  # unless the image was just changed via "Change Image".
  current_image = ds[st.session_state["img"]]['image'].resize(size=(image_size, image_size))
  draw = ImageDraw.Draw(current_image)

  if st.session_state["draw"]:
    coords = get_ellipse_coords(st.session_state["point"])
    draw.rectangle(coords, outline="green", width=2)

  value = streamlit_image_coordinates(current_image, key="pil")

  if value is not None:
    # Snap the click to a patch origin on the stride grid. With patch_size=32
    # and stride=16, subtracting 8 before flooring picks the patch whose centre
    # is nearest the click. Clamp to the valid origins so border clicks cannot
    # select a patch outside the gridsize x gridsize grid.
    max_origin = (gridsize - 1) * stride
    point = tuple(
      min(max((value[axis] - 8) // stride * stride, 0), max_origin)
      for axis in ("x", "y")
    )

    if point != st.session_state["point"]:
      st.session_state["point"] = point
      st.session_state["draw"] = True
      # Newer Streamlit releases replace st.experimental_rerun with st.rerun.
      rerun = getattr(st, "rerun", None) or st.experimental_rerun
      rerun()

  scol1, scol2 = st.columns(2)
  with scol1:
    st.button('Change Image', on_click=button_click)
    st.selectbox("Model", tuple(pred_dict.keys()), key="model")

  with scol2:
    st.button('Find similar parts', on_click=find)

# Side panels: result slot 0 (normally the query patch itself) plus the three
# best matches from other batches, split across two narrow columns.
with col2:
  for i in (0, 2):
    label = "Target" if i == 0 else "Match #" + str(i)
    st.write(label + " in batch " + str(st.session_state["sideix"][i]))
    st.image(st.session_state["sideimg"][i].resize((192, 192)))

with col3:
  for i in (1, 3):
    st.write("Match #" + str(i) + " in batch " + str(st.session_state["sideix"][i]))
    st.image(st.session_state["sideimg"][i].resize((192, 192)))

# The bare string literals below are displayed as Markdown by Streamlit "magic".
"The batch indexing machine shakes parts while recording a video."
"The machine processed 20 batches of random parts, with each batch running for 30 seconds."

"INSTRUCTIONS:"
"Click in the image to set the target part."
"Click “Find similar parts” to find the best matches in other batches."

"The model is trained completely unsupervised using a CNN with a custom contrastive loss. Open source code to be released soon."

"johan.lagerloef@gmail.com"