Spaces:
Sleeping
Sleeping
| from pandasai.llm import OpenAI | |
| from pandasai import Agent | |
| from pandasai import SmartDataframe, SmartDatalake | |
| from pandasai.responses.response_parser import ResponseParser | |
| from pandasai.responses.streamlit_response import StreamlitResponse | |
| from snowflake.snowpark import Session | |
| import json | |
| import pandas as pd | |
| from sqlalchemy import create_engine | |
| import os | |
| from dotenv import load_dotenv | |
| import streamlit as st | |
| load_dotenv() | |
| # ----------------------------------------------------------------------- | |
# Copy the PandasAI key from Streamlit secrets into the process environment
# (presumably so the pandasai library can pick it up — confirm against its docs).
key = st.secrets["PANDASAI_API_KEY"]
os.environ["PANDASAI_API_KEY"] = key

# Module-level LLM client shared by every SmartQuery call below.
openai_llm = OpenAI(api_token=st.secrets["OPENAI_API"])
| # ----------------------------------------------------------------------- | |
| # ----------------------------------------------------------------------- | |
class SmartQuery:
    """
    Class for interacting with dataframes using natural language.

    Also provides helpers that load those dataframes from Snowflake or MySQL.
    The SQL for each table comes from a JSON config file; entries are looked up
    as ``config[table_name]["query"]`` and formatted with a ``brand`` value.
    """

    def __init__(self, config_path="table_config.json"):
        """
        Load the table/query configuration.

        Parameters:
        - config_path (str): Path of the JSON config file. Defaults to
          "table_config.json", resolved against the current working directory.
        """
        with open(config_path, "r", encoding="utf-8") as f:
            self.config = json.load(f)

    def perform_query_on_dataframes(self, query, *dataframes, response_format=None):
        """
        Performs a user-defined query on given pandas DataFrames using PandasAI.

        Parameters:
        - query (str): The user's query or instruction.
        - *dataframes (pd.DataFrame): One or more pandas DataFrames.
        - response_format: Accepted for backward compatibility; it is currently
          not used by this method.

        Returns:
        - The result of the query executed by PandasAI.

        Raises:
        - ValueError: If no dataframe is supplied.
        """
        # Fail early with a clear message instead of handing PandasAI an empty list.
        if not dataframes:
            raise ValueError("At least one dataframe is required to run a query.")

        config = {
            "llm": openai_llm,
            "verbose": True,
            "security": "none",
            "response_parser": OutputParser,
        }
        dataframe_list = list(dataframes)
        if len(dataframe_list) == 1:
            return self.query_single_dataframe(query, dataframe_list[0], config)
        return self.query_multiple_dataframes(query, dataframe_list, config)

    def query_single_dataframe(self, query, dataframe, config):
        """Run ``query`` against one dataframe through a PandasAI ``Agent``."""
        agent = Agent(dataframe, config=config)
        return agent.chat(query)

    def query_multiple_dataframes(self, query, dataframe_list, config):
        """Run ``query`` against several dataframes through a ``SmartDatalake``."""
        datalake = SmartDatalake(dataframe_list, config=config)
        return datalake.chat(query)

    def snowflake_connection(self):
        """
        Create a Snowpark session from the ``snowflake_*`` Streamlit secrets.

        Returns:
        - The created Snowpark session.

        Raises:
        - Exception: Whatever session creation raised, after printing it.
        """
        conn = {
            "user": st.secrets["snowflake_user"],
            "password": st.secrets["snowflake_password"],
            "account": st.secrets["snowflake_account"],
            "role": st.secrets["snowflake_role"],
            "database": st.secrets["snowflake_database"],
            "warehouse": st.secrets["snowflake_warehouse"],
            "schema": st.secrets["snowflake_schema"],
        }
        try:
            return Session.builder.configs(conn).create()
        except Exception as e:
            print(f"Error creating Snowflake session: {e}")
            raise

    def read_snowflake_table(self, session, table_name, brand):
        """
        Read a configured table from Snowflake into a pandas DataFrame.

        Parameters:
        - session: Snowpark session, e.g. from ``snowflake_connection``.
        - table_name (str): Key of the table entry in the loaded config.
        - brand (str): Brand value substituted into the configured query.

        Returns:
        - pd.DataFrame with all column names lower-cased.

        Raises:
        - Exception: Whatever the query raised, after printing it. The error is
          re-raised (as the connection helpers do) so callers never receive
          ``None`` in place of a dataframe.
        """
        query = self._get_query(table_name, brand)
        try:
            dataframe = session.sql(query).to_pandas()
            dataframe.columns = dataframe.columns.str.lower()
        except Exception as e:
            print(f"Error in reading table: {e}")
            raise
        print("reading content table successfully")
        return dataframe

    def _get_query(self, table_name: str, brand: str) -> str:
        """
        Build the SQL for ``table_name`` with the brand filled in.

        The template is ``config[table_name]["query"]``; ``brand`` is lower-cased
        and inserted with ``str.format``. A missing config entry raises KeyError.
        """
        base_query = self.config[table_name]["query"]
        # NOTE(review): brand is interpolated into the SQL text, not bound as a
        # parameter — only pass trusted brand values.
        return base_query.format(brand=brand.lower())

    @staticmethod
    def _build_mysql_url(user, password, host, database):
        """Return a ``mysql+pymysql`` URL with user and password percent-encoded."""
        from urllib.parse import quote

        # Characters such as '@', '/' or ':' in credentials would otherwise be
        # read as URL delimiters. NOTE(review): if the stored secrets are
        # already percent-encoded, this would encode them twice — confirm.
        safe_user = quote(str(user), safe="")
        safe_password = quote(str(password), safe="")
        return f"mysql+pymysql://{safe_user}:{safe_password}@{host}/{database}"

    def mysql_connection(self):
        """
        Create a SQLAlchemy engine for MySQL from the ``mysql_*`` Streamlit secrets.

        Returns:
        - The SQLAlchemy engine.

        Raises:
        - Exception: Whatever engine creation raised, after printing it.
        """
        url = self._build_mysql_url(
            st.secrets["mysql_user"],
            st.secrets["mysql_password"],
            st.secrets["mysql_source"],
            st.secrets["mysql_schema"],
        )
        try:
            return create_engine(url)
        except Exception as e:
            print(f"Error creating MySQL engine: {e}")
            raise

    def read_mysql_table(self, engine, table_name, brand):
        """
        Read a configured table from MySQL into a pandas DataFrame.

        Parameters:
        - engine: SQLAlchemy engine, e.g. from ``mysql_connection``.
        - table_name (str): Key of the table entry in the loaded config.
        - brand (str): Brand value substituted into the configured query.

        Returns:
        - pd.DataFrame with all column names lower-cased.
        """
        query = self._get_query(table_name, brand)
        with engine.connect() as conn:
            dataframe = pd.read_sql_query(query, conn)
        # Normalise column names so callers see lowercase regardless of source.
        dataframe.columns = dataframe.columns.str.lower()
        return dataframe
