onuri BoldStudio committed on
Commit
8df5683
0 Parent(s):

Duplicate from BoldStudio/wab

Browse files

Co-authored-by: Kadir <BoldStudio@users.noreply.huggingface.co>

Files changed (4) hide show
  1. .gitattributes +36 -0
  2. README.md +13 -0
  3. app.py +313 -0
  4. requirements.txt +3 -0
.gitattributes ADDED
@@ -0,0 +1,36 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # GIT
2
+ *.7z filter=lfs diff=lfs merge=lfs -text
3
+ *.arrow filter=lfs diff=lfs merge=lfs -text
4
+ *.bin filter=lfs diff=lfs merge=lfs -text
5
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
6
+ *.ckpt filter=lfs diff=lfs merge=lfs -text
7
+ *.ftz filter=lfs diff=lfs merge=lfs -text
8
+ *.gz filter=lfs diff=lfs merge=lfs -text
9
+ *.h5 filter=lfs diff=lfs merge=lfs -text
10
+ *.joblib filter=lfs diff=lfs merge=lfs -text
11
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
12
+ *.mlmodel filter=lfs diff=lfs merge=lfs -text
13
+ *.model filter=lfs diff=lfs merge=lfs -text
14
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
15
+ *.npy filter=lfs diff=lfs merge=lfs -text
16
+ *.npz filter=lfs diff=lfs merge=lfs -text
17
+ *.onnx filter=lfs diff=lfs merge=lfs -text
18
+ *.ot filter=lfs diff=lfs merge=lfs -text
19
+ *.parquet filter=lfs diff=lfs merge=lfs -text
20
+ *.pb filter=lfs diff=lfs merge=lfs -text
21
+ *.pickle filter=lfs diff=lfs merge=lfs -text
22
+ *.pkl filter=lfs diff=lfs merge=lfs -text
23
+ *.pt filter=lfs diff=lfs merge=lfs -text
24
+ *.pth filter=lfs diff=lfs merge=lfs -text
25
+ *.rar filter=lfs diff=lfs merge=lfs -text
26
+ *.safetensors filter=lfs diff=lfs merge=lfs -text
27
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
28
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
29
+ *.tflite filter=lfs diff=lfs merge=lfs -text
30
+ *.tgz filter=lfs diff=lfs merge=lfs -text
31
+ *.wasm filter=lfs diff=lfs merge=lfs -text
32
+ *.xz filter=lfs diff=lfs merge=lfs -text
33
+ *.zip filter=lfs diff=lfs merge=lfs -text
34
+ *.zst filter=lfs diff=lfs merge=lfs -text
35
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+
README.md ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ title: wab assist
3
+ emoji: 📊
4
+ colorFrom: blue
5
+ colorTo: gray
6
+ sdk: gradio
7
+ sdk_version: 3.20.1
8
+ app_file: app.py
9
+ pinned: false
10
+ duplicated_from: BoldStudio/wab
11
+ ---
12
+
13
+ Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
app.py ADDED
@@ -0,0 +1,313 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import gradio as gr
3
+ from gradio import interface, blocks
4
+ import requests
5
+ from text_generation import Client, InferenceAPIClient
6
# One-shot example prepended to OpenChat conversations: demonstrates that a
# UI request should be answered with a bare HTML element, nothing else.
openchat_preprompt = (
    "\n<prompter>: Generate a button that says hello\n<assistant>:<button>hello</button>\n"
)
# System-style instruction forcing HTML-only output from the model.
# NOTE(review): the literal spells "[REQUIRMENTS]" (sic); it is part of the
# runtime prompt the model sees, so it is left byte-identical here.
preprompt = "[REQUIRMENTS]:\n Only output in html syntax.\n Do not output a html file! \n Do not use the <html> tag! \n DO NOT USE <br/> tag, DO not output explanation. Do not use Natural Language, Only answer in html syntax, \n only output the html for the elements in my question, DO NOT USE HELLO WORLD!!!,"
# HTML tags concatenated into the prompt as a hint of the allowed vocabulary.
prepromptTags = [
    '<div>',
    '<p>',
    '<h1>',
    '<h2>',
    '<h3>',
    '<h4>',
    '<h5>',
    '<h6>',
    '<table>',
    '<form>',
    '<a>',
]
23
+
24
+
25
def get_client(model: str):
    """Return a text-generation client for *model*.

    The OpenChat model is served from a dedicated endpoint whose URL comes
    from the OPENCHAT_API_URL environment variable; every other model goes
    through the Hugging Face Inference API, optionally authenticated with
    HF_TOKEN.
    """
    openchat_url = os.getenv("OPENCHAT_API_URL")
    hf_token = os.getenv("HF_TOKEN", None)
    if model == "togethercomputer/GPT-NeoXT-Chat-Base-20B":
        return Client(openchat_url)
    return InferenceAPIClient(model, token=hf_token)
