pedramyazdipoor committed on
Commit
ed4a795
1 Parent(s): 2f59c7f

Update README.md

Browse files
Files changed (1) hide show
  1. README.md +3 -3
README.md CHANGED
@@ -58,7 +58,7 @@ There are some considerations for inference:
58
  3) The selected span must be the most probable choice among N pairs of candidates.
59
 
60
  ```python
61
- def generate_indexes(start_logits, end_logits, N, min_index_list):
62
 
63
  output_start = start_logits
64
  output_end = end_logits
@@ -79,7 +79,7 @@ def generate_indexes(start_logits, end_logits, N, min_index_list):
79
  for a in range(0,N):
80
  for b in range(0,N):
81
  if (sorted_start_list[a][1] + sorted_end_list[b][1]) > prob :
82
- if (sorted_start_list[a][0] <= sorted_end_list[b][0]) and (sorted_start_list[a][0] > min_index_list) :
83
  prob = sorted_start_list[a][1] + sorted_end_list[b][1]
84
  start_idx = sorted_start_list[a][0]
85
  end_idx = sorted_end_list[b][0]
@@ -104,7 +104,7 @@ encoding = tokenizer(text,question,add_special_tokens = True,
104
  out = model(encoding['input_ids'].to(device),encoding['attention_mask'].to(device), encoding['token_type_ids'].to(device))
105
  #we had to change some pieces of code to make it compatible with one answer generation at a time
106
  #If you have unanswerable questions, use out['start_logits'][0][0:] and out['end_logits'][0][0:] because <s> (the 1st token) is for this situation and must be compared with other tokens.
107
- #you can initialize min_index_list in generate_indexes() to put force on tokens being chosen to be within the context(start index must be bigger than the seperator token.
108
  answer_start_index, answer_end_index = generate_indexes(out['start_logits'][0][1:], out['end_logits'][0][1:], 5, 0)
109
  print(tokenizer.tokenize(text + question))
110
  print(tokenizer.tokenize(text + question)[answer_start_index : (answer_end_index + 1)])
 
58
  3) The selected span must be the most probable choice among N pairs of candidates.
59
 
60
  ```python
61
+ def generate_indexes(start_logits, end_logits, N, max_index):
62
 
63
  output_start = start_logits
64
  output_end = end_logits
 
79
  for a in range(0,N):
80
  for b in range(0,N):
81
  if (sorted_start_list[a][1] + sorted_end_list[b][1]) > prob :
82
+ if (sorted_start_list[a][0] <= sorted_end_list[b][0]) and (sorted_end_list[a][0] < max_index) :
83
  prob = sorted_start_list[a][1] + sorted_end_list[b][1]
84
  start_idx = sorted_start_list[a][0]
85
  end_idx = sorted_end_list[b][0]
 
104
  out = model(encoding['input_ids'].to(device),encoding['attention_mask'].to(device), encoding['token_type_ids'].to(device))
105
  #we had to change some pieces of code to make it compatible with one answer generation at a time
106
  #If you have unanswerable questions, use out['start_logits'][0][0:] and out['end_logits'][0][0:] because <s> (the 1st token) is for this situation and must be compared with other tokens.
107
+ #you can initialize max_index in generate_indexes() to force the chosen tokens to lie within the context (the end index must be less than the separator token's index).
108
  answer_start_index, answer_end_index = generate_indexes(out['start_logits'][0][1:], out['end_logits'][0][1:], 5, 0)
109
  print(tokenizer.tokenize(text + question))
110
  print(tokenizer.tokenize(text + question)[answer_start_index : (answer_end_index + 1)])