peterkchung committed on
Commit
df293a9
1 Parent(s): 0b6aeb4

Initial commit

Browse files
Files changed (1) hide show
  1. app.py +188 -0
app.py ADDED
@@ -0,0 +1,188 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
# Force a known-good Gradio version at process start (Hugging Face Spaces
# workaround). NOTE(review): shelling out to pip at import time is fragile
# and the os.system exit codes are ignored; pinning gradio==3.50.2 in
# requirements.txt would be the cleaner fix.
import os
os.system("pip uninstall -y gradio")
os.system("pip install gradio==3.50.2")

from huggingface_hub import InferenceClient
import gradio as gr

"""
Chat engine.
TODOs:
- Better prompts.
- Output reader / parser.
- Agents for evaluation and task planning / splitting.
    * Haystack for orchestration
- Tools for agents
    * Haystack for orchestration
-
"""

# Hosted model used for every completion; swap this id to change backends.
selected_model = "mistralai/Mixtral-8x7B-Instruct-v0.1"

# Serverless Hugging Face Inference API client bound to the selected model.
client = InferenceClient(selected_model)
23
+
24
def format_prompt(query, history, lookback):
    """Build a Mixtral-instruct prompt from recent chat history plus the new query.

    Args:
        query: The user's current message.
        history: List of [query, completion] pairs (Gradio chatbot format).
        lookback: Number of most recent exchanges to replay as context.

    Returns:
        Prompt string in the ``<s>[INST] ... [/INST] ...</s>`` instruct format,
        prefixed with a length instruction and ending with the new query.
    """
    prompt = "Responses should be no more than 100 words long.\n"

    # Replay only completed exchanges. The in-flight pair appended by
    # query_submit / retry_query has completion == None; without this guard
    # the literal string "None" leaked into the prompt as a fake answer.
    for previous_query, previous_completion in history[-lookback:]:
        if previous_completion is None:
            continue
        prompt += f"<s>[INST] {previous_query} [/INST] {previous_completion}</s> "

    prompt += f"[INST] {query} [/INST]"

    return prompt
33
+
34
def query_submit(user_message, history):
    """Record the pending user turn in the chat log and clear the input box.

    Returns a tuple of (new textbox value, updated history); the completion
    slot is left as None until query_completion streams it in.
    """
    updated_history = history + [[user_message, None]]
    return "", updated_history
36
+
37
def query_completion(
    query,
    history,
    lookback = 3,
    max_new_tokens = 256,
):
    """Stream a completion for *query* into the last history slot.

    Generator: yields the full history after each received token so the
    Gradio chatbot re-renders incrementally. Mutates ``history[-1][1]``
    in place; assumes query_submit has already appended ``[query, None]``.
    """
    prompt = format_prompt(query, history, lookback)

    # seed fixed for reproducible generations across calls.
    stream = client.text_generation(
        prompt,
        max_new_tokens = max_new_tokens,
        seed = 1337,
        stream = True,
        details = True,
        return_full_text = False
    )

    history[-1][1] = ""

    for chunk in stream:
        history[-1][1] += chunk.token.text
        yield history
64
+
65
def retry_query(
    history,
    lookback = 3,
    max_new_tokens = 256,
):
    """Re-generate the answer for the most recent query in *history*.

    Generator: yields the updated history as tokens stream in. A no-op
    (yields nothing) when the history is empty, matching the original
    behavior. Previously this duplicated query_completion's streaming
    body verbatim; it now delegates to keep the two paths in sync.
    """
    if not history:
        return

    query = history[-1][0]
    # Discard the previous completion before re-generating.
    history[-1][1] = None

    yield from query_completion(query, history, lookback, max_new_tokens)
97
+
98
+
99
+ """
100
+ Chat UI using Gradio Blocks.
101
+ Blocks preferred for "lower-level" layout control and state management.
102
+ TODOs:
103
+ - State management for dynamic components update.
104
+ - Add scratchpad readout to the right of chat log.
105
+ * Placeholder added for now.
106
+ - Add functionality to retry button.
107
+ * Placeholder added for now.
108
+ - Add dropdown for model selection.
109
+ * Placeholder added for now.
110
+
111
+ """
112
+
113
+ with gr.Blocks() as chatUI:
114
+ # gr.State()
115
+
116
+ with gr.Row():
117
+ modelSelect = gr.Dropdown(
118
+ label = "Model selection:",
119
+ scale = 0.5,
120
+ )
121
+
122
+ with gr.Row():
123
+ chatOutput = gr.Chatbot(
124
+ bubble_full_width = False,
125
+ scale = 2
126
+ )
127
+ agentWhiteBoard = gr.Markdown(scale = 1)
128
+
129
+ with gr.Row():
130
+ queryInput = gr.Textbox(
131
+ placeholder = "Please enter you question or request here...",
132
+ show_label = False,
133
+ scale = 4,
134
+ )
135
+ submitButton = gr.Button("Submit", scale = 1)
136
+
137
+ with gr.Row():
138
+ fileUpload = gr.File(
139
+ height = 100,
140
+ )
141
+ retryButton = gr.Button("Retry")
142
+ clearButton = gr.ClearButton([queryInput, chatOutput])
143
+
144
+ with gr.Row():
145
+ with gr.Accordion(label = "Expand for edit system prompt:"):
146
+ systemPrompt = gr.Textbox(
147
+ value = "System prompt here (null)",
148
+ show_label = False,
149
+ lines = 4,
150
+ scale = 4,
151
+ )
152
+
153
+
154
+ """
155
+ Event functions
156
+
157
+ """
158
+ queryInput.submit(
159
+ fn = query_submit,
160
+ inputs = [queryInput, chatOutput],
161
+ outputs = [queryInput, chatOutput],
162
+ queue = False,
163
+ ).then(
164
+ fn = query_completion,
165
+ inputs = [queryInput, chatOutput],
166
+ outputs = [chatOutput],
167
+ )
168
+
169
+ submitButton.click(
170
+ fn = query_submit,
171
+ inputs = [queryInput, chatOutput],
172
+ outputs = [queryInput, chatOutput],
173
+ queue = False,
174
+ ).then(
175
+ fn = query_completion,
176
+ inputs = [queryInput, chatOutput],
177
+ outputs = [chatOutput],
178
+ )
179
+
180
+ retryButton.click(
181
+ fn = retry_query,
182
+ inputs = [chatOutput],
183
+ outputs = [chatOutput],
184
+ )
185
+
186
+
187
+ chatUI.queue()
188
+ chatUI.launch(show_api = False)