import streamlit as st
from transformers import AutoModelForCausalLM, AutoTokenizer

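# Tokenizer from the base German GPT-2 checkpoint; the fine-tuned
# Flos checkpoints below were presumably trained from it and share
# its vocabulary.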
tokenizer = AutoTokenizer.from_pretrained("anonymous-german-nlp/german-gpt2")


@st.cache(allow_output_mutation=True)  # cache the model between reruns (st.cache_resource in newer Streamlit)
def load_model(model_name):
    # Load the requested checkpoint; previously the argument was ignored
    # in favor of a hard-coded model name.
    model = AutoModelForCausalLM.from_pretrained(model_name)
    return model


model = load_model("Jipski/Flos_gpt-2_erw-02")
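
# Sampled generation: the arguments are forwarded to model.generate();
# do_sample=True is what makes temperature, top-k and top-p take effect.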
def infer(input_ids, max_length, temperature, top_k, top_p, num_return_sequences):
    output_sequences = model.generate(
        input_ids=input_ids,
        max_length=max_length,
        temperature=temperature,
        top_k=top_k,
        top_p=top_p,
        do_sample=True,
        num_return_sequences=num_return_sequences,
    )
    return output_sequences

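# Callback for the commented-out output text_area at the bottom of the
# script: it copies the generated text (key='gen') back into the prompt box.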
def update_showing():
    st.session_state.showing = st.session_state.gen
    
default_value = "Jetzt tippen!"  # German for "Start typing!"

# --- Prompt input and sampling controls ---
st.title("Flos gpt-2")
#st.write("The almighty king of text generation, GPT-2 comes in four available sizes, only three of which have been publicly made available. Feared for its fake news generation capabilities, it currently stands as the most syntactically coherent model. A direct successor to the original GPT, it reinforces the already established pre-training/fine-tuning killer duo. From the paper: Language Models are Unsupervised Multitask Learners by Alec Radford, Jeffrey Wu, Rewon Child, David Luan, Dario Amodei and Ilya Sutskever.")
sent = st.text_area("Text", default_value, key='showing', height=275)
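
# Decoding controls: temperature rescales the next-token distribution,
# top-k keeps only the k most likely tokens, and top-p (nucleus sampling)
# keeps the smallest token set whose cumulative probability exceeds p.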
max_length = st.sidebar.slider("Max Length", min_value=50, max_value=500)
# temperature must stay > 0 when sampling, so the slider floor is one step above zero
temperature = st.sidebar.slider("Temperature", value=1.0, min_value=0.05, max_value=1.0, step=0.05)
top_k = st.sidebar.slider("Top-k", min_value=0, max_value=5, value=0)
top_p = st.sidebar.slider("Top-p", min_value=0.0, max_value=1.0, step=0.05, value=0.9)
num_return_sequences = st.sidebar.number_input('Number of Return Sequences', min_value=1, max_value=5, value=1, step=1)
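
# An empty prompt encodes to a zero-length tensor; pass None instead so
# model.generate() starts generation from scratch.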
encoded_prompt = tokenizer.encode(sent, add_special_tokens=False, return_tensors="pt")
if encoded_prompt.size()[-1] == 0:
    input_ids = None
else:
    input_ids = encoded_prompt
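
# Streamlit reruns this script top to bottom on every widget change,
# so a fresh batch of text is generated after each interaction.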
output_sequences = infer(input_ids, max_length, temperature, top_k, top_p, num_return_sequences)

# Collect all decoded outputs; the app displays the last one below.
generated_sequences = []
for generated_sequence_idx, generated_sequence in enumerate(output_sequences):

    print(f"=== GENERATED SEQUENCE {generated_sequence_idx + 1} ===")
    # Decode the generated token ids back into text
    text = tokenizer.decode(generated_sequence, clean_up_tokenization_spaces=True)
    # (Optional) truncate everything after a stop token here
    # Re-attach the prompt and strip its decoded copy from the model output
    total_sequence = (
        sent + text[len(tokenizer.decode(encoded_prompt[0], clean_up_tokenization_spaces=True)):]
    )
    generated_sequences.append(total_sequence)
    print(total_sequence)

st.write(generated_sequences[-1])
#st.text_area("Output", generated_sequences[-1], key='gen', height=275, on_change=update_showing)

#st.session_state.catch_rand = generated_sequences[-1]
#st.write(st.session_state.catch_rand)