jhon parra committed on
Commit
586df44
1 Parent(s): 1457ccb

updated app behaviour

Browse files
Files changed (1) hide show
  1. app.py +19 -10
app.py CHANGED
@@ -12,16 +12,22 @@ MODELS={
12
  'tokenizer':AutoTokenizer.from_pretrained("jhonparra18/petro-twitter-assistant-30ep-large"),
13
  'model':AutoModelForCausalLM.from_pretrained("jhonparra18/petro-twitter-assistant-30ep-large")}}
14
 
 
 
 
15
 
16
  def text_completion(tokenizer,model,input_text:str,max_len:int=100):
17
  tokenizer.padding_side="left" ##start padding from left to right
18
  tokenizer.pad_token = tokenizer.eos_token
19
  input_ids = tokenizer([input_text], return_tensors="pt",truncation=True,max_length=128)
20
- outputs = model.generate(**input_ids, do_sample=True, max_length=max_len,top_k=100,top_p=0.95)
 
21
  out_text = tokenizer.batch_decode(outputs, skip_special_tokens=True)[0]
22
  return out_text
23
 
24
 
 
 
25
  st.markdown("<h3 style='text-align: center; color: gray;'> &#128038 Tweet de Político Colombiano: Autocompletado/generación de texto a partir de GPT2</h3>", unsafe_allow_html=True)
26
  st.text("")
27
  st.markdown("<h5 style='text-align: center; color: gray;'>Causal Language Modeling, source code <a href='https://github.com/statscol/twitter-user-autocomplete-assistant'> here </a> </h5>", unsafe_allow_html=True)
@@ -30,7 +36,6 @@ st.text("")
30
 
31
  col1,col2 = st.columns(2)
32
 
33
-
34
  with col1:
35
  with st.form("input_values"):
36
  politician = st.selectbox(
@@ -38,20 +43,24 @@ with col1:
38
  ("Uribe", "Petro")
39
  )
40
  st.text("")
41
- max_length_text=st.slider('Num Max Tokens', 100, 200, 100,10,key="user_max_length")
42
  st.text("")
43
- input_user_text = st.text_area('Input Text', 'Mi gobierno es',key="input_user_txt")
 
44
  st.text("")
45
- go_button=st.form_submit_button('Generate')
 
 
 
46
 
47
 
48
  with col2:
49
 
50
  if go_button: ##avoid re running script
51
  with st.spinner('Generating Text...'):
52
- output_text=text_completion(MODELS[politician.lower()]['tokenizer'],MODELS[politician.lower()]['model'],input_user_text,max_length_text)
53
- st.text_area("Tweet:",output_text,height=380,key="output_text")
54
- else:
55
- st.text_area("Tweet:","",height=380,key="output_text")
56
-
57
 
 
12
  'tokenizer':AutoTokenizer.from_pretrained("jhonparra18/petro-twitter-assistant-30ep-large"),
13
  'model':AutoModelForCausalLM.from_pretrained("jhonparra18/petro-twitter-assistant-30ep-large")}}
14
 
15
+ def callback_input_text(new_text):
16
+ del st.session_state.input_user_txt
17
+ st.session_state.input_user_txt=new_text
18
 
19
  def text_completion(tokenizer,model,input_text:str,max_len:int=100):
20
  tokenizer.padding_side="left" ##start padding from left to right
21
  tokenizer.pad_token = tokenizer.eos_token
22
  input_ids = tokenizer([input_text], return_tensors="pt",truncation=True,max_length=128)
23
+ with torch.no_grad(): ##maybe useless as the generate method does not compute gradients, just in case
24
+ outputs = model.generate(**input_ids, do_sample=True, max_length=max_len,top_k=100,top_p=0.95)
25
  out_text = tokenizer.batch_decode(outputs, skip_special_tokens=True)[0]
26
  return out_text
27
 
28
 
29
+
30
+
31
  st.markdown("<h3 style='text-align: center; color: gray;'> &#128038 Tweet de Político Colombiano: Autocompletado/generación de texto a partir de GPT2</h3>", unsafe_allow_html=True)
32
  st.text("")
33
  st.markdown("<h5 style='text-align: center; color: gray;'>Causal Language Modeling, source code <a href='https://github.com/statscol/twitter-user-autocomplete-assistant'> here </a> </h5>", unsafe_allow_html=True)
 
36
 
37
  col1,col2 = st.columns(2)
38
 
 
39
  with col1:
40
  with st.form("input_values"):
41
  politician = st.selectbox(
 
43
  ("Uribe", "Petro")
44
  )
45
  st.text("")
46
+ max_length_text=st.slider('Num Max Tokens', 50, 200, 100,10,key="user_max_length")
47
  st.text("")
48
+ input_user_text=st.empty()
49
+ input_text_value=input_user_text.text_area('Input Text', 'Mi gobierno no es corrupto',key="input_user_txt",height=300)
50
  st.text("")
51
+ complete_input=st.checkbox("Complete Input [Experimental]",value=False,help="Automáticamente rellenar el texto inicial con el resultado para una nueva iteración")
52
+ _,col_center,_=st.columns(3)
53
+ with col_center:
54
+ go_button=st.form_submit_button('Generate')
55
 
56
 
57
  with col2:
58
 
59
  if go_button: ##avoid re running script
60
  with st.spinner('Generating Text...'):
61
+ output_text=text_completion(MODELS[politician.lower()]['tokenizer'],MODELS[politician.lower()]['model'],input_text_value,max_length_text)
62
+ st.text_area("Tweet:",output_text,height=500,key="output_text")
63
+ if complete_input:
64
+ callback_input_text(output_text)
65
+ input_user_text.text_area("Input Text", output_text,height=300)
66