flyboytarantino14 commited on
Commit
8024d68
1 Parent(s): 1dee0bd

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +18 -16
app.py CHANGED
@@ -52,28 +52,30 @@ question_model = T5ForConditionalGeneration.from_pretrained('ramsrigouthamg/t5_s
52
  question_tokenizer = T5Tokenizer.from_pretrained('t5-base')
53
 
54
  def get_question(sentence,answer):
55
- text = "context: {} answer: {} </s>".format(sentence,answer)
56
- print (text)
57
- max_len = 256
58
- encoding = question_tokenizer.encode_plus(text,max_length=max_len, pad_to_max_length=True, return_tensors="pt")
59
-
60
- input_ids, attention_mask = encoding["input_ids"], encoding["attention_mask"]
61
-
62
- outs = question_model.generate(input_ids=input_ids,
63
  attention_mask=attention_mask,
64
  early_stopping=True,
65
  num_beams=3,
66
  num_return_sequences=3,
67
  no_repeat_ngram_size=2,
68
  max_length=200)
69
-
70
-
71
- dec = [question_tokenizer.decode(ids) for ids in outs]
72
-
73
- #Question = dec[0].replace("question:","")
74
- #Question= Question.strip()
75
- #return Question
76
- return dec
 
 
77
 
78
  input_context = gr.Textbox()
79
  input_answer = gr.Textbox()
 
52
  question_tokenizer = T5Tokenizer.from_pretrained('t5-base')
53
 
54
  def get_question(sentence,answer):
55
+ text = "context: {} answer: {} </s>".format(sentence,answer)
56
+ print (text)
57
+ max_len = 256
58
+ encoding = question_tokenizer.encode_plus(text,max_length=max_len, pad_to_max_length=True, return_tensors="pt")
59
+
60
+ input_ids, attention_mask = encoding["input_ids"], encoding["attention_mask"]
61
+
62
+ outs = question_model.generate(input_ids=input_ids,
63
  attention_mask=attention_mask,
64
  early_stopping=True,
65
  num_beams=3,
66
  num_return_sequences=3,
67
  no_repeat_ngram_size=2,
68
  max_length=200)
69
+
70
+ dec = [question_tokenizer.decode(ids) for ids in outs]
71
+ questions = ""
72
+ for i, question in enumerate(dec):
73
+ question = question.replace("question:", "").replace("<pad>", "").replace("</s>", "")
74
+ question = question.strip()
75
+ questions += question
76
+ if i != len(dec)-1:
77
+ questions += "§"
78
+ return questions
79
 
80
  input_context = gr.Textbox()
81
  input_answer = gr.Textbox()