d2weber committed on
Commit
d7b460a
1 Parent(s): 486638c

Upload app.py

Browse files
Files changed (1) hide show
  1. app.py +51 -0
app.py ADDED
@@ -0,0 +1,51 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
"""Gradio app for interactive text continuation with a German GPT-2 model.

For the current prompt, ``num_sequences`` continuations are generated and each
one is shown as a row of word checkboxes. Ticking a word appends that
continuation, up to and including the ticked word, to the prompt; the changed
prompt then triggers a fresh round of generation.
"""
import gradio as gr

# Number of alternative continuations offered for every prompt.
num_sequences = 4

# When True, the model is not loaded and generation is replaced by a fixed
# placeholder text, which is enough to work on the UI.
demo_mode = False
if not demo_mode:
    # Imported here so that demo mode does not need `transformers` at all.
    from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline

    model = AutoModelForCausalLM.from_pretrained("d2weber/german-gpt2-finetuned-coldmirror-hpodcast1")
    tokenizer = AutoTokenizer.from_pretrained("dbmdz/german-gpt2", use_fast=True)
    lm = pipeline("text-generation", model=model, tokenizer=tokenizer)

    def generate(*args, **kwargs):
        """Run the text-generation pipeline and return the generated strings.

        All positional and keyword arguments are forwarded to the pipeline.
        ``pad_token_id`` is always set to the tokenizer's EOS token id, so
        callers must not pass it themselves.
        """
        outputs = lm(*args, **kwargs, pad_token_id=tokenizer.eos_token_id)
        return [output["generated_text"] for output in outputs]

with gr.Blocks() as app:
    prompt = gr.TextArea(value="Hallo und herzlich willkommen", label="Input")
    sequences = []
    for _ in range(num_sequences):
        # Hidden textbox that stores the raw continuation; the checkbox group
        # is its clickable, word-by-word view.
        seq = gr.Textbox("", visible=False)
        box = gr.CheckboxGroup(choices=[], label="", interactive=True)
        sequences.append(seq)

        @seq.change(inputs=seq, outputs=box)
        def split(seq):
            """Rebuild the checkbox row from the words of a new continuation."""
            return gr.CheckboxGroup(seq.split(), value=[])

        @box.select(inputs=[prompt, seq], outputs=prompt)
        def handle(prompt, sequence, selected: gr.SelectData):
            """Append the continuation, up to the ticked word, to the prompt.

            Returns the new prompt text. Un-ticking a word leaves the prompt
            unchanged.
            """
            # The select event is also emitted when a box is un-ticked;
            # without this guard the same words would be appended a second
            # time. `getattr` keeps the callback working if the event data
            # carries no `selected` flag.
            if not getattr(selected, "selected", True):
                return prompt
            to_append = " ".join(sequence.split()[: selected.index + 1])
            # A continuation that starts with punctuation is attached directly
            # to the prompt; one that starts with a letter or digit gets a
            # separating space.
            delimiter = " " if to_append[:1].isalnum() else ""
            return prompt.rstrip() + delimiter + to_append

    max_new_tokens = gr.Slider(1, 100, value=18, step=1, label="How long should the generated sequences be:")

    gr.Examples(
        [
            ["Hallo und herzlich willkommen"],
        ],
        prompt,
    )

    @prompt.change(inputs=[prompt, max_new_tokens], outputs=sequences)
    def handle(prompt, max_new_tokens):
        """Produce one continuation per hidden textbox for the given prompt."""
        # Trailing whitespace is stripped before the prompt is passed on.
        prompt = prompt.rstrip()
        if demo_mode:
            return ["some new words"] * num_sequences
        return generate(
            prompt,
            return_full_text=False,
            num_return_sequences=num_sequences,
            # The slider value is converted to an int before it is passed on.
            max_new_tokens=int(max_new_tokens),
        )

app.launch()