jhon parra
committed on
Commit
•
586df44
1
Parent(s):
1457ccb
updated app behaviour
Browse files
app.py
CHANGED
@@ -12,16 +12,22 @@ MODELS={
|
|
12 |
'tokenizer':AutoTokenizer.from_pretrained("jhonparra18/petro-twitter-assistant-30ep-large"),
|
13 |
'model':AutoModelForCausalLM.from_pretrained("jhonparra18/petro-twitter-assistant-30ep-large")}}
|
14 |
|
|
|
|
|
|
|
15 |
|
16 |
def text_completion(tokenizer,model,input_text:str,max_len:int=100):
|
17 |
tokenizer.padding_side="left" ##start padding from left to right
|
18 |
tokenizer.pad_token = tokenizer.eos_token
|
19 |
input_ids = tokenizer([input_text], return_tensors="pt",truncation=True,max_length=128)
|
20 |
-
|
|
|
21 |
out_text = tokenizer.batch_decode(outputs, skip_special_tokens=True)[0]
|
22 |
return out_text
|
23 |
|
24 |
|
|
|
|
|
25 |
st.markdown("<h3 style='text-align: center; color: gray;'> 🐦 Tweet de Político Colombiano: Autocompletado/generación de texto a partir de GPT2</h3>", unsafe_allow_html=True)
|
26 |
st.text("")
|
27 |
st.markdown("<h5 style='text-align: center; color: gray;'>Causal Language Modeling, source code <a href='https://github.com/statscol/twitter-user-autocomplete-assistant'> here </a> </h5>", unsafe_allow_html=True)
|
@@ -30,7 +36,6 @@ st.text("")
|
|
30 |
|
31 |
col1,col2 = st.columns(2)
|
32 |
|
33 |
-
|
34 |
with col1:
|
35 |
with st.form("input_values"):
|
36 |
politician = st.selectbox(
|
@@ -38,20 +43,24 @@ with col1:
|
|
38 |
("Uribe", "Petro")
|
39 |
)
|
40 |
st.text("")
|
41 |
-
max_length_text=st.slider('Num Max Tokens',
|
42 |
st.text("")
|
43 |
-
input_user_text
|
|
|
44 |
st.text("")
|
45 |
-
|
|
|
|
|
|
|
46 |
|
47 |
|
48 |
with col2:
|
49 |
|
50 |
if go_button: ##avoid re running script
|
51 |
with st.spinner('Generating Text...'):
|
52 |
-
output_text=text_completion(MODELS[politician.lower()]['tokenizer'],MODELS[politician.lower()]['model'],
|
53 |
-
st.text_area("Tweet:",output_text,height=
|
54 |
-
|
55 |
-
|
56 |
-
|
57 |
|
|
|
12 |
'tokenizer':AutoTokenizer.from_pretrained("jhonparra18/petro-twitter-assistant-30ep-large"),
|
13 |
'model':AutoModelForCausalLM.from_pretrained("jhonparra18/petro-twitter-assistant-30ep-large")}}
|
14 |
|
15 |
+
def callback_input_text(new_text: str) -> None:
    """Overwrite the value stored under the ``input_user_txt`` session-state key.

    Streamlit widgets created with a ``key`` keep their value in
    ``st.session_state``; to programmatically replace the input text area's
    content, the old entry is removed and then re-created with *new_text* so
    the widget picks it up on the next script rerun.

    Args:
        new_text: Text that should appear in the input text area.
    """
    # pop() with a default instead of `del`: does not raise KeyError when the
    # key is not present yet (e.g. the widget has not been rendered), while
    # keeping the same remove-then-set reset effect.
    st.session_state.pop("input_user_txt", None)
    st.session_state.input_user_txt = new_text
|
18 |
|
19 |
def text_completion(tokenizer, model, input_text: str, max_len: int = 100):
    """Generate a sampled text continuation for *input_text* with a causal LM.

    The tokenizer is set up for left padding (what decoder-only models expect)
    and its EOS token is reused as the padding token. The prompt is encoded
    (truncated to 128 tokens), then ``model.generate`` samples a continuation
    with top-k=100 / top-p=0.95 up to *max_len* total tokens.

    Args:
        tokenizer: Hugging Face tokenizer matching *model*.
        model: Causal language model exposing ``generate``.
        input_text: Prompt text to be completed.
        max_len: Maximum total sequence length (in tokens) of the output.

    Returns:
        The decoded generated text, with special tokens stripped.
    """
    # Decoder-only models must be padded on the left of the prompt.
    tokenizer.padding_side = "left"
    tokenizer.pad_token = tokenizer.eos_token
    encoding = tokenizer([input_text], return_tensors="pt", truncation=True, max_length=128)
    # generate() already runs without autograd; the explicit guard is a
    # belt-and-braces measure and costs nothing.
    with torch.no_grad():
        generated = model.generate(**encoding, do_sample=True, max_length=max_len, top_k=100, top_p=0.95)
    return tokenizer.batch_decode(generated, skip_special_tokens=True)[0]
|
27 |
|
28 |
|
29 |
+
|
30 |
+
|
31 |
# --- Page header: title, spacer, and link to the project's source code. ---
st.markdown("<h3 style='text-align: center; color: gray;'> 🐦 Tweet de Político Colombiano: Autocompletado/generación de texto a partir de GPT2</h3>", unsafe_allow_html=True)
st.text("")
st.markdown("<h5 style='text-align: center; color: gray;'>Causal Language Modeling, source code <a href='https://github.com/statscol/twitter-user-autocomplete-assistant'> here </a> </h5>", unsafe_allow_html=True)

# Two-column layout: left column holds the input form, right column the
# generated output.
col1,col2 = st.columns(2)
|
38 |
|
|
|
39 |
with col1:
|
40 |
with st.form("input_values"):
|
41 |
politician = st.selectbox(
|
|
|
43 |
("Uribe", "Petro")
|
44 |
)
|
45 |
st.text("")
|
46 |
+
max_length_text=st.slider('Num Max Tokens', 50, 200, 100,10,key="user_max_length")
|
47 |
st.text("")
|
48 |
+
input_user_text=st.empty()
|
49 |
+
input_text_value=input_user_text.text_area('Input Text', 'Mi gobierno no es corrupto',key="input_user_txt",height=300)
|
50 |
st.text("")
|
51 |
+
complete_input=st.checkbox("Complete Input [Experimental]",value=False,help="Automáticamente rellenar el texto inicial con el resultado para una nueva iteración")
|
52 |
+
_,col_center,_=st.columns(3)
|
53 |
+
with col_center:
|
54 |
+
go_button=st.form_submit_button('Generate')
|
55 |
|
56 |
|
57 |
with col2:
|
58 |
|
59 |
if go_button: ##avoid re running script
|
60 |
with st.spinner('Generating Text...'):
|
61 |
+
output_text=text_completion(MODELS[politician.lower()]['tokenizer'],MODELS[politician.lower()]['model'],input_text_value,max_length_text)
|
62 |
+
st.text_area("Tweet:",output_text,height=500,key="output_text")
|
63 |
+
if complete_input:
|
64 |
+
callback_input_text(output_text)
|
65 |
+
input_user_text.text_area("Input Text", output_text,height=300)
|
66 |
|