import streamlit as st
import torch
from time import perf_counter
from transformers import AutoTokenizer, AutoModelForCausalLM

MODEL = "nan-dre/maneleGPT-medium"
TOKENIZER = "nan-dre/maneleGPT-medium"
MAX_LENGTH = 256   # generation budget in tokens
MAX_CONTEXT = 512  # total sequence length (prompt + generation) this app allows

st.set_page_config(page_title="ManeleGPT", page_icon="🇷🇴", layout="centered")


def typical_sampling(model, input_ids, attention_mask, no_repeat_ngram_size,
                     max_length, temperature, typical_p):
    """Sample from the model with typical decoding; top-k is disabled so
    typical_p and temperature alone shape the distribution."""
    return model.generate(
        input_ids=input_ids,
        attention_mask=attention_mask,
        no_repeat_ngram_size=no_repeat_ngram_size,
        max_length=max_length,
        do_sample=True,
        temperature=temperature,
        typical_p=typical_p,
        top_k=0,  # 0 disables top-k filtering
    )


@st.cache_resource  # cache model and tokenizer across Streamlit reruns
def load_model():
    model = AutoModelForCausalLM.from_pretrained(MODEL)
    tokenizer = AutoTokenizer.from_pretrained(TOKENIZER)
    return model, tokenizer


st.header("ManeleGPT")
temperature = st.slider(label="Temperatura", min_value=0.01, max_value=1.0,
                        value=0.5, step=0.01)
# Label translates to: "What verse should the manea start with?"
seed_text = st.text_input(label="Cu ce vers sa inceapa maneaua?", value="", key="seed")

if seed_text:
    model, tokenizer = load_model()
    tokenized_text = tokenizer(seed_text, add_special_tokens=False, return_tensors="pt")

    if len(tokenized_text.input_ids[0]) + MAX_LENGTH > MAX_CONTEXT:
        # The prompt plus the generation budget would exceed the context
        # window, so only the trailing tokens of the prompt are fed to the model.
        keep_last = MAX_CONTEXT - MAX_LENGTH
        print(f"keep last: {keep_last}")
        input_ids = tokenized_text.input_ids[0][-keep_last:]
        attention_mask = tokenized_text.attention_mask[0][-keep_last:]
        # Everything before the kept window is re-attached to the output below.
        previous_ids = tokenized_text.input_ids[0][:-keep_last]
        st.warning(f"Prompt too long; only the last {keep_last} tokens were used.")
    else:
        input_ids = tokenized_text.input_ids[0]
        attention_mask = tokenized_text.attention_mask[0]
        previous_ids = None

    # generate() counts the prompt toward max_length, so budget for both,
    # capped at the context window.
    length = min(MAX_CONTEXT, len(input_ids) + MAX_LENGTH)

    timer_mark = perf_counter()
    output = typical_sampling(
        model,
        input_ids.unsqueeze(dim=0),
        attention_mask.unsqueeze(dim=0),
        no_repeat_ngram_size=2,
        max_length=length,
        temperature=temperature,
        typical_p=1,  # typical_p=1 disables typical filtering: pure temperature sampling
    )
    details = f"Text generated in {perf_counter() - timer_mark:.2f}s"

    if previous_ids is not None:
        print("Concat prev ids:", tokenizer.decode(previous_ids, skip_special_tokens=True))
        print("With current decode:", tokenizer.decode(output[0], skip_special_tokens=True))
        # Re-attach the truncated prefix so the user sees their full prompt.
        new_text = tokenizer.decode(torch.cat([previous_ids, output[0]], dim=-1),
                                    skip_special_tokens=True)
    else:
        new_text = tokenizer.decode(output[0], skip_special_tokens=True)

    st.text(new_text)
    st.caption(details)
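
# ---------------------------------------------------------------------------
# Usage sketch (not part of the original app; the file name and seed text
# below are assumptions). The app is launched with Streamlit, e.g.:
#
#     streamlit run app.py
#
# For a quick smoke test of the generation path outside Streamlit, something
# like the following should work (st.cache_resource degrades to a no-op
# warning when no Streamlit runtime is active). Note that a typical_p below
# 1.0, unlike the app's hardcoded typical_p=1, actually enables typical
# decoding:
#
#     model, tokenizer = load_model()
#     enc = tokenizer("Am o casa mare", add_special_tokens=False,
#                     return_tensors="pt")
#     out = typical_sampling(model, enc.input_ids, enc.attention_mask,
#                            no_repeat_ngram_size=2, max_length=64,
#                            temperature=0.7, typical_p=0.95)
#     print(tokenizer.decode(out[0], skip_special_tokens=True))
# ---------------------------------------------------------------------------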