Spaces:
Runtime error
Runtime error
readded state
Browse files
app.py
CHANGED
@@ -19,9 +19,12 @@ generation_config = GenerationConfig.from_pretrained('gpt2-medium')
|
|
19 |
generation_config.max_new_tokens = response_length
|
20 |
generation_config.pad_token_id = generation_config.eos_token_id
|
21 |
|
22 |
-
def generate_response(story_so_far, new_prompt):
|
23 |
|
24 |
-
|
|
|
|
|
|
|
|
|
25 |
|
26 |
set_seed(random.randint(0, 4000000000))
|
27 |
inputs = tokenizer.encode(story_so_far + '\n' + new_prompt if story_so_far else new_prompt,
|
@@ -29,9 +32,21 @@ def generate_response(story_so_far, new_prompt):
|
|
29 |
max_length=1024 - response_length)
|
30 |
|
31 |
output = model.generate(inputs, do_sample=True, generation_config=generation_config)
|
32 |
-
response = clean_paragraph(tokenizer.batch_decode(output)[0][((len(truncated_story) + 1) if truncated_story else 0):])
|
33 |
|
34 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
35 |
|
36 |
def clean_paragraph(entry):
|
37 |
paragraphs = entry.split('\n')
|
@@ -44,9 +59,10 @@ def clean_paragraph(entry):
|
|
44 |
return capitalize_first_char("\n".join(paragraphs))
|
45 |
|
46 |
def reset():
|
47 |
-
|
48 |
-
|
49 |
-
|
|
|
50 |
|
51 |
def capitalize_first_char(entry):
|
52 |
for i in range(len(entry)):
|
@@ -58,17 +74,19 @@ with gr.Blocks() as demo:
|
|
58 |
story = gr.Textbox(interactive=False, lines=20)
|
59 |
story.style(show_copy_button=True)
|
60 |
|
|
|
|
|
61 |
prompt = gr.Textbox(placeholder="Continue the story here!", lines=3, max_lines=3)
|
62 |
|
63 |
with gr.Row():
|
64 |
gen_button = gr.Button('Generate')
|
65 |
-
|
66 |
res_button = gr.Button("Reset")
|
67 |
|
68 |
-
prompt.submit(generate_response, [
|
69 |
-
gen_button.click(generate_response, [
|
70 |
-
|
71 |
-
res_button.click(reset, [], story, scroll_to_output=True)
|
72 |
|
73 |
demo.launch(inbrowser=True)
|
74 |
|
|
|
19 |
generation_config.max_new_tokens = response_length
|
20 |
generation_config.pad_token_id = generation_config.eos_token_id
|
21 |
|
|
|
22 |
|
23 |
+
|
24 |
+
|
25 |
+
def generate_response(outputs, new_prompt):
|
26 |
+
|
27 |
+
story_so_far = "\n".join(outputs[:int(1024 / response_length + 1)])
|
28 |
|
29 |
set_seed(random.randint(0, 4000000000))
|
30 |
inputs = tokenizer.encode(story_so_far + '\n' + new_prompt if story_so_far else new_prompt,
|
|
|
32 |
max_length=1024 - response_length)
|
33 |
|
34 |
output = model.generate(inputs, do_sample=True, generation_config=generation_config)
|
|
|
35 |
|
36 |
+
response = clean_paragraph(tokenizer.batch_decode(output)[0][((len(story_so_far) + 1) if story_so_far else 0):])
|
37 |
+
outputs.append(response)
|
38 |
+
return {
|
39 |
+
user_outputs: outputs,
|
40 |
+
story: ((story_so_far + '\n' if story_so_far else '') + response).replace('\n', '\n\n')
|
41 |
+
}
|
42 |
+
|
43 |
+
def undo(outputs):
|
44 |
+
|
45 |
+
outputs = outputs[:-1]
|
46 |
+
return {
|
47 |
+
user_outputs: outputs,
|
48 |
+
story: "\n".join(outputs)
|
49 |
+
}
|
50 |
|
51 |
def clean_paragraph(entry):
|
52 |
paragraphs = entry.split('\n')
|
|
|
59 |
return capitalize_first_char("\n".join(paragraphs))
|
60 |
|
61 |
def reset():
|
62 |
+
return {
|
63 |
+
user_outputs: None,
|
64 |
+
story: None
|
65 |
+
}
|
66 |
|
67 |
def capitalize_first_char(entry):
|
68 |
for i in range(len(entry)):
|
|
|
74 |
story = gr.Textbox(interactive=False, lines=20)
|
75 |
story.style(show_copy_button=True)
|
76 |
|
77 |
+
user_outputs = gr.State()
|
78 |
+
|
79 |
prompt = gr.Textbox(placeholder="Continue the story here!", lines=3, max_lines=3)
|
80 |
|
81 |
with gr.Row():
|
82 |
gen_button = gr.Button('Generate')
|
83 |
+
undo_button = gr.Button("Undo")
|
84 |
res_button = gr.Button("Reset")
|
85 |
|
86 |
+
prompt.submit(generate_response, [user_outputs, prompt], [user_outputs, story], scroll_to_output=True)
|
87 |
+
gen_button.click(generate_response, [user_outputs, prompt], [user_outputs, story], scroll_to_output=True)
|
88 |
+
undo_button.click(undo, user_outputs, [user_outputs, story], scroll_to_output=True)
|
89 |
+
res_button.click(reset, [], [user_outputs, story], scroll_to_output=True)
|
90 |
|
91 |
demo.launch(inbrowser=True)
|
92 |
|