seonglae commited on
Commit
1a11e20
1 Parent(s): ad19f21

fix: summaries list argument error issue

Browse files
Files changed (2) hide show
  1. app.py +1 -1
  2. model.py +3 -1
app.py CHANGED
@@ -76,7 +76,7 @@ def main(question: str):
76
  ctx = '\n'.join(texts)
77
 
78
  # Reader
79
- summary = summarize_text(models['summarizer'][0],
80
  models['summarizer'][1], [ctx])
81
  answers = ask_reader(models['reader'][0],
82
  models['reader'][1], [question], [summary])
 
76
  ctx = '\n'.join(texts)
77
 
78
  # Reader
79
+ [summary] = summarize_text(models['summarizer'][0],
80
  models['summarizer'][1], [ctx])
81
  answers = ask_reader(models['reader'][0],
82
  models['reader'][1], [question], [summary])
model.py CHANGED
@@ -28,7 +28,7 @@ def summarize_text(tokenizer: PegasusTokenizerFast, model: PegasusXForConditiona
28
  return summaries
29
 
30
 
31
- def get_summarizer(model_id="seonglae/resrer") -> Tuple[PegasusTokenizerFast, PegasusXForConditionalGeneration]:
32
  tokenizer = PegasusTokenizerFast.from_pretrained(model_id)
33
  model = PegasusXForConditionalGeneration.from_pretrained(model_id)
34
  if cuda:
@@ -58,6 +58,8 @@ def ask_reader(tokenizer: AutoTokenizer, model: AutoModelForQuestionAnswering,
58
  model=model, tokenizer=tokenizer, device='cpu', max_answer_len=max_answer_len)
59
  answer_infos = pipeline(
60
  question=questions, context=ctxs)
 
 
61
  for answer_info in answer_infos:
62
  answer_info['answer'] = sub(r'[.\(\)"\',]', '', answer_info['answer'])
63
  return answer_infos
 
28
  return summaries
29
 
30
 
31
+ def get_summarizer(model_id="seonglae/resrer-pegasus-x") -> Tuple[PegasusTokenizerFast, PegasusXForConditionalGeneration]:
32
  tokenizer = PegasusTokenizerFast.from_pretrained(model_id)
33
  model = PegasusXForConditionalGeneration.from_pretrained(model_id)
34
  if cuda:
 
58
  model=model, tokenizer=tokenizer, device='cpu', max_answer_len=max_answer_len)
59
  answer_infos = pipeline(
60
  question=questions, context=ctxs)
61
+ if not isinstance(answer_infos, list):
62
+ answer_infos = [answer_infos]
63
  for answer_info in answer_infos:
64
  answer_info['answer'] = sub(r'[.\(\)"\',]', '', answer_info['answer'])
65
  return answer_infos