File size: 2,343 Bytes
5fbf3c7
 
 
 
 
 
 
 
77b575f
 
5fbf3c7
 
77b575f
 
 
5fbf3c7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
77b575f
5fbf3c7
 
 
 
 
 
 
77b575f
 
 
 
 
 
 
 
 
5fbf3c7
 
 
 
77b575f
5fbf3c7
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
import os
import sys
import argparse

import numpy as np
import seaborn as sns
import matplotlib.pyplot as plt

from data_utils import WaterPotabilityDataLoader


def _plot_label_distribution(data_frame, target_column):
    """Show a bar chart of the number of rows for each value of ``target_column``.

    Args:
        data_frame: DataFrame that contains ``target_column``.
        target_column: name of the label column whose values are counted.
    """
    label_counts = dict(data_frame[target_column].value_counts())
    # Labels are plotted as strings so matplotlib treats them as categories;
    # counts are materialised as a list because a dict_values view is not a
    # sequence that every matplotlib version accepts.
    labels = [str(label) for label in label_counts]
    counts = list(label_counts.values())

    plt.figure(figsize=(12, 12))
    plt.bar(labels, counts, width=0.5)
    plt.xlabel(f"{target_column}", fontsize=20)
    plt.ylabel("Number of samples", fontsize=20)
    plt.title("Distribution of samples in the dataset", fontsize=20)
    plt.grid()
    plt.xticks(fontsize=20)
    plt.yticks(fontsize=20)
    plt.show()


def _plot_correlation_matrix(data_frame):
    """Show a heatmap of pairwise correlations between the numeric columns.

    Args:
        data_frame: DataFrame whose numeric/boolean columns are correlated.
    """
    # pandas >= 2.0 raises on non-numeric columns in DataFrame.corr(), while
    # older versions silently dropped them; selecting the columns explicitly
    # gives the same result on every version.
    corr_mat = data_frame.select_dtypes(include=["number", "bool"]).corr()

    # A larger canvas keeps the 20pt tick labels from overlapping.
    plt.figure(figsize=(12, 12))
    sns.heatmap(corr_mat)
    plt.title("Feature correlation matrix", fontsize=20)
    plt.xticks(fontsize=20)
    plt.yticks(fontsize=20)
    plt.tight_layout()
    plt.show()


def do_eda(ARGS):
    """Run exploratory data analysis on the water potability CSV.

    Loads the CSV through ``WaterPotabilityDataLoader`` and uses its
    ``df_csv`` frame. Then shows two figures, one after the other:

    - the number of samples per value of the target column;
    - the correlation matrix of the numeric columns.

    Args:
        ARGS: parsed command-line namespace providing ``file_csv`` (path to
            the CSV file) and ``target_column`` (name of the label column).
    """
    water_pot_dataset = WaterPotabilityDataLoader(ARGS.file_csv)
    water_pot_dataset.read_csv_file()
    data_frame = water_pot_dataset.df_csv

    _plot_label_distribution(data_frame, ARGS.target_column)
    _plot_correlation_matrix(data_frame)


def main(argv=None):
    """Parse command-line options and run the EDA.

    Args:
        argv: argument list to parse in place of ``sys.argv[1:]``. ``None``
            (the default) reads the real command line, as before.
    """
    file_csv = "dataset/water_potability.csv"
    target_column = "Potability"

    parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter
    )
    parser.add_argument(
        "--file_csv", default=file_csv, type=str, help="full path to dataset csv file"
    )
    parser.add_argument(
        "--target_column",
        default=target_column,
        type=str,
        help="target label for which the EDA needs to be done",
    )
    # parse_known_args is kept, presumably so extra arguments (e.g. from a
    # notebook launcher) do not abort the run. Silently dropping them would
    # hide typos such as "--file-csv", so they are reported instead.
    args, unparsed = parser.parse_known_args(argv)
    if unparsed:
        ignored = " ".join(unparsed)
        print(f"warning: ignoring unrecognized arguments: {ignored}", file=sys.stderr)
    do_eda(args)


# Run the EDA only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()