Toonies committed on
Commit
4dc1f91
1 Parent(s): 6ae99b6

fix error return

Browse files
Files changed (1) hide show
  1. app.py +27 -16
app.py CHANGED
@@ -30,28 +30,38 @@ def embedding_input(text_input):
30
  return text_emb
31
 
32
  def embedding_img():
33
- global images
34
- img_batch = imagenette['image']
35
-
36
- images = processor(
37
- text = None,
38
- images = img_batch,
39
- return_tensors = 'pt'
40
- )['pixel_values'].to(device)
41
- batch_emb = model.get_image_features(pixel_values =img_batch)
42
- batch_emb = batch_emb.squeeze(0)
43
- image_arr = batch_emb.cpu().detach().numpy()
44
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
45
  return image_arr
46
 
47
  def norm_val(text_input):
48
- image_arr = embedding_img()
49
- time.sleep(5)
50
  text_emb = embedding_input(text_input)
51
-
52
- image_arr = (image_arr.T / np.linalg.norm(image_arr, axis = 1)).T
53
  text_emb = text_emb.cpu().detach().numpy()
54
- scores = np.dot(text_emb, image_arr.T)
55
  top_k = 1
56
  idx = np.argsort(-scores[0])[:top_k]
57
  return images[idx[0]]
@@ -61,6 +71,7 @@ def norm_val(text_input):
61
 
62
 
63
  if __name__ == "__main__":
 
64
  load_data()
65
  iface = gr.Interface(fn=norm_val, inputs="text", outputs="image")
66
  iface.launch(inline = False )
 
30
  return text_emb
31
 
32
  def embedding_img():
33
+ global images, image_arr
34
+ load_data()
35
+ sample_idx= np.random.randint(0, len(imagenette)+1, 100).tolist()
36
+ images = [imagenette[i]['image'] for i in sample_idx]
37
+ batch_sie = 5
38
+ image_arr = None
39
+ for i in tqdm(range(0, len(images), batch_sie)):
40
+ time.sleep(1)
41
+ batch = images[i:i+batch_sie]
 
 
42
 
43
+ batch = processor(
44
+ text = None,
45
+ images = batch,
46
+ return_tensors= 'pt',
47
+ padding = True
48
+ )['pixel_values'].to(device)
49
+ batch_emb = model.get_image_features(pixel_values = batch)
50
+ batch_emb = batch_emb.squeeze(0)
51
+ batch_emb = batch_emb.cpu().detach().numpy()
52
+
53
+ if image_arr is None:
54
+ image_arr = batch_emb
55
+
56
+ else:
57
+ image_arr = np.concatenate((image_arr, batch_emb), axis = 0)
58
  return image_arr
59
 
60
  def norm_val(text_input):
 
 
61
  text_emb = embedding_input(text_input)
62
+ image_emb = (image_arr.T / np.linalg.norm(image_arr, axis = 1)).T
 
63
  text_emb = text_emb.cpu().detach().numpy()
64
+ scores = np.dot(text_emb, image_emb.T)
65
  top_k = 1
66
  idx = np.argsort(-scores[0])[:top_k]
67
  return images[idx[0]]
 
71
 
72
 
73
  if __name__ == "__main__":
74
+ embedding_img()
75
  load_data()
76
  iface = gr.Interface(fn=norm_val, inputs="text", outputs="image")
77
  iface.launch(inline = False )