temp / DV-AGENT /generate_plot.py
NEXAS's picture
Upload 22 files
182219d verified
raw
history blame
4.11 kB
from langchain.prompts import PromptTemplate
from langchain.chains import LLMChain
import pandas as pd
import os
import streamlit as st
#from langchain_community.llms import HuggingFaceHub
from typing import List
from langchain_groq import ChatGroq
from dotenv import load_dotenv
# Load environment variables (via python-dotenv) before anything below reads
# them; this must stay ahead of the os.getenv call and the client creation.
load_dotenv()
# NOTE(review): this value is not passed to ChatGroq below and is not used
# elsewhere in the visible code; presumably the client picks GROQ_API_KEY up
# from the environment on its own — confirm, or pass it explicitly.
groq_api_key = os.getenv("GROQ_API_KEY")
# Chat model shared by the prompt chains in generate_plot and
# retry_generate_plot.
# NOTE(review): verify that this model id is still served by Groq.
llm1 = ChatGroq(temperature=0, model_name="mixtral-8x7b-32768")
def read_first_3_rows(dataset_path: str = "dataset.csv") -> str:
    """Return the first three rows of a CSV file rendered as plain text.

    Args:
        dataset_path: Location of the CSV file. Defaults to ``"dataset.csv"``
            relative to the current working directory, which is what callers
            that pass no argument get.

    Returns:
        The column header and up to three data rows formatted with
        ``DataFrame.to_string(index=False)``, or a short ``"Error: ..."``
        message when the file does not exist or contains no data at all.
    """
    try:
        # Only the preview rows are parsed, so a large dataset is not loaded
        # in full every time a prompt is built. Column dtypes are therefore
        # inferred from these rows only.
        preview = pd.read_csv(dataset_path, nrows=3)
    except FileNotFoundError:
        return "Error: Dataset file not found."
    except pd.errors.EmptyDataError:
        # A zero-byte file has no header to parse; report it the same way as
        # a missing file instead of letting the exception escape.
        return "Error: Dataset file is empty."
    return preview.to_string(index=False)
def generate_plot(question: str) -> str:
    """Ask the LLM for a Streamlit/Altair program that answers *question*.

    The prompt consists of an instruction header, a text preview of the
    dataset obtained from ``read_first_3_rows()``, the question itself and one
    worked example program.

    Args:
        question: Natural-language description of the chart to produce.

    Returns:
        The model's reply as returned by ``LLMChain.predict``.
    """
    dataset_preview = read_first_3_rows()

    template_prefix = """You are a high skilled visualization assistant that can modify a provided visualization code based on a set of instructions. You MUST return a full program. DO NOT include any preamble text. Do not include explanations or prose.
First 3 rows of the dataset:"""
    template_suffix = """
Question:
{question}
# comment Example for protein count of different products:
import altair as alt
import pandas as pd
import streamlit as st
# comment Read the dataset
df = pd.read_csv('dataset.csv')
# comment Calculate the protein count of different products
product_protein = df.groupby('name')['protein'].sum().reset_index()
# comment Create the chart
chart = alt.Chart(product_protein).mark_bar().encode(
x=alt.X('name:N', title='Product Name'),
y=alt.Y('protein:Q', title='Protein Count')
)
# comment Display the chart
st.altair_chart(chart, use_container_width=True)
"""
    # The preview goes in through its own placeholder instead of being pasted
    # into the template text: raw CSV content may contain "{" or "}", which
    # the f-string style template would otherwise try to interpret as
    # placeholders and fail on. The rendered prompt is unchanged.
    template = template_prefix + "{dataset}" + template_suffix
    prompt = PromptTemplate(
        template=template,
        input_variables=["question", "dataset"],
    )
    llm_chain = LLMChain(prompt=prompt, llm=llm1)
    return llm_chain.predict(question=question, dataset=dataset_preview)
def retry_generate_plot(question: str, error_message: str, error_code: str) -> str:
    """Ask the LLM to repair a generated plotting program that failed.

    The prompt consists of an instruction header with one worked example, a
    text preview of the dataset obtained from ``read_first_3_rows()``, the
    original objective, the failing code and the error it produced.

    Args:
        question: The original objective the visualization should meet.
        error_message: The error text produced by the failing program.
        error_code: Source of the program that raised the error.

    Returns:
        The model's reply as returned by ``LLMChain.predict``.
    """
    dataset_preview = read_first_3_rows()

    retry_prefix = """You are a high skilled visualization assistant that can modify a provided visualization code based on a set of instructions. You MUST return a full program. DO NOT include any preamble text. Do not include explanations or prose.
Current code attempts to create a visualization of dataset.csv to meet the objective. but it has encounted the given error. provide a corrected code. if you are adding comments or explanations they should start with #.
#Example:
import altair as alt
import pandas as pd
import streamlit as st
# Read the dataset
df = pd.read_csv('dataset.csv')
# Calculate the total social media followers for each region
region_followers = df.groupby('Region of Focus')[['X (Twitter) Follower #', 'Facebook Follower #', 'Instagram Follower #', 'Threads Follower #', 'YouTube Subscriber #', 'TikTok Subscriber #']].sum().reset_index()
# Melt the dataframe to convert it into long format
region_followers = region_followers.melt(id_vars='Region of Focus', var_name='Social Media', value_name='Total Followers')
# Create the chart
chart = alt.Chart(region_followers).mark_bar().encode(
x=alt.X('Region of Focus:N', title='Region of Focus'),
y=alt.Y('Total Followers:Q', title='Total Followers'),
color=alt.Color('Social Media:N', title='Social Media')
)
# Display the chart
st.altair_chart(chart, use_container_width=True)
First 3 rows of the dataset:"""
    retry_suffix = """
Objective: {question}
Current Code:
{error_code}
Error Message:
{error_message}
Corrected Code:
"""
    # The preview goes in through its own placeholder instead of being pasted
    # into the template text: raw CSV content may contain "{" or "}", which
    # the f-string style template would otherwise try to interpret as
    # placeholders and fail on. The rendered prompt is unchanged.
    retry_template = retry_prefix + "{dataset}" + retry_suffix
    # Each placeholder is listed as its own entry; the names must match the
    # placeholders used in the template exactly.
    retry_prompt = PromptTemplate(
        template=retry_template,
        input_variables=["question", "error_message", "error_code", "dataset"],
    )
    llm_chain = LLMChain(prompt=retry_prompt, llm=llm1)
    return llm_chain.predict(
        question=question,
        error_message=error_message,
        error_code=error_code,
        dataset=dataset_preview,
    )