Enrico Shippole commited on
Commit
9224f39
1 Parent(s): d057fcb

Add initial gradio setup

Browse files
Files changed (1) hide show
  1. app.py +7 -4
app.py CHANGED
@@ -3,7 +3,7 @@ from transformers import AutoTokenizer
3
  from palm_rlhf_pytorch import PaLM
4
  import gradio as gr
5
 
6
- def generate(prompt, seq_len=128, temperature=0.8, filter_thres=0.9):
7
  device = torch.device("cpu")
8
 
9
  model = PaLM(
@@ -30,9 +30,7 @@ def generate(prompt, seq_len=128, temperature=0.8, filter_thres=0.9):
30
 
31
  decoded_output = tokenizer.batch_decode(output_tensor, skip_special_tokens=True)
32
 
33
- return decoded_output[0]
34
-
35
-
36
 
37
  iface = gr.Interface(
38
  fn=generate,
@@ -40,6 +38,11 @@ iface = gr.Interface(
40
  description="Open-source PaLM demo.",
41
  inputs="text",
42
  outputs="text"
 
 
 
 
 
43
  )
44
 
45
  iface.launch()
 
3
  from palm_rlhf_pytorch import PaLM
4
  import gradio as gr
5
 
6
+ def generate(prompt, seq_len, temperature, filter_thres):
7
  device = torch.device("cpu")
8
 
9
  model = PaLM(
 
30
 
31
  decoded_output = tokenizer.batch_decode(output_tensor, skip_special_tokens=True)
32
 
33
+ return decoded_output
 
 
34
 
35
  iface = gr.Interface(
36
  fn=generate,
 
38
  description="Open-source PaLM demo.",
39
  inputs="text",
40
  outputs="text"
41
+ [
42
+ gr.Slider(minimum=1, maximum=512, step=1, default=128, label="Sequence Length"),
43
+ gr.Slider(minimum=0.0, maximum=1.0, step=0.01, default=0.9, label="Temperature"),
44
+ gr.Slider(minimum=0.0, maximum=1.0, step=0.01, default=0.9, label="Filter Threshold"),
45
+ ],
46
  )
47
 
48
  iface.launch()