Spaces:
Sleeping
Sleeping
File size: 2,218 Bytes
5fbf3c7 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 |
import os
import sys
import argparse
import numpy as np
import seaborn as sns
import matplotlib.pyplot as plt
from data_utils import read_csv_file, get_data_from_data_frame
def do_eda(ARGS):
    """Show exploratory plots for the dataset selected on the command line.

    Two figures are displayed in turn:

    1. a bar chart of the number of samples per value of the target column;
    2. a heatmap of the pairwise correlations between the numeric columns.

    Args:
        ARGS: parsed argument namespace providing ``file_csv`` (path to the
            CSV file passed to ``read_csv_file``) and ``target_column`` (name
            of the label column to summarise).
    """
    data_frame = read_csv_file(ARGS.file_csv)
    _plot_label_distribution(data_frame, ARGS.target_column)
    _plot_correlation_matrix(data_frame)


def _plot_label_distribution(data_frame, target_column):
    """Show a bar chart with the number of samples for each target value."""
    # value_counts() drops NaN labels and orders values by descending count.
    label_counts = data_frame[target_column].value_counts()

    plt.figure(figsize=(12, 12))
    # Stringify the labels so numeric classes (e.g. 0/1) are drawn as
    # categorical ticks instead of positions on a numeric x-axis.
    plt.bar(
        [str(label) for label in label_counts.index],
        label_counts.tolist(),
        width=0.5,
    )
    plt.xlabel(f"{target_column}", fontsize=20)
    plt.ylabel("Number of samples", fontsize=20)
    plt.title("Distribution of samples in the dataset", fontsize=20)
    plt.grid()
    plt.xticks(fontsize=20)
    plt.yticks(fontsize=20)
    plt.show()


def _plot_correlation_matrix(data_frame):
    """Show a heatmap of the pairwise correlations between numeric columns."""
    # Keep only numeric (and bool) columns: since pandas 2.0,
    # DataFrame.corr() defaults to numeric_only=False and raises on text
    # columns, which an arbitrary --file_csv may contain.
    numeric_frame = data_frame.select_dtypes(include=[np.number, "bool"])
    corr_mat = numeric_frame.corr()

    # Same size as the label-distribution figure so the 20-pt tick labels
    # below have room to be drawn without overlapping.
    plt.figure(figsize=(12, 12))
    sns.heatmap(corr_mat)
    plt.title("Feature correlation matrix", fontsize=20)
    plt.xticks(fontsize=20)
    plt.yticks(fontsize=20)
    plt.show()
def main(argv=None):
    """Parse command-line options and run the exploratory data analysis.

    Args:
        argv: optional list of argument strings to parse instead of
            ``sys.argv[1:]``; lets ``main`` be called programmatically.
            Defaults to ``None`` (read the real command line).
    """
    file_csv = "dataset/water_potability.csv"
    target_column = "Potability"

    parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter
    )
    parser.add_argument("--file_csv", default=file_csv,
                        type=str, help="full path to dataset csv file")
    parser.add_argument("--target_column", default=target_column,
                        type=str, help="target label for which the EDA needs to be done")

    # parse_known_args() tolerates options the parser does not define, but
    # that also means a typo such as "--file-csv" would be dropped silently
    # and the default dataset analysed instead, so report what was ignored.
    ARGS, unparsed = parser.parse_known_args(argv)
    if unparsed:
        print(
            f"Warning: ignoring unrecognized arguments: {' '.join(unparsed)}",
            file=sys.stderr,
        )

    do_eda(ARGS)


if __name__ == "__main__":
    main()
|