Spaces:
Running
Running
File size: 6,619 Bytes
d65e306 1c98fd4 d65e306 1c98fd4 d65e306 1c98fd4 d65e306 1c98fd4 d65e306 1d14b94 1c98fd4 1d14b94 1c98fd4 9fa94f5 1c98fd4 d65e306 1c98fd4 9fa94f5 1c98fd4 9fa94f5 1c98fd4 9fa94f5 1d14b94 1c98fd4 d65e306 9fa94f5 d65e306 1c98fd4 d65e306 1c98fd4 d65e306 1c98fd4 d65e306 1c98fd4 d65e306 1c98fd4 d65e306 1c98fd4 d65e306 1c98fd4 d65e306 1c98fd4 d65e306 1c98fd4 d65e306 1c98fd4 d65e306 1c98fd4 d65e306 1c98fd4 d65e306 1c98fd4 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 |
"""Custom actions used within a dashboard."""
import base64
import io
import logging
import black
import dash
import dash_bootstrap_components as dbc
import pandas as pd
from _utils import check_file_extension
from dash.exceptions import PreventUpdate
from langchain_openai import ChatOpenAI
from plotly import graph_objects as go
from vizro.models.types import capture
from vizro_ai import VizroAI
# Optional vendor integrations: when a package is not installed the class name
# is bound to None instead, so the module still imports.
try:
    from langchain_anthropic import ChatAnthropic
except ImportError:
    ChatAnthropic = None
try:
    from langchain_mistralai import ChatMistralAI
except ImportError:
    ChatMistralAI = None
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)  # TODO: remove manual setting and make centrally controlled

# Maps the vendor name selected in the UI to the chat-model class used to build the LLM.
# "Anthropic" and "Mistral" may map to None when their optional package is missing.
# "xAI" reuses ChatOpenAI (it is constructed with the OpenAI-style key/base arguments).
SUPPORTED_VENDORS = {
    "OpenAI": ChatOpenAI,
    "Anthropic": ChatAnthropic,
    "Mistral": ChatMistralAI,
    "xAI": ChatOpenAI,
}

# Model names offered per vendor; keys mirror SUPPORTED_VENDORS.
# NOTE(review): not referenced in this section - presumably consumed by the UI; confirm.
SUPPORTED_MODELS = {
    "OpenAI": [
        "gpt-4o-mini",
        "gpt-4o",
        "gpt-4-turbo",
    ],
    "Anthropic": [
        "claude-3-opus-latest",
        "claude-3-5-sonnet-latest",
        "claude-3-sonnet-20240229",
        "claude-3-haiku-20240307",
    ],
    "Mistral": ["mistral-large-latest", "open-mistral-nemo", "codestral-latest"],
    "xAI": ["grok-beta"],
}

# Temperature passed to every chat model built in get_vizro_ai_plot.
DEFAULT_TEMPERATURE = 0.1
# Passed as max_debug_retry to VizroAI.plot.
DEFAULT_RETRY = 3
def get_vizro_ai_plot(user_prompt, df, model, api_key, api_base, vendor_input):
    """Build the selected vendor's chat model and ask VizroAI to plot ``df``.

    Args:
        user_prompt: Chart request forwarded to ``VizroAI.plot``.
        df: Data frame the chart is generated from.
        model: Model name handed to the vendor's chat-model class.
        api_key: API key for the selected vendor.
        api_base: API base URL for the selected vendor (callers may pass ``None``).
        vendor_input: Vendor name; must be a key of ``SUPPORTED_VENDORS``.

    Returns:
        The result of ``VizroAI.plot`` called with ``return_elements=True``.

    Raises:
        KeyError: If ``vendor_input`` is not a key of ``SUPPORTED_VENDORS``.
        ImportError: If the vendor's optional integration package could not be imported.
        ValueError: If the vendor is registered but has no construction branch below.
    """
    vendor = SUPPORTED_VENDORS[vendor_input]
    if vendor is None:
        # Optional integrations fall back to None at import time; fail with a
        # readable message instead of "'NoneType' object is not callable".
        raise ImportError(
            f"The integration package for vendor '{vendor_input}' is not installed."
        )

    # Each vendor class spells its key/base-URL keyword arguments differently.
    if vendor_input == "OpenAI":
        llm = vendor(
            model_name=model, openai_api_key=api_key, openai_api_base=api_base, temperature=DEFAULT_TEMPERATURE
        )
    elif vendor_input == "Anthropic":
        llm = vendor(
            model=model, anthropic_api_key=api_key, anthropic_api_url=api_base, temperature=DEFAULT_TEMPERATURE
        )
    elif vendor_input == "Mistral":
        llm = vendor(model=model, mistral_api_key=api_key, mistral_api_url=api_base, temperature=DEFAULT_TEMPERATURE)
    elif vendor_input == "xAI":
        llm = vendor(model=model, openai_api_key=api_key, openai_api_base=api_base, temperature=DEFAULT_TEMPERATURE)
    else:
        # Guards against a vendor being added to SUPPORTED_VENDORS without a branch here,
        # which would otherwise leave ``llm`` unbound.
        raise ValueError(f"No model configuration is defined for vendor '{vendor_input}'.")

    vizro_ai = VizroAI(model=llm)
    ai_outputs = vizro_ai.plot(df, user_prompt, max_debug_retry=DEFAULT_RETRY, return_elements=True)
    return ai_outputs
