OlivierDehaene committed on
Commit
ef366f8
1 Parent(s): 35788c2
Files changed (2) hide show
  1. app.py +196 -0
  2. requirements.txt +2 -0
app.py ADDED
@@ -0,0 +1,196 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+
3
+ import gradio as gr
4
+
5
+ from text_generation import Client, InferenceAPIClient
6
+
7
+
8
def get_client(model: str):
    """Return a streaming text-generation client for *model*.

    Every model goes through the hosted Inference API (optionally
    authenticated via the HF_TOKEN environment variable), except the Joi
    model, which is served from a dedicated endpoint configured through
    the API_URL environment variable.
    """
    if model != "Rallio67/joi_20B_instruct_alpha":
        return InferenceAPIClient(model, token=os.getenv("HF_TOKEN", None))
    return Client(os.getenv("API_URL"))
12
+
13
+
14
def get_usernames(model: str):
    """Map *model* to the ``(user_prefix, assistant_prefix)`` pair used
    when serializing the conversation into a prompt.
    """
    is_joi = model == "Rallio67/joi_20B_instruct_alpha"
    assistant_prefix = "Joi: " if is_joi else "Assistant: "
    return "User: ", assistant_prefix
18
+
19
+
20
def predict(
    model: str,
    inputs: str,
    top_p: float,
    temperature: float,
    top_k: int,
    repetition_penalty: float,
    watermark: bool,
    chatbot,
    history,
):
    """Stream a model reply to *inputs*, yielding updated chat state.

    Serializes the prior (user, model) pairs from *chatbot* plus the new
    user *inputs* into a "User: ...\\n\\nAssistant: ..." prompt, streams
    tokens from the selected backend, and yields ``(chat_pairs, history)``
    after every token so the Gradio UI can render incrementally.

    history is a flat list of alternating user/model utterances kept in
    a gr.State; chatbot is the list of (user, model) display pairs.
    """
    client = get_client(model)
    user_name, assistant_name = get_usernames(model)
    # "User:" without the trailing space — the marker the model emits when
    # it starts hallucinating the next user turn.
    stop_prefix = user_name.rstrip()

    # History stores the raw (unprefixed) user utterance.
    history.append(inputs)

    # Rebuild previous turns in prompt format, adding role prefixes only
    # if they are missing.
    past = []
    for user_data, model_data in chatbot:
        if not user_data.startswith(user_name):
            user_data = user_name + user_data
        if not model_data.startswith("\n\n" + assistant_name):
            model_data = "\n\n" + assistant_name + model_data
        past.append(user_data + model_data + "\n\n")

    if not inputs.startswith(user_name):
        inputs = user_name + inputs

    total_inputs = "".join(past) + inputs + "\n\n" + assistant_name
    print(total_inputs)

    partial_words = ""
    for i, response in enumerate(
        client.generate_stream(
            total_inputs,
            top_p=top_p,
            top_k=top_k,
            repetition_penalty=repetition_penalty,
            watermark=watermark,
            temperature=temperature,
            max_new_tokens=1000,
            stop_sequences=[stop_prefix],
        )
    ):
        if response.token.special:
            # Skip special/control tokens (e.g. EOS) — not user-visible text.
            continue

        partial_words = partial_words + response.token.text
        if partial_words.endswith(stop_prefix):
            # BUGFIX: strip the echoed "User:" stop marker as a *suffix*.
            # The original used partial_words.rstrip(stop_prefix), which
            # treats the argument as a character set and would also eat
            # legitimate trailing output made of {U, s, e, r, :}.
            partial_words = partial_words.removesuffix(stop_prefix)

        if i == 0:
            history.append(" " + partial_words)
        else:
            history[-1] = partial_words

        # Pair up the flat history into (user, model) tuples for display.
        chat = [
            (history[j], history[j + 1]) for j in range(0, len(history) - 1, 2)
        ]
        yield chat, history