29
+
30
+
31
def get_usernames(model: str):
    """Return the chat-formatting strings used to build prompts for *model*.

    Returns:
        (str, str, str, str): pre-prompt, user prefix, bot prefix, separator
    """
    if model == "OpenAssistant/oasst-sft-1-pythia-12b":
        names = ("", "<|prompter|>", "<|assistant|>", "<|endoftext|>")
    elif model == "togethercomputer/GPT-NeoXT-Chat-Base-20B":
        names = (openchat_preprompt, "<human>: ", "<bot>: ", "\n")
    else:
        names = ("", "User: ", "Assistant: ", "\n")
    return names
41
+
42
+
43
def predict(
    model: str,
    inputs: str,
    typical_p: float,
    top_p: float,
    temperature: float,
    top_k: int,
    repetition_penalty: float,
    watermark: bool,
    chatbot,
    history,
):
    """Stream a model reply for *inputs*, yielding updated chat state.

    Builds the full prompt from the fixed pre-prompt, the allowed HTML tag
    list, the prior conversation, and the new user message, then streams
    tokens from the inference client, yielding after every token so the
    Gradio UI updates incrementally.

    Args:
        model: model id; selects both the client and the chat formatting.
        inputs: the new user message.
        typical_p, top_p, temperature, top_k, repetition_penalty, watermark:
            sampling parameters forwarded to the client.
        chatbot: list of (user, bot) message pairs currently shown in the UI.
        history: flat list of alternating user/bot messages (mutated in place).

    Yields:
        (list[tuple[str, str]], list): chat pairs for the Chatbot component
        and the updated flat history.
    """
    client = get_client(model)
    preprompt, user_name, assistant_name, sep = get_usernames(model)

    history.append(inputs)

    # Re-serialize the prior conversation, ensuring each turn carries its
    # role prefix exactly once.
    past = []
    for data in chatbot:
        user_data, model_data = data

        if not user_data.startswith(user_name):
            user_data = user_name + user_data
        if not model_data.startswith(sep + assistant_name):
            model_data = sep + assistant_name + model_data

        past.append(user_data + model_data.rstrip() + sep)

    if not inputs.startswith(user_name):
        inputs = user_name + inputs

    total_inputs = preprompt + \
        "".join(prepromptTags) + "".join(past) + \
        inputs + sep + assistant_name.rstrip()

    partial_words = ""

    if model == "OpenAssistant/oasst-sft-1-pythia-12b":
        iterator = client.generate_stream(
            total_inputs,
            typical_p=typical_p,
            truncate=1000,
            watermark=watermark,
            max_new_tokens=500,
        )
    else:
        iterator = client.generate_stream(
            total_inputs,
            top_p=top_p if top_p < 1.0 else None,
            top_k=top_k,
            truncate=1000,
            repetition_penalty=repetition_penalty,
            watermark=watermark,
            temperature=temperature,
            max_new_tokens=500,
            stop_sequences=[user_name.rstrip(), assistant_name.rstrip()],
        )

    # Hoist the role-prefix stop strings out of the token loop.
    user_stop = user_name.rstrip()
    assistant_stop = assistant_name.rstrip()

    for i, response in enumerate(iterator):
        if response.token.special:
            continue

        partial_words = partial_words + response.token.text
        # BUG FIX: the original called str.rstrip(prefix), which strips any
        # run of trailing characters drawn from the prefix's character SET
        # (e.g. "<|prompter|>" would also eat a legitimate trailing "ript>"),
        # not the exact suffix. Trim the exact suffix string instead.
        if user_stop and partial_words.endswith(user_stop):
            partial_words = partial_words[: -len(user_stop)]
        if assistant_stop and partial_words.endswith(assistant_stop):
            partial_words = partial_words[: -len(assistant_stop)]

        if i == 0:
            history.append(" " + partial_words)
        elif response.token.text not in user_name:
            history[-1] = partial_words

        # Pair up the flat history into (user, bot) tuples for the Chatbot.
        chat = [
            (history[j].strip(), history[j + 1].strip())
            for j in range(0, len(history) - 1, 2)
        ]
        yield chat, history
