pedramyazdipoor
committed on
Commit
•
7730914
1
Parent(s):
2963ed6
Update README.md
Browse files
README.md
CHANGED
@@ -54,7 +54,7 @@ There are some considerations for inference:
|
|
54 |
3) The selected span must be the most probable choice among N pairs of candidates.
|
55 |
|
56 |
```python
|
57 |
-
def generate_indexes(start_logits, end_logits, N,
|
58 |
|
59 |
output_start = start_logits
|
60 |
output_end = end_logits
|
@@ -71,7 +71,7 @@ def generate_indexes(start_logits, end_logits, N, min_index_list):
|
|
71 |
for a in range(0,N):
|
72 |
for b in range(0,N):
|
73 |
if (sorted_start_list[a][1] + sorted_end_list[b][1]) > prob :
|
74 |
-
if (sorted_start_list[a][0] <= sorted_end_list[b][0]) and (
|
75 |
prob = sorted_start_list[a][1] + sorted_end_list[b][1]
|
76 |
start_idx = sorted_start_list[a][0]
|
77 |
end_idx = sorted_end_list[b][0]
|
@@ -97,7 +97,7 @@ encoding = tokenizer(text,question,add_special_tokens = True,
|
|
97 |
|
98 |
out = model(encoding['input_ids'].to(device),encoding['attention_mask'].to(device), encoding['token_type_ids'].to(device))
|
99 |
#we had to change some pieces of code to make it compatible with one answer generation at a time.
|
100 |
-
#you can initialize
|
101 |
start_index, end_index = generate_indexes(out['start_logits'][0], out['end_logits'][0], 5, 0)
|
102 |
print(tokenizer.tokenize(text + question)[start_index:end_index+1])
|
103 |
>>> ['اسمم', 'پدرام', '##ه', '.', 'اسمم', 'چیه', '؟']
|
|
|
54 |
3) The selected span must be the most probable choice among N pairs of candidates.
|
55 |
|
56 |
```python
|
57 |
+
def generate_indexes(start_logits, end_logits, N, max_index):
|
58 |
|
59 |
output_start = start_logits
|
60 |
output_end = end_logits
|
|
|
71 |
for a in range(0,N):
|
72 |
for b in range(0,N):
|
73 |
if (sorted_start_list[a][1] + sorted_end_list[b][1]) > prob :
|
74 |
+
if (sorted_start_list[a][0] <= sorted_end_list[b][0]) and (sorted_end_list[a][0] < max_index) :
|
75 |
prob = sorted_start_list[a][1] + sorted_end_list[b][1]
|
76 |
start_idx = sorted_start_list[a][0]
|
77 |
end_idx = sorted_end_list[b][0]
|
|
|
97 |
|
98 |
out = model(encoding['input_ids'].to(device),encoding['attention_mask'].to(device), encoding['token_type_ids'].to(device))
|
99 |
#we had to change some pieces of code to make it compatible with one answer generation at a time.
|
100 |
+
#you can initialize max_index in generate_indexes() to force the chosen tokens to lie within the context (the end index must be less than the separator token's index).
|
101 |
start_index, end_index = generate_indexes(out['start_logits'][0], out['end_logits'][0], 5, 0)
|
102 |
print(tokenizer.tokenize(text + question)[start_index:end_index+1])
|
103 |
>>> ['اسمم', 'پدرام', '##ه', '.', 'اسمم', 'چیه', '؟']
|