pedramyazdipoor committed on
Commit
7730914
1 Parent(s): 2963ed6

Update README.md

Browse files
Files changed (1) hide show
  1. README.md +3 -3
README.md CHANGED
@@ -54,7 +54,7 @@ There are some considerations for inference:
54
  3) The selected span must be the most probable choice among N pairs of candidates.
55
 
56
  ```python
57
- def generate_indexes(start_logits, end_logits, N, min_index_list):
58
 
59
  output_start = start_logits
60
  output_end = end_logits
@@ -71,7 +71,7 @@ def generate_indexes(start_logits, end_logits, N, min_index_list):
71
  for a in range(0,N):
72
  for b in range(0,N):
73
  if (sorted_start_list[a][1] + sorted_end_list[b][1]) > prob :
74
- if (sorted_start_list[a][0] <= sorted_end_list[b][0]) and (sorted_start_list[a][0] > min_index_list) :
75
  prob = sorted_start_list[a][1] + sorted_end_list[b][1]
76
  start_idx = sorted_start_list[a][0]
77
  end_idx = sorted_end_list[b][0]
@@ -97,7 +97,7 @@ encoding = tokenizer(text,question,add_special_tokens = True,
97
 
98
  out = model(encoding['input_ids'].to(device),encoding['attention_mask'].to(device), encoding['token_type_ids'].to(device))
99
  #we had to change some pieces of code to make it compatible with one answer generation at a time.
100
- #you can initialize min_index_list in generate_indexes() to put force on tokens being chosen to be within the context(start index must be bigger than the seperator token.
101
  start_index, end_index = generate_indexes(out['start_logits'][0], out['end_logits'][0], 5, 0)
102
  print(tokenizer.tokenize(text + question)[start_index:end_index+1])
103
  >>> ['اسمم', 'پدرام', '##ه', '.', 'اسمم', 'چیه', '؟']
 
54
  3) The selected span must be the most probable choice among N pairs of candidates.
55
 
56
  ```python
57
+ def generate_indexes(start_logits, end_logits, N, max_index):
58
 
59
  output_start = start_logits
60
  output_end = end_logits
 
71
  for a in range(0,N):
72
  for b in range(0,N):
73
  if (sorted_start_list[a][1] + sorted_end_list[b][1]) > prob :
74
+ if (sorted_start_list[a][0] <= sorted_end_list[b][0]) and (sorted_end_list[b][0] < max_index) :
75
  prob = sorted_start_list[a][1] + sorted_end_list[b][1]
76
  start_idx = sorted_start_list[a][0]
77
  end_idx = sorted_end_list[b][0]
 
97
 
98
  out = model(encoding['input_ids'].to(device),encoding['attention_mask'].to(device), encoding['token_type_ids'].to(device))
99
  #we had to change some pieces of code to make it compatible with one answer generation at a time.
100
+ #you can initialize max_index in generate_indexes() to force the chosen tokens to lie within the context (the end index must be less than the separator token's index).
101
  start_index, end_index = generate_indexes(out['start_logits'][0], out['end_logits'][0], 5, 0)
102
  print(tokenizer.tokenize(text + question)[start_index:end_index+1])
103
  >>> ['اسمم', 'پدرام', '##ه', '.', 'اسمم', 'چیه', '؟']