| # ---------------------------------------------------------------------------------------------------- | |
| # ---------------------------------------------------------------------------------------------------- | |
class OutputParser(ResponseParser):
    """Response parser that hands the query result back unchanged."""

    def __init__(self, context) -> None:
        """Forward ``context`` to the base ``ResponseParser`` constructor."""
        super().__init__(context)

    def parse(self, result):
        """Return ``result`` as-is, with no post-processing."""
        return result
| # ---------------------------------------------------------------------------------------------------- | |
def main():
    """
    Demo run: load two Snowflake tables and ask one multi-dataframe question.

    Other prompts that were used against a single comments dataframe:
    - "select the comments that was on 'pack-bundle-lesson' content_type and have more than 10 likes"
    - "what is the number of likes, content_title and content_description for the content that received the most comments? "
    """
    query_multi = "get top 5 contents that had the most interactions and their 'content_type' is 'song'. Also include the number of interaction for these contents"

    sq = SmartQuery()
    session = sq.snowflake_connection()
    try:
        interactions_df = sq.read_snowflake_table(session, table_name="interactions", brand="drumeo")
        content_df = sq.read_snowflake_table(session, table_name="contents", brand="drumeo")
    finally:
        # Both tables are already materialised as pandas DataFrames, so the
        # session is no longer needed; close it even if a read failed.
        session.close()

    result = sq.perform_query_on_dataframes(query_multi, interactions_df, content_df, response_format="dataframe")
    print(result)


if __name__ == "__main__":
    main()