Daemontatox committed · Commit d6381b1 · verified · 1 Parent(s): 54b3ef7

Create app.py

Files changed (1):
  app.py +335 -0
app.py ADDED
@@ -0,0 +1,335 @@
+ import subprocess
+
+ subprocess.run(
+     'pip install flash-attn --no-build-isolation',
+     env={'FLASH_ATTENTION_SKIP_CUDA_BUILD': "TRUE"},
+     shell=True
+ )
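+ # Note: FLASH_ATTENTION_SKIP_CUDA_BUILD=TRUE tells the flash-attn setup to
+ # skip compiling its CUDA kernels at install time, the usual workaround on
+ # Spaces where a from-source build would be far too slow.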
+ import os
+ import re
+ import time
+ import torch
+ import spaces
+ import gradio as gr
+ from threading import Thread
+ from transformers import (
+     AutoModelForCausalLM,
+     AutoTokenizer,
+     BitsAndBytesConfig,
+     TextIteratorStreamer
+ )
+
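+ # `spaces` is the Hugging Face Spaces helper library; its @spaces.GPU
+ # decorator (used on chat_response below) attaches GPU hardware to a call
+ # on ZeroGPU Spaces.
+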
+ # Configuration Constants
+ MODEL_ID = "critical-hf/MAI_phd_ltd"
+ DEFAULT_SYSTEM_PROMPT = """You are a Smart Reasoning assistant. You should think step-by-step. You should answer all questions by prioritizing the principle of openness"""
+
+ # UI Configuration
+ TITLE = "<h1><center>AI Reasoning Assistant</center></h1>"
+ PLACEHOLDER = "Ask me anything! I'll think through it step by step."
+
+ CSS = """
+ .duplicate-button {
+     margin: auto !important;
+     color: white !important;
+     background: black !important;
+     border-radius: 100vh !important;
+ }
+ h3 {
+     text-align: center;
+ }
+ .message-wrap {
+     overflow-x: auto;
+ }
+ .message-wrap p {
+     margin-bottom: 1em;
+ }
+ .message-wrap pre {
+     background-color: #f6f8fa;
+     border-radius: 3px;
+     padding: 16px;
+     overflow-x: auto;
+ }
+ .message-wrap code {
+     background-color: rgba(175,184,193,0.2);
+     border-radius: 3px;
+     padding: 0.2em 0.4em;
+     font-family: monospace;
+ }
+ .custom-tag {
+     color: #0066cc;
+     font-weight: bold;
+ }
+ .chat-area {
+     height: 500px !important;
+     overflow-y: auto !important;
+ }
+ """
+
+ def initialize_model():
+     """Initialize the tokenizer and the 8-bit quantized model."""
+     # load_in_8bit enables bitsandbytes int8 weights; the compute-dtype and
+     # double-quantization knobs in BitsAndBytesConfig are 4-bit-only options,
+     # so they are omitted here.
+     quantization_config = BitsAndBytesConfig(load_in_8bit=True)
+
+     tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
+     if tokenizer.pad_token_id is None:
+         tokenizer.pad_token_id = tokenizer.eos_token_id
+
+     model = AutoModelForCausalLM.from_pretrained(
+         MODEL_ID,
+         torch_dtype=torch.float16,
+         device_map="cuda",
+         attn_implementation="flash_attention_2",
+         quantization_config=quantization_config
+     )
+
+     return model, tokenizer
+
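+ # A minimal smoke test of the loading path: a hypothetical helper that is not
+ # called anywhere in the app and assumes a CUDA device is available.
+ def _smoke_test_generation(model, tokenizer, prompt="Hello"):
+     """Generate a few tokens to confirm the quantized model responds."""
+     inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
+     output_ids = model.generate(**inputs, max_new_tokens=8)
+     return tokenizer.decode(output_ids[0], skip_special_tokens=True)
+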
+ def format_text(text):
+     """Format text with proper spacing and tag highlighting (but keep tags visible)"""
+     tag_patterns = [
+         (r'<Thinking>', '\n<Thinking>\n'),
+         (r'</Thinking>', '\n</Thinking>\n'),
+         (r'<Critique>', '\n<Critique>\n'),
+         (r'</Critique>', '\n</Critique>\n'),
+         (r'<Revising>', '\n<Revising>\n'),
+         (r'</Revising>', '\n</Revising>\n'),
+         (r'<Final>', '\n<Final>\n'),
+         (r'</Final>', '\n</Final>\n')
+     ]
+
+     formatted = text
+     for pattern, replacement in tag_patterns:
+         formatted = re.sub(pattern, replacement, formatted)
+
+     formatted = '\n'.join(line for line in formatted.split('\n') if line.strip())
+
+     return formatted
+
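+ # Illustrative example: format_text("a<Thinking>b</Thinking>c") returns
+ # "a\n<Thinking>\nb\n</Thinking>\nc" — each tag lands on its own line and the
+ # blank lines introduced by the substitutions are stripped.
+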
+ def format_chat_history(history):
+     """Format chat history for display, keeping tags visible"""
+     formatted = []
+     for user_msg, assistant_msg in history:
+         formatted.append(f"User: {user_msg}")
+         if assistant_msg:
+             formatted.append(f"Assistant: {assistant_msg}")
+     return "\n\n".join(formatted)
+
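+ # Illustrative example: format_chat_history([("Hi", "Hello!")]) returns
+ # "User: Hi\n\nAssistant: Hello!".
+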
+ def create_examples():
+     """Create example queries for the UI"""
+     return [
+         "Explain the concept of artificial intelligence.",
+         "How does photosynthesis work?",
+         "What are the main causes of climate change?",
+         "Describe the process of protein synthesis.",
+         "What are the key features of a democratic government?",
+         "Explain the theory of relativity.",
+         "How do vaccines work to prevent diseases?",
+         "What are the major events of World War II?",
+         "Describe the structure of a human cell.",
+         "What is the role of DNA in genetics?"
+     ]
+
+ @spaces.GPU(duration=120)
+ def chat_response(
+     message: str,
+     history: list,
+     chat_display: str,
+     system_prompt: str,
+     temperature: float = 0.2,
+     max_new_tokens: int = 4000,
+     top_p: float = 0.8,
+     top_k: int = 40,
+     penalty: float = 1.2,
+ ):
+     """Generate chat responses, keeping tags visible in the output"""
+     conversation = [
+         {"role": "system", "content": system_prompt}
+     ]
+
+     for prompt, answer in history:
+         conversation.extend([
+             {"role": "user", "content": prompt},
+             {"role": "assistant", "content": answer}
+         ])
+
+     conversation.append({"role": "user", "content": message})
+
+     input_ids = tokenizer.apply_chat_template(
+         conversation,
+         add_generation_prompt=True,
+         return_tensors="pt"
+     ).to(model.device)
+
+     streamer = TextIteratorStreamer(
+         tokenizer,
+         timeout=60.0,
+         skip_prompt=True,
+         skip_special_tokens=True
+     )
+
+     generate_kwargs = dict(
+         input_ids=input_ids,
+         max_new_tokens=max_new_tokens,
+         do_sample=temperature > 0,  # fall back to greedy decoding at temperature 0
+         top_p=top_p,
+         top_k=top_k,
+         temperature=temperature,
+         repetition_penalty=penalty,
+         streamer=streamer,
+     )
+
+     buffer = ""
+
+     # Run generation in a background thread; model.generate disables gradient
+     # tracking internally, so no torch.no_grad() wrapper is needed here.
+     thread = Thread(target=model.generate, kwargs=generate_kwargs)
+     thread.start()
+
+     history = history + [[message, ""]]
+
+     for new_text in streamer:
+         buffer += new_text
+         formatted_buffer = format_text(buffer)
+         history[-1][1] = formatted_buffer
+         chat_display = format_chat_history(history)
+
+         yield history, chat_display
+
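+ # The streaming above is a producer/consumer pattern: model.generate pushes
+ # decoded text into the TextIteratorStreamer from the worker thread while this
+ # generator drains it, yielding partial history updates for Gradio to render
+ # incrementally.
+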
+ def process_example(example: str) -> tuple:
+     """Process example query and return empty history and updated display"""
+     return [], f"User: {example}\n\n"
+
+ def main():
+     """Main function to set up and launch the Gradio interface"""
+     global model, tokenizer
+     model, tokenizer = initialize_model()
+
+     with gr.Blocks(css=CSS, theme="soft") as demo:
+         gr.HTML(TITLE)
+         gr.DuplicateButton(
+             value="Duplicate Space for private use",
+             elem_classes="duplicate-button"
+         )
+
+         with gr.Row():
+             with gr.Column():
+                 chat_history = gr.State([])
+                 chat_display = gr.TextArea(
+                     value="",
+                     label="Chat History",
+                     interactive=False,
+                     elem_classes=["chat-area"],
+                 )
+
+                 message = gr.TextArea(
+                     placeholder=PLACEHOLDER,
+                     label="Your message",
+                     lines=3
+                 )
+
+                 with gr.Row():
+                     submit = gr.Button("Send")
+                     clear = gr.Button("Clear")
+
+                 with gr.Accordion("⚙️ Advanced Settings", open=False):
+                     system_prompt = gr.TextArea(
+                         value=DEFAULT_SYSTEM_PROMPT,
+                         label="System Prompt",
+                         lines=5,
+                     )
+                     temperature = gr.Slider(
+                         minimum=0,
+                         maximum=1,
+                         step=0.1,
+                         value=0.2,
+                         label="Temperature",
+                     )
+                     max_tokens = gr.Slider(
+                         minimum=128,
+                         maximum=32000,
+                         step=128,
+                         value=4000,
+                         label="Max Tokens",
+                     )
+                     top_p = gr.Slider(
+                         minimum=0.1,
+                         maximum=1.0,
+                         step=0.1,
+                         value=0.8,
+                         label="Top-p",
+                     )
+                     top_k = gr.Slider(
+                         minimum=1,
+                         maximum=100,
+                         step=1,
+                         value=40,
+                         label="Top-k",
+                     )
+                     penalty = gr.Slider(
+                         minimum=1.0,
+                         maximum=2.0,
+                         step=0.1,
+                         value=1.2,
+                         label="Repetition Penalty",
+                     )
+
+         examples = gr.Examples(
+             examples=create_examples(),
+             inputs=[message],
+             outputs=[chat_history, chat_display],
+             fn=process_example,
+             cache_examples=False,
+         )
+
+         # Set up event handlers
+         submit_click = submit.click(
+             chat_response,
+             inputs=[
+                 message,
+                 chat_history,
+                 chat_display,
+                 system_prompt,
+                 temperature,
+                 max_tokens,
+                 top_p,
+                 top_k,
+                 penalty,
+             ],
+             outputs=[chat_history, chat_display],
+             show_progress=True,
+         )
+
+         submit_event = message.submit(
+             chat_response,
+             inputs=[
+                 message,
+                 chat_history,
+                 chat_display,
+                 system_prompt,
+                 temperature,
+                 max_tokens,
+                 top_p,
+                 top_k,
+                 penalty,
+             ],
+             outputs=[chat_history, chat_display],
+             show_progress=True,
+         )
+
+         clear.click(
+             lambda: ([], ""),
+             outputs=[chat_history, chat_display],
+             show_progress=True,
+         )
+
+         # Clear the input box only after a response finishes, chaining off both
+         # triggers so the textbox is not wiped by a second, concurrent handler.
+         submit_click.then(lambda: "", outputs=message)
+         submit_event.then(lambda: "", outputs=message)
+
+     return demo
+
+ if __name__ == "__main__":
+     demo = main()
+     demo.launch()