Spaces:
Runtime error
Runtime error
File size: 3,189 Bytes
b785737 2fecb66 b785737 bcda36f 9b51c9a bcda36f b785737 4161351 b785737 835aa93 b785737 af931e5 2718811 c17c61f b4316c5 f98ae08 b785737 2fe2f77 f63d876 fc3188e 2fecb66 b785737 18d5a92 b785737 af931e5 b785737 195e9be b785737 d2bb275 0db29c3 f63d876 2718811 6280603 a93ff49 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 |
import transformers
import streamlit as st
from transformers import AutoTokenizer, AutoModelWithLMHead
# Tokenizer shared by prompt encoding and by decoding of generated token ids.
# NOTE(review): this is loaded from a different repository than the model
# checkpoint used by load_model() -- presumably the base model the checkpoint
# was fine-tuned from; confirm the vocabularies match.
TOKENIZER_NAME = "anonymous-german-nlp/german-gpt2"
tokenizer = AutoTokenizer.from_pretrained(TOKENIZER_NAME)
# allow_output_mutation=True stops st.cache from hashing the returned model on
# every rerun to detect mutation; the cached object is returned as-is.
@st.cache(allow_output_mutation=True)
def load_model(model_name):
    """Load and cache the language model used for generation.

    Args:
        model_name: Identifier passed to ``from_pretrained`` (a model
            repository name or a local path).

    Returns:
        The model object returned by ``AutoModelWithLMHead.from_pretrained``.
    """
    # The argument is honored here; previously it was ignored in favor of a
    # hard-coded checkpoint name, so the call site did not control what loaded.
    # NOTE(review): AutoModelWithLMHead is deprecated in recent transformers
    # releases; AutoModelForCausalLM is the suggested replacement -- confirm
    # against the pinned transformers version before switching.
    return AutoModelWithLMHead.from_pretrained(model_name)


# Same checkpoint the app was effectively loading before the fix above.
model = load_model("Jipski/Flos_gpt-2_erw-02")
def infer(input_ids, max_length, temperature, top_k, top_p, num_return_sequences):
    """Run sampling-based generation on the module-level ``model``.

    Every argument is forwarded unchanged to ``model.generate`` under the
    keyword of the same name; ``do_sample`` is always set to ``True``.

    Args:
        input_ids: Encoded prompt, or ``None`` when there is no prompt.
        max_length: Forwarded as ``max_length``.
        temperature: Forwarded as ``temperature``.
        top_k: Forwarded as ``top_k``.
        top_p: Forwarded as ``top_p``.
        num_return_sequences: Forwarded as ``num_return_sequences``.

    Returns:
        Whatever ``model.generate`` returns for these settings.
    """
    sampling_options = {
        "max_length": max_length,
        "temperature": temperature,
        "top_k": top_k,
        "top_p": top_p,
        "do_sample": True,
        "num_return_sequences": num_return_sequences,
    }
    return model.generate(input_ids=input_ids, **sampling_options)
def update_showing():
    """Copy the session-state entry ``gen`` into the entry ``showing``.

    ``showing`` is the key of the prompt text area, so this overwrites the
    prompt with whatever is stored under ``gen``.
    """
    state = st.session_state
    state.showing = state.gen
# --- Page layout: title, prompt text area and sidebar sampling controls ---
default_value = "Jetzt tippen!"
#prompts
st.title("Flos gpt-2")
#st.write("The almighty king of text generation, GPT-2 comes in four available sizes, only three of which have been publicly made available. Feared for its fake news generation capabilities, it currently stands as the most syntactically coherent model. A direct successor to the original GPT, it reinforces the already established pre-training/fine-tuning killer duo. From the paper: Language Models are Unsupervised Multitask Learners by Alec Radford, Jeffrey Wu, Rewon Child, David Luan, Dario Amodei and Ilya Sutskever.")
# The prompt widget is keyed 'showing' so its value is reachable through
# st.session_state.showing.
sent = st.text_area("Text", default_value, key='showing', height=275)
max_length = st.sidebar.slider("Max Length", min_value=50, max_value=500)
# Temperature must be strictly positive when sampling: a value of 0.0 is
# rejected by generate(), so the slider starts one step above zero.
temperature = st.sidebar.slider("Temperature", value=1.0, min_value=0.05, max_value=1.0, step=0.05)
top_k = st.sidebar.slider("Top-k", min_value=0, max_value=5, value=0)
top_p = st.sidebar.slider("Top-p", min_value=0.0, max_value=1.0, step=0.05, value=0.9)
num_return_sequences = st.sidebar.number_input('Number of Return Sequences', min_value=1, max_value=5, value=1, step=1)
# Encode the prompt text as a PyTorch tensor, without special tokens.
encoded_prompt = tokenizer.encode(sent, add_special_tokens=False, return_tensors="pt")
# When the encoded prompt has no tokens along its last dimension, None is
# handed to infer() instead of the empty tensor.
prompt_is_empty = encoded_prompt.size()[-1] == 0
input_ids = None if prompt_is_empty else encoded_prompt
output_sequences = infer(
    input_ids,
    max_length,
    temperature,
    top_k,
    top_p,
    num_return_sequences,
)
# Length of the prompt after a tokenizer round trip; it does not change
# between sequences, so it is computed once instead of per iteration.
prompt_text_length = len(
    tokenizer.decode(encoded_prompt[0], clean_up_tokenization_spaces=True)
)
# Accumulates the final text of every returned sequence. Previously this name
# was overwritten with the token-id list on each iteration, so nothing was
# actually collected across sequences.
generated_sequences = []
for generated_sequence_idx, generated_sequence in enumerate(output_sequences):
    print(f"=== GENERATED SEQUENCE {generated_sequence_idx + 1} ===")
    token_ids = generated_sequence.tolist()
    # Decode text
    text = tokenizer.decode(token_ids, clean_up_tokenization_spaces=True)
    # Remove all text after the stop token
    #text = text[: text.find(args.stop_token) if args.stop_token else None]
    # Add the prompt at the beginning of the sequence. Remove the excess text that was used for pre-processing
    total_sequence = sent + text[prompt_text_length:]
    generated_sequences.append(total_sequence)
    print(total_sequence)
    # Show every requested sequence on the page, not only the last one.
    st.write(total_sequence)
#st.text_area("Output", generated_sequences[-1], key='gen', height=275, on_change=update_showing)
#st.session_state.catch_rand = generated_sequences[-1]
#st.write(st.session_state.catch_rand)
|