# ManeleGPT: Streamlit app that generates manele lyrics with a Romanian causal LM.
import streamlit as st
import torch
from time import perf_counter
from transformers import AutoTokenizer, AutoModelForCausalLM

MODEL = 'nan-dre/maneleGPT-medium'
TOKENIZER = 'nan-dre/maneleGPT-medium'
MAX_LENGTH = 256  # token budget for generation

st.set_page_config(
    page_title="ManeleGPT",
    page_icon="🇷🇴",
    layout="centered"
)


def typical_sampling(model, input_ids, attention_mask, no_repeat_ngram_size, max_length, temperature, typical_p):
    # Sample with locally typical decoding; top_k=0 disables top-k filtering so
    # only temperature and typical_p shape the sampling distribution.
    return model.generate(
        input_ids=input_ids,
        attention_mask=attention_mask,
        no_repeat_ngram_size=no_repeat_ngram_size,
        max_length=max_length,
        do_sample=True,
        temperature=temperature,
        typical_p=typical_p,
        top_k=0
    )


@st.cache_resource
def setModel():
    # Load the model and tokenizer once; Streamlit caches them across reruns.
    model = AutoModelForCausalLM.from_pretrained(MODEL)
    tokenizer = AutoTokenizer.from_pretrained(TOKENIZER)
    return model, tokenizer


st.header("ManeleGPT")
temperature = st.slider(label="Temperatura", min_value=0.01, max_value=1.0, value=0.5, step=0.01)  # "Temperature"
input = st.text_input(label="Cu ce vers sa inceapa maneaua?", value="", key="seed")  # "Which verse should the manea start with?"

if input:
    model, tokenizer = setModel()

    tokenized_text = tokenizer(input, add_special_tokens=False, return_tensors="pt")

    # The model's context window is 512 tokens; if the prompt plus the generation
    # budget would exceed it, keep only the last (512 - MAX_LENGTH) prompt tokens
    # and remember the truncated prefix so it can be re-attached to the output.
    if len(tokenized_text.input_ids[0]) + MAX_LENGTH > 512:
        keep_last = 512 - MAX_LENGTH
        print(f"keep last: {keep_last}")
        input_ids, attention_mask = tokenized_text.input_ids[0][-keep_last:], tokenized_text.attention_mask[0][-keep_last:]
        previous_ids = tokenized_text.input_ids[0][:-keep_last]  # everything except the kept tail
        st.warning(f"kept last {keep_last}")
    else:
        input_ids, attention_mask = tokenized_text.input_ids[0], tokenized_text.attention_mask[0]
        previous_ids = None

    # max_length counts the prompt tokens too, so allow the prompt length plus
    # the generation budget, capped at the 512-token context window.
    length = min(512, len(input_ids) + MAX_LENGTH)
    timer_mark = perf_counter()
    output = typical_sampling(model, input_ids.unsqueeze(dim=0), attention_mask.unsqueeze(dim=0), no_repeat_ngram_size=2, max_length=length, temperature=temperature, typical_p=1)
    details = f"Text generated in {perf_counter()-timer_mark:.2f}s"

    # Re-attach the truncated prefix (if any) before displaying the generated text.
    if previous_ids is not None:
        print("\nConcat prev id: " + tokenizer.decode(previous_ids, skip_special_tokens=True))
        print("\nWith current decode: " + tokenizer.decode(output[0], skip_special_tokens=True))
        new_text = tokenizer.decode(torch.cat([previous_ids, output[0]], dim=-1), skip_special_tokens=True)
    else:
        new_text = tokenizer.decode(output[0], skip_special_tokens=True)

    st.text(new_text)