# tweak-mpl-chat / app.py
# Last commit: "Update app.py" (0b2a850, verified) by ahuang11
import re
import os
import panel as pn
from mistralai.async_client import MistralAsyncClient
from mistralai.models.chat_completion import ChatMessage
from panel.io.mime_render import exec_with_return
# Enable Panel's code-editor extension and make components stretch to the
# available width by default.
pn.extension(
    "codeeditor",
    sizing_mode="stretch_width",
)

# Mistral chat model queried by `callback` for every request.
LLM_MODEL = "mistral-small"
# System prompt sent as the first message of every request. It asks the model
# to answer with a ```python fenced block whose last line is `fig`; the chat
# callback extracts code in exactly that format and renders it.
SYSTEM_MESSAGE = ChatMessage(
    role="system",
    content=(
        "You are a renowned data visualization expert "
        "with a strong background in matplotlib. "
        "Your primary goal is to assist the user "
        "in editing the code based on the user's request "
        "using best practices. Simply provide code "
        "in code fences (```python). You must have `fig` "
        "as the last line of code."
    ),
)
# Template for the newest user turn. It pairs the chat request with the
# editor's current code so the model always edits the latest version.
# Placeholders: {content} = user's request, {code} = current editor text.
USER_CONTENT_FORMAT = "\n".join(
    (
        "Request:",
        "{content}",
        "Code:",
        "```python",
        "{code}",
        "```",
    )
)
# Starter snippet: it seeds the code editor and is executed once to produce
# the initial plot. It must end with a bare `fig` expression, because the plot
# pane is given the value of the snippet's last expression.
DEFAULT_MATPLOTLIB = "\n".join(
    (
        "import numpy as np",
        "import matplotlib.pyplot as plt",
        "fig = plt.figure()",
        'ax = plt.axes(title="Plot Title", xlabel="X Label", ylabel="Y Label")',
        "x = np.linspace(1, 10)",
        "y = np.sin(x)",
        "z = np.cos(x)",
        "c = np.log(x)",
        'ax.plot(x, y, c="blue", label="sin")',
        'ax.plot(x, z, c="orange", label="cos")',
        'img = ax.scatter(x, c, c=c, label="log")',
        'plt.colorbar(img, label="Colorbar")',
        "plt.legend()",
        "# must have fig at the end!",
        "fig",
    )
)
def extract_code(message):
    """Return the first ```python fenced block in ``message``, ending in ``fig``.

    The fence body is matched non-greedily. A reply with several fenced blocks
    therefore yields only the first block, not everything between the first
    opening fence and the last closing fence. Trailing whitespace is dropped,
    and a final ``fig`` line is appended when the block does not already end
    with one.

    Args:
        message: Full text of the model's reply.

    Returns:
        The extracted code as a string, or None when the reply has no
        ```python fence or the fenced block is empty/whitespace-only.
    """
    match = re.search(r"```python\n(.*?)\n```", message, re.DOTALL)
    if match is None:
        return None
    code = match.group(1).rstrip()
    if not code.strip():
        return None
    # The plot pane shows the value of the snippet's last expression, so the
    # code must end with a bare `fig`.
    if code.splitlines()[-1].strip() != "fig":
        code += "\nfig"
    return code


async def callback(content: str, user: str, instance: "pn.chat.ChatInterface"):
    """Stream a Mistral reply to ``content`` and load any returned code.

    This is the ``ChatInterface`` callback. While the reply streams, it yields
    the text accumulated so far. Afterwards, the first ```python block in the
    reply (see ``extract_code``) is written to the code editor. If the reply
    contains no usable code block, the editor is left unchanged.

    Args:
        content: Text of the message the user just sent.
        user: Name of the sender (unused).
        instance: The chat interface, used to read prior chat history.
    """
    if not api_key_input.value:
        yield "Please first enter your Mistral API key"
        return
    client = MistralAsyncClient(api_key=api_key_input.value)

    messages = [SYSTEM_MESSAGE]
    # Prior chat turns. NOTE(review): the [1:-1] slice drops the first
    # serialized message and the last one (presumably the request just sent,
    # which is re-sent below together with the current code); confirm that
    # skipping the first message is intended.
    messages.extend(ChatMessage(**message) for message in instance.serialize()[1:-1])
    # The newest request is sent together with the editor's current code.
    messages.append(
        ChatMessage(
            role="user",
            content=USER_CONTENT_FORMAT.format(content=content, code=code_editor.value),
        )
    )

    reply = ""
    async for chunk in client.chat_stream(model=LLM_MODEL, messages=messages):
        delta = chunk.choices[0].delta.content
        if delta is not None:
            reply += delta
            yield reply

    llm_code = extract_code(reply)
    if llm_code is None:
        # Nothing runnable in the reply; the streamed text already shows what
        # the model said, so keep the current code and plot.
        return
    code_editor.value = llm_code
def update_plot(event):
    """Re-render the Matplotlib pane from the code editor's new text.

    Intended as a ``param`` watcher on the editor's ``value``. ``event.new``
    holds the updated source. ``exec_with_return`` runs it and returns the
    value of its final expression (the trailing ``fig``), and that value
    becomes the pane's object.

    NOTE(review): this executes arbitrary user- or LLM-supplied Python in the
    server process, so it is only acceptable for trusted/local use. Each run
    of snippets that call ``plt.figure()`` also creates a new pyplot figure
    that nothing here closes.
    """
    source_code = event.new
    rendered = exec_with_return(source_code)
    matplotlib_pane.object = rendered
# --- widgets and panes -------------------------------------------------------

# Mistral API key; `callback` refuses to query the model until it is set.
api_key_input = pn.widgets.PasswordInput(
    placeholder="Enter your MistralAI API Key",
)

# Chat UI whose replies come from `callback`. The clear/undo buttons, button
# labels, and per-message reaction/copy icons are hidden.
chat_interface = pn.chat.ChatInterface(
    callback=callback,
    show_clear=False,
    show_undo=False,
    show_button_name=False,
    message_params={
        "show_reaction_icons": False,
        "show_copy_icon": False,
    },
    height=650,
    callback_exception="verbose",
)

# Plot pane, seeded with the figure produced by running the default snippet.
matplotlib_pane = pn.pane.Matplotlib(
    exec_with_return(DEFAULT_MATPLOTLIB),
    sizing_mode="stretch_both",
    tight=True,
)

# Editable plot source; starts out as the same default snippet.
code_editor = pn.widgets.CodeEditor(
    value=DEFAULT_MATPLOTLIB,
    language="python",
    sizing_mode="stretch_both",
)
# Re-render the plot whenever the editor's text changes, including when
# `callback` writes model-generated code into it.
code_editor.param.watch(update_plot, "value")

# Layout: the sidebar holds the API key input and the chat. The main area has
# two tabs: the rendered plot and its editable source.
tabs = pn.Tabs(("Plot", matplotlib_pane), ("Code", code_editor))
sidebar = [
    api_key_input,
    chat_interface,
]
main = [tabs]

_BRAND_COLOR = "#fd7000"  # shared accent and header color
template = pn.template.FastListTemplate(
    sidebar=sidebar,
    main=main,
    sidebar_width=600,
    main_layout=None,
    accent_base_color=_BRAND_COLOR,
    header_background=_BRAND_COLOR,
    title="Chat with Plot",
)
template.servable()