shrutisd1003 committed
Commit 0459303
1 Parent(s): 8e5dec6

automated visualizations

.gitignore CHANGED
@@ -1,5 +1,6 @@
 .aider*
 # Byte-compiled / optimized / DLL files
+Modules/__pycache__/
 __pycache__/
 *.py[cod]
 *$py.class
Experimentation/visualizations.py ADDED
@@ -0,0 +1,68 @@
+import streamlit as st
+import pandas as pd
+import seaborn as sns
+import matplotlib.pyplot as plt
+
+from litellm import completion
+from dotenv import load_dotenv
+import os
+import pandas as pd
+
+load_dotenv() # take environment variables from .env.
+os.environ['GEMINI_API_KEY'] = os.getenv("GOOGLE_API_KEY")
+
+def define_viz():
+    info = get_info()
+
+    message = f'''
+    You are a data analyst working with a given dataset. Below is the column-wise information about the dataset:
+    {info}
+
+    Each line represents a column name followed by its respective information or statistics. Columns are separated by "*****".
+
+    Your task:
+    - Analyze the dataset to determine the appropriate visualization for each column.
+    - Generate ONLY a Python dictionary where the key is the column name and the value is the visualization suitable for the column.
+    - You can use BAR PLOT, HISTOGRAMS and PIE CHARTS.
+    - Assign the value "NA" to columns that CANNOT have a meaningful count plot, such as ID columns or columns with UNIQUE VALUES FOR EACH ENTRY.
+
+    '''
+    output = completion(
+        model="gemini/gemini-pro",
+        messages=[
+            {"role": "user", "content": message}
+        ]
+    )
+
+    return output.choices[0].message.content
+
+def get_info():
+    file_path = './test_data.csv'
+    data = pd.read_csv(file_path)
+
+    numeric_cols = data.describe()
+    non_numeric_cols = data.describe(include=object)
+
+    formatted_str = ""
+
+    # For numeric columns
+    for col in numeric_cols.columns:
+        formatted_str += f"{col}\n"
+        for stat in numeric_cols.index:
+            formatted_str += f"{stat} = {numeric_cols.loc[stat, col]}\n"
+        formatted_str += "\n*****\n\n"
+
+    # For non-numeric columns
+    for col in non_numeric_cols.columns:
+        formatted_str += f"{col}\n"
+        for stat in non_numeric_cols.index:
+            formatted_str += f"{stat} = {non_numeric_cols.loc[stat, col]}\n"
+        formatted_str += "\n*****\n\n"
+
+    return formatted_str
+
+def main():
+    print(define_viz())
+
+if __name__ == "__main__":
+    main()
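Reviewer note: define_viz() asks the model to reply with ONLY a Python dictionary, but what comes back from completion() is still free-form text. Below is a minimal sketch of how a caller might turn that reply into an actual dict; the brace-matching regex, the ast.literal_eval step, and the parse_viz_dict helper name are assumptions for illustration, not part of this commit.

import ast
import re

def parse_viz_dict(reply: str) -> dict:
    # Hypothetical helper: pull the first {...} literal out of an LLM reply,
    # even if the model wrapped it in extra prose or a code fence.
    match = re.search(r"\{.*\}", reply, re.DOTALL)
    if not match:
        raise ValueError("No dictionary literal found in the model reply")
    # literal_eval only evaluates Python literals, so arbitrary code is not executed.
    return ast.literal_eval(match.group(0))

# Example with a reply shaped the way the prompt requests:
reply = '{"age": "HISTOGRAM", "gender": "PIE CHART", "customer_id": "NA"}'
print(parse_viz_dict(reply))  # {'age': 'HISTOGRAM', 'gender': 'PIE CHART', 'customer_id': 'NA'}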
Modules/__init__.py DELETED
File without changes
Modules/data_analyzer.py CHANGED
@@ -9,10 +9,12 @@ class DataAnalyzer:
         self.data = data
         st.header("Exploratory Data Analysis")

-    def show_eda(self):
+    def show_llm_summary(self):
         st.subheader("Summary")
         summary = LLM_summary()
-        st.write(summary)
+        st.write(summary)
+
+    def show_eda(self):
         st.write("Number of rows:", self.data.shape[0])
         st.write("Number of columns:", self.data.shape[1])
         null_counts = self.data.isnull().sum()
@@ -21,7 +23,7 @@
         null_percentages = (null_counts / total_rows) * 100
         columns_stats = []
         for column_name in self.data.columns:
-            dtype = self.data[column_name].dtype
+            dtype = str(self.data[column_name].dtype)
             null_count = null_counts[column_name]
             null_percentage = null_percentages[column_name]
             columns_stats.append({
@@ -40,14 +42,16 @@
         st.write(self.data.describe(include=object))

     def count_plot(self, column_name):
-        st.write(column_name)
-        unique_values_ratio = self.data[column_name].nunique() / len(self.data)
-        fig, ax = plt.subplots(figsize=(9, 5))
-        if unique_values_ratio <= 0.3:
-            sns.countplot(data=self.data, x=column_name, ax=ax)
-        else:
-            sns.histplot(data=self.data, x=column_name, bins=20, ax=ax)
-        st.pyplot(fig)
+        unique_values = self.data[column_name].nunique()
+        unique_values_ratio = unique_values / len(self.data)
+        if unique_values_ratio != 1 and unique_values != 1:
+            st.write(column_name)
+            fig, ax = plt.subplots(figsize=(9, 5))
+            if unique_values_ratio <= 0.3:
+                sns.countplot(data=self.data, x=column_name, ax=ax)
+            else:
+                sns.histplot(data=self.data, x=column_name, bins=20, ax=ax)
+            st.pyplot(fig)

     def show_count_plots(self):
         st.subheader("Count Plots")
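The reworked count_plot() now skips columns whose values are either all unique (an ID-like column) or all identical, and keeps the 0.3 cardinality cutoff for choosing a count plot over a histogram. Here is a standalone sketch of the same heuristic outside Streamlit, using a made-up DataFrame and saving figures to disk instead of calling st.pyplot:

import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt

df = pd.DataFrame({
    "id": range(100),                              # every value unique -> skipped
    "city": ["A", "B", "C", "D"] * 25,             # low cardinality -> count plot
    "income": (pd.Series(range(100)) // 2) * 1.7,  # many distinct values -> histogram
})

for column in df.columns:
    unique_values = df[column].nunique()
    ratio = unique_values / len(df)
    if ratio == 1 or unique_values == 1:
        continue  # same skip condition as the new count_plot, written as its negation
    fig, ax = plt.subplots(figsize=(9, 5))
    if ratio <= 0.3:
        sns.countplot(data=df, x=column, ax=ax)
    else:
        sns.histplot(data=df, x=column, bins=20, ax=ax)
    fig.savefig(f"{column}.png")
    plt.close(fig)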
Modules/data_transformer.py CHANGED
@@ -5,7 +5,6 @@ import numpy as np
 class DataTransformer:
     def __init__(self, data):
         self.data = data
-

     def perform_column_operation(self):
         column_operation = st.sidebar.text_input('Column operation (e.g., age * 2)')
@@ -15,22 +14,42 @@
         st.write(self.data)
         return self.data

+    def handle_null_remove(self,col):
+        self.data.dropna(subset=col, inplace=True)
+        self.data.to_csv("data.csv", index=False)
+        return self.data
+
+    def handle_null_impute(self,col,option):
+        if option == "mean":
+            self.data[col] = self.data[col].fillna(self.data[col].mean())
+        elif option == "mode":
+            self.data[col] = self.data[col].fillna(self.data[col].mode())
+        elif option == "0":
+            self.data[col] = self.data[col].fillna(0)
+        elif option == "-Select-":
+            raise ValueError("Select an option")
+        self.data.to_csv("data.csv", index=False)
+        return self.data
+
     def handle_null(self):
         left, right = st.columns([2,1])
         with left:
             st.subheader("Remove Null Values")
             col = st.multiselect('Choose columns to remove nulls', self.data.columns)
             if st.button('Remove Null'):
-                self.handle_null_remove(col)
-                st.success("Null values removed")
+                try:
+                    self.handle_null_remove(col)
+                    st.success("Null values removed")
+                except Exception as e:
+                    st.error(str(e))
             st.subheader("Impute Null Values")
             col = st.multiselect('Choose columns to impute nulls', self.data.select_dtypes(include=[np.number]).columns)
             option = st.selectbox('Impute nulls with', ('-Select-','mean', 'mode', '0'))
             if st.button('Impute Null'):
                 try:
-                    self.handle_null_impute(col,option)
+                    self.handle_null_impute(col, option)
                     st.success("Null values filled")
-                except ValueError as e:
+                except Exception as e:
                     st.error(str(e))
         with right:
             st.write("Null Stats")
@@ -56,7 +75,10 @@
         st.subheader("Convert Categorical to Numerical")
         columns_to_encode = st.multiselect('Choose columns to convert', self.data.select_dtypes(include=object).columns)
         if st.button('Convert'):
-            self.categorical_to_numerical_func(columns_to_encode)
+            for col in columns_to_encode:
+                one_hot_encoded = pd.get_dummies(self.data[col], prefix=col).astype(int)
+                self.data = pd.concat([self.data, one_hot_encoded], axis=1)
+                self.data.drop(col, axis=1, inplace=True)
             st.success("Converted categoricals variables")
             st.write(self.data.head())
         return self.data
@@ -65,46 +87,16 @@
         st.subheader("Remove Columns")
         col = st.multiselect('Choose columns to remove', self.data.columns)
         if st.button('Remove Columns'):
-            self.remove_columns_func(col)
+            self.data.drop(columns=col, inplace=True)
+            self.data.to_csv("data.csv", index=False)
             st.success("Columns removed")
         return self.data

-
-    #---CORE FUNCTIONALITY---
-    def remove_columns_func(self,col):
-        self.data.drop(columns=col, inplace=True)
-        self.data.to_csv("data.csv", index=False)
-        return self.data
-
-    def handle_null_remove(self,col):
-        self.data.dropna(subset=col, inplace=True)
-        print(self.data)
-        self.data.to_csv("data.csv", index=False)
-
-    def handle_null_impute(self,col,option):
-        if option == "mean":
-            self.data[col] = self.data[col].fillna(self.data[col].mean())
-        elif option == "mode":
-            self.data[col] = self.data[col].fillna(self.data[col].mode().iloc[0])
-        elif option == "0":
-            self.data[col] = self.data[col].fillna(0)
-        elif option == "-Select-":
-            raise ValueError("Select an option")
-        self.data.to_csv("data.csv", index=False)
-
-
-    def categorical_to_numerical_func(self,columns_to_encode):
-        for col in columns_to_encode:
-            one_hot_encoded = pd.get_dummies(self.data[col], prefix=col).astype(int)
-            self.data = pd.concat([self.data, one_hot_encoded], axis=1)
-            self.data.drop(col, axis=1, inplace=True)
-            self.data.to_csv("data.csv", index=False)
-
     # PROBLEMS RESOLVED
     #transformed data is not retained
     #null values handling
     #2 options - to remove or to impute that is the question
     #categorical to numerical

-    # PROBLEMS TO BE ADDRESSED
+    # PROBLEMS TO BE ADDRESSED
     #give option to analyse the transformed dataset or save it.
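One behavioural detail worth flagging in the relocated handle_null_impute: it now fills with self.data[col].mode(), while the deleted helper used .mode().iloc[0]. When col is a list of columns, DataFrame.mode() returns a DataFrame of ranked modes, and fillna then aligns it by row index and column instead of broadcasting a single value, so nulls outside the first few rows are left untouched. Below is a minimal sketch of per-column most-frequent imputation under that reading; the toy frame and column names are made up:

import numpy as np
import pandas as pd

df = pd.DataFrame({"a": [1, 1, 2, np.nan, np.nan], "b": [5.0, np.nan, 5.0, 7.0, np.nan]})
cols = ["a", "b"]

# Take the single most frequent value per column and broadcast it into the nulls.
fill_values = {c: df[c].mode().iloc[0] for c in cols}
df[cols] = df[cols].fillna(value=fill_values)
print(df)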
Modules/data_visualizer.py CHANGED
@@ -1,74 +1,76 @@
 import streamlit as st
-import pandas as pd
-import numpy as np
-import matplotlib.pyplot as plt
-import seaborn as sns
+import re
+
+from litellm import completion
+from dotenv import load_dotenv
+import os
+
+load_dotenv() # take environment variables from .env.
+os.environ['GEMINI_API_KEY'] = os.getenv("GOOGLE_API_KEY")

 class DataVisualizer:
     def __init__(self, data):
         self.data = data
         st.subheader("Data Visualizer")

-    def visualize_data(self):
-        plot_type = st.selectbox('Choose a type of plot', ['Histogram', 'Box Plot', 'Pie Chart', 'Scatter Plot', 'Heatmap'])
+    def suggestions(self):
+        message = f'''
+        You are a data analyst working with a given dataset. Below is the information about the dataset:
+        ========
+        {self.data.describe(include='all')}
+        ========

-        if plot_type == 'Histogram':
-            numeric_columns = self.data.select_dtypes(include=[np.number]).columns
-            if numeric_columns.empty:
-                st.warning('No numeric columns in the data to visualize.')
-            else:
-                column_to_visualize = st.selectbox('Choose a column to visualize', numeric_columns)
-                fig, ax = plt.subplots()
-                ax.hist(self.data[column_to_visualize])
-                ax.set_title(f'Histogram of {column_to_visualize}')
-                ax.set_xlabel(column_to_visualize)
-                ax.set_ylabel('Frequency')
-                st.pyplot(fig)
+        Here is a sample of the data:
+        {self.data.head()}

-        elif plot_type == 'Box Plot':
-            numeric_columns = self.data.select_dtypes(include=[np.number]).columns
-            if numeric_columns.empty:
-                st.warning('No numeric columns in the data to visualize.')
-            else:
-                column_to_visualize = st.selectbox('Choose a column to visualize', numeric_columns)
-                fig, ax = plt.subplots()
-                ax.boxplot(self.data[column_to_visualize].dropna())
-                ax.set_title(f'Box Plot of {column_to_visualize}')
-                ax.set_ylabel(column_to_visualize)
-                st.pyplot(fig)
+        Number of rows in the dataset: {self.data.shape[0]}

-        elif plot_type == 'Pie Chart':
-            nonnumeric_columns = self.data.select_dtypes(include=['object']).columns
-            if nonnumeric_columns.empty:
-                st.warning('No non numeric columns in the data to visualize.')
-            else:
-                column_to_visualize = st.selectbox('Choose a column to visualize', nonnumeric_columns)
-                fig, ax = plt.subplots()
-                self.data[column_to_visualize].value_counts().plot(kind='pie', ax=ax, autopct='%1.1f%%', textprops={'fontsize': 'small'})
-                ax.set_title(f'Pie Chart of {column_to_visualize}')
-                ax.set_ylabel('')
-                st.pyplot(fig)
-
-        elif plot_type == 'Scatter Plot':
-            left, right = st.columns(2)
-            with left:
-                x_col = st.selectbox('Choose values on X axis', self.data.select_dtypes(include=[np.number]).columns)
-            with right:
-                y_col = st.selectbox('Choose values on Y axis', self.data.select_dtypes(include=[np.number]).columns)
-            if x_col == y_col:
-                st.warning('Please select two different columns for scatter plot.')
+        Your task:
+        Suggest 5 visualizations that can be made in bullet points
+        '''
+        output = completion(
+            model="gemini/gemini-pro",
+            messages=[
+                {"role": "user", "content": message}
+            ]
+        )
+
+        output_str = output.choices[0].message.content
+        st.write("Here are some suggestions")
+        st.write(output_str)
+
+    def generate_viz(self):
+        graph = st.text_input("What graph do you want to generate?")
+        if graph:
+            message = f'''
+            You are a data analyst working with a given dataset. Below is the information about the dataset:
+            {self.data.describe(include='all')}
+
+            Here is a sample of the data:
+            {self.data.head()}
+
+            Your task:
+            Generate a python code to create the following visualization and show it in streamlit - {graph}
+            The data is stored in a csv file named "data.csv"
+            '''
+            output = completion(
+                model="gemini/gemini-pro",
+                messages=[
+                    {"role": "user", "content": message}
+                ]
+            )
+
+            output_str = output.choices[0].message.content
+
+            pattern = r'`python(.*?)`'
+            match = re.search(pattern, output_str, re.DOTALL)
+
+            if match:
+                code_block = match.group(1).strip()
             else:
-                fig, ax = plt.subplots()
-                ax.scatter(self.data[x_col], self.data[y_col])
-                ax.set_title(f'Scatter Plot of {x_col} vs {y_col}')
-                ax.set_xlabel(x_col)
-                ax.set_ylabel(y_col)
-                st.pyplot(fig)
-
-        elif plot_type == 'Heatmap':
-            numeric_data = self.data.select_dtypes(include=[np.number])
-            corr = numeric_data.corr()
-            fig, ax = plt.subplots()
-            sns.heatmap(corr, annot=True, ax=ax)
-            ax.set_title('Correlation Heatmap')
-            st.pyplot(fig)
+                code_block = output_str.strip()  # If no code block found, assume entire text is code
+
+            try:
+                exec(code_block)
+            except Exception as e:
+                print(e)
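generate_viz() extracts code from the model reply with the pattern r'`python(.*?)`'. If the model answers with a conventional fenced block delimited by triple backticks, that single-backtick pattern will not capture the body, and exec() then runs whatever raw text came back. Here is a hedged sketch of a stricter extraction step; the triple-backtick patterns and the explicit namespace handed to exec are suggestions, not what this commit ships:

import re

def extract_code(reply: str) -> str:
    # Prefer a ```python fenced block, then any fenced block, then fall back to raw text.
    for pattern in (r"```python\s*(.*?)```", r"```\s*(.*?)```"):
        match = re.search(pattern, reply, re.DOTALL)
        if match:
            return match.group(1).strip()
    return reply.strip()

reply = "Here you go:\n```python\nimport pandas as pd\nprint(pd.__version__)\n```"
code = extract_code(reply)
# Executing generated code is inherently risky; at minimum give it its own namespace.
exec(code, {"__name__": "__generated__"})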
app.py CHANGED
@@ -48,7 +48,9 @@ def main():
             data_analyzer.show_count_plots()

             data_visualizer = DataVisualizer(data)
-            data_visualizer.visualize_data()
+            data_visualizer.suggestions()
+            data_visualizer.generate_viz()
+            # data_visualizer.visualize_data()

         # --- DATA CLEANING ---
         if selected == "Data Cleaning":
@@ -110,10 +112,11 @@

         # --- DATA PARTY ---
         if selected == "Data Party":
-            st.write("To be Added)")
+            st.write("To be Added:)")

-    except:
-        st.write("Please upload a csv file")
+    except Exception as e:
+        # st.write("Please upload a csv file")
+        print(e)


 if __name__ == "__main__":
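The broadened except Exception as e in app.py prints the error to the server console, which is easy to miss when only the Streamlit page is visible. A small alternative sketch, assuming the intent is still to nudge the user toward uploading a CSV while surfacing the real error in the app; st.exception renders the traceback inline:

import streamlit as st

try:
    ...  # load the uploaded CSV and render the selected page
except Exception as e:
    st.warning("Please upload a csv file")
    st.exception(e)  # show the traceback in the app instead of the server console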
requirements.txt CHANGED
@@ -10,4 +10,5 @@ tabulate
 litellm
 streamlit_option_menu
 scikit-learn
-pytest
+pytest
+streamlit-modal