Pippoz commited on
Commit
8f074bc
1 Parent(s): 5c4eafa

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +36 -11
app.py CHANGED
@@ -3,20 +3,45 @@ import time
3
  from transformers import pipeline
4
  import torch
5
 
6
- st.markdown('## OPT-1.3 Billion parameter (Meta)')
7
 
8
- with st.spinner('Loading Model... (This may take a while)'):
9
- generator = pipeline('text-generation', model="facebook/opt-1.3b", skip_special_tokens=True)
10
- st.success('Model loaded correctly!')
11
 
12
- prompt= st.text_area('Your prompt here',
13
- '''AI will help humanity?''')
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
14
 
15
  answer = generator(prompt, max_length=60, no_repeat_ngram_size=2, early_stopping=True, num_beams=5)
16
-
17
  lst = answer[0]['generated_text']
18
 
19
- t = st.empty()
20
- for i in range(len(lst)):
21
- t.markdown(" %s..." % lst[0:i])
22
- time.sleep(0.04)
 
 
3
  from transformers import pipeline
4
  import torch
5
 
6
+ st.markdown('## Text-generation OPT from Meta ')
7
 
8
+ @st.cache(allow_output_mutation=True)
9
+ def get_model():
10
+ return pipeline('text-generation', model=model)
11
 
12
+ col1, col2 = st.beta_columns([2,1])
13
+
14
+ with col2:
15
+ select_model = st.radio(
16
+ "Select the model to use:",
17
+ ('OPT-125m', 'OPT-350m', 'OPT-1.3b'))
18
+
19
+ if select_model == 'OPT-1.3b':
20
+ model = 'facebook/opt-1.3b'
21
+ elif select_model == 'OPT-350m':
22
+ model = 'facebook/opt-350m'
23
+ elif select_model == 'OPT-125m':
24
+ model = 'facebook/opt-125m'
25
+
26
+ if select_model:
27
+ with st.spinner('Loading Model... (This may take a while)'):
28
+ generator = get_model()
29
+ #time.sleep(2)
30
+ st.success('Model loaded correctly!')
31
+
32
+
33
+ with col1:
34
+ prompt= st.text_area('Your prompt here',
35
+ '''AI will help humanity?''')
36
+
37
+ # answer = ['ciao come stai stutto bene']
38
+ # lst = ''.join(answer)
39
 
40
  answer = generator(prompt, max_length=60, no_repeat_ngram_size=2, early_stopping=True, num_beams=5)
 
41
  lst = answer[0]['generated_text']
42
 
43
+
44
+ t = st.empty()
45
+ for i in range(len(lst)):
46
+ t.markdown(" %s..." % lst[0:i])
47
+ time.sleep(0.04)