vinayakdev committed on
Commit
8ee4f89
1 Parent(s): 8a463ae

Revert run_model

Browse files
Files changed (1) hide show
  1. generator.py +10 -20
generator.py CHANGED
@@ -34,34 +34,29 @@ import streamlit as st
34
  # hfmodel = pickle.load(open('models/hfmodel.sav', 'rb'))
35
 
36
  def load_model():
37
- hfm = pickle.load(open('t5_model.sav','rb'))
38
  hft = T5TokenizerFast.from_pretrained("t5-base")
39
  model = pickle.load(open('electra_model.sav','rb'))
40
  tok = et.from_pretrained("mrm8488/electra-small-finetuned-squadv2")
41
  # return hfm, hft,tok, model
42
  return hfm, hft,tok, model
43
 
44
- hfmodel, hftokenizer,tok, model = load_model()
45
 
46
  def run_model(input_string, **generator_args):
47
  generator_args = {
48
  "max_length": 256,
49
  "num_beams": 4,
50
  "length_penalty": 1.5,
51
- "no_repeat_ngram_size": 2,
52
- "early_stopping": False,
53
  }
54
  # tokenizer = att.from_pretrained("ThomasSimonini/t5-end2end-question-generation")
55
- # output = nlp(input_string)
56
-
57
  input_string = "generate questions: " + input_string + " </s>"
58
-
59
- inputs = tokenize([input_string])
60
-
61
- res = hfmodel.generate(input_ids=inputs['input_ids'], attention_mask=inputs['attention_mask'], **generator_args)
62
- output = hftokenizer.decode(res[0], skip_special_tokens=True)
63
- # output = output.split('</sep>')
64
- # output = [o.strip() for o in output[:-1]]
65
  return output
66
 
67
 
@@ -126,18 +121,13 @@ def read_file(filepath_name):
126
 
127
  def create_string_for_generator(context):
128
  gen_list = gen_question(context)
129
- return gen_list
130
 
131
  def creator(context):
132
  questions = create_string_for_generator(context)
133
- questions = questions.split('?')
134
  pairs = []
135
  for ques in questions:
136
- l = len(ques)
137
- if(l == 0):
138
- continue
139
- if ques[l-1] != '?':
140
- ques = ques + '?'
141
  pair = QA(ques,context)
142
  print(pair)
143
  pairs.append(pair)
 
34
  # hfmodel = pickle.load(open('models/hfmodel.sav', 'rb'))
35
 
36
  def load_model():
37
+ hfm = pickle.load(open('hfmodel.sav','rb'))
38
  hft = T5TokenizerFast.from_pretrained("t5-base")
39
  model = pickle.load(open('electra_model.sav','rb'))
40
  tok = et.from_pretrained("mrm8488/electra-small-finetuned-squadv2")
41
  # return hfm, hft,tok, model
42
  return hfm, hft,tok, model
43
 
44
+ hfmodel, hftokenizer, tok, model = load_model()
45
 
46
  def run_model(input_string, **generator_args):
47
  generator_args = {
48
  "max_length": 256,
49
  "num_beams": 4,
50
  "length_penalty": 1.5,
51
+ "no_repeat_ngram_size": 3,
52
+ "early_stopping": True,
53
  }
54
  # tokenizer = att.from_pretrained("ThomasSimonini/t5-end2end-question-generation")
 
 
55
  input_string = "generate questions: " + input_string + " </s>"
56
+ input_ids = hftokenizer.encode(input_string, return_tensors="pt")
57
+ res = hfmodel.generate(input_ids, **generator_args)
58
+ output = hftokenizer.batch_decode(res, skip_special_tokens=True)
59
+ output = [item.split("<sep>") for item in output]
 
 
 
60
  return output
61
 
62
 
 
121
 
122
  def create_string_for_generator(context):
123
  gen_list = gen_question(context)
124
+ return (gen_list[0][0]).split('? ')
125
 
126
  def creator(context):
127
  questions = create_string_for_generator(context)
128
+ # questions = questions.split('?')
129
  pairs = []
130
  for ques in questions:
 
 
 
 
 
131
  pair = QA(ques,context)
132
  print(pair)
133
  pairs.append(pair)