File size: 3,293 Bytes
920a9a0
 
ab13803
920a9a0
 
 
1e333df
920a9a0
 
6b60a0b
920a9a0
 
 
 
 
 
 
 
a0bf383
920a9a0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3ebfb41
920a9a0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ab13803
920a9a0
ab13803
920a9a0
3ebfb41
920a9a0
 
 
3ebfb41
ab13803
920a9a0
 
 
 
 
a4b0766
 
 
 
920a9a0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
04d5a98
920a9a0
 
 
9567c67
920a9a0
 
 
a4b0766
 
 
920a9a0
 
 
 
 
ab13803
920a9a0
 
 
 
 
 
 
 
 
8196738
920a9a0
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
import re
import os
import panel as pn
from mistralai.async_client import MistralAsyncClient
from mistralai.models.chat_completion import ChatMessage
from panel.io.mime_render import exec_with_return

# Load Panel's "codeeditor" extension (needed by the CodeEditor widget used
# in this file) and make components stretch to the available width by default.
pn.extension("codeeditor", sizing_mode="stretch_width")

# Model name passed to the Mistral streaming chat call in `callback`.
LLM_MODEL = "mistral-small"
# System prompt placed first in every request. It asks the model to answer
# with a ```python fenced block whose last line is `fig`; `callback` extracts
# exactly that fence and appends `fig` itself when the model forgets it.
# NOTE(review): "in edit the code" looks like a typo for "in editing the
# code"; left untouched here because the prompt text is runtime behavior.
SYSTEM_MESSAGE = ChatMessage(
    role="system",
    content=(
        "You are a renowned data visualization expert "
        "with a strong background in matplotlib. "
        "Your primary goal is to assist the user "
        "in edit the code based on user request "
        "using best practices. Simply provide code "
        "in code fences (```python). You must have `fig` "
        "as the last line of code"
    ),
)

# Template for the user turn sent to the model: `{content}` is the chat
# request and `{code}` is the current contents of the code editor.
USER_CONTENT_FORMAT = """
Request:
{content}

Code:
```python
{code}
```
""".strip()

# Script the app starts with: it is both the code editor's initial value and
# the source executed to produce the initial plot. The trailing bare `fig`
# expression is presumably what `exec_with_return` hands back as the result
# -- confirm against panel.io.mime_render.
DEFAULT_MATPLOTLIB = """
import numpy as np
import matplotlib.pyplot as plt

fig = plt.figure()
ax = plt.axes(title="Plot Title", xlabel="X Label", ylabel="Y Label")

x = np.linspace(1, 10)
y = np.sin(x)
z = np.cos(x)
c = np.log(x)

ax.plot(x, y, c="blue", label="sin")
ax.plot(x, z, c="orange", label="cos")

img = ax.scatter(x, c, c=c, label="log")
plt.colorbar(img, label="Colorbar")
plt.legend()

# must have fig at the end!
fig
""".strip()


async def callback(content: str, user: str, instance: pn.chat.ChatInterface):
    """Stream the LLM reply to a chat request and load its code into the editor.

    The conversation sent to the model is the system prompt, the serialized
    chat history, and a new user turn that combines the request with the
    code currently in ``code_editor``. Once streaming finishes, the first
    ```python fenced block of the reply replaces the editor contents; a
    reply without such a block leaves the editor untouched.

    Args:
        content: Text of the user's request.
        user: Name of the sender (not used here).
        instance: Chat interface whose serialized history is replayed.

    Yields:
        The reply text accumulated so far, once for every streamed chunk
        that carries content.
    """
    # system
    messages = [SYSTEM_MESSAGE]

    # history: the first and last serialized entries are dropped
    # NOTE(review): presumably the last entry is the request being handled,
    # which is re-sent below with the code attached -- confirm.
    messages.extend([ChatMessage(**message) for message in instance.serialize()[1:-1]])

    # new user contents
    user_content = USER_CONTENT_FORMAT.format(
        content=content, code=code_editor.value
    )
    messages.append(ChatMessage(role="user", content=user_content))

    # stream LLM tokens
    message = ""
    async for chunk in client.chat_stream(model=LLM_MODEL, messages=messages):
        if chunk.choices[0].delta.content is not None:
            message += chunk.choices[0].delta.content
            yield message

    # extract code: non-greedy so that a reply with several fenced blocks
    # yields only the first one instead of everything between the first
    # opening fence and the last closing fence
    match = re.search(r"```python\n(.*?)\n```", message, re.DOTALL)
    if match is None:
        # the model answered without code (e.g. with a question); keep the
        # editor as it is rather than failing with an IndexError
        return

    llm_code = match.group(1)
    code_lines = llm_code.splitlines()
    # an empty fenced block has no last line to inspect
    if not code_lines or code_lines[-1].strip() != "fig":
        llm_code += "\nfig"
    code_editor.value = llm_code


def update_plot(event):
    """Refresh the plot pane from the code carried by a change event.

    ``event.new`` (the editor's new value) is run through
    ``exec_with_return`` and the result becomes the object displayed by
    ``matplotlib_pane``.
    """
    # NOTE(review): this runs the editor text as Python, which includes code
    # written by the LLM -- only serve the app where that is acceptable.
    rendered = exec_with_return(event.new)
    matplotlib_pane.object = rendered


# --- LLM client ---------------------------------------------------------
# Reads the API key from the MISTRAL_API_KEY environment variable; a missing
# variable raises KeyError at startup.
client = MistralAsyncClient(api_key=os.environ["MISTRAL_API_KEY"])

# --- widgets and panes --------------------------------------------------
# Chat panel: submitted messages are handled by `callback`.
chat_interface = pn.chat.ChatInterface(
    callback=callback,
    show_clear=False,
    show_undo=False,
    show_button_name=False,
    message_params={
        "show_reaction_icons": False,
        "show_copy_icon": False,
    },
    height=700,
    callback_exception="verbose",
)

# Editor pre-filled with the default script.
code_editor = pn.widgets.CodeEditor(
    value=DEFAULT_MATPLOTLIB, language="python", sizing_mode="stretch_both"
)

# Plot pane showing the result of running that same default script.
matplotlib_pane = pn.pane.Matplotlib(
    exec_with_return(DEFAULT_MATPLOTLIB), sizing_mode="stretch_both", tight=True
)

# Re-run the code and redraw whenever the editor's value changes.
code_editor.param.watch(update_plot, "value")

# --- layout -------------------------------------------------------------
tabs = pn.Tabs(("Plot", matplotlib_pane), ("Code", code_editor))

sidebar = [chat_interface]
main = [tabs]
template = pn.template.FastListTemplate(
    sidebar=sidebar,
    main=main,
    sidebar_width=600,
    main_layout=None,
    accent_base_color="#fd7000",
    header_background="#fd7000",
    title="Chat with Plot",
)
template.servable()