# Streamlit playground for text generation with Romanian causal language models.

import streamlit as st
import torch
from time import perf_counter
from transformers import AutoTokenizer, AutoModelForCausalLM

st.set_page_config(
    page_title="Romanian Text Generator",
    page_icon="🇷🇴",
    layout="wide"
)

#############################################
# Model list and decoding helpers

model_list = [
    "dumitrescustefan/gpt-neo-romanian-780m",
    "readerbench/RoGPT2-base",
    "readerbench/RoGPT2-medium",
    "readerbench/RoGPT2-large"
]


def greedy_search(model, input_ids, attention_mask, no_repeat_ngram_size, max_length):
    # Greedy decoding: always pick the most probable next token.
    return model.generate(
        input_ids=input_ids,
        attention_mask=attention_mask,
        no_repeat_ngram_size=no_repeat_ngram_size,
        max_length=max_length
    )


def beam_search(model, input_ids, attention_mask, no_repeat_ngram_size, max_length, num_beams):
    # Beam search decoding (defined for completeness; no UI button triggers it).
    return model.generate(
        input_ids=input_ids,
        attention_mask=attention_mask,
        no_repeat_ngram_size=no_repeat_ngram_size,
        max_length=max_length,
        num_beams=num_beams,
        early_stopping=True
    )


def sampling(model, input_ids, attention_mask, no_repeat_ngram_size, max_length, temperature, top_k, top_p):
    # Temperature sampling with optional top-k and nucleus (top-p) filtering.
    return model.generate(
        input_ids=input_ids,
        attention_mask=attention_mask,
        no_repeat_ngram_size=no_repeat_ngram_size,
        max_length=max_length,
        do_sample=True,
        temperature=temperature,
        top_k=top_k,
        top_p=top_p
    )


def typical_sampling(model, input_ids, attention_mask, no_repeat_ngram_size, max_length, temperature, typical_p):
    # Locally typical sampling; top_k=0 disables top-k filtering.
    return model.generate(
        input_ids=input_ids,
        attention_mask=attention_mask,
        no_repeat_ngram_size=no_repeat_ngram_size,
        max_length=max_length,
        do_sample=True,
        temperature=temperature,
        typical_p=typical_p,
        top_k=0
    )


# Cache the model and tokenizer so each checkpoint is loaded only once.
# Note: st.cache(allow_output_mutation=True) is deprecated in recent Streamlit
# releases; st.cache_resource is the modern equivalent.
@st.cache(allow_output_mutation=True)
def setModel(model_checkpoint):
    model = AutoModelForCausalLM.from_pretrained(model_checkpoint)
    tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)
    return model, tokenizer

#############################################

col_title, _, col_b1, col_b2, col_b3, _ = st.columns([18, 1, 8, 8, 8, 1])
col_title.markdown("**Playground for text generation with Romanian models**")
button_greedy = col_b1.button("Greedy generation")
button_sampling = col_b2.button("Sampling generation")
button_typical = col_b3.button("Typical sampling generation")

col1, _, col2 = st.columns([10, 1, 16])

with col1:
    st.markdown("**Step 1: Select model**")
    model_checkpoint = st.selectbox("Select model", model_list)
    st.markdown("**Step 2: Adjust text generation parameters**")
    max_length = st.slider("Number of tokens to generate", value=50, min_value=10, max_value=256)
    top_k = st.slider("Top-k", min_value=0, max_value=100, step=10, value=0)
    top_p = st.slider("Top-p", min_value=0.0, max_value=1.0, step=0.05, value=0.9)
    typical_p = st.slider("Typical-p", min_value=0.0, max_value=1.0, step=0.1, value=1.0)
    temperature = st.slider("Temperature", value=1.0, min_value=0.1, max_value=1.0, step=0.1)
    no_repeat_ngrams = st.slider("No repeat n-grams", value=2, min_value=0, max_value=3)
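# A note on these defaults (behavior of transformers' generation filters):
# top_k=0 disables top-k filtering and typical_p=1.0 disables typical
# filtering, so out of the box the sampling button is governed only by
# top_p=0.9 and the temperature.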
#####################################################
# show-time

if 'text' not in st.session_state:
    # Romanian placeholder: "Write any text you like here, then press one of
    # the buttons above. The selected model will continue the text."
    st.session_state['text'] = 'Scrieți aici orice text doriți și apoi apăsați unul din butoanele de mai sus. Modelul selectat va continua să scrie în continuare'

details = ""
tokenized_text = None

if button_greedy or button_sampling or button_typical:
    if len(st.session_state['text'].strip()) == 0:
        col2.warning("Please input some text!")
        text_element = col2.text_area('Text:', height=400, key="text")
        st.stop()

    model, tokenizer = setModel(model_checkpoint)

    tokenized_text = tokenizer(st.session_state['text'], add_special_tokens=False, return_tensors="pt")

    if len(tokenized_text.input_ids[0]) + max_length > 512:
        # The prompt plus the new tokens would exceed the 512-token window, so
        # keep only the last `keep_last` prompt tokens and stash the rest.
        keep_last = 512 - max_length
        print(f"keep last: {keep_last}")
        input_ids, attention_mask = tokenized_text.input_ids[0][-keep_last:], tokenized_text.attention_mask[0][-keep_last:]
        previous_ids = tokenized_text.input_ids[0][:-keep_last]
        st.warning(f"Prompt truncated to its last {keep_last} tokens to fit the model's 512-token window.")
    else:
        input_ids, attention_mask = tokenized_text.input_ids[0], tokenized_text.attention_mask[0]
        previous_ids = None

    # generate()'s max_length counts the prompt as well, so add the prompt
    # length and cap at the 512-token window.
    length = min(512, len(input_ids) + max_length)

    timer_mark = perf_counter()
    if button_greedy:
        output = greedy_search(model, input_ids.unsqueeze(dim=0), attention_mask.unsqueeze(dim=0), no_repeat_ngrams, length)
        details = f"Text generated using greedy decoding in {perf_counter()-timer_mark:.2f}s"
    if button_sampling:
        output = sampling(model, input_ids.unsqueeze(dim=0), attention_mask.unsqueeze(dim=0), no_repeat_ngrams, length, temperature, top_k, top_p)
        details = f"Text generated using sampling, top-p={top_p:.2f}, top-k={top_k}, temperature={temperature:.2f} in {perf_counter()-timer_mark:.2f}s"
    if button_typical:
        output = typical_sampling(model, input_ids.unsqueeze(dim=0), attention_mask.unsqueeze(dim=0), no_repeat_ngrams, length, temperature, typical_p)
        details = f"Text generated using typical sampling, typical-p={typical_p:.2f}, temperature={temperature:.2f} in {perf_counter()-timer_mark:.2f}s"

    if previous_ids is not None:
        # Re-attach the truncated prefix so the user never loses text.
        print("\nConcat prev id: " + tokenizer.decode(previous_ids, skip_special_tokens=True))
        print("\nWith current decode: " + tokenizer.decode(output[0], skip_special_tokens=True))
        new_text = tokenizer.decode(torch.cat([previous_ids, output[0]], dim=-1), skip_special_tokens=True)
    else:
        new_text = tokenizer.decode(output[0], skip_special_tokens=True)

    st.session_state['text'] = new_text

text_element = col2.text_area('Text:', height=400, key="text")

col2.markdown("---")
col2.text("Statistics and details:")
if details != "":
    col2.caption("\u00a0\u00a0\u00a0Generation details: " + details)
if tokenized_text is None:
    # No generation happened this run, so load a tokenizer just for the stats.
    tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)
tt = tokenizer(text_element, add_special_tokens=False, return_tensors="pt")
col2.caption(f"\u00a0\u00a0\u00a0Text length is {len(text_element)} characters, {len(tt.input_ids[0])} tokens.")
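# To try the app locally (a sketch: the file name and the unpinned package
# versions are assumptions, not part of the original project):
#
#   pip install streamlit torch transformers
#   streamlit run app.py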