ajanz committed
Commit 66d0fee
1 Parent(s): fb53c32

an extended tokenizing function (as proposed in the source project)

Files changed (1)
  1. app.py +49 -5
app.py CHANGED
@@ -23,6 +23,49 @@ textbox = gr.Textbox(
 )
 
 
+def prepare_query(tokenizer, query, max_seq_length=300):
+    # temporary solution
+    mention_start_token: str = "[unused0]"
+    mention_end_token: str = "[unused1]"
+
+    left_context = query.split(mention_start_token)[0]
+    right_context = query.split(mention_end_token)[-1]
+    mention = query.split(mention_start_token)[-1].split(mention_end_token)[0]
+
+    mention_ids = tokenizer(
+        mention_start_token + mention + mention_end_token,
+        add_special_tokens=False
+    )['input_ids']
+
+    left_ids = tokenizer(left_context, add_special_tokens=False)['input_ids']
+    left_quota = (max_seq_length - len(mention_ids)) // 2 - 1
+
+    right_ids = tokenizer(right_context, add_special_tokens=False)['input_ids']
+    right_quota = max_seq_length - len(mention_ids) - left_quota - 2
+
+    left_add, right_add = len(left_ids), len(right_ids)
+    if left_add <= left_quota:
+        right_quota += left_quota - left_add if right_add > right_quota else 0
+    else:
+        left_quota += right_quota - right_add if right_add <= right_quota else 0
+
+    context_ids = [
+        tokenizer.cls_token_id,
+        *left_ids[-left_quota:],
+        *mention_ids,
+        *right_ids[:right_quota],
+        tokenizer.sep_token_id
+    ]
+
+    padding_length = max_seq_length - len(context_ids)
+    # attention_mask = [1] * len(context_ids) + [0] * padding_length
+
+    context_ids += [tokenizer.pad_token_id] * padding_length
+
+    assert len(context_ids) == max_seq_length
+    return context_ids
+
+
 def load_index(index_data: str = "clarin-knext/entity-linking-index"):
     ds = datasets.load_dataset(index_data, use_auth_token=auth_token)['train']
     index_data = {
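A note on the quota arithmetic above (explanatory, not part of the commit): after setting the mention aside, the remaining token budget is split roughly in half, keeping two positions for [CLS] and [SEP]. For example, with max_seq_length=300 and a 10-token mention span (markers included), left_quota = (300 - 10) // 2 - 1 = 144 and right_quota = 300 - 10 - 144 - 2 = 144, so [CLS] + 144 + 10 + 144 + [SEP] fills exactly 300 positions; the if/else rebalancing then donates any quota a short context does not use to the opposite side.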
 
@@ -44,7 +87,8 @@ model = load_model()
 index = load_index()
 
 
-def predict(query: str = sample_text, top_k: int=3):
+def predict(text: str = sample_text, top_k: int=3):
+    query = prepare_query(text)
     index_data, faiss_index = index
     # takes only the [CLS] embedding (for now)
     query = model(query, return_tensors = "pt")[0][0].numpy().reshape(1, -1)
@@ -52,13 +96,13 @@ def predict(query: str = sample_text, top_k: int=3):
     scores, indices = faiss_index.search(query, top_k)
     scores, indices = scores.tolist(), indices.tolist()
 
-    results = [
-        (index_data[result[0]], result[1])
+    results = "\n".join([
+        f"{index_data[result[0]]}: {result[1]}"
         for output in zip(indices, scores)
         for result in zip(*output)
-    ]
+    ])
 
-    return str(results)
+    return results
 
 
 demo = gr.Interface(fn=predict, inputs=textbox, outputs="text").launch()
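For reference, a minimal sketch of calling the new prepare_query standalone, assuming the function above is in scope; the tokenizer name and example sentence are illustrative and not part of the commit. The [unused0]/[unused1] markers are consumed as plain string delimiters, so they only need to appear as substrings of the input (registering them as special tokens would additionally make them single ids inside mention_ids).

from transformers import AutoTokenizer

# illustrative tokenizer choice; the Space's actual model may ship its own
tokenizer = AutoTokenizer.from_pretrained("bert-base-cased")

# the mention "Warsaw" is wrapped in the reserved marker tokens
text = "The old town of [unused0] Warsaw [unused1] was rebuilt after the war."

context_ids = prepare_query(tokenizer, text, max_seq_length=300)
assert len(context_ids) == 300  # always padded or truncated to max_seq_length

Note that predict as committed calls prepare_query(text) with a single argument, while the new signature expects the tokenizer first; threading one through (e.g. via functools.partial(prepare_query, tokenizer)) would be needed for the two to line up.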