flyboytarantino14 commited on
Commit
ef98ab3
1 Parent(s): e9da942

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +4 -3
app.py CHANGED
@@ -4,10 +4,11 @@ import gradio as gr
4
  from transformers import T5ForConditionalGeneration, T5Tokenizer
5
  #question_model = T5ForConditionalGeneration.from_pretrained('ramsrigouthamg/t5_squad_v1')
6
  #question_tokenizer = T5Tokenizer.from_pretrained('t5-base')
7
- question_model = T5ForConditionalGeneration.from_pretrained('ramsrigouthamg/t5_squad_v1')
8
- question_tokenizer = T5Tokenizer.from_pretrained('t5-small')
9
 
10
  def get_question(context, answer):
 
 
11
  text = "context: {} answer: {}".format(context, answer)
12
  #max_len = 512
13
  #encoding = question_tokenizer.encode_plus(text, max_length=max_len, padding='max_length', truncation=True, return_tensors="pt")
@@ -19,7 +20,7 @@ def get_question(context, answer):
19
  num_beams=3, # Use fewer beams to generate fewer but higher-quality questions
20
  num_return_sequences=3,
21
  no_repeat_ngram_size=3, # Allow some repetition to avoid generating nonsensical questions
22
- max_length=128) # Use a shorter max length to focus on generating more relevant questions
23
  dec = [question_tokenizer.decode(ids) for ids in outs]
24
  questions = ""
25
  for i, question in enumerate(dec):
 
4
  from transformers import T5ForConditionalGeneration, T5Tokenizer
5
  #question_model = T5ForConditionalGeneration.from_pretrained('ramsrigouthamg/t5_squad_v1')
6
  #question_tokenizer = T5Tokenizer.from_pretrained('t5-base')
7
+
 
8
 
9
  def get_question(context, answer):
10
+ question_model = T5ForConditionalGeneration.from_pretrained('ramsrigouthamg/t5_squad_v1')
11
+ question_tokenizer = T5Tokenizer.from_pretrained('t5-small')
12
  text = "context: {} answer: {}".format(context, answer)
13
  #max_len = 512
14
  #encoding = question_tokenizer.encode_plus(text, max_length=max_len, padding='max_length', truncation=True, return_tensors="pt")
 
20
  num_beams=3, # Use fewer beams to generate fewer but higher-quality questions
21
  num_return_sequences=3,
22
  no_repeat_ngram_size=3, # Allow some repetition to avoid generating nonsensical questions
23
+ max_length=256) # Use a shorter max length to focus on generating more relevant questions
24
  dec = [question_tokenizer.decode(ids) for ids in outs]
25
  questions = ""
26
  for i, question in enumerate(dec):