vumichien commited on
Commit
091632c
β€’
1 Parent(s): dbf451c
Files changed (2) hide show
  1. app.py +6 -2
  2. utils.py +74 -71
app.py CHANGED
@@ -25,6 +25,11 @@ def load_lottieurl(url: str):
25
  return r.json()
26
 
27
 
 
 
 
 
 
28
  lottie_penguin = load_lottieurl('https://assets10.lottiefiles.com/datafiles/Yv8B88Go8kHRZ5T/data.json')
29
  st_lottie(lottie_penguin, height=200)
30
 
@@ -56,11 +61,10 @@ method = st.sidebar.radio('Choose input source πŸ‘‡', options=['Image', 'Webcam'
56
  def initial_setup():
57
  df_train = pd.read_csv('full_set.csv')
58
  sub_test_list = sorted(list(df_train['Image'].map(lambda x: get_image(x))))
59
- # embeddings = torch.load('embeddings.pt')
60
  with open('embeddings.npy', 'rb') as f:
61
  embeddings = np.load(f)
62
  PATH = 'model_onnx.onnx'
63
- ort_session = onnxruntime.InferenceSession(PATH)
64
  input_name = ort_session.get_inputs()[0].name
65
  return df_train, sub_test_list, embeddings, ort_session, input_name
66
 
 
25
  return r.json()
26
 
27
 
28
+ # Configure
29
+ options = onnxruntime.SessionOptions()
30
+ options.intra_op_num_threads = 8
31
+ options.inter_op_num_threads = 8
32
+
33
  lottie_penguin = load_lottieurl('https://assets10.lottiefiles.com/datafiles/Yv8B88Go8kHRZ5T/data.json')
34
  st_lottie(lottie_penguin, height=200)
35
 
 
61
  def initial_setup():
62
  df_train = pd.read_csv('full_set.csv')
63
  sub_test_list = sorted(list(df_train['Image'].map(lambda x: get_image(x))))
 
64
  with open('embeddings.npy', 'rb') as f:
65
  embeddings = np.load(f)
66
  PATH = 'model_onnx.onnx'
67
+ ort_session = onnxruntime.InferenceSession(PATH, sess_options=options)
68
  input_name = ort_session.get_inputs()[0].name
69
  return df_train, sub_test_list, embeddings, ort_session, input_name
70
 
utils.py CHANGED
@@ -1,6 +1,6 @@
1
- #import torch
2
- #import torch.nn.functional as F
3
- #from torchvision import transforms
4
 
5
  from PIL import Image
6
  import numpy as np
@@ -10,7 +10,7 @@ from numpy.linalg import norm
10
  import onnx, os, time, onnxruntime
11
  import pandas as pd
12
  import threading
13
- #import queue
14
  import cv2
15
  import av
16
 
@@ -25,17 +25,17 @@ from streamlit_webrtc import (
25
  import args
26
 
27
 
28
-
29
  # def to_numpy(tensor):
30
  # return tensor.detach().cpu().numpy() if tensor.requires_grad else tensor.cpu().numpy()
31
 
 
32
  def get_image(x):
33
- return x.split(', ')[0]
 
34
 
35
  # Transform image to ToTensor
36
  @st.cache(suppress_st_warning=True, allow_output_mutation=True)
37
- def transform_image(image, IMG = True):
38
-
39
  # transform = transforms.Compose([
40
  # transforms.Resize((224, 224)),
41
  # transforms.ToTensor(),
@@ -45,10 +45,10 @@ def transform_image(image, IMG = True):
45
  image = np.asarray(Image.open(image))
46
  # -------------- RESIZE USING CV2 ---------------------
47
  image = cv2.resize(image, dsize=(224, 224))
48
- image = np.transpose(image, (2,0,1))
49
- #image = (image/255-np.expand_dims(np.array([0.485, 0.456, 0.4065]),axis = (1,2)))/np.expand_dims(np.array([0.229, 0.224, 0.225]),axis = (1,2))
50
- image = (image/255-np.array(args.MEAN))/np.array(args.STD)
51
- img_transformed = np.expand_dims(image.astype(np.float32), axis = 0)
52
  # x = torch.from_numpy(image.astype(np.float32))
53
  # x = torch.transpose(x, 2, 0) # shape [3, 224, 224]
54
  # -------------- RESIZE USING CV2 ---------------------
@@ -59,10 +59,10 @@ def transform_image(image, IMG = True):
59
  else:
60
  # -------------- RESIZE USING CV2 ---------------------
61
  image = cv2.resize(image, dsize=(224, 224))
62
- image = np.transpose(image, (2,0,1))
63
- #image = (image/255-np.expand_dims(np.array([0.485, 0.456, 0.4065]),axis = (1,2)))/np.expand_dims(np.array([0.229, 0.224, 0.225]),axis = (1,2))
64
- image = (image/255-np.array(args.MEAN))/np.array(args.STD)
65
- img_transformed = np.expand_dims(image.astype(np.float32), axis = 0)
66
  # x = torch.from_numpy(image.astype(np.float32))
67
  # x = torch.transpose(x, 2, 0)
68
  # -------------- RESIZE USING CV2 ---------------------
@@ -72,22 +72,24 @@ def transform_image(image, IMG = True):
72
 
73
  return img_transformed
74
 
 
75
  # predict multi-level classification
76
  @st.cache(suppress_st_warning=True, allow_output_mutation=True)
77
  def get_classification(image_tensor, df_train, sub_test_list, embeddings,
78
- ort_session, input_name, confidence
79
- ):
80
  # Prediction time
81
  start = time.time()
82
- #ort_inputs = {input_name: to_numpy(image_tensor)}
83
  ort_inputs = {input_name: image_tensor}
84
  pred, em = ort_session.run(None, ort_inputs)
85
 
86
- if pred.max(axis=1) > confidence: # threshold to select of item is car part or not, Yes if > 0.5
87
  # Compute kNN (using Cosine)
88
- #knn = torch.nn.CosineSimilarity(dim = 1)(torch.tensor(em), embeddings).topk(1, largest=True)
89
 
90
- knn = np.array([dot((em), embeddings[i])/(norm(em)*norm(embeddings[i])) for i in range(embeddings.shape[0])]).flatten()
 
91
  knn = np.argsort(knn)[-1]
92
 
93
  # maker = 'Maker: '+str(df_train.iloc[knn.indices.item(), 0])
@@ -95,77 +97,81 @@ def get_classification(image_tensor, df_train, sub_test_list, embeddings,
95
  # vehicle = str(df_train.iloc[knn.indices.item(), 2])
96
  # year = str(df_train.iloc[knn.indices.item(), 3])
97
  # part = 'Part: '+str(df_train.iloc[knn.indices.item(), 4])
98
- maker = 'Maker: '+str(df_train.iloc[knn, 0])
99
  model = str(df_train.iloc[knn, 1])
100
- if model=='nan':
101
- model='Model: No information'
102
  else:
103
- model='Model: '+model
104
  vehicle = str(df_train.iloc[knn, 2])
105
- if vehicle=='nan':
106
- vehicle='Vehicle: No information'
107
  else:
108
- vehicle='Vehicle: '+vehicle
109
  year = str(df_train.iloc[knn, 3])
110
- if year=='nan':
111
- year='Year: No information'
112
  else:
113
- year='Year: '+year
114
- part = 'Part: '+str(df_train.iloc[knn, 4])
115
- predict_time = 'Predict time: '+str(round(time.time() - start,4))+' seconds'
116
 
117
  # Similarity score
118
- sim_score = 'Confidence: '+str(round(pred.max(axis=1).item()*100, 2))+'%'
119
 
120
  else:
121
  maker = 'This is not car part !'
122
- model=vehicle=year=part=predict_time=sim_score=None
 
 
 
123
 
124
- return {'maker':maker,'model':model,'vehicle':vehicle,'year':year, 'part':part, 'predict_time':predict_time,'sim_score':sim_score}
125
 
126
  @st.cache(suppress_st_warning=True, allow_output_mutation=True)
127
  def get_classification_frame(image_tensor, df_train, sub_test_list, embeddings,
128
- ort_session, input_name
129
- ):
130
-
131
- #ort_inputs = {input_name: to_numpy(image_tensor)}
132
  ort_inputs = {input_name: image_tensor}
133
  pred, em = ort_session.run(None, ort_inputs)
134
 
135
  if pred.max(axis=1) > args.VIDEO_CONFIDENCE:
136
  # knn = torch.nn.CosineSimilarity(dim = 1)(torch.tensor(em), embeddings).topk(1, largest=True)
137
  # part = str(df_train.iloc[knn.indices.item(), 4])
138
- knn = np.array([dot((em), embeddings[i])/(norm(em)*norm(embeddings[i])) for i in range(embeddings.shape[0])]).flatten()
 
139
  knn = np.argsort(knn)[-1]
140
  part = str(df_train.iloc[knn, 4])
141
  # Similarity score
142
- sim_score = str(round(pred.max(axis=1).item()*100, 2))+'%'
143
  else:
144
  part = 'No part detected'
145
  sim_score = ''
146
 
147
- return {'part_name':part,'sim_score':sim_score}
 
148
 
149
  # predict similarity
150
  @st.cache(suppress_st_warning=True, allow_output_mutation=True)
151
  def get_similarity(image_tensor, df_train, sub_test_list, embeddings,
152
- ort_session, input_name
153
- ):
154
  start = time.time()
155
- #ort_inputs = {input_name: to_numpy(image_tensor)}
156
  ort_inputs = {input_name: image_tensor}
157
  pred, em = ort_session.run(None, ort_inputs)
158
 
159
  # Compute kNN (using Cosine)
160
- #knn = torch.nn.CosineSimilarity(dim = 1)(torch.tensor(em), embeddings).topk(6, largest=True)
161
  # idx = knn.indices.numpy()
162
- knn = np.array([dot((em), embeddings[i])/(norm(em)*norm(embeddings[i])) for i in range(embeddings.shape[0])]).flatten()
 
163
  idx = np.argsort(knn)[-6:]
164
- predict_time = 'Predict time: '+str(round(time.time() - start,4))+' seconds'
165
  images_path = 'Test_set'
166
  images = [os.path.join(images_path, sub_test_list[i]) for i in idx]
167
  # sub_test_list
168
- return {'images': images, 'predict_time':predict_time}
169
 
170
 
171
  # --------------------------------------------------------------------------------------------
@@ -173,13 +179,12 @@ def get_similarity(image_tensor, df_train, sub_test_list, embeddings,
173
  # --------------------------------------------------------------------------------------------
174
 
175
  content_images_dict = {
176
- name: os.path.join(args.IMAGES_PATH, filee) for name, filee in zip(args.CONTENT_IMAGES_NAME, args.CONTENT_IMAGES_FILE)
 
177
  }
178
 
179
 
180
-
181
  def show_original():
182
-
183
  """ Show Uploaded or Example image before prediction
184
 
185
  Returns:
@@ -188,7 +193,7 @@ def show_original():
188
  path to image
189
  """
190
 
191
- if st.sidebar.checkbox('Upload', value= True, help = 'Select Upload to browse image from local machine'):
192
  content_file = st.sidebar.file_uploader("", type=["png", "jpg", "jpeg"])
193
  else:
194
  content_name = st.sidebar.selectbox("or Choose an example Image below", args.CONTENT_IMAGES_NAME)
@@ -196,7 +201,7 @@ def show_original():
196
 
197
  col1, col2 = st.columns(2)
198
  with col1:
199
- #col1.markdown('## Target image')
200
  if content_file:
201
  col1.write('')
202
  col1.image(content_file, channels='BGR', width=300, clamp=True, caption='Input image')
@@ -205,7 +210,6 @@ def show_original():
205
 
206
 
207
  def image_input(content_file, df_train, sub_test_list, embeddings, ort_session, input_name, col2):
208
-
209
  # Set confidence level
210
  confidence_threshold = st.slider(
211
  "Confidence threshold", 0.0, 1.0, args.DEFAULT_CONFIDENCE_THRESHOLD, 0.05,
@@ -232,9 +236,9 @@ def image_input(content_file, df_train, sub_test_list, embeddings, ort_session,
232
  if col7.button("SEARCH SIMILAR"):
233
  print_classification(col2, content_file, pred_info)
234
 
235
- if pred_info['maker']!='This is not car part !':
236
- #container = st.container()
237
- print_similar_img(pred_images) #, container)
238
  else:
239
  st.warning("No similar car part image ! Reduce confidence threshold OR Choose another image.")
240
  else:
@@ -278,7 +282,7 @@ def webcam_input(df_train, sub_test_list, embeddings, ort_session, input_name):
278
 
279
  def recv(self, frame: av.VideoFrame) -> av.VideoFrame:
280
  image = frame.to_ndarray(format="bgr24")
281
- content = transform_image(image, IMG = False)
282
  pred_info = get_classification_frame(
283
  content, df_train, sub_test_list,
284
  embeddings, ort_session, input_name
@@ -297,14 +301,13 @@ def webcam_input(df_train, sub_test_list, embeddings, ort_session, input_name):
297
 
298
 
299
  def print_classification(col2, content_file, pred_info):
300
-
301
  """ Print classification prediction
302
  """
303
 
304
  with col2:
305
  col2.markdown('### Predicted information')
306
  col2.markdown('')
307
- if pred_info['maker']!='This is not car part !':
308
  col2.markdown('#### - {}'.format(pred_info['maker']))
309
  col2.markdown('#### - {}'.format(pred_info['model']))
310
  col2.markdown('#### - {}'.format(pred_info['vehicle']))
@@ -315,8 +318,8 @@ def print_classification(col2, content_file, pred_info):
315
  else:
316
  col2.markdown('### {}'.format(pred_info['maker']))
317
 
318
- def print_similar_img(pred_images):
319
 
 
320
  """ Print similarity images prediction
321
  """
322
 
@@ -325,13 +328,13 @@ def print_similar_img(pred_images):
325
 
326
  col3, col4, col5 = st.columns(3)
327
  with col3:
328
- col3.image(pred_images['images'][0], channels='BGR', clamp=True, width = 300)
329
- col3.image(pred_images['images'][1], channels='BGR', clamp=True, width = 300)
330
 
331
  with col4:
332
- #col4.markdown('# ')
333
- col4.image(pred_images['images'][3], channels='BGR', clamp=True, width = 300)
334
- col4.image(pred_images['images'][4], channels='BGR', clamp=True, width = 300)
335
  with col5:
336
- col5.image(pred_images['images'][5], channels='BGR', clamp=True, width = 300)
337
- col5.image(pred_images['images'][2], channels='BGR', clamp=True, width = 300)
 
1
+ # import torch
2
+ # import torch.nn.functional as F
3
+ # from torchvision import transforms
4
 
5
  from PIL import Image
6
  import numpy as np
 
10
  import onnx, os, time, onnxruntime
11
  import pandas as pd
12
  import threading
13
+ # import queue
14
  import cv2
15
  import av
16
 
 
25
  import args
26
 
27
 
 
28
  # def to_numpy(tensor):
29
  # return tensor.detach().cpu().numpy() if tensor.requires_grad else tensor.cpu().numpy()
30
 
31
+
32
  def get_image(x):
33
+ return x.split(', ')[0]
34
+
35
 
36
  # Transform image to ToTensor
37
  @st.cache(suppress_st_warning=True, allow_output_mutation=True)
38
+ def transform_image(image, IMG=True):
 
39
  # transform = transforms.Compose([
40
  # transforms.Resize((224, 224)),
41
  # transforms.ToTensor(),
 
45
  image = np.asarray(Image.open(image))
46
  # -------------- RESIZE USING CV2 ---------------------
47
  image = cv2.resize(image, dsize=(224, 224))
48
+ image = np.transpose(image, (2, 0, 1))
49
+ # image = (image/255-np.expand_dims(np.array([0.485, 0.456, 0.4065]),axis = (1,2)))/np.expand_dims(np.array([0.229, 0.224, 0.225]),axis = (1,2))
50
+ image = (image / 255 - np.array(args.MEAN)) / np.array(args.STD)
51
+ img_transformed = np.expand_dims(image.astype(np.float32), axis=0)
52
  # x = torch.from_numpy(image.astype(np.float32))
53
  # x = torch.transpose(x, 2, 0) # shape [3, 224, 224]
54
  # -------------- RESIZE USING CV2 ---------------------
 
59
  else:
60
  # -------------- RESIZE USING CV2 ---------------------
61
  image = cv2.resize(image, dsize=(224, 224))
62
+ image = np.transpose(image, (2, 0, 1))
63
+ # image = (image/255-np.expand_dims(np.array([0.485, 0.456, 0.4065]),axis = (1,2)))/np.expand_dims(np.array([0.229, 0.224, 0.225]),axis = (1,2))
64
+ image = (image / 255 - np.array(args.MEAN)) / np.array(args.STD)
65
+ img_transformed = np.expand_dims(image.astype(np.float32), axis=0)
66
  # x = torch.from_numpy(image.astype(np.float32))
67
  # x = torch.transpose(x, 2, 0)
68
  # -------------- RESIZE USING CV2 ---------------------
 
72
 
73
  return img_transformed
74
 
75
+
76
  # predict multi-level classification
77
  @st.cache(suppress_st_warning=True, allow_output_mutation=True)
78
  def get_classification(image_tensor, df_train, sub_test_list, embeddings,
79
+ ort_session, input_name, confidence
80
+ ):
81
  # Prediction time
82
  start = time.time()
83
+ # ort_inputs = {input_name: to_numpy(image_tensor)}
84
  ort_inputs = {input_name: image_tensor}
85
  pred, em = ort_session.run(None, ort_inputs)
86
 
87
+ if pred.max(axis=1) > confidence: # threshold to select of item is car part or not, Yes if > 0.5
88
  # Compute kNN (using Cosine)
89
+ # knn = torch.nn.CosineSimilarity(dim = 1)(torch.tensor(em), embeddings).topk(1, largest=True)
90
 
91
+ knn = np.array(
92
+ [dot((em), embeddings[i]) / (norm(em) * norm(embeddings[i])) for i in range(embeddings.shape[0])]).flatten()
93
  knn = np.argsort(knn)[-1]
94
 
95
  # maker = 'Maker: '+str(df_train.iloc[knn.indices.item(), 0])
 
97
  # vehicle = str(df_train.iloc[knn.indices.item(), 2])
98
  # year = str(df_train.iloc[knn.indices.item(), 3])
99
  # part = 'Part: '+str(df_train.iloc[knn.indices.item(), 4])
100
+ maker = 'Maker: ' + str(df_train.iloc[knn, 0])
101
  model = str(df_train.iloc[knn, 1])
102
+ if model == 'nan':
103
+ model = 'Model: No information'
104
  else:
105
+ model = 'Model: ' + model
106
  vehicle = str(df_train.iloc[knn, 2])
107
+ if vehicle == 'nan':
108
+ vehicle = 'Vehicle: No information'
109
  else:
110
+ vehicle = 'Vehicle: ' + vehicle
111
  year = str(df_train.iloc[knn, 3])
112
+ if year == 'nan':
113
+ year = 'Year: No information'
114
  else:
115
+ year = 'Year: ' + year
116
+ part = 'Part: ' + str(df_train.iloc[knn, 4])
117
+ predict_time = 'Predict time: ' + str(round(time.time() - start, 4)) + ' seconds'
118
 
119
  # Similarity score
120
+ sim_score = 'Confidence: ' + str(round(pred.max(axis=1).item() * 100, 2)) + '%'
121
 
122
  else:
123
  maker = 'This is not car part !'
124
+ model = vehicle = year = part = predict_time = sim_score = None
125
+
126
+ return {'maker': maker, 'model': model, 'vehicle': vehicle, 'year': year, 'part': part,
127
+ 'predict_time': predict_time, 'sim_score': sim_score}
128
 
 
129
 
130
  @st.cache(suppress_st_warning=True, allow_output_mutation=True)
131
  def get_classification_frame(image_tensor, df_train, sub_test_list, embeddings,
132
+ ort_session, input_name
133
+ ):
134
+ # ort_inputs = {input_name: to_numpy(image_tensor)}
 
135
  ort_inputs = {input_name: image_tensor}
136
  pred, em = ort_session.run(None, ort_inputs)
137
 
138
  if pred.max(axis=1) > args.VIDEO_CONFIDENCE:
139
  # knn = torch.nn.CosineSimilarity(dim = 1)(torch.tensor(em), embeddings).topk(1, largest=True)
140
  # part = str(df_train.iloc[knn.indices.item(), 4])
141
+ knn = np.array(
142
+ [dot((em), embeddings[i]) / (norm(em) * norm(embeddings[i])) for i in range(embeddings.shape[0])]).flatten()
143
  knn = np.argsort(knn)[-1]
144
  part = str(df_train.iloc[knn, 4])
145
  # Similarity score
146
+ sim_score = str(round(pred.max(axis=1).item() * 100, 2)) + '%'
147
  else:
148
  part = 'No part detected'
149
  sim_score = ''
150
 
151
+ return {'part_name': part, 'sim_score': sim_score}
152
+
153
 
154
  # predict similarity
155
  @st.cache(suppress_st_warning=True, allow_output_mutation=True)
156
  def get_similarity(image_tensor, df_train, sub_test_list, embeddings,
157
+ ort_session, input_name
158
+ ):
159
  start = time.time()
160
+ # ort_inputs = {input_name: to_numpy(image_tensor)}
161
  ort_inputs = {input_name: image_tensor}
162
  pred, em = ort_session.run(None, ort_inputs)
163
 
164
  # Compute kNN (using Cosine)
165
+ # knn = torch.nn.CosineSimilarity(dim = 1)(torch.tensor(em), embeddings).topk(6, largest=True)
166
  # idx = knn.indices.numpy()
167
+ knn = np.array(
168
+ [dot((em), embeddings[i]) / (norm(em) * norm(embeddings[i])) for i in range(embeddings.shape[0])]).flatten()
169
  idx = np.argsort(knn)[-6:]
170
+ predict_time = 'Predict time: ' + str(round(time.time() - start, 4)) + ' seconds'
171
  images_path = 'Test_set'
172
  images = [os.path.join(images_path, sub_test_list[i]) for i in idx]
173
  # sub_test_list
174
+ return {'images': images, 'predict_time': predict_time}
175
 
176
 
177
  # --------------------------------------------------------------------------------------------
 
179
  # --------------------------------------------------------------------------------------------
180
 
181
  content_images_dict = {
182
+ name: os.path.join(args.IMAGES_PATH, filee) for name, filee in
183
+ zip(args.CONTENT_IMAGES_NAME, args.CONTENT_IMAGES_FILE)
184
  }
185
 
186
 
 
187
  def show_original():
 
188
  """ Show Uploaded or Example image before prediction
189
 
190
  Returns:
 
193
  path to image
194
  """
195
 
196
+ if st.sidebar.checkbox('Upload', value=True, help='Select Upload to browse image from local machine'):
197
  content_file = st.sidebar.file_uploader("", type=["png", "jpg", "jpeg"])
198
  else:
199
  content_name = st.sidebar.selectbox("or Choose an example Image below", args.CONTENT_IMAGES_NAME)
 
201
 
202
  col1, col2 = st.columns(2)
203
  with col1:
204
+ # col1.markdown('## Target image')
205
  if content_file:
206
  col1.write('')
207
  col1.image(content_file, channels='BGR', width=300, clamp=True, caption='Input image')
 
210
 
211
 
212
  def image_input(content_file, df_train, sub_test_list, embeddings, ort_session, input_name, col2):
 
213
  # Set confidence level
214
  confidence_threshold = st.slider(
215
  "Confidence threshold", 0.0, 1.0, args.DEFAULT_CONFIDENCE_THRESHOLD, 0.05,
 
236
  if col7.button("SEARCH SIMILAR"):
237
  print_classification(col2, content_file, pred_info)
238
 
239
+ if pred_info['maker'] != 'This is not car part !':
240
+ # container = st.container()
241
+ print_similar_img(pred_images) # , container)
242
  else:
243
  st.warning("No similar car part image ! Reduce confidence threshold OR Choose another image.")
244
  else:
 
282
 
283
  def recv(self, frame: av.VideoFrame) -> av.VideoFrame:
284
  image = frame.to_ndarray(format="bgr24")
285
+ content = transform_image(image, IMG=False)
286
  pred_info = get_classification_frame(
287
  content, df_train, sub_test_list,
288
  embeddings, ort_session, input_name
 
301
 
302
 
303
  def print_classification(col2, content_file, pred_info):
 
304
  """ Print classification prediction
305
  """
306
 
307
  with col2:
308
  col2.markdown('### Predicted information')
309
  col2.markdown('')
310
+ if pred_info['maker'] != 'This is not car part !':
311
  col2.markdown('#### - {}'.format(pred_info['maker']))
312
  col2.markdown('#### - {}'.format(pred_info['model']))
313
  col2.markdown('#### - {}'.format(pred_info['vehicle']))
 
318
  else:
319
  col2.markdown('### {}'.format(pred_info['maker']))
320
 
 
321
 
322
+ def print_similar_img(pred_images):
323
  """ Print similarity images prediction
324
  """
325
 
 
328
 
329
  col3, col4, col5 = st.columns(3)
330
  with col3:
331
+ col3.image(pred_images['images'][0], channels='BGR', clamp=True, width=300)
332
+ col3.image(pred_images['images'][1], channels='BGR', clamp=True, width=300)
333
 
334
  with col4:
335
+ # col4.markdown('# ')
336
+ col4.image(pred_images['images'][3], channels='BGR', clamp=True, width=300)
337
+ col4.image(pred_images['images'][4], channels='BGR', clamp=True, width=300)
338
  with col5:
339
+ col5.image(pred_images['images'][5], channels='BGR', clamp=True, width=300)
340
+ col5.image(pred_images['images'][2], channels='BGR', clamp=True, width=300)