wwwillchen commited on
Commit
47cc999
0 Parent(s):
Files changed (8) hide show
  1. .gitignore +1 -0
  2. __init__.py +0 -0
  3. claude.py +25 -0
  4. data_model.py +37 -0
  5. dialog.py +65 -0
  6. duo_chat.py +381 -0
  7. gemini.py +71 -0
  8. requirements.txt +4 -0
.gitignore ADDED
@@ -0,0 +1 @@
 
 
1
+ __pycache__
__init__.py ADDED
File without changes
claude.py ADDED
@@ -0,0 +1,25 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import anthropic
2
+
3
+ from data_model import ChatMessage
4
+
5
+ client = anthropic.Anthropic(
6
+ # defaults to os.environ.get("ANTHROPIC_API_KEY")
7
+ # api_key="my_api_key",
8
+ )
9
+
10
+
11
+ def call_claude_sonnet(input: str, history: list[ChatMessage]):
12
+ messages = [
13
+ {
14
+ "role": "assistant" if message.role == "model" else message.role,
15
+ "content": message.content,
16
+ }
17
+ for message in history
18
+ ] + [{"role": "user", "content": input}]
19
+ with client.messages.stream(
20
+ max_tokens=1024,
21
+ messages=messages,
22
+ model="claude-3-5-sonnet-20240620",
23
+ ) as stream:
24
+ for text in stream.text_stream:
25
+ yield text
data_model.py ADDED
@@ -0,0 +1,37 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from dataclasses import dataclass, field
2
+ from enum import Enum
3
+ from typing import Literal
4
+
5
+ import mesop as me
6
+
7
+
8
+ Role = Literal["user", "model"]
9
+
10
+
11
+ @dataclass(kw_only=True)
12
+ class ChatMessage:
13
+ """Chat message metadata."""
14
+
15
+ role: Role = "user"
16
+ content: str = ""
17
+ in_progress: bool = False
18
+
19
+ class Models(Enum):
20
+ GEMINI_1_5_FLASH = "Gemini 1.5 Flash"
21
+ GEMINI_1_5_PRO = "Gemini 1.5 Pro"
22
+ CLAUDE_3_5_SONNET = "Claude 3.5 Sonnet"
23
+
24
+ @dataclass
25
+ class Conversation:
26
+ model: str = ""
27
+ messages: list[ChatMessage] = field(default_factory=list)
28
+
29
+ @me.stateclass
30
+ class State:
31
+ is_model_picker_dialog_open: bool
32
+ input: str
33
+ conversations: list[Conversation]
34
+ models: list[str]
35
+ gemini_api_key: str
36
+ claude_api_key: str
37
+
dialog.py ADDED
@@ -0,0 +1,65 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import mesop as me
2
+
3
+
4
+ @me.content_component
5
+ def dialog(is_open: bool):
6
+ """Renders a dialog component.
7
+
8
+ The design of the dialog borrows from the Angular component dialog. So basically
9
+ rounded corners and some box shadow.
10
+
11
+ One current drawback is that it's not possible to close the dialog
12
+ by clicking on the overlay background. This is due to
13
+ https://github.com/google/mesop/issues/268.
14
+
15
+ Args:
16
+ is_open: Whether the dialog is visible or not.
17
+ """
18
+ with me.box(
19
+ style=me.Style(
20
+ background="rgba(0,0,0,0.4)",
21
+ display="block" if is_open else "none",
22
+ height="100%",
23
+ overflow_x="auto",
24
+ overflow_y="auto",
25
+ position="fixed",
26
+ width="100%",
27
+ z_index=1000,
28
+ )
29
+ ):
30
+ with me.box(
31
+ style=me.Style(
32
+ align_items="center",
33
+ display="grid",
34
+ height="100vh",
35
+ justify_items="center",
36
+ )
37
+ ):
38
+ with me.box(
39
+ style=me.Style(
40
+ background="#fff",
41
+ border_radius=20,
42
+ box_sizing="content-box",
43
+ box_shadow=(
44
+ "0 3px 1px -2px #0003, 0 2px 2px #00000024, 0 1px 5px #0000001f"
45
+ ),
46
+ margin=me.Margin.symmetric(vertical="0", horizontal="auto"),
47
+ padding=me.Padding.all(20),
48
+ )
49
+ ):
50
+ me.slot()
51
+
52
+
53
+ @me.content_component
54
+ def dialog_actions():
55
+ """Helper component for rendering action buttons so they are right aligned.
56
+
57
+ This component is optional. If you want to position action buttons differently,
58
+ you can just write your own Mesop markup.
59
+ """
60
+ with me.box(
61
+ style=me.Style(
62
+ display="flex", justify_content="end", margin=me.Margin(top=20)
63
+ )
64
+ ):
65
+ me.slot()
duo_chat.py ADDED
@@ -0,0 +1,381 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import mesop as me
2
+
3
+
4
+ import claude
5
+ import gemini
6
+ from data_model import ChatMessage, State, Models, Conversation
7
+ from dialog import dialog, dialog_actions
8
+
9
+
10
+ ROOT_BOX_STYLE = me.Style(
11
+ background="#e7f2ff",
12
+ height="100%",
13
+ font_family="Inter",
14
+ display="flex",
15
+ flex_direction="column",
16
+ )
17
+
18
+ darker_bg_color = "#b9e1ff"
19
+
20
+ STYLESHEETS = [
21
+ "https://fonts.googleapis.com/css2?family=Inter:wght@100..900&display=swap"
22
+ ]
23
+
24
+
25
+ @me.stateclass
26
+ class ModelDialogState:
27
+ selected_models: list[str]
28
+
29
+
30
+ def change_model_option(e: me.CheckboxChangeEvent):
31
+ s = me.state(ModelDialogState)
32
+ if e.checked:
33
+ s.selected_models.append(e.key)
34
+ else:
35
+ s.selected_models.remove(e.key)
36
+
37
+
38
+ def set_gemini_api_key(e: me.InputBlurEvent):
39
+ me.state(State).gemini_api_key = e.value
40
+
41
+
42
+ def set_claude_api_key(e: me.InputBlurEvent):
43
+ me.state(State).claude_api_key = e.value
44
+
45
+
46
+ def model_picker_dialog():
47
+ state = me.state(State)
48
+ with dialog(state.is_model_picker_dialog_open):
49
+ with me.box(style=me.Style(display="flex", flex_direction="column", gap=12)):
50
+ me.text("API keys")
51
+ me.input(
52
+ label="Gemini API Key",
53
+ value=state.gemini_api_key,
54
+ on_blur=set_gemini_api_key,
55
+ )
56
+ me.input(
57
+ label="Claude API Key",
58
+ value=state.claude_api_key,
59
+ on_blur=set_claude_api_key,
60
+ )
61
+ me.text("Pick a model")
62
+ for model in Models:
63
+ if model.name.startswith("GEMINI"):
64
+ disabled = not state.gemini_api_key
65
+ elif model.name.startswith("CLAUDE"):
66
+ disabled = not state.claude_api_key
67
+ else:
68
+ disabled = False
69
+ me.checkbox(
70
+ key=model.value,
71
+ label=model.value,
72
+ checked=model.value in state.models,
73
+ disabled=disabled,
74
+ on_change=change_model_option,
75
+ style=me.Style(
76
+ display="flex",
77
+ flex_direction="column",
78
+ gap=4,
79
+ padding=me.Padding(top=12),
80
+ ),
81
+ )
82
+ with dialog_actions():
83
+ me.button("Cancel", on_click=close_model_picker_dialog)
84
+ me.button("Confirm", on_click=confirm_model_picker_dialog)
85
+
86
+
87
+ @me.page(
88
+ path="/",
89
+ stylesheets=STYLESHEETS,
90
+ security_policy=me.SecurityPolicy(
91
+ allowed_iframe_parents=["https://huggingface.co"]
92
+ ),
93
+ )
94
+ def page():
95
+ model_picker_dialog()
96
+ with me.box(style=ROOT_BOX_STYLE):
97
+ header()
98
+ with me.box(
99
+ style=me.Style(
100
+ width="min(680px, 100%)",
101
+ margin=me.Margin.symmetric(
102
+ horizontal="auto",
103
+ vertical=36,
104
+ ),
105
+ )
106
+ ):
107
+ me.text(
108
+ "Chat with multiple models at once",
109
+ style=me.Style(
110
+ font_size=20,
111
+ margin=me.Margin(
112
+ bottom=24,
113
+ ),
114
+ ),
115
+ )
116
+ examples_row()
117
+ chat_input()
118
+
119
+
120
+ EXAMPLES = [
121
+ "Create a file-lock in Python",
122
+ "Write an email to Congress to have free milk for all",
123
+ "Make a nice box shadow in CSS",
124
+ ]
125
+
126
+
127
+ def examples_row():
128
+ with me.box(
129
+ style=me.Style(
130
+ display="flex", flex_direction="row", gap=16, margin=me.Margin(bottom=24)
131
+ )
132
+ ):
133
+ for i in EXAMPLES:
134
+ example(i)
135
+
136
+
137
+ def example(text: str):
138
+ with me.box(
139
+ key=text,
140
+ on_click=click_example,
141
+ style=me.Style(
142
+ cursor="pointer",
143
+ background=darker_bg_color,
144
+ width="215px",
145
+ height=160,
146
+ font_weight=500,
147
+ line_height="1.5",
148
+ padding=me.Padding.all(16),
149
+ border_radius=16,
150
+ border=me.Border.all(me.BorderSide(width=1, color="blue", style="none")),
151
+ ),
152
+ ):
153
+ me.text(text)
154
+
155
+
156
+ def click_example(e: me.ClickEvent):
157
+ state = me.state(State)
158
+ state.input = e.key
159
+
160
+
161
+ def header():
162
+ def navigate_home(e: me.ClickEvent):
163
+ me.navigate("/")
164
+ state = me.state(State)
165
+ state.conversations = []
166
+
167
+ with me.box(
168
+ on_click=navigate_home,
169
+ style=me.Style(
170
+ cursor="pointer",
171
+ padding=me.Padding.all(16),
172
+ ),
173
+ ):
174
+ me.text(
175
+ "DuoChat",
176
+ style=me.Style(
177
+ font_weight=500,
178
+ font_size=24,
179
+ color="#3D3929",
180
+ letter_spacing="0.3px",
181
+ ),
182
+ )
183
+
184
+
185
+ def chat_input():
186
+ state = me.state(State)
187
+
188
+ with me.box(
189
+ style=me.Style(
190
+ border_radius=16,
191
+ padding=me.Padding.all(8),
192
+ background="white",
193
+ display="flex",
194
+ width="100%",
195
+ )
196
+ ):
197
+ with me.box(
198
+ style=me.Style(
199
+ flex_grow=1,
200
+ )
201
+ ):
202
+ with me.box():
203
+ me.native_textarea(
204
+ value=state.input,
205
+ autosize=True,
206
+ min_rows=4,
207
+ max_rows=8,
208
+ placeholder="Enter a prompt",
209
+ on_blur=on_blur,
210
+ style=me.Style(
211
+ padding=me.Padding(top=16, left=16),
212
+ outline="none",
213
+ width="100%",
214
+ overflow_y="auto",
215
+ border=me.Border.all(
216
+ me.BorderSide(style="none"),
217
+ ),
218
+ ),
219
+ )
220
+ with me.box(
221
+ style=me.Style(
222
+ display="flex",
223
+ padding=me.Padding(left=12, bottom=12),
224
+ cursor="pointer",
225
+ ),
226
+ on_click=switch_model,
227
+ ):
228
+ me.text(
229
+ "Model:",
230
+ style=me.Style(font_weight=500, padding=me.Padding(right=6)),
231
+ )
232
+ if state.models:
233
+ me.text(", ".join(state.models))
234
+ else:
235
+ me.text("(no model selected)")
236
+ with me.content_button(
237
+ type="icon", on_click=send_prompt, disabled=not state.models
238
+ ):
239
+ me.icon("send")
240
+
241
+
242
+ def switch_model(e: me.ClickEvent):
243
+ state = me.state(State)
244
+ state.is_model_picker_dialog_open = True
245
+ dialog_state = me.state(ModelDialogState)
246
+ dialog_state.selected_models = state.models[:]
247
+
248
+
249
+ def close_model_picker_dialog(e: me.ClickEvent):
250
+ state = me.state(State)
251
+ state.is_model_picker_dialog_open = False
252
+
253
+
254
+ def confirm_model_picker_dialog(e: me.ClickEvent):
255
+ dialog_state = me.state(ModelDialogState)
256
+ state = me.state(State)
257
+ state.is_model_picker_dialog_open = False
258
+ state.models = dialog_state.selected_models
259
+
260
+
261
+ def on_blur(e: me.InputBlurEvent):
262
+ state = me.state(State)
263
+ state.input = e.value
264
+
265
+
266
+ def send_prompt(e: me.ClickEvent):
267
+ state = me.state(State)
268
+ if not state.conversations:
269
+ me.navigate("/conversation")
270
+ for model in state.models:
271
+ state.conversations.append(Conversation(model=model, messages=[]))
272
+ input = state.input
273
+ state.input = ""
274
+
275
+ for conversation in state.conversations:
276
+ model = conversation.model
277
+ messages = conversation.messages
278
+ history = messages[:]
279
+ messages.append(ChatMessage(role="user", content=input))
280
+ messages.append(ChatMessage(role="model", in_progress=True))
281
+ yield
282
+ me.scroll_into_view(key="end_of_messages")
283
+ if model == Models.GEMINI_1_5_FLASH.value:
284
+ llm_response = gemini.send_prompt_flash(input, history)
285
+ elif model == Models.GEMINI_1_5_PRO.value:
286
+ llm_response = gemini.send_prompt_pro(input, history)
287
+ elif model == Models.CLAUDE_3_5_SONNET.value:
288
+ llm_response = claude.call_claude_sonnet(input, history) # type: ignore
289
+ else:
290
+ raise Exception("Unhandled model", model)
291
+ for chunk in llm_response: # type: ignore
292
+ messages[-1].content += chunk
293
+ yield
294
+ messages[-1].in_progress = False
295
+ yield
296
+
297
+
298
+ @me.page(path="/conversation", stylesheets=STYLESHEETS)
299
+ def conversation_page():
300
+ state = me.state(State)
301
+ model_picker_dialog()
302
+ with me.box(style=ROOT_BOX_STYLE):
303
+ header()
304
+
305
+ models = len(state.conversations)
306
+ models_px = models * 680
307
+ with me.box(
308
+ style=me.Style(
309
+ width=f"min({models_px}px, calc(100% - 32px))",
310
+ display="grid",
311
+ gap=16,
312
+ grid_template_columns=f"repeat({models}, 1fr)",
313
+ flex_grow=1,
314
+ overflow_y="hidden",
315
+ margin=me.Margin.symmetric(horizontal="auto"),
316
+ padding=me.Padding.symmetric(horizontal=16),
317
+ )
318
+ ):
319
+ for conversation in state.conversations:
320
+ model = conversation.model
321
+ messages = conversation.messages
322
+ with me.box(
323
+ style=me.Style(
324
+ overflow_y="auto",
325
+ )
326
+ ):
327
+ me.text("Model: " + model, style=me.Style(font_weight=500))
328
+
329
+ for message in messages:
330
+ if message.role == "user":
331
+ user_message(message.content)
332
+ else:
333
+ model_message(message)
334
+ if messages and model == state.conversations[-1]:
335
+ me.box(
336
+ key="end_of_messages",
337
+ style=me.Style(
338
+ margin=me.Margin(
339
+ bottom="50vh" if messages[-1].in_progress else 0
340
+ )
341
+ ),
342
+ )
343
+ with me.box(
344
+ style=me.Style(
345
+ display="flex",
346
+ justify_content="center",
347
+ )
348
+ ):
349
+ with me.box(
350
+ style=me.Style(
351
+ width="min(680px, 100%)",
352
+ padding=me.Padding(top=24, bottom=24),
353
+ )
354
+ ):
355
+ chat_input()
356
+
357
+
358
+ def user_message(content: str):
359
+ with me.box(
360
+ style=me.Style(
361
+ background=darker_bg_color,
362
+ padding=me.Padding.all(16),
363
+ margin=me.Margin.symmetric(vertical=16),
364
+ border_radius=16,
365
+ )
366
+ ):
367
+ me.text(content)
368
+
369
+
370
+ def model_message(message: ChatMessage):
371
+ with me.box(
372
+ style=me.Style(
373
+ background="#fff",
374
+ padding=me.Padding.all(16),
375
+ border_radius=16,
376
+ margin=me.Margin.symmetric(vertical=16),
377
+ )
378
+ ):
379
+ me.markdown(message.content)
380
+ if message.in_progress:
381
+ me.progress_spinner()
gemini.py ADDED
@@ -0,0 +1,71 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+
3
+ import google.generativeai as genai
4
+ import mesop as me
5
+ from typing_extensions import TypedDict
6
+
7
+ from data_model import ChatMessage, State
8
+
9
+
10
+
11
+
12
+ class Pad(TypedDict):
13
+ title: str
14
+ content: str
15
+
16
+
17
+ class Output(TypedDict):
18
+ intro: str
19
+ pad: Pad
20
+ conclusion: str
21
+
22
+
23
+ # Create the model
24
+ # See https://ai.google.dev/api/python/google/generativeai/GenerativeModel
25
+ generation_config = {
26
+ "temperature": 1,
27
+ "top_p": 0.95,
28
+ "top_k": 64,
29
+ "max_output_tokens": 8192,
30
+ "response_mime_type": "text/plain",
31
+ # "response_schema": Output,
32
+ }
33
+
34
+ def configure_gemini():
35
+ state = me.state(State)
36
+ genai.configure(api_key=state.gemini_api_key)
37
+
38
+ def send_prompt_pro(prompt: str, history: list[ChatMessage]):
39
+ configure_gemini()
40
+
41
+ model = genai.GenerativeModel(
42
+ model_name="gemini-1.5-pro-latest",
43
+ generation_config=generation_config,
44
+ )
45
+
46
+ chat_session = model.start_chat(
47
+ history=[
48
+ {"role": message.role, "parts": [message.content]} for message in history
49
+ ]
50
+ )
51
+
52
+ for chunk in chat_session.send_message(prompt, stream=True):
53
+ yield chunk.text
54
+
55
+
56
+ def send_prompt_flash(prompt: str, history: list[ChatMessage]):
57
+ configure_gemini()
58
+
59
+ model = genai.GenerativeModel(
60
+ model_name="gemini-1.5-flash-latest",
61
+ generation_config=generation_config,
62
+ )
63
+
64
+ chat_session = model.start_chat(
65
+ history=[
66
+ {"role": message.role, "parts": [message.content]} for message in history
67
+ ]
68
+ )
69
+
70
+ for chunk in chat_session.send_message(prompt, stream=True):
71
+ yield chunk.text
requirements.txt ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ mesop
2
+ gunicorn
3
+ anthropic
4
+ google-generativeai