mojtaba-nafez commited on
Commit
fd6aade
1 Parent(s): 1385d75

fix app.py to read from saved poem_embeddings.json

Browse files
Files changed (2) hide show
  1. app.py +8 -4
  2. inference.py +64 -8
app.py CHANGED
@@ -3,6 +3,7 @@ from inference import predict_poems_from_text
3
  from utils import get_poem_embeddings
4
  import config as CFG
5
  import json
 
6
  import gradio as gr
7
 
8
  def greet_user(name):
@@ -12,15 +13,18 @@ if __name__ == "__main__":
12
  model = PoemTextModel(poem_encoder_pretrained=True, text_encoder_pretrained=True).to(CFG.device)
13
  model.eval()
14
  # Inference: Output some example predictions and write them in a file
15
- with open(CFG.dataset_path, encoding="utf-8") as f:
16
- dataset = json.load(f)
 
 
 
 
17
 
18
  def gradio_make_predictions(text):
19
- beyts = predict_poems_from_text(model, poem_embeddings, text, [data['beyt'] for data in dataset], n=10)
20
  return "\n".join(beyts)
21
 
22
  CFG.batch_size = 512
23
- model, poem_embeddings = get_poem_embeddings(dataset, model)
24
  # print(poem_embeddings[0])
25
  # with open('poem_embeddings.json'.format(CFG.poem_encoder_model, CFG.text_encoder_model),'w', encoding="utf-8") as f:
26
  # f.write(json.dumps(poem_embeddings, indent= 4))
 
3
  from utils import get_poem_embeddings
4
  import config as CFG
5
  import json
6
+ import torch
7
  import gradio as gr
8
 
9
  def greet_user(name):
 
13
  model = PoemTextModel(poem_encoder_pretrained=True, text_encoder_pretrained=True).to(CFG.device)
14
  model.eval()
15
  # Inference: Output some example predictions and write them in a file
16
+ with open('poem_embeddings.json', encoding="utf-8") as f:
17
+ pe = json.load(f)
18
+
19
+ poem_embeddings = torch.Tensor([p['embeddings'] for p in pe]).to(CFG.device)
20
+ print(poem_embeddings.shape)
21
+ poems = [p['beyt'] for p in pe]
22
 
23
  def gradio_make_predictions(text):
24
+ beyts = predict_poems_from_text(model, poem_embeddings, text, poems, n=10)
25
  return "\n".join(beyts)
26
 
27
  CFG.batch_size = 512
 
28
  # print(poem_embeddings[0])
29
  # with open('poem_embeddings.json'.format(CFG.poem_encoder_model, CFG.text_encoder_model),'w', encoding="utf-8") as f:
30
  # f.write(json.dumps(poem_embeddings, indent= 4))
inference.py CHANGED
@@ -12,9 +12,10 @@ from models import PoemTextModel
12
  from utils import get_poem_embeddings
13
  import json
14
  import os
 
15
 
16
 
17
- def predict_poems_from_text(model, poem_embeddings, query, poems, text_tokenizer=None, n=10):
18
  """
19
  Returns n poems which are the most similar to a text query
20
 
@@ -32,6 +33,8 @@ def predict_poems_from_text(model, poem_embeddings, query, poems, text_tokenizer
32
  tokenizer to tokenize query with. if none, will instantiate a new text tokenizer using configs.
33
  n: int, optional
34
  number of poems to return
 
 
35
 
36
  Returns:
37
  --------
@@ -63,11 +66,36 @@ def predict_poems_from_text(model, poem_embeddings, query, poems, text_tokenizer
63
  dot_similarity = text_embeddings_n @ poem_embeddings_n.T
64
 
65
  # returning top n poems based on embedding similarity
66
- _, indices = torch.topk(dot_similarity.squeeze(0), n)
67
- return [poems[idx] for idx in indices]
68
-
69
-
70
- def predict_poems_from_image(model, poem_embeddings, image_filename, poems, n=10):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
71
  """
72
  Returns n poems which are the most similar to an image query
73
 
