Tahsin-Mayeesha committed on
Commit 2f00cc5
1 Parent(s): 90b8c65
Files changed (1)
  1. app.py +28 -9
app.py CHANGED
@@ -14,7 +14,7 @@ def choose_model(model_choice):
     return "jannatul17/squad-bn-qgen-banglat5-v1"
 
 
-def generate__questions(model_choice,context,answer):
+def generate_questions(model_choice,context,answer,numReturnSequences=1,num_beams=None,do_sample=False,top_p=None,top_k=None,temperature=None):
     model_name = choose_model(model_choice)
     model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
     tokenizer = AutoTokenizer.from_pretrained(model_name)
@@ -26,19 +26,38 @@ def generate__questions(model_choice,context,answer):
     generated_ids = model.generate(
         input_ids=text_encoding['input_ids'],
         attention_mask=text_encoding['attention_mask'],
-        max_length=64,
-        num_beams=5,
-        num_return_sequences=1
+        max_length=120,
+        num_beams=num_beams,
+        do_sample=do_sample,
+        top_k=top_k,
+        top_p=top_p,
+        temperature=temperature,
+        num_return_sequences=numReturnSequences
     )
 
-    return tokenizer.decode(generated_ids[0],skip_special_tokens=True,clean_up_tokenization_spaces=True).replace('question: ',' ')
-
-demo = gr.Interface(fn=generate__questions, inputs=[gr.Dropdown(label="Model", choices=["mt5-small","mt5-base","banglat5"],value="banglat5"),
+    text = []
+    for id in generated_ids:
+        text.append(tokenizer.decode(id,skip_special_tokens=True,clean_up_tokenization_spaces=True).replace('question: ',' '))
+
+    return " ".join(text)
+
+
+demo = gr.Interface(fn=generate_questions, inputs=[gr.Dropdown(label="Model", choices=["mt5-small","mt5-base","banglat5"],value="banglat5"),
                     gr.Textbox(label='Context'),
-                    gr.Textbox(label='Answer')],
+                    gr.Textbox(label='Answer'),
+                    # hyperparameters
+                    gr.Slider(1, 3, 1, step=1, label="Num return Sequences"),
+                    # beam search
+                    gr.Slider(1, 10, value=None, step=1, label="Beam width"),
+                    # top-k/top-p
+                    gr.Checkbox(label="Do Random Sample", value=False),
+                    gr.Slider(0, 50, value=None, step=1, label="Top K"),
+                    gr.Slider(0, 1, value=None, label="Top P/Nucleus Sampling"),
+                    gr.Slider(0.01, 1, value=None, label="Temperature")],
+                    # output
                     outputs=gr.Textbox(label='Question'),
                     examples=[["banglat5",example_context,example_answer]],
                     cache_examples=False,
                     title="Bangla Question Generation",
                     description="Get the Question from given Context and an Answer")
-demo.launch()
+demo.launch()
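
The new parameters map onto the two decoding modes that transformers' generate() supports: deterministic beam search (num_beams with do_sample=False) and stochastic top-k / nucleus sampling (do_sample=True with top_k, top_p, temperature). The sketch below is not part of the commit; it reproduces both calls against the banglat5 checkpoint named in choose_model(). The "answer: ... context: ..." prompt string is an assumption, since the hunk that builds text_encoding is not shown in this diff.

    from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

    model_name = "jannatul17/squad-bn-qgen-banglat5-v1"   # returned by choose_model("banglat5")
    model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
    tokenizer = AutoTokenizer.from_pretrained(model_name)

    # Placeholder prompt; how answer and context are concatenated is assumed here.
    text = "answer: <answer span> context: <Bangla passage>"
    enc = tokenizer(text, return_tensors="pt", truncation=True)

    # Beam search: deterministic; a wider beam keeps more candidate sequences per step.
    beam_ids = model.generate(**enc, max_length=120, num_beams=5, num_return_sequences=1)

    # Top-k / nucleus sampling: stochastic; top_k, top_p and temperature shape the distribution.
    sample_ids = model.generate(**enc, max_length=120, do_sample=True,
                                top_k=50, top_p=0.95, temperature=0.7,
                                num_return_sequences=3)

    for ids in list(beam_ids) + list(sample_ids):
        print(tokenizer.decode(ids, skip_special_tokens=True, clean_up_tokenization_spaces=True))

Note that the new Gradio sliders default to value=None, so untouched knobs are forwarded to generate() as None and rely on the library falling back to its own generation defaults.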
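
One wiring detail worth illustrating: gr.Interface passes the widget values to fn positionally, in the order of the inputs list. The stand-in function below (hypothetical, no model load) makes the widget-to-parameter mapping visible; because the new signature declares top_p before top_k while the widgets are ordered "Top K" then "Top P", the two sliders appear to feed each other's parameter.

    def echo_generate_questions(model_choice, context, answer,
                                numReturnSequences=1, num_beams=None, do_sample=False,
                                top_p=None, top_k=None, temperature=None):
        # Same signature as the committed generate_questions; just echoes what each knob received.
        return {"num_beams": num_beams, "do_sample": do_sample,
                "top_p": top_p, "top_k": top_k, "temperature": temperature}

    # Values in widget order: Dropdown, Context, Answer, Num return Sequences,
    # Beam width, Do Random Sample, Top K, Top P, Temperature
    print(echo_generate_questions("banglat5", "ctx", "ans", 1, 5, True, 50, 0.95, 0.7))
    # -> {'num_beams': 5, 'do_sample': True, 'top_p': 50, 'top_k': 0.95, 'temperature': 0.7}
    #    i.e. the "Top K" slider value (50) lands in top_p and the "Top P" value (0.95) in top_k.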