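# Streamlit demo: interactive text generation with Facebook's OPT models
# through the Hugging Face `transformers` pipeline.
# Run with e.g. `streamlit run app.py` (the filename is an assumption).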
import streamlit as st
import time
from transformers import pipeline
import torch

st.markdown('## Text generation with OPT from Facebook')

@st.cache(allow_output_mutation=True, suppress_st_warning=True, show_spinner=False)
def get_model(model_name):
    # Take the model name as an argument so the cache keys on it;
    # otherwise switching models would keep returning the first pipeline.
    return pipeline('text-generation', model=model_name, do_sample=True)
    
col1, col2 = st.columns([2, 1])

with st.sidebar:
    st.markdown('## Model Parameters')

    # Minimum raised from 0: a zero-token budget would make generation fail.
    max_length = st.slider('Max text length', 10, 150, 80)

    num_beams = st.slider('Number of beams (beam search)', 2, 15, 5)

    # Use real booleans as options; the string 'False' would be truthy
    # when passed to the generation call.
    early_stopping = st.selectbox(
        'Early stopping text generation',
        (True, False), index=0)

    no_ngram_repeat = st.slider('No-repeat n-gram size', 1, 5, 2)
    
with col1:
    prompt = st.text_area('Your prompt here', 'Who is Elon Musk?')
        
with col2:
    select_model = st.radio(
        "Select the model to use:",
        ('OPT-125m', 'OPT-350m', 'OPT-1.3b'), index=1)

    # Map the radio label to its Hugging Face Hub model id.
    model = {
        'OPT-125m': 'facebook/opt-125m',
        'OPT-350m': 'facebook/opt-350m',
        'OPT-1.3b': 'facebook/opt-1.3b',
    }[select_model]

    with st.spinner('Loading model... (this may take a while)'):
        generator = get_model(model)
        st.success('Model loaded correctly!')
     
gen = st.info('Generating text...')
answer = generator(prompt,
                   max_length=max_length,
                   no_repeat_ngram_size=no_ngram_repeat,
                   early_stopping=early_stopping,
                   num_beams=num_beams)
gen.empty()
                       
generated = answer[0]['generated_text']

# "Typewriter" effect: reveal the generated text one character at a time.
t = st.empty()
for i in range(len(generated) + 1):
    t.markdown("#### %s" % generated[:i])
    time.sleep(0.04)