@@ -83,6 +111,8 @@ def predict_poems_from_image(model, poem_embeddings, image_filename, poems, n=10
83
  poems corresponding to poem_embeddings
84
  n: int, optional
85
  number of poems to return
 
 
86
 
87
  Returns:
88
  --------
@@ -107,8 +137,34 @@ def predict_poems_from_image(model, poem_embeddings, image_filename, poems, n=10
107
  dot_similarity = image_embeddings_n @ poem_embeddings_n.T
108
 
109
  # returning top n poems based on embedding similarity
110
- _, indices = torch.topk(dot_similarity.squeeze(0), n)
111
- return [poems[idx] for idx in indices]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
112
 
113
  if __name__ == "__main__":
114
  """
 
12
  from utils import get_poem_embeddings
13
  import json
14
  import os
15
+ import regex
16
 
17
 
18
+ def predict_poems_from_text(model, poem_embeddings, query, poems, text_tokenizer=None, n=10, return_similarities=False):
19
  """
20
  Returns n poems which are the most similar to a text query
21
 
 
33
  tokenizer to tokenize query with. if none, will instantiate a new text tokenizer using configs.
34
  n: int, optional
35
  number of poems to return
36
+ return_similarities: bool, optional
37
+ if True, a dictionary will be returned which has the poem beyts and their similarities to the text
38
 
39
  Returns:
40
  --------
 
66
  dot_similarity = text_embeddings_n @ poem_embeddings_n.T
67
 
68
  # returning top n poems based on embedding similarity
69
+ values, indices = torch.topk(dot_similarity.squeeze(0), len(poems))
70
+
71
+ # since we collected poems from many sources, some of them are equal (the same beyt with different meanings),
72
+ # so we must check the poems added to result not to be duplicates
73
+ def is_poem_duplicate(poem, poems):
74
+ poem = regex.findall(r'\p{L}+', poem.replace('\u200c', ''))
75
+ for other_poem in poems:
76
+ other_poem = regex.findall(r'\p{L}+', other_poem.replace('\u200c', ''))
77
+ if poem == other_poem:
78
+ return True
79
+ return False
80
+
81
+ results = []
82
+ computed_k = 0
83
+ for i in range(len(poems)):
84
+ if computed_k == n:
85
+ break
86
+ if not is_poem_duplicate(poems[indices[i]], [res['beyt'] for res in results]):
87
+ results.append({
88
+ 'beyt': poems[indices[i]].replace(' * * ', ' * ').replace('*** * ', ''),
89
+ 'similarity': values[i]
90
+ })
91
+ computed_k += 1
92
+ if return_similarities:
93
+ return results
94
+ else:
95
+ return [res['beyt'] for res in results]
96
+
97
+
98
+ def predict_poems_from_image(model, poem_embeddings, image_filename, poems, n=10, return_similarities=False):
99
  """
100
  Returns n poems which are the most similar to an image query
101
 
 
111
  poems corresponding to poem_embeddings
112
  n: int, optional
113
  number of poems to return
114
+ return_similarities: bool, optional
115
+ if True, a dictionary will be returned which has the poem beyts and their similarities to the text
116
 
117
  Returns:
118
  --------
 
137
  dot_similarity = image_embeddings_n @ poem_embeddings_n.T
138
 
139
  # returning top n poems based on embedding similarity
140
+ values, indices = torch.topk(dot_similarity.squeeze(0), len(poems))
141
+
142
+ # since we collected poems from many sources, some of them are equal (the same beyt with different meanings),
143
+ # so we must check the poems added to result not to be duplicates
144
+ def is_poem_duplicate(poem, poems):
145
+ poem = regex.findall(r'\p{L}+', poem.replace('\u200c', ''))
146
+ for other_poem in poems:
147
+ other_poem = regex.findall(r'\p{L}+', other_poem.replace('\u200c', ''))
148
+ if poem == other_poem:
149
+ return True
150
+ return False
151
+
152
+ results = []
153
+ computed_k = 0
154
+ for i in range(len(poems)):
155
+ if computed_k == n:
156
+ break
157
+ if not is_poem_duplicate(poems[indices[i]], [res['beyt'] for res in results]):
158
+ results.append({
159
+ 'beyt': poems[indices[i]].replace(' * * ', ' * ').replace('*** * ', ''),
160
+ 'similarity': values[i]
161
+ })
162
+ computed_k += 1
163
+ if return_similarities:
164
+ return results
165
+ else:
166
+ return [res['beyt'] for res in results]
167
+
168
 
169
  if __name__ == "__main__":
170
  """