Christian Koch committed on
Commit
32ee8bd
1 Parent(s): 0df07e9

fix missing t5 model

Browse files
Files changed (2) hide show
  1. app.py +8 -5
  2. question_generator.py +0 -33
app.py CHANGED
@@ -1,11 +1,14 @@
1
  import streamlit as st
2
- from transformers import pipeline, PegasusForConditionalGeneration, PegasusTokenizer
3
  import nltk
4
 
5
  from fill_in_summary import FillInSummary
6
  from paraphrase import PegasusParaphraser
7
  import question_generator as q
8
 
 
 
 
9
 
10
  # Question Generator Variables
11
  ids = {'mt5-small': st.secrets['small'],
@@ -25,11 +28,11 @@ if select == "Question Generator":
25
  #st.selectbox('Model', ['T5', 'GPT Neo-X'])
26
 
27
  # Download all models from drive
28
- q.download_models(ids)
29
 
30
  # Model selection
31
  model_path = st.selectbox('', options=[k for k in ids], index=1, help='Model to use. ')
32
- model = q.load_model(model_path=f"model/{model_path}.ckpt")
33
 
34
  text_input = st.text_area("Input Text")
35
 
@@ -39,7 +42,7 @@ if select == "Question Generator":
39
 
40
  if split:
41
  # Split into sentences
42
- sent_tokenized = nltk.sent_tokenize(inputs)
43
  res = {}
44
 
45
  with st.spinner('Please wait while the inputs are being processed...'):
@@ -61,7 +64,7 @@ if select == "Question Generator":
61
  else:
62
  with st.spinner('Please wait while the inputs are being processed...'):
63
  # Prediction
64
- predictions = model.multitask([inputs], max_length=512)
65
  questions, answers, answers_bis = predictions['questions'], predictions['answers'], predictions[
66
  'answers_bis']
67
 
 
1
  import streamlit as st
2
+ from transformers import pipeline, PegasusForConditionalGeneration, PegasusTokenizer, AutoTokenizer, AutoModelForSeq2SeqLM
3
  import nltk
4
 
5
  from fill_in_summary import FillInSummary
6
  from paraphrase import PegasusParaphraser
7
  import question_generator as q
8
 
9
+ tokenizer = AutoTokenizer.from_pretrained("google/mt5-small")
10
+
11
+ model = AutoModelForSeq2SeqLM.from_pretrained("google/mt5-small")
12
 
13
  # Question Generator Variables
14
  ids = {'mt5-small': st.secrets['small'],
 
28
  #st.selectbox('Model', ['T5', 'GPT Neo-X'])
29
 
30
  # Download all models from drive
31
+ # q.download_models(ids)
32
 
33
  # Model selection
34
  model_path = st.selectbox('', options=[k for k in ids], index=1, help='Model to use. ')
35
+
36
 
37
  text_input = st.text_area("Input Text")
38
 
 
42
 
43
  if split:
44
  # Split into sentences
45
+ sent_tokenized = nltk.sent_tokenize(text_input)
46
  res = {}
47
 
48
  with st.spinner('Please wait while the inputs are being processed...'):
 
64
  else:
65
  with st.spinner('Please wait while the inputs are being processed...'):
66
  # Prediction
67
+ predictions = model.multitask([text_input], max_length=512)
68
  questions, answers, answers_bis = predictions['questions'], predictions['answers'], predictions[
69
  'answers_bis']
70
 
question_generator.py CHANGED
@@ -9,39 +9,6 @@ from transformers import AutoTokenizer
9
  from mt5 import MT5
10
 
11
 
12
- def download_models(ids):
13
- """
14
- Download all models.
15
- :param ids: name and links of models
16
- :return:
17
- """
18
-
19
- # Download sentence tokenizer
20
- nltk.download('punkt')
21
-
22
- # Download model from drive if not stored locally
23
- for key in ids:
24
- if not os.path.isfile(f"model/{key}.ckpt"):
25
- url = f"https://drive.google.com/u/0/uc?id={ids[key]}"
26
- gdown.download(url=url, output=f"model/{key}.ckpt")
27
-
28
-
29
- @st.cache(allow_output_mutation=True)
30
- def load_model(model_path):
31
- """
32
- Load model and cache it.
33
- :param model_path: path to model
34
- :return:
35
- """
36
-
37
- device = 'cuda' if torch.cuda.is_available() else 'cpu'
38
-
39
- # Loading model and tokenizer
40
- model = MT5.load_from_checkpoint(model_path).eval().to(device)
41
- model.tokenizer = AutoTokenizer.from_pretrained('tokenizer')
42
-
43
- return model
44
-
45
  # elif task == 'Question Answering':
46
  #
47
  # # Input area
 
9
  from mt5 import MT5
10
 
11
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
12
  # elif task == 'Question Answering':
13
  #
14
  # # Input area