coffeeee committed on
Commit cfc7981
1 Parent(s): 1239a24

fixed global var wack stuff

Files changed (1): app.py +10 -26
app.py CHANGED
@@ -19,35 +19,19 @@ generation_config = GenerationConfig.from_pretrained('gpt2-medium')
 generation_config.max_new_tokens = response_length
 generation_config.pad_token_id = generation_config.eos_token_id
 
+def generate_response(story_so_far, new_prompt):
 
-outputs = []
+    truncated_story = story_so_far[:1024 - response_length]
 
-
-def generate_response(new_prompt):
-    print('a')
-    global outputs
-    print('b')
-    story_so_far = "\n".join(outputs[:int(1024 / response_length + 1)])
-    print('c')
     set_seed(random.randint(0, 4000000000))
     inputs = tokenizer.encode(story_so_far + '\n' + new_prompt if story_so_far else new_prompt,
                               return_tensors='pt', truncation=True,
                               max_length=1024 - response_length)
-    print('d')
+
     output = model.generate(inputs, do_sample=True, generation_config=generation_config)
-    print('e')
-    response = clean_paragraph(tokenizer.batch_decode(output)[0][((len(story_so_far) + 1) if story_so_far else 0):])
-    print('f')
-    outputs.append(response)
-    print('g')
-    return ((story_so_far + '\n' if story_so_far else '') + response).replace('\n', '\n\n')
+    response = clean_paragraph(tokenizer.batch_decode(output)[0][((len(truncated_story) + 1) if truncated_story else 0):])
 
-def undo():
-    global outputs
-    print(outputs)
-    outputs = outputs[:-1]
-    print(outputs)
-    return "\n".join(outputs).replace('\n', '\n\n')
+    return ((story_so_far + '\n' if story_so_far else '') + response).replace('\n', '\n\n')
 
 def clean_paragraph(entry):
     paragraphs = entry.split('\n')
@@ -60,8 +44,8 @@ def clean_paragraph(entry):
     return capitalize_first_char("\n".join(paragraphs))
 
 def reset():
-    global outputs
-    outputs = []
+    # global outputs
+    # outputs = []
     return None
 
 def capitalize_first_char(entry):
@@ -78,12 +62,12 @@ with gr.Blocks() as demo:
 
     with gr.Row():
        gen_button = gr.Button('Generate')
-        undo_button = gr.Button("Undo")
+        # undo_button = gr.Button("Undo")
        res_button = gr.Button("Reset")
 
    prompt.submit(generate_response, prompt, story, scroll_to_output=True)
-    gen_button.click(generate_response, prompt, story, scroll_to_output=True)
-    undo_button.click(undo, [], story, scroll_to_output=True)
+    gen_button.click(generate_response, [story, prompt], story, scroll_to_output=True)
+    # undo_button.click(undo, [], story, scroll_to_output=True)
    res_button.click(reset, [], story, scroll_to_output=True)
 
 demo.launch(inbrowser=True)
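
In short, the commit drops the module-level outputs list and instead feeds the text already shown in the story component back into generate_response as an input, writing the updated story back to the same component. A minimal, self-contained sketch of that Gradio wiring, assuming a stand-in continuation function in place of the Space's GPT-2 tokenize/generate/decode pipeline (labels and the placeholder response text are made up for illustration):

import gradio as gr

def generate_response(story_so_far, new_prompt):
    # Stand-in for the real tokenizer.encode -> model.generate -> decode chain.
    response = "(continuation of: %s)" % new_prompt
    return (story_so_far + '\n' if story_so_far else '') + response

def reset():
    # Returning None clears the story textbox.
    return None

with gr.Blocks() as demo:
    story = gr.Textbox(label="Story", lines=10)
    prompt = gr.Textbox(label="Prompt")
    with gr.Row():
        gen_button = gr.Button('Generate')
        res_button = gr.Button("Reset")

    # The current story is read from the component and the updated story is
    # written back to it, so no global state is needed.
    gen_button.click(generate_response, [story, prompt], story, scroll_to_output=True)
    res_button.click(reset, [], story, scroll_to_output=True)

demo.launch()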