Heiko Hotz commited on
Commit
4b9c730
1 Parent(s): 833c58b

initial commit

Browse files
Files changed (1) hide show
  1. predict.py +3 -3
predict.py CHANGED
@@ -13,8 +13,8 @@ from transformers import (
13
  from transformers.data.processors.squad import SquadResult, SquadV2Processor, SquadExample
14
  from transformers.data.metrics.squad_metrics import compute_predictions_logits
15
 
 
16
  def run_prediction(question_texts, context_text, model_path, n_best_size=1):
17
- ### Setting hyperparameters
18
  max_seq_length = 512
19
  doc_stride = 256
20
  n_best_size = n_best_size
@@ -102,7 +102,7 @@ def run_prediction(question_texts, context_text, model_path, n_best_size=1):
102
  print(all_results)
103
 
104
  output_nbest_file = None
105
- if n_best_size > 1:
106
  output_nbest_file = "nbest.json"
107
 
108
  timer = time.time()
@@ -123,4 +123,4 @@ def run_prediction(question_texts, context_text, model_path, n_best_size=1):
123
  )
124
  print(f'Logits converted to predictions in {time.time()-timer} seconds')
125
 
126
- return final_predictions
 
13
  from transformers.data.processors.squad import SquadResult, SquadV2Processor, SquadExample
14
  from transformers.data.metrics.squad_metrics import compute_predictions_logits
15
 
16
+
17
  def run_prediction(question_texts, context_text, model_path, n_best_size=1):
 
18
  max_seq_length = 512
19
  doc_stride = 256
20
  n_best_size = n_best_size
 
102
  print(all_results)
103
 
104
  output_nbest_file = None
105
+ if int(n_best_size) > 1:
106
  output_nbest_file = "nbest.json"
107
 
108
  timer = time.time()
 
123
  )
124
  print(f'Logits converted to predictions in {time.time()-timer} seconds')
125
 
126
+ return final_predictions