81
+
82
+
83
def reset_textbox():
    """Return a Gradio update that clears the input textbox."""
    cleared = gr.update(value="")
    return cleared
85
+
86
+
87
+ title = """<h1 align="center">🔥Large Language Model API 🚀Streaming🚀</h1>"""
88
+ description = """Language models can be conditioned to act like dialogue agents through a conversational prompt that typically takes the form:
89
+
90
+ ```
91
+ User: <utterance>
92
+ Assistant: <utterance>
93
+ User: <utterance>
94
+ Assistant: <utterance>
95
+ ...
96
+ ```
97
+
98
+ In this app, you can explore the outputs of multiple LLMs when prompted in this way.
99
+ """
100
+
101
+ with gr.Blocks(
102
+ css="""#col_container {width: 1000px; margin-left: auto; margin-right: auto;}
103
+ #chatbot {height: 520px; overflow: auto;}"""
104
+ ) as demo:
105
+ gr.HTML(title)
106
+ with gr.Column(elem_id="col_container"):
107
+ model = gr.Radio(
108
+ value="Rallio67/joi_20B_instruct_alpha",
109
+ choices=[
110
+ "Rallio67/joi_20B_instruct_alpha",
111
+ "google/flan-t5-xxl",
112
+ "google/flan-ul2",
113
+ "bigscience/bloom",
114
+ "bigscience/bloomz",
115
+ "EleutherAI/gpt-neox-20b",
116
+ ],
117
+ label="Model",
118
+ interactive=True,
119
+ )
120
+ chatbot = gr.Chatbot(elem_id="chatbot")
121
+ inputs = gr.Textbox(
122
+ placeholder="Hi there!", label="Type an input and press Enter"
123
+ )
124
+ state = gr.State([])
125
+ b1 = gr.Button()
126
+
127
+ with gr.Accordion("Parameters", open=False):
128
+ top_p = gr.Slider(
129
+ minimum=-0,
130
+ maximum=1.0,
131
+ value=0.95,
132
+ step=0.05,
133
+ interactive=True,
134
+ label="Top-p (nucleus sampling)",
135
+ )
136
+ temperature = gr.Slider(
137
+ minimum=-0,
138
+ maximum=5.0,
139
+ value=0.5,
140
+ step=0.1,
141
+ interactive=True,
142
+ label="Temperature",
143
+ )
144
+ top_k = gr.Slider(
145
+ minimum=1,
146
+ maximum=50,
147
+ value=4,
148
+ step=1,
149
+ interactive=True,
150
+ label="Top-k",
151
+ )
152
+ repetition_penalty = gr.Slider(
153
+ minimum=0.1,
154
+ maximum=3.0,
155
+ value=1.03,
156
+ step=0.01,
157
+ interactive=True,
158
+ label="Repetition Penalty",
159
+ )
160
+ watermark = gr.Checkbox(value=True, label="Text watermarking")
161
+
162
+ inputs.submit(
163
+ predict,
164
+ [
165
+ model,
166
+ inputs,
167
+ top_p,
168
+ temperature,
169
+ top_k,
170
+ repetition_penalty,
171
+ watermark,
172
+ chatbot,
173
+ state,
174
+ ],
175
+ [chatbot, state],
176
+ )
177
+ b1.click(
178
+ predict,
179
+ [
180
+ model,
181
+ inputs,
182
+ top_p,
183
+ temperature,
184
+ top_k,
185
+ repetition_penalty,
186
+ watermark,
187
+ chatbot,
188
+ state,
189
+ ],
190
+ [chatbot, state],
191
+ )
192
+ b1.click(reset_textbox, [], [inputs])
193
+ inputs.submit(reset_textbox, [], [inputs])
194
+
195
+ gr.Markdown(description)
196
+ demo.queue().launch(debug=True)
requirements.txt ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ text-generation
2
+ gradio==3.20.1