import re
import os
import panel as pn
from io import StringIO
from panel.io.mime_render import exec_with_return
from llama_index import (
    VectorStoreIndex,
    SimpleDirectoryReader,
    ServiceContext,
    StorageContext,
    load_index_from_storage,
)
from llama_index.chat_engine import ContextChatEngine
from llama_index.embeddings import OpenAIEmbedding
from llama_index.llms import OpenAI
SYSTEM_PROMPT = (
    "You are a data visualization pro and expert in HoloViz hvplot + holoviews. "
    "Your primary goal is to assist the user in editing the plot code based on their requests, using best practices. "
    "Simply provide code in code fences (```python). You must have `hvplot_obj` as the last line of code. "
    "Note that the data columns are ['sepal_length', 'sepal_width', 'petal_length', 'petal_width', 'species'] and "
    "hvplot is built on top of holoviews, so anything you can do with holoviews, you can do "
    "with hvplot. First try to use hvplot **kwargs instead of opts, e.g. `legend='top_right'` "
    "instead of `opts(legend_position='top_right')`. If you need to use opts, use the "
    "concise version, e.g. `opts(xlabel='Petal Length')` vs `opts(hv.Opts(xlabel='Petal Length'))`."
)
USER_CONTENT_FORMAT = """
Request:
{content}
Code:
```python
{code}
```
""".strip()
DEFAULT_HVPLOT = """
import hvplot.pandas
from bokeh.sampledata.iris import flowers
hvplot_obj = flowers.hvplot(x='petal_length', y='petal_width', by='species', kind='scatter')
hvplot_obj
""".strip()
def exception_handler(exc):
    if retries.value == 0:
        chat_interface.send(f"Can't figure this out: {exc}", respond=False)
        return
    chat_interface.send(f"Fix this error:\n```python\n{exc}\n```")
    retries.value = retries.value - 1
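# Instantiate the OpenAI LLM from the password input, falling back to the
# OPENAI_API_KEY environment variable when no key is provided.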
def init_llm(event):
    api_key = event.new
    if not api_key:
        api_key = os.environ.get("OPENAI_API_KEY")
    if not api_key:
        return
    pn.state.cache["llm"] = OpenAI(api_key=api_key)
def create_chat_engine(llm):
    try:
        # reuse the previously persisted vector index if available
        storage_context = StorageContext.from_defaults(persist_dir="persisted/")
        index = load_index_from_storage(storage_context=storage_context)
    except Exception:
        # otherwise embed the hvplot markdown docs and persist the index
        embed_model = OpenAIEmbedding()
        service_context = ServiceContext.from_defaults(llm=llm, embed_model=embed_model)
        documents = SimpleDirectoryReader(
            input_dir="hvplot_docs", required_exts=[".md"], recursive=True
        ).load_data()
        index = VectorStoreIndex.from_documents(
            documents, service_context=service_context, show_progress=True
        )
        index.storage_context.persist("persisted/")
    retriever = index.as_retriever()
    chat_engine = ContextChatEngine.from_defaults(
        system_prompt=SYSTEM_PROMPT,
        retriever=retriever,
        verbose=True,
    )
    return chat_engine
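# Chat callback: send the user's request plus the current editor code to the
# chat engine, stream the reply, and copy any returned code into the editor.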
def callback(content: str, user: str, instance: pn.chat.ChatInterface):
    if "llm" not in pn.state.cache:
        yield "Need to set OpenAI API key first"
        return
    if "engine" not in pn.state.cache:
        engine = pn.state.cache["engine"] = create_chat_engine(pn.state.cache["llm"])
    else:
        engine = pn.state.cache["engine"]
    # combine the user's request with the code currently in the editor
    user_content = USER_CONTENT_FORMAT.format(
        content=content, code=code_editor.value
    )
    # send the combined content to the chat engine and stream its reply
    agent_response = engine.stream_chat(user_content)
    message = None
    for chunk in agent_response.response_gen:
        message = instance.stream(chunk, message=message, user="OpenAI")
    # extract the returned code block, ensure it ends with hvplot_obj, and push it to the editor
    llm_matches = re.findall(r"```python\n(.*)\n```", message.object, re.DOTALL)
    if llm_matches:
        llm_code = llm_matches[0]
        if llm_code.splitlines()[-1].strip() != "hvplot_obj":
            llm_code += "\nhvplot_obj"
        code_editor.value = llm_code
    retries.value = 2
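# Re-execute the editor contents whenever they change, capturing stderr so any
# execution errors can be routed back through exception_handler.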
def update_plot(event):
    with StringIO() as buf:
        hvplot_pane.object = exec_with_return(event.new, stderr=buf)
        buf.seek(0)
        errors = buf.read()
    if errors:
        exception_handler(errors)
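# Load the code editor extension and route uncaught exceptions through the
# retry-aware handler above.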
pn.extension("codeeditor", sizing_mode="stretch_width", exception_handler=exception_handler)
# instantiate widgets and panes
api_key_input = pn.widgets.PasswordInput(
    placeholder=(
        "Currently subsidized by Andrew, "
        "but you can also pass your own OpenAI API Key"
    )
)
chat_interface = pn.chat.ChatInterface(
    callback=callback,
    show_clear=False,
    show_undo=False,
    show_button_name=False,
    message_params=dict(
        show_reaction_icons=False,
        show_copy_icon=False,
    ),
    height=650,
    callback_exception="verbose",
)
hvplot_pane = pn.pane.HoloViews(
    exec_with_return(DEFAULT_HVPLOT),
    sizing_mode="stretch_both",
)
code_editor = pn.widgets.CodeEditor(
    value=DEFAULT_HVPLOT,
    language="python",
    sizing_mode="stretch_both",
)
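# hidden state widgets tracking the automatic-retry budget and error text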
retries = pn.widgets.IntInput(value=2, visible=False)
error = pn.widgets.StaticText(visible=False)
# watch for API key and code changes; the initial trigger lets init_llm fall back to OPENAI_API_KEY
api_key_input.param.watch(init_llm, "value")
code_editor.param.watch(update_plot, "value")
api_key_input.param.trigger("value")
# lay them out
tabs = pn.Tabs(
    ("Plot", hvplot_pane),
    ("Code", code_editor),
)
sidebar = [api_key_input, chat_interface]
main = [tabs]
template = pn.template.FastListTemplate(
    sidebar=sidebar,
    main=main,
    sidebar_width=600,
    main_layout=None,
    accent_base_color="#fd7000",
    header_background="#fd7000",
    title="Chat with Plot",
)
template.servable()