File size: 3,123 Bytes
920a9a0
 
ab13803
920a9a0
 
 
1e333df
920a9a0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3ebfb41
920a9a0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ab13803
920a9a0
ab13803
920a9a0
3ebfb41
920a9a0
 
 
3ebfb41
ab13803
920a9a0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ab13803
920a9a0
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
import re
import os
import panel as pn
from mistralai.async_client import MistralAsyncClient
from mistralai.models.chat_completion import ChatMessage
from panel.io.mime_render import exec_with_return

# Load Panel's "codeeditor" extension (needed by the CodeEditor widget created
# further down) and make components stretch to the available width by default.
pn.extension("codeeditor", sizing_mode="stretch_width")

# Model name passed to the Mistral client on every chat request.
LLM_MODEL = "mistral-medium"

# System prompt sent as the first message of every request.
# It asks for a ```python fence explicitly because `callback` pulls the code
# out of the fenced block in the reply, and it asks for a trailing `fig`
# because the executed code's final expression is what gets rendered.
SYSTEM_MESSAGE = ChatMessage(
    role="system",
    content=(
        "You are a renowned data visualization expert "
        "with a strong background in matplotlib. "
        "Your primary goal is to assist the user "
        "in editing the code based on the user's request "
        "using best practices. Simply provide the complete code "
        "in a single ```python code fence. You must have `fig` "
        "as the last line of code."
    ),
)

# Template for the user turn sent to the model: `{content}` is the chat
# request and `{code}` is the current contents of the code editor.
USER_CONTENT_FORMAT = (
    "Request:\n"
    "{content}\n"
    "\n"
    "Code:\n"
    "```python\n"
    "{code}\n"
    "```"
)

# Starter script used both as the editor's initial contents and to render the
# initial plot. Its final expression must be `fig`.
DEFAULT_MATPLOTLIB = "\n".join(
    (
        "import numpy as np",
        "import matplotlib.pyplot as plt",
        "",
        "fig = plt.figure()",
        'ax = plt.axes(title="Plot Title", xlabel="X Label", ylabel="Y Label")',
        "",
        "x = np.linspace(1, 10)",
        "y = np.sin(x)",
        "z = np.cos(x)",
        "c = np.log(x)",
        "",
        'ax.plot(x, y, c="blue", label="sin")',
        'ax.plot(x, z, c="orange", label="cos")',
        "",
        'img = ax.scatter(x, c, c=c, label="log")',
        'plt.colorbar(img, label="Colorbar")',
        "plt.legend()",
        "",
        "# must have fig at the end!",
        "fig",
    )
)


# One fenced block: an optional language tag on the opening fence, then the
# body up to the *nearest* closing fence. The non-greedy body keeps several
# blocks in one reply from being merged together with the prose between them.
_CODE_FENCE_RE = re.compile(
    r"```([\w+-]*)[^\S\n]*\n(.*?)\n[^\S\n]*```", re.DOTALL
)
# Fence tags treated as Python; "" covers a fence without a language tag.
_PYTHON_FENCE_TAGS = frozenset({"", "py", "python", "python3"})


def _extract_python_code(message: str) -> str:
    """Return the first non-empty Python code block in *message*, or ``""``."""
    for tag, body in _CODE_FENCE_RE.findall(message):
        if tag.lower() in _PYTHON_FENCE_TAGS and body.strip():
            return body.rstrip()
    return ""


async def callback(content: str, user: str, instance: pn.chat.ChatInterface):
    """Stream the model's reply, then apply the code it contains.

    The request is combined with the current editor contents, sent to the
    model together with the system prompt and earlier chat messages, and the
    reply is streamed back to the chat. If the reply contains a Python code
    block, that code replaces the editor contents and is executed to refresh
    the plot. Otherwise the editor and plot are left unchanged.

    Args:
        content: Text of the user's request.
        user: Name of the sender (unused).
        instance: Chat interface whose serialized messages provide history.

    Yields:
        The reply accumulated so far, once per received token chunk.
    """
    # system
    messages = [SYSTEM_MESSAGE]

    # history
    # NOTE(review): the slice drops the first and the last serialized entry.
    # The last is presumably the request being handled (it is re-sent below
    # with the code attached); confirm why the first one is skipped too.
    messages.extend(ChatMessage(**entry) for entry in instance.serialize()[1:-1])

    # new user contents
    user_content = USER_CONTENT_FORMAT.format(
        content=content, code=code_editor.value
    )
    messages.append(ChatMessage(role="user", content=user_content))

    # stream LLM tokens
    message = ""
    async for chunk in client.chat_stream(model=LLM_MODEL, messages=messages):
        delta = chunk.choices[0].delta.content
        if delta is not None:
            message += delta
            yield message

    # extract code
    llm_code = _extract_python_code(message)
    if not llm_code:
        # The reply had no runnable code (e.g. a clarifying question); the
        # streamed text is already shown, so keep the editor and plot as is.
        return

    # the executed code must end with `fig` so there is a figure to display
    if llm_code.splitlines()[-1].strip() != "fig":
        llm_code += "\nfig"
    code_editor.value = llm_code
    # NOTE(security): this executes model-generated code in the server
    # process without sandboxing; only deploy where that is acceptable.
    matplotlib_pane.object = exec_with_return(llm_code)


# --- widgets and panes ---
# The API key is read from the environment; a missing key raises KeyError.
client = MistralAsyncClient(api_key=os.environ["MISTRAL_API_KEY"])

# Chat panel: every submitted message is handled by `callback`, and callback
# errors are reported with `callback_exception="verbose"`.
chat_interface = pn.chat.ChatInterface(
    callback=callback,
    callback_exception="verbose",
    height=700,
    show_button_name=False,
    show_clear=False,
    show_undo=False,
    message_params={"show_reaction_icons": False, "show_copy_icon": False},
)

# The plot pane starts from the result of running the default script, and the
# editor starts with that same script as its text.
matplotlib_pane = pn.pane.Matplotlib(
    exec_with_return(DEFAULT_MATPLOTLIB), sizing_mode="stretch_both"
)
code_editor = pn.widgets.CodeEditor(
    value=DEFAULT_MATPLOTLIB, sizing_mode="stretch_both"
)

# --- layout ---
tabs = pn.Tabs(("Plot", matplotlib_pane), ("Code", code_editor))
sidebar = [chat_interface]
main = [tabs]

# One colour is used for both the accent and the header background.
brand_color = "#fd7000"
template = pn.template.FastListTemplate(
    main=main,
    main_layout=None,
    sidebar=sidebar,
    sidebar_width=600,
    accent_base_color=brand_color,
    header_background=brand_color,
)
template.servable()