makiisthebes commited on
Commit
5b59d25
·
verified ·
1 Parent(s): af1ae39
README.MD ADDED
@@ -0,0 +1,155 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ## RNNs & LSTMs Explained with Example
2
+ By Michael Peres
3
+
4
+ _28th May 2024_
5
+
6
+ ----
7
+
8
+ ### Introduction
9
+
10
+ I recently started looking for solutions related to sequential data for audio, and although I understand transformers,
11
+ I wanted to appreciate the work done by RNNs and LSTMs.
12
+ Thus, this repo and work is related to the tutorial and expand my knowledge to a practical implementation.
13
+ For my dissertation, three models are created, a CNN, TinyViT model and a LSTM model. However, I was not aware of the technology and operations necessary with LSTMs.
14
+ Aiming to remember and obtain some implementation skillset, I aim to create a simple LSTM model for a simple task, and document my understanding during the process.
15
+
16
+
17
+ I would like to also note some resources that made this straightforward to grasp,
18
+
19
+ - RNNs: https://d2l.ai/chapter_recurrent-neural-networks/index.html
20
+
21
+ - LSTMs: https://www.youtube.com/watch?v=YCzL96nL7j0
22
+
23
+
24
+ ### RNNs
25
+ ![RNN Architecture](img/rnn_arch.png)
26
+
27
+ RNNs work with data of variable time lengths T, in which for each time step, the forward propagation is calculated and updates the same weights recursively, hence the name.
28
+
29
+ The initial block shows this recursion, but in other contexts we may want to `unroll` this operation across time steps providing a kind of transition between hidden layers.
30
+
31
+ We learn more about the problems with RNNs due to this sharing of a hidden state on the vanishing and exploding gradients problems.
32
+
33
+ But before this, we should understand one concept from RNNs, that I learned via the D2L book,
34
+ that is the idea of Markov Order Model, which is that if we imagine we have a sequence of length 1-T and we want to predict T+1,
35
+ then the Markov Order Model states the number of previous time steps needed tau, in order to accurately predict the next time step, such that any further removal of historical timesteps will not effect the overall loss of the model.
36
+
37
+ For example, if we have a stock price of 10 days of TSCO, and we aim to predict the stock price on the eleventh day, we could use all 10 previous days, and achieve a 90% accuracy, however if we have a model that can predict using 4 historical days, and still have 90% accuracy, such that if we go to 3 days it falls, then we can state this model follows a 4th order Markov Model.
38
+
39
+ **Main Idea**
40
+
41
+ ![RNN Operation](img/rnn_op.png)
42
+
43
+ The main idea of RNNs is that we work with both an input and a hidden state, such that in reality, the probability of the next word is dependent on the previous hidden state rather than the previous time steps, saving space, and not making a model that is dependant on the size of the sequence.
44
+
45
+ For example:
46
+
47
+ P(x_t=1 | x_t-1, x_t-2, x_t-3) ~~ P(x_t=1 | h_t-1)
48
+
49
+ I put a "~" for approximation, but in reality can be near equal, if the function is strong enough to represent it.
50
+
51
+ The book goes through some examples of teaching a Linear regression model to use 4 previous time steps to predict the next time step, and the results are quite good, but the main issue is that the model is not able to learn long term dependencies (failing when it comes to k steps prediction), and this is where LSTMs come in.
52
+
53
+ The first thing with RNNs is that we obtain a function that updates the hidden state based on the input at time t and the previous hidden state at time t-1.
54
+
55
+ ![Hidden State Function](img/hidden_state_func.png)
56
+
57
+ In RNNs we have three different types of weights, which depend on 3 different parameters,
58
+
59
+ - W_xh: Weights for the input (d,h)
60
+ - W_hh: Weights for the hidden state (h,h)
61
+ - b_h: Bias for the hidden state (h,)
62
+ - W_hq: Weights for the output (h,q)
63
+
64
+ (n being the batch size, d being the input dimension, h being the hidden dimension, and q being the output dimension)
65
+ Here we need the input dimension, which usually, will be of dimensions n,d
66
+ Hidden state being a vector h of dimensions n,h
67
+ Output being a vector q of dimensions n,q
68
+
69
+ The two main equations of RNNs being.
70
+
71
+ (1) to update the hidden state,
72
+
73
+ ![Update Hidden State Equation](img/update_hidden_state.png)
74
+
75
+ (2) To calculate the output from hidden state and weight w_hq.
76
+
77
+ ![Output RNN Equation](img/output_rnn_eq.png)
78
+
79
+
80
+ The diagram shows this flow, where the hidden state persists updating,
81
+
82
+ ![Overview RNN Flow](img/overview_rnn_flow.png)
83
+
84
+ For the RNN part, implementation file has been added, namely `RNN_Scratch.py` & `RNN.py`, with an example followed from PyTorch.
85
+
86
+ However, as mentioned before, there are problems with RNNs, notably the vanishing and exploding gradients problem.
87
+
88
+
89
+ #### Vanishing and Exploding Gradients of RNNs
90
+
91
+ The main idea of exploding gradients, was highlighted by this video, and its to do with the recurrent weights above 1,
92
+
93
+ ![Exploding Gradient of RNN Example](img/exploding_grad_RNN.png)
94
+
95
+ Such that recursively multiplication of above causes the model have large gradients, moving it out of a the gradient slope, and possibly ruining any current progress, similar to a large learning rate.
96
+
97
+ The same can be said for weights less than 1, say 0.5, which eventually go to 0, making that specific weights effect on overall model close to 0, and making it hard or backpropagation to learn from it.
98
+
99
+
100
+
101
+ ------------
102
+
103
+ ### LSTMs
104
+
105
+ LSTMs are a noticeable solution to this, although I do not see this, the D2L book mentions the calculation of loss, and could aim to help understand how LSTMs aid in this problem.
106
+
107
+ However, I did not bother, as current time constraints on dissertation.
108
+
109
+ The main idea of LSTMs seems confusing at the start and made no sense to me, however with some time, the video to start with and then the book to solidify, it becomes quite trivial and even easier to implement.
110
+
111
+ The solution uses both short term and long-term memory in the form of two rolling values (persistently used through time steps), a internal state (long term) and a hidden state (short term).
112
+
113
+ We need to remember to key things before going forward, activation functions:
114
+
115
+ - Sigmoid: `1/(1+e^-x)`
116
+
117
+ The sigmoid function is used to squish the values between 0 - 1, and is used to gate the information flow.
118
+ It generates a value from 0 to 1, and we will explain what this is used for (filtering).
119
+
120
+ ![Sigmoid Activation](img/sigmoid.png)
121
+
122
+ - Tanh: `(e^x - e^-x)/(e^x + e^-x)`
123
+
124
+ The tanh function is used to squish the values between -1 - 1, and is used to gate the information flow.
125
+ (Value generation)
126
+
127
+ ![Tanh Activation](img/tanh_activation.png)
128
+
129
+
130
+ LTSMs are very simple, for instance, we have three gates and I will explain this all, here is an overview of the LSTM cell, with some annatations.
131
+
132
+ ![LSTM Architecture](img/LSTM_cell_arch.png)
133
+
134
+ There are 3 different gates, which are the forget gate, input gate, and output gate.
135
+
136
+ The gates can be calculated to input the internal state, with the following equations:
137
+
138
+ ![Dimension of LTSM](img/dimensions_lstm.png)
139
+
140
+ The forget gate, states how much of the current long-term memory is remembered, as a percentage of 0-1, which is multiplied by the previous long term memory.
141
+
142
+ The input gate takes both a percentage (sin) and value (-1 to 1, tanh) to add to the value.
143
+ The output gate is the final output, which is a squished version of the internal state, to be used as the hidden state.
144
+
145
+ Here is another daigram showing the same thing, it's quite simple to use,
146
+
147
+ ![Complete Equation for Input Node](img/arch_rnn.png)
148
+
149
+
150
+ ----
151
+
152
+ We will now implement the LSTM model in the `LSTM_Scratch.py` file, and `LSTM.py` torch implementation.
153
+
154
+
155
+ ### End - Michael Peres
data.py ADDED
File without changes
img/LSTM_cell_arch.png ADDED
img/arch_rnn.png ADDED
img/dimensions_lstm.png ADDED
img/exploding_grad_RNN.png ADDED
img/hidden_state_func.png ADDED
img/output_rnn_eq.png ADDED
img/overview_rnn_flow.png ADDED
img/rnn_arch.png ADDED
img/rnn_op.png ADDED
img/sigmoid.png ADDED
img/tanh_activation.png ADDED
img/update_hidden_state.png ADDED
lstm.py ADDED
File without changes
lstm_scratch.py ADDED
File without changes
rnn.py ADDED
File without changes
rnn_scratch.py ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ # LTSM RNNs - Python
2
+
3
+ import torch
4
+ from torch import nn