cbspace committed
Commit 1ceabce · 1 Parent(s): 1e21a19

Added streamed generation

Files changed (1):
  app.py +17 -5
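
The change below turns generate() into a Python generator so the UI can stream partial completions instead of waiting for the full output. As a primer, here is a minimal self-contained sketch of the same pattern (stream_words is a hypothetical handler, unrelated to this app's model): with queuing enabled, each yield from the handler replaces the contents of the output component.

import time
import gradio as gr

def stream_words(prompt):
    # Generator handler: every yield pushes the accumulated text to the UI.
    out = prompt
    for word in ("alpha", "beta", "gamma"):
        time.sleep(0.5)  # stand-in for a model forward pass
        out += " " + word
        yield out

demo = gr.Interface(stream_words, gr.Textbox(label='Prompt'), gr.Textbox(label='Output'))
demo.queue().launch()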
app.py CHANGED
@@ -1,3 +1,4 @@
+# gpt Gradio App by Craig Brennan
 import torch
 import tiktoken
 import gradio as gr
@@ -13,6 +14,9 @@ n_vocab = 50257
 max_seq_len = 740
 dropout = 0.0
 
+# Number of tokens per update interval
+update_interval = 16
+
 @spaces.GPU
 def load_model():
     from model import GPTModel
@@ -30,9 +34,17 @@ def load_model():
 @spaces.GPU
 def generate(prompt,out_tokens,temperature,top_p_value):
     model.to(device)
-    inputs = tokenizer.encode(prompt)
-    outputs = model.generate(inputs, int(out_tokens), temperature, top_p_value)
-    return tokenizer.decode(outputs)
+    outputs = tokenizer.encode(prompt)
+    tokens_remaining = int(out_tokens)
+    out_text = prompt
+    yield out_text
+
+    while tokens_remaining:
+        new_inputs_len = update_interval if tokens_remaining >= update_interval else tokens_remaining % update_interval
+        outputs = model.generate(outputs, len(outputs)+new_inputs_len, temperature, top_p_value)
+        tokens_remaining -= new_inputs_len
+        out_text += tokenizer.decode(outputs[-new_inputs_len:])
+        yield out_text
 
 # Create the model
 model, tokenizer, device = load_model()
@@ -41,10 +53,10 @@ app = gr.Interface(
     generate,
     [
         gr.Textbox(label='Prompt', lines=3),
-        gr.Number(label='Output Tokens', value=54),
+        gr.Number(label='Output Tokens', value=150),
         gr.Slider(0.1, 1.0, step=0.05, value=0.95, label='Top-p Value'),
         gr.Slider(0.1, 2.0, step=0.05, value=0.95, label='Temperature')
     ],
-    gr.Textbox(label='Output', lines=15)
+    gr.Textbox(label='Output', lines=15, max_lines=15)
 )
 app.queue().launch(ssr_mode=False, share=True)
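
In the rewritten generate(), model.generate appears to take the running token sequence plus a total target length and to return the whole sequence, so each pass through the loop extends the output by at most update_interval tokens, decodes only the new tail, and yields the accumulated text. A minimal sketch of just the chunking arithmetic (chunk_sizes is a hypothetical helper, not part of the app): the loop emits full chunks of update_interval, then one remainder chunk. When tokens_remaining is already below update_interval, tokens_remaining % update_interval equals tokens_remaining, so the branch is effectively min(update_interval, tokens_remaining).

# Sketch of the chunk sizes the streaming loop produces (hypothetical helper).
update_interval = 16

def chunk_sizes(out_tokens):
    sizes, tokens_remaining = [], int(out_tokens)
    while tokens_remaining:
        # Same expression as the loop above; the modulo reduces to
        # tokens_remaining once it drops below update_interval.
        n = update_interval if tokens_remaining >= update_interval else tokens_remaining % update_interval
        sizes.append(n)
        tokens_remaining -= n
    return sizes

print(chunk_sizes(150))  # nine full chunks of 16, then a final chunk of 6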
 
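One wiring detail worth noting: gr.Interface passes input components to the handler positionally, and the inputs list places the 'Top-p Value' slider ahead of 'Temperature' while the signature is generate(prompt, out_tokens, temperature, top_p_value). As committed, the slider labelled Top-p therefore feeds the temperature argument and vice versa; both default to 0.95, which makes the swap easy to miss. A sketch of one way to realign them (not part of this commit):

app = gr.Interface(
    generate,
    [
        gr.Textbox(label='Prompt', lines=3),
        gr.Number(label='Output Tokens', value=150),
        gr.Slider(0.1, 2.0, step=0.05, value=0.95, label='Temperature'),   # third arg: temperature
        gr.Slider(0.1, 1.0, step=0.05, value=0.95, label='Top-p Value'),   # fourth arg: top_p_value
    ],
    gr.Textbox(label='Output', lines=15, max_lines=15)
)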