shivi commited on
Commit
70944a5
1 Parent(s): 6d90210

Update README.md

Browse files
Files changed (1) hide show
  1. README.md +87 -4
README.md CHANGED
@@ -8,18 +8,101 @@ tags:
8
 
9
  ## Model description
10
 
11
- More information needed
12
 
13
- ## Intended uses & limitations
 
14
 
15
- More information needed
 
 
 
 
16
 
17
  ## Training and evaluation data
18
 
19
- More information needed
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
20
 
21
  ## Training procedure
22
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
23
  ### Training hyperparameters
24
 
25
  The following hyperparameters were used during training:
 
8
 
9
  ## Model description
10
 
11
+ This model is built using two important architectural components proposed by Bryan Lim et al. in [Temporal Fusion Transformers (TFT) for Interpretable Multi-horizon Time Series Forecasting](https://arxiv.org/abs/1912.09363) called GRN and VSN which are very useful for structured data classification tasks.
12
 
13
+ 1. Gated Residual Networks(GRN) consist of skip connections and gating layers that facilitate information flow efficiently. They have the flexibility to apply non-linear processing only where needed.
14
+ 2. Variable Selection Networks(VSN) help in carefully selecting the most important features from the input by getting rid of any unnecessary noisy inputs which could harm the model's performance.
15
 
16
+ **Note:** This model is not based on the whole TFT model but only uses the GRN and VSN components described in the mentioned paper demonstrating that GRN and VSNs on their own also can be very useful for structured data learning tasks.
17
+
18
+ ## Intended uses
19
+
20
+ This model can be used for binary classification task to determine whether a person makes over $500K a year.
21
 
22
  ## Training and evaluation data
23
 
24
+ This model was trained using the [United States Census Income Dataset](https://archive.ics.uci.edu/ml/datasets/Census-Income+%28KDD%29) provided by the UCI Machine Learning Repository.
25
+ The dataset contains weighted census data extracted from 1994 and 1995 Current Population Surveys conducted by the US Census Bureau.
26
+ The dataset comprises of ~300K samples with 41 input features containing 7 numerical features and 34 categorical features:
27
+
28
+ | Numerical Features | Categorical Features |
29
+ | :-- | :-- |
30
+ | age | class of worker |
31
+ | wage per hour | industry code |
32
+ | capital gains | occupation code |
33
+ | capital losses | adjusted gross income |
34
+ | dividends from stocks | education |
35
+ | num persons worked for employer | veterans benefits |
36
+ | weeks worked in year | enrolled in edu inst last wk
37
+ || marital status |
38
+ || major industry code |
39
+ || major occupation code |
40
+ || mace |
41
+ || hispanic Origin |
42
+ || sex |
43
+ || member of a labor union |
44
+ || reason for unemployment |
45
+ || full or part time employment stat |
46
+ || federal income tax liability |
47
+ || tax filer status |
48
+ || region of previous residence |
49
+ || state of previous residence |
50
+ || detailed household and family stat |
51
+ || detailed household summary in household |
52
+ || instance weight |
53
+ || migration code-change in msa |
54
+ || migration code-change in reg |
55
+ || migration code-move within reg |
56
+ || live in this house 1 year ago |
57
+ || migration prev res in sunbelt |
58
+ || family members under 18 |
59
+ || total person earnings |
60
+ || country of birth father |
61
+ || country of birth mother |
62
+ || country of birth self |
63
+ || citizenship |
64
+ || total person income |
65
+ || own business or self employed |
66
+ || taxable income amount |
67
+ || fill inc questionnaire for veteran's admin |
68
+
69
 
70
  ## Training procedure
71
 
72
+ 0. **Prepare Data:** Download the data and convert the target column *income_level* from string to integer and finally split the data into train and validation.
73
+
74
+ 1. **Prepare tf.data.Dataset:** Train and validation datasets created using Step 0 are passed to a function that converts the features and labels into a tf.data.Dataset for training and evaluation.
75
+
76
+ 2. **Define logic for Encoding input features:** All features are encoded while also ensuring that they all have the same dimensionality.
77
+
78
+ - **Categorical Features:** are encoded using *Embedding* layer provided by Keras with output dimension of embedding equal to *encoding_size*
79
+
80
+ - **Numerical Features:** are projected into a *encoding_size* dimensional vector by applying a linear transformation using *Dense* layer provided by Keras
81
+
82
+ 3. **Implement the Gated Linear Unit (GLU):** consists of two Dense layers where the last last dense layer has a sigmoid activation. GLUs help in suppressing inputs that are not useful for a given task.
83
+
84
+ 4. **Implement the Gated Residual Network:**
85
+ - Applies Non-linear ELU tranformation on its inputs
86
+ - Applies linear transformation followed by dropout
87
+ - Applies GLU and adds the original inputs to the output of the GLU to perform skip (residual) connection
88
+ - Applies layer normalization and produces the output
89
+
90
+ 5. **Implement the Variable Selection Network:**
91
+ - Applies a Gated Residual Network (GRN) which was defined in step 4 to each feature individually.
92
+ - Applies a GRN for the concatenation of all features followed by a softmax to produce feature weights
93
+ - Produces a weighted sum of the output of the individual GRN
94
+
95
+ 6. **Create Model:**
96
+ - The model will have input layers corresponding to both numerical and categorical features of the given dataset
97
+ - The features received by the input layers are then encoded using the encoding logic defined in Step 2.
98
+ - The encoded features pass through the Variable Selection Network(VSN)
99
+ - The output produced by the VSN are passed through a final *Dense* layer with sigmoid activation to produce the final output of the model
100
+
101
+ 7. **Compile, Train and Evaluate Model**: The model is compiled using Adam optimizer and since the model is meant to binary classification, the loss function chosen is Binary Cross Entropy.
102
+ The model is trained for 20 epochs and batch_size of 265 with a callback for early stopping.
103
+ The model performance is evaluated based on the accuracy and loss being observed on the validation set.
104
+
105
+
106
  ### Training hyperparameters
107
 
108
  The following hyperparameters were used during training: