ahuang11 committed
Commit 5d7b030
1 Parent(s): 0d54009

Update app.py

Files changed (1): app.py +94 -60
app.py CHANGED
@@ -1,24 +1,32 @@
 import re
 import os
 import panel as pn
-from mistralai.async_client import MistralAsyncClient
-from mistralai.models.chat_completion import ChatMessage
 from panel.io.mime_render import exec_with_return
+from llama_index import (
+    VectorStoreIndex,
+    SimpleDirectoryReader,
+    ServiceContext,
+    StorageContext,
+    load_index_from_storage,
+)
+from llama_index.chat_engine import ContextChatEngine
+from llama_index.embeddings import OpenAIEmbedding
+from llama_index.llms import OpenAI
 
 pn.extension("codeeditor", sizing_mode="stretch_width")
 
-LLM_MODEL = "mistral-small"
-SYSTEM_MESSAGE = ChatMessage(
-    role="system",
-    content=(
-        "You are a renowned data visualization expert "
-        "with a strong background in matplotlib. "
-        "Your primary goal is to assist the user "
-        "in edit the code based on user request "
-        "using best practices. Simply provide code "
-        "in code fences (```python). You must have `fig` "
-        "as the last line of code"
-    ),
+SYSTEM_PROMPT = (
+    "You are a renowned data visualization expert "
+    "with a strong background in hvplot and holoviews. "
+    "Note, hvplot is built on top of holoviews; so "
+    "anything you can do with holoviews, you can do "
+    "with hvplot, but prioritize hvplot kwargs "
+    "first as its simpler. Your primary goal is "
+    "to assist the user in edit the code based on user request "
+    "using best practices. Simply provide code "
+    "in code fences (```python). You absolutely "
+    "must have `hvplot_obj` as the last line of code. FYI,"
+    "Data columns: ['sepal_length', 'sepal_width', 'petal_length', 'petal_width', 'species']"
 )
 
 USER_CONTENT_FORMAT = """
@@ -31,68 +39,93 @@ Code:
 ```
 """.strip()
 
-DEFAULT_MATPLOTLIB = """
-import numpy as np
-import matplotlib.pyplot as plt
-
-fig = plt.figure()
-ax = plt.axes(title="Plot Title", xlabel="X Label", ylabel="Y Label")
-
-x = np.linspace(1, 10)
-y = np.sin(x)
-z = np.cos(x)
-c = np.log(x)
-
-ax.plot(x, y, c="blue", label="sin")
-ax.plot(x, z, c="orange", label="cos")
-
-img = ax.scatter(x, c, c=c, label="log")
-plt.colorbar(img, label="Colorbar")
-plt.legend()
+DEFAULT_HVPLOT = """
+import hvplot.pandas
+from bokeh.sampledata.iris import flowers
 
-# must have fig at the end!
-fig
+hvplot_obj = flowers.hvplot(x='petal_length', y='petal_width', by='species', kind='scatter')
+hvplot_obj
 """.strip()
 
 
-async def callback(content: str, user: str, instance: pn.chat.ChatInterface):
-    if not api_key_input.value:
-        yield "Please first enter your Mistral API key"
+def init_llm(event):
+    api_key = event.new
+    if not api_key:
+        api_key = os.environ.get("OPENAI_API_KEY")
+    if not api_key:
         return
-    client = MistralAsyncClient(api_key=api_key_input.value)
+    pn.state.cache["llm"] = OpenAI(api_key=api_key)
+
+
+def create_chat_engine(llm):
+    try:
+        storage_context = StorageContext.from_defaults(persist_dir="persisted/")
+        index = load_index_from_storage(storage_context=storage_context)
+    except Exception as exc:
+        embed_model = OpenAIEmbedding()
+        service_context = ServiceContext.from_defaults(llm=llm, embed_model=embed_model)
+        documents = SimpleDirectoryReader(
+            input_dir="hvplot_docs", required_exts=[".md"], recursive=True
+        ).load_data()
+        index = VectorStoreIndex.from_documents(
+            documents, service_context=service_context, show_progress=True
+        )
+        index.storage_context.persist("persisted/")
+
+    retriever = index.as_retriever()
+    chat_engine = ContextChatEngine.from_defaults(
+        system_prompt=SYSTEM_PROMPT,
+        retriever=retriever,
+        verbose=True,
+    )
+    return chat_engine
+
 
-    # system
-    messages = [SYSTEM_MESSAGE]
+def callback(content: str, user: str, instance: pn.chat.ChatInterface):
+    if "llm" not in pn.state.cache:
+        yield "Need to set OpenAI API key first"
+        return
 
-    # history
-    messages.extend([ChatMessage(**message) for message in instance.serialize()[1:-1]])
+    if "engine" not in pn.state.cache:
+        engine = pn.state.cache["engine"] = create_chat_engine(pn.state.cache["llm"])
+    else:
+        engine = pn.state.cache["engine"]
 
     # new user contents
     user_content = USER_CONTENT_FORMAT.format(
        content=content, code=code_editor.value
    )
-    messages.append(ChatMessage(role="user", content=user_content))
 
-    # stream LLM tokens
-    message = ""
-    async for chunk in client.chat_stream(model=LLM_MODEL, messages=messages):
-        if chunk.choices[0].delta.content is not None:
-            message += chunk.choices[0].delta.content
-            yield message
+    # send user content to chat engine
+    agent_response = engine.stream_chat(user_content)
+
+    message = None
+    for chunk in agent_response.response_gen:
+        message = instance.stream(chunk, message=message, user="OpenAI")
 
     # extract code
-    llm_code = re.findall(r"```python\n(.*)\n```", message, re.DOTALL)[0]
-    if llm_code.splitlines()[-1].strip() != "fig":
-        llm_code += "\nfig"
-    code_editor.value = llm_code
+    llm_matches = re.findall(r"```python\n(.*)\n```", message.object, re.DOTALL)
+    if llm_matches:
+        llm_code = llm_matches[0]
+        if llm_code.splitlines()[-1].strip() != "hvplot_obj":
+            llm_code += "\nhvplot_obj"
+        code_editor.value = llm_code
 
 
 def update_plot(event):
-    matplotlib_pane.object = exec_with_return(event.new)
+    try:
+        hvplot_pane.object = exec_with_return(event.new)
+    except Exception as exc:
+        chat_interface.send(f"Fix this error: {exc}")
 
 
 # instantiate widgets and panes
-api_key_input = pn.widgets.PasswordInput(placeholder="Enter your MistralAI API Key")
+api_key_input = pn.widgets.PasswordInput(
+    placeholder=(
+        "Currently subsidized by Andrew, "
+        "but you can also pass your own OpenAI API Key"
+    )
+)
 chat_interface = pn.chat.ChatInterface(
     callback=callback,
     show_clear=False,
@@ -105,23 +138,24 @@ chat_interface = pn.chat.ChatInterface(
     height=650,
     callback_exception="verbose",
 )
-matplotlib_pane = pn.pane.Matplotlib(
-    exec_with_return(DEFAULT_MATPLOTLIB),
+hvplot_pane = pn.pane.HoloViews(
+    exec_with_return(DEFAULT_HVPLOT),
     sizing_mode="stretch_both",
-    tight=True,
 )
 code_editor = pn.widgets.CodeEditor(
-    value=DEFAULT_MATPLOTLIB,
+    value=DEFAULT_HVPLOT,
     language="python",
     sizing_mode="stretch_both",
 )
 
 # watch for code changes
+api_key_input.param.watch(init_llm, "value")
 code_editor.param.watch(update_plot, "value")
+api_key_input.param.trigger("value")
 
 # lay them out
 tabs = pn.Tabs(
-    ("Plot", matplotlib_pane),
+    ("Plot", hvplot_pane),
     ("Code", code_editor),
 )
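For context on the mechanism the new code relies on: `exec_with_return` from `panel.io.mime_render` executes a code string and returns the value of its final expression, which `pn.pane.HoloViews` then renders; this is why both the system prompt and the post-processing in `callback` force `hvplot_obj` onto the last line. A minimal sketch of that behavior outside the app (not part of this commit; it reuses the same iris sample code, and the printed type is only indicative):

# Sketch only: run DEFAULT_HVPLOT-style code and get back the last expression's value.
from panel.io.mime_render import exec_with_return

code = """
import hvplot.pandas
from bokeh.sampledata.iris import flowers

hvplot_obj = flowers.hvplot(x='petal_length', y='petal_width', by='species', kind='scatter')
hvplot_obj
""".strip()

obj = exec_with_return(code)  # the hvplot/HoloViews object, not None
print(type(obj))  # a HoloViews element/overlay that pn.pane.HoloViews can display

If the code string ended with the assignment instead of a bare `hvplot_obj`, `exec_with_return` would typically return None and the pane would have nothing to show, which is what the `llm_code += "\nhvplot_obj"` fix-up in `callback` guards against.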