evelyncsb commited on
Commit
de33293
1 Parent(s): 50fc40e
app.py CHANGED
@@ -10,6 +10,7 @@ from imagebind import data
10
  from imagebind.models import imagebind_model
11
  from imagebind.models.imagebind_model import ModalityType
12
  import torch.nn as nn
 
13
 
14
 
15
  device = "cpu" #"cuda:0" if torch.cuda.is_available() else "cpu"
@@ -17,9 +18,29 @@ model = imagebind_model.imagebind_huge(pretrained=True)
17
  model.eval()
18
  model.to(device)
19
 
 
 
 
20
 
21
  def generate_image(text):
22
- return Image.open("./assets/ICA-Logo.png").convert("RGB")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
23
 
24
  # Interface do Gradio
25
  iface = gr.Interface(
 
10
  from imagebind.models import imagebind_model
11
  from imagebind.models.imagebind_model import ModalityType
12
  import torch.nn as nn
13
+ import pickle
14
 
15
 
16
  device = "cpu" #"cuda:0" if torch.cuda.is_available() else "cpu"
 
18
  model.eval()
19
  model.to(device)
20
 
21
+ image_features = pickle.load(open("./assets/image_features.pkl","rb"))
22
+ image_paths = pickle.load(open("./assets/image_paths.pkl","rb"))
23
+
24
 
25
  def generate_image(text):
26
+ inputs = {
27
+ ModalityType.TEXT: data.load_and_transform_text([text], device)
28
+ }
29
+
30
+ with torch.no_grad():
31
+ embeddings = model(inputs)
32
+
33
+ text_features = embeddings[ModalityType.TEXT]
34
+ text_features /= text_features.norm(dim=-1, keepdim=True)
35
+
36
+ similarity = text_features.cpu().numpy() @ image_features.cpu().numpy().T
37
+
38
+ #pega index maior
39
+ index_img = np.argmax(similarity)
40
+ img_name = os.path.basename(image_paths[index_img])
41
+ im = Image.open(f"./assets/images/{img_name}").convert("RGB")
42
+
43
+ return im
44
 
45
  # Interface do Gradio
46
  iface = gr.Interface(
assets/image_paths.pkl ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:512b07acbc5f8a2b02a78e8dbb417f86ee68fa87906101aa9b15c527b437b818
3
+ size 435
assets/images/astronaut.png ADDED
assets/images/camera.png ADDED
assets/images/chelsea.png ADDED
assets/images/coffee.png ADDED
assets/images/horse.png ADDED
assets/images/motorcycle_right.png ADDED
assets/images/page.png ADDED
assets/images/rocket.jpg ADDED
requirements.txt CHANGED
@@ -13,4 +13,5 @@ numpy>=1.19
13
  matplotlib
14
  types-regex
15
  mayavi
16
- scikit-image
 
 
13
  matplotlib
14
  types-regex
15
  mayavi
16
+ scikit-image
17
+ pickle
utils.py CHANGED
@@ -1,5 +1,7 @@
1
  import os
 
2
 
 
3
  def calculate_similarity():
4
  image_features = embeddings[ModalityType.VISION]
5
  text_features = embeddings[ModalityType.TEXT]
 
1
  import os
2
+ import pickle
3
 
4
+ # teste = pickle.load(open("df_to_search.pkl","rb"))
5
  def calculate_similarity():
6
  image_features = embeddings[ModalityType.VISION]
7
  text_features = embeddings[ModalityType.TEXT]