# maneleGPT / app.py: Streamlit demo for the nan-dre/maneleGPT-medium model
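# Flow: load the causal LM once (cached), read a seed verse from the UI,
# trim the prompt to fit the 512-token context, sample a continuation, and
# display the decoded text.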
import streamlit as st
import torch
from time import perf_counter
from transformers import AutoTokenizer, AutoModelForCausalLM

MODEL = 'nan-dre/maneleGPT-medium'
TOKENIZER = 'nan-dre/maneleGPT-medium'
MAX_LENGTH = 256  # budget for newly generated tokens; the 512-token context caps prompt + output

st.set_page_config(
    page_title="ManeleGPT",
    page_icon="🇷🇴",
    layout="centered"
)

def typical_sampling(model, input_ids, attention_mask, no_repeat_ngram_size, max_length, temperature, typical_p):
    """Sample a continuation; typical_p < 1.0 enables locally typical sampling."""
    return model.generate(
        input_ids=input_ids,
        attention_mask=attention_mask,
        no_repeat_ngram_size=no_repeat_ngram_size,
        max_length=max_length,
        do_sample=True,
        temperature=temperature,
        typical_p=typical_p,
        top_k=0
    )
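
# Note: as called below (typical_p=1, top_k=0), generate() applies no typicality
# or top-k filtering, so generation reduces to plain temperature sampling;
# passing a typical_p below 1.0 would activate locally typical sampling.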

@st.cache(allow_output_mutation=True)
def load_model():
    model = AutoModelForCausalLM.from_pretrained(MODEL)
    tokenizer = AutoTokenizer.from_pretrained(TOKENIZER)
    return model, tokenizer
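
# The cache keeps the model and tokenizer loaded across Streamlit reruns;
# allow_output_mutation=True skips hashing the returned (mutable) objects.
# On newer Streamlit releases, st.cache_resource is the direct replacement.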

st.header("ManeleGPT")
temperature = st.slider(label="Temperatura", min_value=0.01, max_value=1.0, value=0.5, step=0.01)  # "Temperature"
user_input = st.text_input(label="Cu ce vers sa inceapa maneaua?", value="", key="seed")  # "What verse should the manea start with?"

if user_input:
    model, tokenizer = load_model()
    tokenized_text = tokenizer(user_input, add_special_tokens=False, return_tensors="pt")
    if len(tokenized_text.input_ids[0]) + MAX_LENGTH > 512:  # prompt + budget exceeds the context; keep fewer tokens
        keep_last = 512 - MAX_LENGTH
        print(f"keep last: {keep_last}")
        input_ids, attention_mask = tokenized_text.input_ids[0][-keep_last:], tokenized_text.attention_mask[0][-keep_last:]
        previous_ids = tokenized_text.input_ids[0][:-keep_last]  # the trimmed prefix, re-attached after generation
        st.warning(f"Prompt trimmed to the last {keep_last} tokens")
    else:
        input_ids, attention_mask = tokenized_text.input_ids[0], tokenized_text.attention_mask[0]
        previous_ids = None
    length = min(512, len(input_ids) + MAX_LENGTH)
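    # Example: with MAX_LENGTH = 256 and the 512-token context, keep_last = 256,
    # so a 300-token prompt is trimmed to its last 256 tokens and length = 512;
    # a 40-token prompt passes through unchanged with length = min(512, 296) = 296.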
    timer_mark = perf_counter()
    # max_length counts prompt plus generated tokens, so pass the capped total length
    output = typical_sampling(model, input_ids.unsqueeze(dim=0), attention_mask.unsqueeze(dim=0),
                              no_repeat_ngram_size=2, max_length=length, temperature=temperature, typical_p=1)
    details = f"Text generated in {perf_counter()-timer_mark:.2f}s"
    if previous_ids is not None:
        print("\nConcat prev ids: " + tokenizer.decode(previous_ids, skip_special_tokens=True))
        print("\nWith current decode: " + tokenizer.decode(output[0], skip_special_tokens=True))
        # Re-attach the trimmed prefix so the user sees the full prompt plus the continuation
        new_text = tokenizer.decode(torch.cat([previous_ids, output[0]], dim=-1), skip_special_tokens=True)
    else:
        new_text = tokenizer.decode(output[0], skip_special_tokens=True)
    st.text(new_text)
    st.caption(details)  # show the generation time measured above