lvwerra committed e6cd808 (1 parent: 2355d45)

Update app.py

Files changed (1): app.py (+29 -18)
app.py CHANGED
@@ -1,5 +1,6 @@
 import json
 import os
+import shutil
 
 import gradio as gr
 from huggingface_hub import Repository
@@ -19,8 +20,13 @@ theme = gr.themes.Monochrome(
     font=[gr.themes.GoogleFont("Open Sans"), "ui-sans-serif", "system-ui", "sans-serif"],
 )
 if HF_TOKEN:
+    try:
+        shutil.rmtree("./data/")
+    except:
+        pass
+
     repo = Repository(
-        local_dir="data", clone_from="trl-lib/stack-llama-prompts", use_auth_token=HF_TOKEN, repo_type="dataset"
+        local_dir="./data/", clone_from="trl-lib/stack-llama-prompts", use_auth_token=HF_TOKEN, repo_type="dataset"
     )
     repo.git_pull()
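Reviewer note on this hunk: a leftover `./data/` clone from a previous run of the Space would otherwise collide with `Repository(clone_from=...)`, so the directory is wiped before re-cloning. A minimal sketch of the same setup under that assumption, using `shutil.rmtree(..., ignore_errors=True)` in place of the bare `except:`:

```python
import os
import shutil

from huggingface_hub import Repository

HF_TOKEN = os.environ.get("HF_TOKEN", None)

if HF_TOKEN:
    # Remove any stale working copy so the clone below starts clean;
    # ignore_errors=True makes this a no-op when ./data/ does not exist.
    shutil.rmtree("./data/", ignore_errors=True)
    repo = Repository(
        local_dir="./data/",
        clone_from="trl-lib/stack-llama-prompts",
        use_auth_token=HF_TOKEN,
        repo_type="dataset",
    )
    repo.git_pull()
```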
 
@@ -39,10 +45,12 @@ def save_inputs_and_outputs(inputs, outputs, generate_kwargs):
     commit_url = repo.push_to_hub()
 
 
-def generate(instruction, temperature=0.9, max_new_tokens=256, top_p=0.95, top_k=100):
+def generate(instruction, temperature=0.9, max_new_tokens=256, top_p=0.95, top_k=100, do_save=True):
     formatted_instruction = PROMPT_TEMPLATE.format(prompt=instruction)
 
     temperature = float(temperature)
+    if temperature < 1e-2:
+        temperature = 1e-2
     top_p = float(top_p)
 
     generate_kwargs = dict(
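The new clamp matters because the temperature slider's minimum drops to 0.0 in this commit, and inference backends generally reject a temperature of zero. A hypothetical standalone helper with the same behavior:

```python
def clamp_temperature(temperature: float, floor: float = 1e-2) -> float:
    """Raise temperature to `floor` if the UI slider was set below it."""
    return max(float(temperature), floor)

assert clamp_temperature(0.0) == 0.01  # slider at its new minimum of 0.0
assert clamp_temperature(0.9) == 0.9   # values above the floor pass through
```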
@@ -65,10 +73,13 @@ def generate(instruction, temperature=0.9, max_new_tokens=256, top_p=0.95, top_k
     for response in stream:
         output += response.token.text
         yield output
-    if HF_TOKEN:
-        print("Pushing prompt and completion to the Hub")
-        save_inputs_and_outputs(formatted_instruction, output, generate_kwargs)
-
+    if HF_TOKEN and do_save:
+        try:
+            print("Pushing prompt and completion to the Hub")
+            save_inputs_and_outputs(formatted_instruction, output, generate_kwargs)
+        except Exception as e:
+            print(e)
+
     return output
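The push to the Hub becomes opt-in via `do_save` and is wrapped in a try/except, so a failed upload is logged instead of interrupting the response. A minimal sketch of the stream-then-save shape of `generate`, with stand-in names (the real function streams tokens from an inference endpoint and calls `save_inputs_and_outputs`):

```python
from typing import Iterator, List

def stream_then_save(tokens: List[str], do_save: bool = True) -> Iterator[str]:
    output = ""
    for token in tokens:
        output += token
        yield output  # Gradio re-renders the output box on every yield
    if do_save:
        try:
            # Stand-in for save_inputs_and_outputs(...): persist once streaming ends.
            print(f"would push to the Hub: {output!r}")
        except Exception as e:
            print(e)  # log and move on; saving must never break generation

for partial in stream_then_save(["Stack", "LLaMa"]):
    pass  # the UI would display "Stack", then "StackLLaMa"
```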
 
@@ -91,15 +102,15 @@ css = ".generating {visibility: hidden}" + share_btn_css
 with gr.Blocks(theme=theme, analytics_enabled=False, css=css) as demo:
     with gr.Column():
         gr.Markdown(
-            """<h1><center>🦙🦙🦙 StackLLaMa 🦙🦙🦙</center></h1>
+            """![](https://huggingface.co/spaces/trl-lib/stack-llama/resolve/main/stackllama_logo.png)
+
 
-            StackLLaMa is a 7 billion parameter language model that has been trained on pairs of questions and answers from [Stack Exchange](https://stackexchange.com) using Reinforcement Learning from Human Feedback with the [TRL library](https://github.com/lvwerra/trl). For more details, check out our [blog post](https://huggingface.co/blog/stackllama).
+            StackLLaMa is a 7 billion parameter language model that has been trained on pairs of questions and answers from [Stack Exchange](https://stackexchange.com) using Reinforcement Learning from Human Feedback with the [TRL library](https://github.com/lvwerra/trl). For more details, check out our [blog post](https://huggingface.co/blog/stackllama).
 
-            Type in the box below and click the button to generate answers to your most pressing questions 🔥!
-
-            **Note:** we are collecting your prompts and model completions for research purposes.
+            Type in the box below and click the button to generate answers to your most pressing questions!
             """
         )
+        do_save = gr.Checkbox(value=True, label="You consent to the storage of your prompt and generated text for research and development purposes.")
     with gr.Row():
         with gr.Column(scale=3):
             instruction = gr.Textbox(placeholder="Enter your question here", label="Question", elem_id="q-input")
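The removed `**Note:**` sentence is replaced by an explicit consent checkbox whose value is forwarded to the event handler. A hedged sketch of that wiring pattern (component names and the trivial handler are assumptions, not the app's code):

```python
import gradio as gr

def answer(question: str, do_save: bool = True) -> str:
    if do_save:
        print("would store prompt and completion:", question)  # stand-in for the Hub push
    return f"echo: {question}"

with gr.Blocks() as demo:
    do_save = gr.Checkbox(value=True, label="You consent to the storage of your prompt for research purposes.")
    question = gr.Textbox(label="Question")
    output = gr.Textbox(label="Answer")
    submit = gr.Button("Generate")
    # The checkbox is simply another input; Gradio passes its boolean value.
    submit.click(answer, inputs=[question, do_save], outputs=[output])
```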
@@ -122,8 +133,8 @@ with gr.Blocks(theme=theme, analytics_enabled=False, css=css) as demo:
             with gr.Column(scale=1):
                 temperature = gr.Slider(
                     label="Temperature",
-                    value=0.8,
-                    minimum=0.01,
+                    value=0.9,
+                    minimum=0.0,
                     maximum=2.0,
                     step=0.1,
                     interactive=True,
@@ -131,16 +142,16 @@ with gr.Blocks(theme=theme, analytics_enabled=False, css=css) as demo:
                 )
                 max_new_tokens = gr.Slider(
                     label="Max new tokens",
-                    value=256,
+                    value=128,
                     minimum=0,
-                    maximum=2048,
+                    maximum=512,
                     step=4,
                     interactive=True,
                     info="The maximum number of new tokens",
                 )
                 top_p = gr.Slider(
                     label="Top-p (nucleus sampling)",
-                    value=0.95,
+                    value=0.90,
                     minimum=0.0,
                     maximum=1,
                     step=0.05,
@@ -149,7 +160,7 @@ with gr.Blocks(theme=theme, analytics_enabled=False, css=css) as demo:
                 )
                 top_k = gr.Slider(
                     label="Top-k",
-                    value=40,
+                    value=50,
                     minimum=0,
                     maximum=100,
                     step=2,
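Taken together, the three slider hunks retune the decoding defaults: temperature 0.8 to 0.9 (minimum 0.01 to 0.0), max new tokens 256 to 128 (ceiling 2048 to 512), top-p 0.95 to 0.90, top-k 40 to 50. Collected as a kwargs dict in the shape `generate` builds (the dict in app.py may carry additional fields; this is a summary, not the file's literal code):

```python
# New slider defaults after this commit; generate() floors temperature at 1e-2.
generate_kwargs = dict(
    temperature=0.9,     # was 0.8
    max_new_tokens=128,  # was 256; slider maximum lowered from 2048 to 512
    top_p=0.90,          # was 0.95
    top_k=50,            # was 40
)
```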
@@ -157,8 +168,8 @@ with gr.Blocks(theme=theme, analytics_enabled=False, css=css) as demo:
                     info="Sample from top-k tokens",
                 )
 
-    submit.click(generate, inputs=[instruction, temperature, max_new_tokens, top_p, top_k], outputs=[output])
+    submit.click(generate, inputs=[instruction, temperature, max_new_tokens, top_p, top_k, do_save], outputs=[output])
     instruction.submit(generate, inputs=[instruction, temperature, max_new_tokens, top_p, top_k], outputs=[output])
     share_button.click(None, [], [], _js=share_js)
 
-demo.queue(concurrency_count=16).launch(debug=True)
+demo.queue(concurrency_count=16).launch(debug=True)
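One observation on this hunk: only the button's `click` gains `do_save`; `instruction.submit`, the Enter-key path, still passes the original five inputs, so those submissions fall back to the parameter default of `True` regardless of the checkbox. A possible follow-up, not part of this commit, would mirror the button wiring:

```python
# Hypothetical follow-up: keep the Enter-key path consistent with the button
# by forwarding the consent checkbox as well.
instruction.submit(
    generate,
    inputs=[instruction, temperature, max_new_tokens, top_p, top_k, do_save],
    outputs=[output],
)
```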
 