121
+
122
+
123
def reset_textbox():
    """Clear the user-input textbox after a message has been sent."""
    return gr.update(value="")
125
+
126
+
127
def radio_on_change(
    value: str,
    typical_p,
    top_p,
    top_k,
    temperature,
    repetition_penalty,
    watermark,
):
    """Show/hide and re-default the sampling controls for the chosen model.

    Each model family exposes a different set of sampling parameters:
    OASST uses typical-p only, while the other presets use nucleus
    sampling (top-p / top-k / temperature / repetition penalty).
    Returns the updated component configs in the same order as the
    ``.change()`` handler's ``outputs`` list.
    """
    if value == "OpenAssistant/oasst-sft-1-pythia-12b":
        # OASST: typical-p sampling only; hide the nucleus-sampling controls.
        typical_p = typical_p.update(value=0.2, visible=True)
        top_p = top_p.update(visible=False)
        top_k = top_k.update(visible=False)
        temperature = temperature.update(visible=False)
        repetition_penalty = repetition_penalty.update(visible=False)
        watermark = watermark.update(False)
    elif value == "togethercomputer/GPT-NeoXT-Chat-Base-20B":
        typical_p = typical_p.update(visible=False)
        top_p = top_p.update(value=0.25, visible=True)
        top_k = top_k.update(value=50, visible=True)
        temperature = temperature.update(value=0.6, visible=True)
        repetition_penalty = repetition_penalty.update(
            value=1.01, visible=True)
        watermark = watermark.update(False)
    else:
        # Default preset for the remaining (currently commented-out) models.
        typical_p = typical_p.update(visible=False)
        top_p = top_p.update(value=0.95, visible=True)
        top_k = top_k.update(value=4, visible=True)
        temperature = temperature.update(value=0.5, visible=True)
        repetition_penalty = repetition_penalty.update(
            value=1.03, visible=True)
        watermark = watermark.update(True)
    return (
        typical_p,
        top_p,
        top_k,
        temperature,
        repetition_penalty,
        watermark,
    )
