breadlicker45 committed on
Commit
392eb9f
1 Parent(s): 768480e

Update app.py

Files changed (1)
  1. app.py +10 -10
app.py CHANGED
@@ -3,7 +3,7 @@ import time
 from transformers import pipeline
 import torch
 trust_remote_code=True
-st.markdown('## Text-generation OPT from Facebook')
+st.markdown('## Text-generation gpt Muse models from Breadlicker')
 
 @st.cache(allow_output_mutation=True, suppress_st_warning =True, show_spinner=False)
 def get_model():
@@ -14,9 +14,9 @@ col1, col2 = st.columns([2,1])
 with st.sidebar:
     st.markdown('## Model Parameters')
 
-    max_length = st.slider('Max text length', 0, 250, 80)
+    max_length = st.slider('Max text length', 0, 500, 80)
 
-    num_beams = st.slider('N° tree beams search', 2, 15, 0)
+    num_beams = st.slider('N° tree beams search', 2, 15, 2)
 
     early_stopping = st.selectbox(
         'Early stopping text generation',
@@ -26,19 +26,19 @@ with st.sidebar:
 
 with col1:
     prompt= st.text_area('Your prompt here',
-    '''Who is Elon Musk?''')
+    '''2623 2619 3970 3976 2607 3973 2735 3973 2598 3985 2726 3973 2607 4009 2735 3973 2598 3973 2726 3973 2607 3973 2735 4009''')
 
 with col2:
     select_model = st.radio(
         "Select the model to use:",
-        ('OPT-125m', 'OPT-350m', 'OPT-1.3b'), index = 1)
+        ('MuseWeb', 'MusePy', 'MuseNeo'), index = 2)
 
-    if select_model == 'OPT-1.3b':
-        model = 'facebook/opt-1.3b'
-    elif select_model == 'OPT-350m':
+    if select_model == 'MuseWeb':
+        model = 'breadlicker45/MuseWeb'
+    elif select_model == 'MusePy':
         model = 'breadlicker45/MusePy'
-    elif select_model == 'OPT-125m':
-        model = 'BAAI/glm-515m'
+    elif select_model == 'MuseNeo':
+        model = 'breadlicker45/MuseNeo'
 
     with st.spinner('Loading Model... (This may take a while)'):
         generator = get_model()
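
Note: the hunks above only change the title string, the slider defaults, and the model ids; the body of get_model() and the code that actually runs the generator are not part of this commit. A minimal sketch of how the pieces could fit together, assuming get_model() simply wraps transformers.pipeline around the selected checkpoint (the model id, prompt, max_length, num_beams, and early_stopping values below just mirror the widget defaults shown in the diff):

import streamlit as st
from transformers import pipeline

# Assumed default selection, matching index=2 in the radio widget above.
model = 'breadlicker45/MuseNeo'

@st.cache(allow_output_mutation=True, suppress_st_warning=True, show_spinner=False)
def get_model():
    # Build and cache a text-generation pipeline for the chosen repo id.
    return pipeline('text-generation', model=model)

with st.spinner('Loading Model... (This may take a while)'):
    generator = get_model()

# Generation kwargs mirror the sidebar controls (values here are the defaults).
output = generator('2623 2619 3970 3976', max_length=80, num_beams=2, early_stopping=True)
st.write(output[0]['generated_text'])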