Jipski commited on
Commit
74a9667
1 Parent(s): b785737

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +5 -7
app.py CHANGED
@@ -1,15 +1,13 @@
1
  import transformers
2
  import streamlit as st
3
-
4
  from transformers import AutoTokenizer, AutoModelWithLMHead
5
-
6
- tokenizer='anonymous-german-nlp/german-gpt2'
7
-
8
  @st.cache
9
  def load_model(model_name):
10
- model = AutoModelWithLMHead.from_pretrained("Flos_gpt-2")
11
  return model
12
- model = load_model("Flos_gpt-2")
13
  def infer(input_ids, max_length, temperature, top_k, top_p, num_return_sequences):
14
  output_sequences = model.generate(
15
  input_ids=input_ids,
@@ -50,4 +48,4 @@ for generated_sequence_idx, generated_sequence in enumerate(output_sequences):
50
  )
51
  generated_sequences.append(total_sequence)
52
  print(total_sequence)
53
- st.write(generated_sequences[-1])
1
  import transformers
2
  import streamlit as st
 
3
  from transformers import AutoTokenizer, AutoModelWithLMHead
4
+
5
+ tokenizer = AutoTokenizer.from_pretrained("gpt2-large")
 
6
  @st.cache
7
  def load_model(model_name):
8
+ model = AutoModelWithLMHead.from_pretrained("gpt2-large")
9
  return model
10
+ model = load_model("gpt2-large")
11
  def infer(input_ids, max_length, temperature, top_k, top_p, num_return_sequences):
12
  output_sequences = model.generate(
13
  input_ids=input_ids,
48
  )
49
  generated_sequences.append(total_sequence)
50
  print(total_sequence)
51
+ st.write(generated_sequences[-1])