167
+
168
+
169
+ title = """<h3 align="left">WAB-Assist</h3>"""
170
+
171
+
172
# ---------------------------------------------------------------------------
# Gradio UI: a single chat column (hidden model picker, chatbot pane, input
# row) plus a hidden accordion of sampling parameters, wired to predict().
# ---------------------------------------------------------------------------
with gr.Blocks(
    css="""
    #col_container {margin-left: auto; margin-right: auto;}
    #chatbot {height: 420px; overflow: auto; box-shadow: 0 0 10px rgba(0,0,0,0.2)}
    #userInput { box-shadow: 0 0 10px rgba(0,0,0,0.2);padding:0px;}
    #userInput span{display:none}
    #submit, #api {max-width: max-content;background: #313170;color: white;}
    """
) as view:
    gr.HTML(title)
    gr.Markdown(visible=True)
    with gr.Column(elem_id="col_container"):
        # Model choice is fixed (non-interactive and invisible); alternative
        # models remain commented out for future use.
        model = gr.Radio(
            value="OpenAssistant/oasst-sft-1-pythia-12b",
            choices=[
                "OpenAssistant/oasst-sft-1-pythia-12b",
                # "togethercomputer/GPT-NeoXT-Chat-Base-20B",
                # "google/flan-t5-xxl",
                # "google/flan-ul2",
                # "bigscience/bloom",
                # "bigscience/bloomz",
                # "EleutherAI/gpt-neox-20b",
            ],
            label="Model",
            interactive=False,
            visible=False
        )
        chatbot = gr.Chatbot(elem_id="chatbot")
        with gr.Row(elem_id="row"):
            inputs = gr.Textbox(
                placeholder="hey!",
                label="",
                elem_id="userInput"
            )
            buttonSend = gr.Button(value="send", elem_id="submit")
            # NOTE(review): buttonAPI is rendered but never wired to a handler.
            buttonAPI = gr.Button(value="api", elem_id="api")
        # Flat alternating user/bot message history consumed by predict().
        state = gr.State([])

        # Sampling controls; the whole accordion is hidden from end users.
        with gr.Accordion("Parameters", open=False, visible=False):
            typical_p = gr.Slider(
                minimum=-0,
                maximum=1.0,
                value=0.55,
                step=0.05,
                interactive=True,
                label="Typical P mass",
            )
            top_p = gr.Slider(
                minimum=-0,
                maximum=1.0,
                value=0.55,
                step=0.05,
                interactive=True,
                label="Top-p (nucleus sampling)",
                visible=True,
            )
            temperature = gr.Slider(
                minimum=-0,
                maximum=5.0,
                value=3,
                step=0.1,
                interactive=True,
                label="Temperature",
                visible=True,
            )
            top_k = gr.Slider(
                minimum=1,
                maximum=50,
                value=50,
                step=1,
                interactive=True,
                label="Top-k",
                visible=True,
            )
            repetition_penalty = gr.Slider(
                minimum=0.1,
                maximum=3.0,
                value=2,
                step=0.01,
                interactive=True,
                label="Repetition Penalty",
                visible=True,
            )
            watermark = gr.Checkbox(value=False, label="Text watermarking")

    # Re-default/show/hide the sliders whenever the model selection changes;
    # the lambda closes over the component objects so radio_on_change can
    # call .update() on each of them.
    model.change(
        lambda value: radio_on_change(
            value,
            typical_p,
            top_p,
            top_k,
            temperature,
            repetition_penalty,
            watermark,
        ),
        inputs=model,
        outputs=[
            typical_p,
            top_p,
            top_k,
            temperature,
            repetition_penalty,
            watermark,
        ],
    )

    # Both Enter-in-textbox and the send button stream a reply via predict().
    inputs.submit(
        predict,
        [
            model,
            inputs,
            typical_p,
            top_p,
            temperature,
            top_k,
            repetition_penalty,
            watermark,
            chatbot,
            state,
        ],
        [chatbot, state],
    )
    buttonSend.click(
        predict,
        [
            model,
            inputs,
            typical_p,
            top_p,
            temperature,
            top_k,
            repetition_penalty,
            watermark,
            chatbot,
            state,
        ],
        [chatbot, state],
    )
    # Clear the textbox after a message is sent (either trigger).
    buttonSend.click(reset_textbox, [], [inputs])
    inputs.submit(reset_textbox, [], [inputs])

view.launch()
requirements.txt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ # REQUIREMENTS
2
+ text-generation==0.3.0
3
+ gradio==3.20.1