|
from transformers import AutoTokenizer, AutoModelForCausalLM,pipeline |
|
import torch |
|
import streamlit as st |
|
|
|
|
|
|
|
# Hugging Face Hub repositories of the fine-tuned models, keyed by the
# lower-cased politician name chosen in the UI.
_MODEL_REPOS = {
    "uribe": "jhonparra18/uribe-twitter-assistant-30ep",
    "petro": "jhonparra18/petro-twitter-assistant-30ep-large",
}


def _resource_cache():
    """Return this Streamlit version's decorator for caching shared, unhashable resources."""
    for attr in ("cache_resource", "experimental_singleton"):
        decorator = getattr(st, attr, None)
        if decorator is not None:
            return decorator
    # Legacy Streamlit: the old cache must not try to hash/copy the returned models.
    return st.cache(allow_output_mutation=True)


@_resource_cache()
def _load_models():
    """Load every tokenizer/model pair once and reuse it across reruns.

    Streamlit re-executes this script on every widget interaction, so building
    the models directly at module level would call ``from_pretrained`` for all
    of them again on each rerun.

    Returns:
        ``{name: {"tokenizer": ..., "model": ...}}`` for every entry of
        ``_MODEL_REPOS``.
    """
    return {
        name: {
            "tokenizer": AutoTokenizer.from_pretrained(repo),
            "model": AutoModelForCausalLM.from_pretrained(repo),
        }
        for name, repo in _MODEL_REPOS.items()
    }


MODELS = _load_models()
|
|
|
def callback_input_text(new_text):
    """Overwrite the ``input_user_txt`` session-state entry with ``new_text``.

    The entry is removed before being reassigned. Presumably this works around
    Streamlit rejecting direct writes to a widget's key once that widget has
    been rendered in the current run. TODO confirm against the Streamlit
    version in use.

    Args:
        new_text: value to store under ``input_user_txt``.
    """
    session = st.session_state
    del session.input_user_txt
    session.input_user_txt = new_text
|
|
|
def text_completion(tokenizer, model, input_text: str, max_len: int = 100, max_input_len: int = 128) -> str:
    """Sample a continuation of ``input_text`` from a causal language model.

    Args:
        tokenizer: Hugging Face tokenizer paired with ``model``. Its padding
            side, truncation side and pad token are overwritten in place.
        model: Hugging Face causal LM exposing ``generate`` and ``device``.
        input_text: prompt to continue. An empty prompt is replaced by the
            tokenizer's BOS token (when it has one) so that ``generate`` gets
            at least one token to condition on.
        max_len: total length in tokens of the generated sequence, prompt
            included (passed to ``generate`` as ``max_length``). If the prompt
            already fills it, there is effectively no room left to generate.
        max_input_len: the prompt is truncated to at most this many tokens.
            Tokens are dropped from the start so that the end of the text,
            which is what gets continued, is kept.

    Returns:
        The (possibly truncated) prompt followed by the sampled continuation,
        decoded with special tokens removed.
    """
    tokenizer.padding_side = "left"
    # Keep the *end* of an over-long prompt: that is the text being completed.
    tokenizer.truncation_side = "left"
    tokenizer.pad_token = tokenizer.eos_token

    # An empty string tokenizes to zero ids, which leaves generate() nothing to start from.
    prompt = input_text or (tokenizer.bos_token or "")
    encoded = tokenizer([prompt], return_tensors="pt", truncation=True, max_length=max_input_len)
    # Follow the model wherever it lives (CPU or an accelerator).
    encoded = encoded.to(model.device)

    with torch.no_grad():
        outputs = model.generate(
            **encoded,
            do_sample=True,
            max_length=max_len,
            top_k=100,
            top_p=0.95,
            # Pad token was set to EOS above; pass the matching id explicitly
            # instead of letting generate() infer it.
            pad_token_id=tokenizer.eos_token_id,
        )

    return tokenizer.batch_decode(outputs, skip_special_tokens=True)[0]
|
|
|
|
|
|
|
|
|
# Page header: centered title plus a link to the project's source code.
# The title text was mis-encoded (UTF-8 read as GBK); accents are restored.
st.markdown("<h3 style='text-align: center; color: gray;'> 🐦 Tweet de Político Colombiano: Autocompletado/generación de texto a partir de GPT2</h3>", unsafe_allow_html=True)

st.text("")  # vertical spacer

st.markdown("<h5 style='text-align: center; color: gray;'>Causal Language Modeling, source code <a href='https://github.com/statscol/twitter-user-autocomplete-assistant'> here </a> </h5>", unsafe_allow_html=True)

st.text("")  # vertical spacer
|
|
|
|
|
# Two-column layout: inputs on the left (col1), generated tweet on the right (col2).
col1, col2 = st.columns(2)

with col1:
    with st.form("input_values"):
        # Options come from MODELS so every choice has a loaded model;
        # the output column maps the choice back with `.lower()`.
        politician = st.selectbox(
            "Selecciona el político",
            tuple(name.capitalize() for name in MODELS),
        )

        st.text("")

        max_length_text = st.slider(
            'Num Max Tokens',
            min_value=50,
            max_value=200,
            value=100,
            step=10,
            key="user_max_length",
        )

        st.text("")

        # Placeholder so the output column can re-render this text area with
        # the generated text when "Complete Input" is checked.
        input_user_text = st.empty()
        input_text_value = input_user_text.text_area('Input Text', 'Mi gobierno no es corrupto', key="input_user_txt", height=300)

        st.text("")

        complete_input = st.checkbox(
            "Complete Input [Experimental]",
            value=False,
            help="Automáticamente rellenar el texto inicial con el resultado para una nueva iteración",
        )

        go_button = st.form_submit_button('Generate')
|
|
|
|
|
with col2:
    # Generate only on the run where the form's Generate button was pressed.
    if go_button:
        with st.spinner('Generating Text...'):
            selected_model = MODELS[politician.lower()]
            output_text = text_completion(
                selected_model['tokenizer'],
                selected_model['model'],
                input_text_value,
                max_length_text,
            )

        st.text_area("Tweet:", output_text, height=500, key="output_text")

        if complete_input:
            # Seed the input widget's state with the result for the next run,
            # and show it right away in the input column's placeholder.
            callback_input_text(output_text)
            input_user_text.text_area("Input Text", output_text, height=300)
|
|
|
|