Stefan Dumitrescu committed on
Commit c90ce91
1 Parent(s): 19c9e19
Files changed (1)
  1. app.py +17 -40
app.py CHANGED
@@ -1,5 +1,6 @@
 import streamlit as st
 import torch
+from time import perf_counter
 from transformers import AutoTokenizer, AutoModelForCausalLM
 
 st.set_page_config(
@@ -104,7 +105,7 @@ with col1:
     temperature = st.slider("Temperature", value=1.0, min_value=0.1, max_value=1.0, step=0.1)
     max_length = st.slider("Number of tokens to generate", value=50, min_value=10, max_value=256)
 
-    st.markdown("**Step 4: Select a prompt or input your own text, and click generate in the left panel**")
+    # st.markdown("**Step 4: Select a prompt or input your own text, and click generate in the left panel**")
 
 
 
@@ -129,6 +130,11 @@ details = ""
129
  tokenized_text = None
130
 
131
  if button_greedy or button_sampling or button_typical:
 
 
 
 
 
132
  model, tokenizer = setModel(model_checkpoint)
133
 
134
  tokenized_text = tokenizer(st.session_state['text'], add_special_tokens=False, return_tensors="pt")
@@ -144,7 +150,16 @@ if button_greedy or button_sampling or button_typical:
     previous_ids = None
 
     length = min(512, len(input_ids)+max_length)
-    output = greedy_search(model, input_ids.unsqueeze(dim=0), attention_mask.unsqueeze(dim=0), no_repeat_ngrams, length)
+    timer_mark = perf_counter()
+    if button_greedy:
+        output = greedy_search(model, input_ids.unsqueeze(dim=0), attention_mask.unsqueeze(dim=0), no_repeat_ngrams, length)
+        details = f"Text generated using greedy decoding in {perf_counter()-timer_mark:.2f}s"
+    if button_sampling:
+        output = sampling(model, input_ids.unsqueeze(dim=0), attention_mask.unsqueeze(dim=0), no_repeat_ngrams, length, temperature, top_k, top_p)
+        details = f"Text generated using sampling, top-p={top_p:.2f}, top-k={top_k}, temperature={temperature:.2f} in {perf_counter()-timer_mark:.2f}s"
+    if button_typical:
+        output = typical_sampling(model, input_ids.unsqueeze(dim=0), attention_mask.unsqueeze(dim=0), no_repeat_ngrams, length, temperature, typical_p)
+        details = f"Text generated using typical sampling, typical-p={typical_p:.2f}, temperature={temperature:.2f} in {perf_counter()-timer_mark:.2f}s"
 
     if previous_ids is not None:
         print(f"\nConcat prev id: "+tokenizer.decode(previous_ids, skip_special_tokens=True))
@@ -154,46 +169,8 @@ if button_greedy or button_sampling or button_typical:
     new_text = tokenizer.decode(output[0], skip_special_tokens=True)
 
     st.session_state['text'] = new_text
-    details = "Text generated using greedy decoding"
 
-"""
-if button_greedy:
-
-    tokenized_text = tokenizer(st.session_state['text'], add_special_tokens=False, return_tensors="pt")
-    print(f"len text: {len(tokenized_text.input_ids[0])}")
-    print(f"max_len : {max_length}")
-    if len(tokenized_text.input_ids[0]) + max_length > 512: # need to keep less words
-        keep_last = 512 - max_length
-        print(f"keep last: {keep_last}")
-        input_ids, attention_mask = tokenized_text.input_ids[0][:-keep_last], tokenized_text.attention_mask[0][:-keep_last]
-        st.warning(f"kept last {keep_last}")
-    else:
-        input_ids, attention_mask = tokenized_text.input_ids[0], tokenized_text.attention_mask[0]
-
-    length = min(512, len(input_ids)+max_length)
-    output = greedy_search(model, input_ids.unsqueeze(dim=0), attention_mask.unsqueeze(dim=0), no_repeat_ngrams, length)
-    st.session_state['text'] = tokenizer.decode(output[0], skip_special_tokens=True)
-    details = "Text generated using greedy decoding"
-
-if button_sampling:
-    model, tokenizer = setModel(model_checkpoint)
-    tokenized_text = tokenizer(st.session_state['text'], add_special_tokens=False, return_tensors="pt")
-    input_ids = tokenized_text.input_ids
-    attention_mask = tokenized_text.attention_mask
-    length = min(512, len(input_ids[0]) + max_length)
-    output = sampling(model, input_ids, attention_mask, no_repeat_ngrams, length, temperature, top_k, top_p)
-    st.session_state['text'] = tokenizer.decode(output[0], skip_special_tokens=True)
-    details = f"Text generated using sampling, top-p={top_p:.2f}, top-k={top_k:.2f}, temperature={temperature:.2f}"
-
-if button_typical:
-    model, tokenizer = setModel(model_checkpoint)
-    tokenized_text = tokenizer(st.session_state['text'], add_special_tokens=False, return_tensors="pt")
-    input_ids, attention_mask = tokenized_text.input_ids, tokenized_text.attention_mask
-    length = min(512, len(input_ids[0]) + max_length)
-    output = typical_sampling(model, input_ids, attention_mask, no_repeat_ngrams, length, temperature, typical_p)
-    st.session_state['text'] = tokenizer.decode(output[0], skip_special_tokens=True)
-    details = f"Text generated using typical sampling, typical-p={typical_p:.2f}, temperature={temperature:.2f}"
-"""
 
+
 text_element = col2.text_area('Text:', height=400, key="text")
 col2.markdown("""---""")
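The three decoding helpers called in the refactored branch (greedy_search, sampling, typical_sampling) are defined elsewhere in app.py and are not part of this diff. As a rough sketch of what they presumably wrap, with signatures inferred from the call sites above and the bodies mapped onto transformers' model.generate() (an assumption, not the app's actual implementation):

```python
# Sketch only: signatures inferred from the call sites in the diff above.

def greedy_search(model, input_ids, attention_mask, no_repeat_ngrams, length):
    # Deterministic decoding: always pick the highest-probability token.
    return model.generate(input_ids, attention_mask=attention_mask,
                          do_sample=False,
                          no_repeat_ngram_size=no_repeat_ngrams,
                          max_length=length)

def sampling(model, input_ids, attention_mask, no_repeat_ngrams, length,
             temperature, top_k, top_p):
    # Stochastic decoding: sample from the temperature-scaled distribution,
    # restricted to the top-k tokens and the top-p probability nucleus.
    return model.generate(input_ids, attention_mask=attention_mask,
                          do_sample=True,
                          no_repeat_ngram_size=no_repeat_ngrams,
                          max_length=length, temperature=temperature,
                          top_k=top_k, top_p=top_p)

def typical_sampling(model, input_ids, attention_mask, no_repeat_ngrams, length,
                     temperature, typical_p):
    # Typical decoding: keep tokens whose information content is close to the
    # model's conditional entropy (typical_p requires transformers >= 4.17).
    return model.generate(input_ids, attention_mask=attention_mask,
                          do_sample=True,
                          no_repeat_ngram_size=no_repeat_ngrams,
                          max_length=length, temperature=temperature,
                          typical_p=typical_p)
```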
 
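setModel(model_checkpoint) is likewise defined outside the changed region. A minimal sketch of what it presumably does, assuming it loads the checkpoint with the AutoTokenizer/AutoModelForCausalLM classes imported at the top of app.py and caches the pair so a button click does not reload weights on every Streamlit rerun (the caching decorator is an assumption):

```python
import streamlit as st
from transformers import AutoTokenizer, AutoModelForCausalLM

@st.cache(allow_output_mutation=True)  # assumed: cache the pair across reruns
def setModel(model_checkpoint):
    # Load tokenizer and causal LM for the checkpoint chosen in the UI.
    tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)
    model = AutoModelForCausalLM.from_pretrained(model_checkpoint)
    model.eval()  # inference only
    return model, tokenizer
```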