|
from langchain.prompts import PromptTemplate |
|
from langchain.chains import LLMChain |
|
import pandas as pd |
|
import os |
|
import streamlit as st |
|
|
|
from typing import List |
|
from langchain_groq import ChatGroq |
|
from dotenv import load_dotenv |
|
# Load environment variables via python-dotenv before anything below reads
# the environment.
load_dotenv()

# Groq credential as found in the environment (None when the variable is
# unset). NOTE: it is not passed to ChatGroq below.
groq_api_key = os.environ.get("GROQ_API_KEY")

# Chat model instance shared by the prompt chains defined in this module.
llm1 = ChatGroq(model_name="mixtral-8x7b-32768", temperature=0)
|
|
|
|
|
def read_first_3_rows(dataset_path: str = "dataset.csv") -> str:
    """Return the first three rows of a CSV file as plain text.

    The text is meant to be embedded in an LLM prompt as a small preview of
    the dataset's columns and values.

    Args:
        dataset_path: Path of the CSV file to preview. Defaults to
            ``"dataset.csv"`` in the current working directory.

    Returns:
        The header and up to three data rows rendered with
        ``DataFrame.to_string(index=False)``; or an ``"Error: ..."`` message
        when the file is missing or has nothing to parse.
    """
    try:
        # Only three rows are ever shown, so do not parse the whole file.
        preview = pd.read_csv(dataset_path, nrows=3)
    except FileNotFoundError:
        return "Error: Dataset file not found."
    except pd.errors.EmptyDataError:
        # The file exists but has no header/columns to parse.
        return "Error: Dataset file is empty."

    return preview.to_string(index=False)
|
|
|
def generate_plot(question):
    """Ask the LLM for plotting code that answers ``question``.

    The prompt contains a preview of the dataset (from
    ``read_first_3_rows()``), the user's question, and one worked
    Altair/Streamlit example.

    Args:
        question: The visualization request to insert into the prompt.

    Returns:
        The text returned by ``LLMChain.predict`` for the filled-in prompt.
    """
    dataset_first_3_rows = read_first_3_rows()

    GENERATE_PLOT_TEMPLATE_PREFIX = """You are a high skilled visualization assistant that can modify a provided visualization code based on a set of instructions. You MUST return a full program. DO NOT include any preamble text. Do not include explanations or prose.
First 3 rows of the dataset:"""

    GENERATE_PLOT_TEMPLATE_SUFIX = """
Question:
{question}

# comment Example for protein count of different products:

import altair as alt
import pandas as pd
import streamlit as st

# comment Read the dataset
df = pd.read_csv('dataset.csv')

# comment Calculate the protein count of different products
product_protein = df.groupby('name')['protein'].sum().reset_index()

# comment Create the chart
chart = alt.Chart(product_protein).mark_bar().encode(
    x=alt.X('name:N', title='Product Name'),
    y=alt.Y('protein:Q', title='Protein Count')
)

# comment Display the chart
st.altair_chart(chart, use_container_width=True)

"""

    # The dataset preview is supplied as a template *variable* rather than
    # being pasted into the template text: any "{" or "}" in the data would
    # otherwise be parsed as a placeholder and break prompt formatting.
    template = (
        GENERATE_PLOT_TEMPLATE_PREFIX + "{dataset}" + GENERATE_PLOT_TEMPLATE_SUFIX
    )
    prompt = PromptTemplate(
        template=template,
        input_variables=["question", "dataset"],
    )
    llm_chain = LLMChain(prompt=prompt, llm=llm1)
    response = llm_chain.predict(question=question, dataset=dataset_first_3_rows)
    return response
|
|
|
|
|
|
|
def retry_generate_plot(question, error_message, error_code):
    """Ask the LLM to repair plotting code that failed.

    The prompt contains one worked Altair/Streamlit example, a preview of
    the dataset (from ``read_first_3_rows()``), the original objective, the
    failing code, and the error it produced.

    Args:
        question: The original visualization objective.
        error_message: The error text produced by the failing code.
        error_code: The code that raised the error.

    Returns:
        The text returned by ``LLMChain.predict`` for the filled-in prompt.
    """
    dataset_first_3_rows = read_first_3_rows()

    RETRY_TEMPLATE_PREFIX = """You are a high skilled visualization assistant that can modify a provided visualization code based on a set of instructions. You MUST return a full program. DO NOT include any preamble text. Do not include explanations or prose.
Current code attempts to create a visualization of dataset.csv to meet the objective. but it has encounted the given error. provide a corrected code. if you are adding comments or explanations they should start with #.

#Example:
import altair as alt
import pandas as pd
import streamlit as st

# Read the dataset
df = pd.read_csv('dataset.csv')

# Calculate the total social media followers for each region
region_followers = df.groupby('Region of Focus')[['X (Twitter) Follower #', 'Facebook Follower #', 'Instagram Follower #', 'Threads Follower #', 'YouTube Subscriber #', 'TikTok Subscriber #']].sum().reset_index()

# Melt the dataframe to convert it into long format
region_followers = region_followers.melt(id_vars='Region of Focus', var_name='Social Media', value_name='Total Followers')

# Create the chart
chart = alt.Chart(region_followers).mark_bar().encode(
    x=alt.X('Region of Focus:N', title='Region of Focus'),
    y=alt.Y('Total Followers:Q', title='Total Followers'),
    color=alt.Color('Social Media:N', title='Social Media')
)

# Display the chart
st.altair_chart(chart, use_container_width=True)

First 3 rows of the dataset:"""

    RETRY_TEMPLATE_SUFIX = """
Objective: {question}

Current Code:
{error_code}

Error Message:
{error_message}

Corrected Code:
"""

    # The dataset preview is supplied as a template *variable* rather than
    # being pasted into the template text: any "{" or "}" in the data would
    # otherwise be parsed as a placeholder and break prompt formatting.
    retry_template = RETRY_TEMPLATE_PREFIX + "{dataset}" + RETRY_TEMPLATE_SUFIX

    # Each placeholder is listed as its own entry; the previous list fused
    # "error_message, error_code" into a single, non-existent variable name.
    retry_prompt = PromptTemplate(
        template=retry_template,
        input_variables=["question", "error_message", "error_code", "dataset"],
    )

    llm_chain = LLMChain(prompt=retry_prompt, llm=llm1)
    response = llm_chain.predict(
        question=question,
        error_message=error_message,
        error_code=error_code,
        dataset=dataset_first_3_rows,
    )
    return response