Spaces:
Runtime error
Runtime error
File size: 3,293 Bytes
920a9a0 ab13803 920a9a0 1e333df 920a9a0 6b60a0b 920a9a0 a0bf383 920a9a0 3ebfb41 920a9a0 ab13803 920a9a0 ab13803 920a9a0 3ebfb41 920a9a0 3ebfb41 ab13803 920a9a0 a4b0766 920a9a0 04d5a98 920a9a0 9567c67 920a9a0 a4b0766 920a9a0 ab13803 920a9a0 8196738 920a9a0 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 |
import re
import os
import panel as pn
from mistralai.async_client import MistralAsyncClient
from mistralai.models.chat_completion import ChatMessage
from panel.io.mime_render import exec_with_return
# Load the code editor extension; components stretch to the page width by default.
pn.extension("codeeditor", sizing_mode="stretch_width")

# Mistral chat model used for every completion request.
LLM_MODEL = "mistral-small"

# System prompt sent first on every request. It insists on ```python fences
# and a trailing `fig` line because the reply's code is pulled out of that
# fence and the plot pane displays the value of the code's last expression.
SYSTEM_MESSAGE = ChatMessage(
    role="system",
    content=(
        "You are a renowned data visualization expert "
        "with a strong background in matplotlib. "
        "Your primary goal is to assist the user "
        "in editing the code based on the user's request "
        "using best practices. Simply provide code "
        "in code fences (```python). You must have `fig` "
        "as the last line of code."
    ),
)
# Template for each user turn sent to the model: the chat request, then the
# editor's current code inside a python fence. Fill via
# USER_CONTENT_FORMAT.format(content=..., code=...).
USER_CONTENT_FORMAT = "\n".join(
    (
        "Request:",
        "{content}",
        "Code:",
        "```python",
        "{code}",
        "```",
    )
)

# Starter code shown in the editor and rendered on page load. Its final bare
# `fig` expression is the object that ends up in the plot pane.
DEFAULT_MATPLOTLIB = """import numpy as np
import matplotlib.pyplot as plt
fig = plt.figure()
ax = plt.axes(title="Plot Title", xlabel="X Label", ylabel="Y Label")
x = np.linspace(1, 10)
y = np.sin(x)
z = np.cos(x)
c = np.log(x)
ax.plot(x, y, c="blue", label="sin")
ax.plot(x, z, c="orange", label="cos")
img = ax.scatter(x, c, c=c, label="log")
plt.colorbar(img, label="Colorbar")
plt.legend()
# must have fig at the end!
fig"""
# First ```python fenced block in an LLM reply. Non-greedy, so a reply with
# several fences (or prose containing backticks after the code) is not
# captured as one giant span running up to the last ``` in the message.
_PYTHON_FENCE_RE = re.compile(r"```python[ \t]*\r?\n(.*?)\r?\n```", re.DOTALL)


def _extract_python_code(reply: str):
    """Return the first ```python block of *reply*, ready to execute.

    Returns None when the reply has no python fence or the fence is empty.
    A trailing ``fig`` line is appended when missing, because the plot pane
    shows the value of the code's last expression.
    """
    match = _PYTHON_FENCE_RE.search(reply)
    if match is None:
        return None
    code = match.group(1).rstrip()
    if not code.strip():
        return None
    if code.splitlines()[-1].strip() != "fig":
        code += "\nfig"
    return code


async def callback(content: str, user: str, instance: pn.chat.ChatInterface):
    """Stream a Mistral reply to *content* and load any code it contains.

    The model receives the system prompt, the earlier chat turns, and the new
    request combined with the code currently in the editor. The growing reply
    is yielded after every token chunk so the chat interface can render it
    live. If the finished reply contains a ```python block, that code replaces
    the editor's contents (the editor's ``value`` watcher then re-renders the
    plot); otherwise the editor and plot are left unchanged.
    """
    messages = [SYSTEM_MESSAGE]
    # Earlier turns. The last serialized message is the request being answered
    # now; it is re-sent below together with the current code, so drop only
    # that one (nothing in this app posts a welcome message that would need
    # skipping at the front).
    history = instance.serialize()[:-1]
    messages.extend(ChatMessage(**message) for message in history)
    user_content = USER_CONTENT_FORMAT.format(
        content=content, code=code_editor.value
    )
    messages.append(ChatMessage(role="user", content=user_content))

    # Stream tokens, yielding the accumulated text so the chat bubble updates.
    reply = ""
    async for chunk in client.chat_stream(model=LLM_MODEL, messages=messages):
        delta = chunk.choices[0].delta.content
        if delta is not None:
            reply += delta
            yield reply

    llm_code = _extract_python_code(reply)
    if llm_code is None:
        # Nothing runnable in the reply (e.g. a clarifying question): keep the
        # current code and plot instead of failing with IndexError.
        return
    code_editor.value = llm_code
def update_plot(event):
    """Re-run the editor's code and display its result in the plot pane.

    Registered as a watcher on the code editor's ``value``; ``event.new`` is
    the updated source text.
    """
    # NOTE(security): this executes arbitrary Python coming from the editor
    # (typed by the user or written by the LLM) inside the app's process.
    # Acceptable for a demo; sandbox it before exposing the app to untrusted users.
    source = event.new
    matplotlib_pane.object = exec_with_return(source)
# instantiate widgets and panes
# Check the API key up front so a missing value stops the app with an
# actionable message instead of a bare KeyError. An empty key is rejected
# too; presumably it would otherwise only fail on the first chat request.
_mistral_api_key = os.environ.get("MISTRAL_API_KEY")
if not _mistral_api_key:
    raise RuntimeError(
        "MISTRAL_API_KEY is not set; put your Mistral API key in that "
        "environment variable before starting the app."
    )
client = MistralAsyncClient(api_key=_mistral_api_key)
# Chat sidebar: `callback` answers every message. Clear/undo buttons, button
# labels, and per-message reaction/copy icons are hidden; callback errors are
# reported verbosely in the chat.
chat_interface = pn.chat.ChatInterface(
    callback=callback,
    show_clear=False,
    show_undo=False,
    show_button_name=False,
    message_params=dict(
        show_reaction_icons=False,
        show_copy_icon=False,
    ),
    height=700,
    callback_exception="verbose",
)
# Plot pane, pre-rendered from the starter code so the page is not empty on load.
matplotlib_pane = pn.pane.Matplotlib(
    exec_with_return(DEFAULT_MATPLOTLIB),
    sizing_mode="stretch_both",
    tight=True,
)
# Editor holding the code behind the plot; the chat callback also writes here.
code_editor = pn.widgets.CodeEditor(
    value=DEFAULT_MATPLOTLIB,
    language="python",
    sizing_mode="stretch_both",
)
# watch for code changes: every edit (manual or from the LLM) re-renders the plot
code_editor.param.watch(update_plot, "value")
# Lay out the app: the chat lives in a wide sidebar, and the main area holds
# two tabs, the rendered plot and the editable code behind it.
ACCENT_COLOR = "#fd7000"
tabs = pn.Tabs(("Plot", matplotlib_pane), ("Code", code_editor))
sidebar = [chat_interface]
main = [tabs]
template = pn.template.FastListTemplate(
    sidebar=sidebar,
    main=main,
    sidebar_width=600,
    main_layout=None,  # no card wrapper around the main-area components
    accent_base_color=ACCENT_COLOR,
    header_background=ACCENT_COLOR,
    title="Chat with Plot",
)
# Register the template as the page served by `panel serve`.
template.servable()
|