@capture("action")
def run_vizro_ai(user_prompt, n_clicks, data, model, api_key, api_base, vendor_input):  # noqa: PLR0913
    """Run the chart request and return the outputs shown in the dashboard.

    Returns a 3-tuple of (response text, figure, ``{"ai_outputs": ...}``). Input
    problems and chart-generation errors are reported through the response text
    together with an empty figure and ``None`` outputs.

    Raises:
        PreventUpdate: If ``n_clicks`` is falsy.
    """

    def _reject(message):
        # Every failure path answers with an empty figure and no stored outputs.
        return (message, go.Figure(), {"ai_outputs": None})

    if not n_clicks:
        raise PreventUpdate

    if not data:
        return _reject("Please upload data to proceed!")

    if not api_key:
        return _reject("API key not found. Make sure you enter your API key!")

    if api_key.startswith('"'):
        return _reject("Make sure you enter your API key without quotes!")

    if api_base is not None and api_base.startswith('"'):
        return _reject("Make sure you enter your API base without quotes!")

    try:
        logger.info("Attempting chart code.")
        frame = pd.DataFrame(data["data"])
        outputs = get_vizro_ai_plot(
            user_prompt=user_prompt,
            df=frame,
            model=model,
            api_key=api_key,
            api_base=api_base,
            vendor_input=vendor_input,
        )
        vizro_code = outputs.code_vizro
        vizro_fig = outputs.get_fig_object(data_frame=frame, vizro=True)
        plotly_fig = outputs.get_fig_object(data_frame=frame, vizro=False)
        pretty_code = black.format_str(vizro_code, mode=black.Mode(line_length=100))

        # Both code flavours are kept, each with its figure serialised to JSON.
        stored_outputs = {
            "vizro": {"code": outputs.code_vizro, "fig": vizro_fig.to_json()},
            "plotly": {"code": outputs.code, "fig": plotly_fig.to_json()},
        }

        # The response is the formatted Vizro code wrapped in a fenced python block.
        fenced_code = "\n".join(["```python", pretty_code, "```"])
        logger.info("Successful query produced.")
        return (fenced_code, vizro_fig, {"ai_outputs": stored_outputs})

    except Exception as exc:
        logger.debug(exc)
        logger.info("Chart creation failed.")
        return _reject(f"Sorry, I can't do that. Following Error occurred: {exc}")
@capture("action")
def data_upload_action(contents, filename):
    """Decode an uploaded file and return its rows plus styling for the UI.

    Args:
        contents: Upload payload of the form ``"<header>,<base64 data>"``.
        filename: Name of the uploaded file; its extension selects the parser.

    Returns:
        A 3-tuple. On success: ``{"data": <list of row dicts>, "filename": filename}``,
        ``{"cursor": "pointer"}`` and ``{}``. On an unsupported extension or any
        processing error: a dict with an ``"error_message"`` plus two style dicts.

    Raises:
        PreventUpdate: If ``contents`` is empty.
    """
    if not contents:
        raise PreventUpdate

    if not check_file_extension(filename=filename):
        return (
            {"error_message": "Unsupported file extension.. Make sure to upload either csv or an excel file."},
            {"color": "gray"},
            {"display": "none"},
        )

    try:
        # Split on the first comma only, and inside the try block, so a payload
        # without a comma (or with extra ones) yields the error response below
        # instead of an unhandled unpacking ValueError.
        _, content_string = contents.split(",", 1)
        decoded = base64.b64decode(content_string)
        # Compare case-insensitively so e.g. "DATA.CSV" is not sent to the Excel reader.
        if filename.lower().endswith(".csv"):
            # Handle CSV file
            df = pd.read_csv(io.StringIO(decoded.decode("utf-8")))
        else:
            # Handle Excel file
            df = pd.read_excel(io.BytesIO(decoded))

        data = df.to_dict("records")
        return {"data": data, "filename": filename}, {"cursor": "pointer"}, {}

    except Exception as e:
        logger.debug(e)
        return (
            {"error_message": "There was an error processing this file."},
            {"color": "gray", "cursor": "default"},
            {"display": "none"},
        )
@capture("action")
def display_filename(data):
    """Return the text shown for the current upload.

    A payload containing a ``"filename"`` key produces the "Uploaded file name"
    message; otherwise the looked-up text (the error message, when present) is
    returned unchanged.

    Raises:
        PreventUpdate: If ``data`` is ``None``.
    """
    if data is None:
        raise PreventUpdate

    shown_text = data.get("filename") or data.get("error_message")
    if "filename" in data:
        return f"Uploaded file name: '{shown_text}'"
    return shown_text
@capture("action")
def update_table(data):
    """Build a preview table and modal title from the uploaded data.

    Args:
        data: Upload payload; expected to hold the rows under ``"data"`` and
            the file name under ``"filename"``.

    Returns:
        ``dash.no_update`` when there is no payload or it carries no rows
        (e.g. it only holds an ``"error_message"``); otherwise a tuple of
        (table component, modal title).
    """
    # A payload without a "data" key would raise KeyError below, so skip the update.
    if not data or "data" not in data:
        return dash.no_update

    df = pd.DataFrame(data["data"])
    filename = data.get("filename") or data.get("error_message")
    modal_title = f"Data sample preview for {filename} file"

    # DataFrame.sample raises ValueError when asked for more rows than exist,
    # so cap the sample size for uploads with fewer than five rows.
    df_sample = df.sample(min(5, len(df)))
    table = dbc.Table.from_dataframe(df_sample, striped=False, bordered=True, hover=True)
    return table, modal_title
|