flyboytarantino14 commited on
Commit
ac436d4
1 Parent(s): 84710d4

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +15 -6
app.py CHANGED
@@ -13,13 +13,22 @@ def get_question(context, answer):
13
  #encoding = question_tokenizer.encode_plus(text, max_length=max_len, padding='max_length', truncation=True, return_tensors="pt")
14
  encoding = question_tokenizer.encode_plus(text, return_tensors="pt")
15
  input_ids, attention_mask = encoding["input_ids"], encoding["attention_mask"]
 
 
 
 
 
 
 
 
16
  outs = question_model.generate(input_ids=input_ids,
17
- attention_mask=attention_mask,
18
- early_stopping=True,
19
- num_beams=3, # Use fewer beams to generate fewer but higher-quality questions
20
- num_return_sequences=3,
21
- no_repeat_ngram_size=3, # Allow some repetition to avoid generating nonsensical questions
22
- max_length=256) # Use a shorter max length to focus on generating more relevant questions
 
23
  dec = [question_tokenizer.decode(ids) for ids in outs]
24
  questions = ""
25
  for i, question in enumerate(dec):
 
13
  #encoding = question_tokenizer.encode_plus(text, max_length=max_len, padding='max_length', truncation=True, return_tensors="pt")
14
  encoding = question_tokenizer.encode_plus(text, return_tensors="pt")
15
  input_ids, attention_mask = encoding["input_ids"], encoding["attention_mask"]
16
+ #outs = question_model.generate(input_ids=input_ids,
17
+ # attention_mask=attention_mask,
18
+ # early_stopping=True,
19
+ # num_beams=3, # Use fewer beams to generate fewer but higher-quality questions
20
+ # num_return_sequences=3,
21
+ # no_repeat_ngram_size=3, # Allow some repetition to avoid generating nonsensical questions
22
+ # max_length=256) # Use a shorter max length to focus on generating more relevant questions
23
+
24
  outs = question_model.generate(input_ids=input_ids,
25
+ attention_mask=attention_mask,
26
+ early_stopping=True,)
27
+
28
+
29
+
30
+
31
+
32
  dec = [question_tokenizer.decode(ids) for ids in outs]
33
  questions = ""
34
  for i, question in enumerate(dec):