How an LSTM updates its memory
LSTM Gender Prediction
What we are going to see
Today we are going to take a deep dive into the LSTM circuit: how exactly the calculations work, and how an LSTM forgets old information and writes new information.
For the sake of clear understanding, we are going to work through 2 sentences:
John is a doctor.
Mary is a teacher
So the complete input becomes: John is a doctor. Mary is a teacher.
Our job is to predict the gender at each word using an LSTM.
Why I am writing this
One of my juniors asked me a question:
Heyy I'm reading this blog by colah on LSTM...I have some doubts...
(LSTM Diagram is attached)
So here we have 3 things
1) C(t-1) cell state long term memory
2) H(t-1) info from last iteration
3) Xt current input
Lets say we have 2 sentence
John is a doctor
Mary is a teacher
Now let's say C(t-1) had stored things like
subject=john, gender = male
Current input is Mary...
now model wants to forget gender = male...
So we take f(W(h(t-1),Xt))...
means we applied sigmoid on current input and past iteration info...
which will generate vector.. let's say [0,1,0,0,1]
then we multiply it with Cell state...
and it will forget some things, remember something based on 0 and 1...
now my que is...
how the hell... it calculated what to remove in cell state
by looking at current input
Understanding how something works is one thing; explaining it effectively is a completely different thing.
So today I am going to try to explain this with the help of this blog and the same example. I am also in a learning phase, so if there is any mistake, please feel free to point it out in the comment section.
Some theoretical stuff before the mathematical
☑️ Pre-requisite: I am assuming you have a basic understanding of how an LSTM works.
In an LSTM we have 3 gates:
Forget Gate - What information should be forgotten/erased from the cell state
Input Gate - What new information should be written to the cell state
Output Gate - What information from the current cell state passes to the new hidden state
So, just to be clear, the output gate doesn't play any role in updating the cell state. The cell state is affected only by the Forget Gate and the Input Gate.
Do the Forget gate and Input gate have any idea what's in the cell state ❓
➡️ NOT REALLY ❌
The forget gate never looks at the cell state directly. Let's take a look at the equation of the forget gate:
$$f_t = \sigma \left( W_f \cdot [h_{t-1}, x_t] + b_f \right)$$
If we look at the equation carefully, the previous cell state C_{t-1} is not present. The forget gate only sees h_{t-1} and x_t, so it has no direct view of where the gender value is stored in memory. It is blind to what is present in the cell state.
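As a minimal sketch of the equation above (the weights, hidden state and embedding are made-up values, and `forget_gate` is just an illustrative helper name), notice that the function simply has no parameter for the cell state:

```python
import numpy as np

def sigmoid(a):
    return 1.0 / (1.0 + np.exp(-a))

def forget_gate(W_f, b_f, h_prev, x_t):
    """f_t = sigmoid(W_f . [h_{t-1}, x_t] + b_f). There is no C_{t-1} argument."""
    z = np.concatenate([h_prev, x_t])   # [h_{t-1}, x_t]
    return sigmoid(W_f @ z + b_f)

# Hypothetical values, only to make the sketch runnable.
rng = np.random.default_rng(0)
W_f = rng.normal(size=(3, 6))           # 3 memory slots, 3 hidden + 3 input dims
b_f = np.zeros(3)
h_prev = np.array([0.1, -0.2, 0.58])    # assumed previous hidden state
x_t = np.array([1.0, 0.0, 1.0])         # assumed embedding of the current word

f_t = forget_gate(W_f, b_f, h_prev, x_t)
print(f_t)  # one value between 0 and 1 per memory slot
```

Whatever is stored in the cell state, the same h_{t-1} and x_t always give the same f_t.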
Then the question remains the same: "how did it know what to remove by looking at the input❓"
➡️ The answer is it didn't know. It's not reading the cell state and reasoning "ah, gender is in there, let me clear it." It's just emitting a vector of gates from the input alone.
Then why does the mask happen to zero out exactly the gender slot?
Because the weights W_f are learned during training.
Which cell gets read and written as the gender cell is settled during training. It is not specified anywhere; it emerges from the weight updates during backpropagation. When the gender is predicted for "Mary is a teacher" and the output comes out as male, that loss is backpropagated and the weights are updated. The LSTM learns that the gender slot should be cleared when an input word like Mary, a new subject, arrives. The input embeddings carry information about whether the current input is a subject or not and, if it is, whether its gender is male or female.
So the next time a similar input pattern appears, the forget gate knows it has to forget the gender information and the input gate knows it has to write the new information there.
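To make the "learned during training" point a bit more concrete, here is a rough sketch of the gradient that reaches W_f from a single time step. It follows from the forget gate equation and the standard cell update C_t = f_t ⊙ C_{t-1} + i_t ⊙ C~_t. The symbol δ_t (the loss gradient arriving at the cell state) is my own notation, and the contributions from other time steps are ignored:

$$\delta_t = \frac{\partial L}{\partial C_t}, \qquad \frac{\partial L}{\partial W_f} = \left( \delta_t \odot C_{t-1} \odot f_t \odot (1 - f_t) \right) [h_{t-1}, x_t]^{\top}$$

So a row of W_f gets a meaningful update only when its slot held a non-zero old value C_{t-1} and the loss wanted that slot changed. Because the update is multiplied by the input vector, it lands on the weights of the input features that were active at that step.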
Small Mathematical Example before actual calculation
Let's say the word Mary arrives. In the forget gate:
0: Forget
1: Keep the information
Assume the 3rd cell is the gender cell. A positive value represents male gender and a negative value represents female gender.
$$C_t = {f_t \odot C_{t-1}} + {i_t \odot \tilde{C}_t}$$
| state | 1 | 2 | 3 (gender) | 4 |
|---|---|---|---|---|
| Old Cell : C_t-1 | 0.6 | -0.3 | +0.9 (male) | 0.5 |
| x Forget | 1 | 1 | 0 | 1 |
| = erased | 0.6 | -0.3 | 0 | 0.5 |
| + write | 0 | 0 | -0.8 | 0 |
| = new cell : C_t | 0.6 | -0.3 | -0.8 (Female) | 0.5 |
So here you can see the forget gate gave the value 0 for the 3rd cell, so the previous value got erased from the cell state.
Later, at the input gate:
i_t : selects which cells should be updated and how strongly
C~_t : the proposed values to write
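The table above can be reproduced in a few lines of NumPy. This is only a sketch: the table gives the "write" row (the product i_t ⊙ C~_t) but not i_t and C~_t separately, so the split used here is my assumption.

```python
import numpy as np

C_prev = np.array([0.6, -0.3, 0.9, 0.5])     # old cell state, slot 3 = gender (male)
f_t = np.array([1.0, 1.0, 0.0, 1.0])         # forget gate: 0 = forget, 1 = keep

# Assumed split of the "write" row [0, 0, -0.8, 0] into gate and candidate.
i_t = np.array([0.0, 0.0, 1.0, 0.0])         # input gate: write only to slot 3
C_tilde = np.array([0.0, 0.0, -0.8, 0.0])    # candidate value (female)

erased = f_t * C_prev                        # element-wise: gender slot becomes 0
write = i_t * C_tilde                        # element-wise: new value for slot 3
C_t = erased + write                         # C_t = f_t * C_{t-1} + i_t * C~_t

print(C_t)
```

The other three slots pass through unchanged because their forget value is 1 and their input gate value is 0.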
Mathematical stuff for complete sentence
Concatenated input z
$$z = \left[ h_{\text{prev}}\ (3\ \text{dims}) \;\middle|\; x_{\text{embedding}}\ (3\ \text{dims}) \right]$$
So, h has h1, h2, h3 and the x_embedding has x1, x2, x3 for the input word.
Example:
Each gate has one set of weights per memory slot, and the 3rd one controls the gender slot. Therefore h3 represents the gender value up to the current word. In this example h_3 is 0.58, which means the current sentence has a male subject, so 'He' can be used. If the value is < 0.5, the gender is predicted as female, so 'She' can be used.
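As a small sketch of the concatenation and the gender read-out described above: h_3 = 0.58 and the 0.5 cut-off come from the text, while h1, h2 and the word embedding are placeholder values I made up.

```python
import numpy as np

h_prev = np.array([0.0, 0.0, 0.58])       # h1, h2 are placeholders; h3 = 0.58 from the text
x_embedding = np.array([0.0, 0.0, 0.0])   # hypothetical embedding of the current word

z = np.concatenate([h_prev, x_embedding]) # z = [h1, h2, h3, x1, x2, x3]

# Read-out rule as stated in the text: below 0.5 means female, otherwise male.
predicted = "male" if h_prev[2] >= 0.5 else "female"

print(z.shape, predicted)
```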
Now let's see the weight vectors for the respective gates, along with their biases.
W_f : Forget Gate
Bias: +3
Here the weight on x_1 will be -5, as we want to forget the gender cell.
| h1 | h2 | h3 | x1 | x2 | x3 |
|---|---|---|---|---|---|
| 0 | 0 | 0 | -5 | 0 | 0 |
W_i: Input Gate
Bias: -2
Here the weight on x_1 will be 5, as we want to update the gender cell.
| h1 | h2 | h3 | x1 | x2 | x3 |
|---|---|---|---|---|---|
| 0 | 0 | 0 | 5 | 0 | 0 |
W_c: Candidate
Bias: +0
Here the weight on x_2 will be 5 and the weight on x_3 will be -5, as these decide the male/female value.
| h1 | h2 | h3 | x1 | x2 | x3 |
|---|---|---|---|---|---|
| 0 | 0 | 0 | 0 | 5 | -5 |
W_o: Output Gate
Bias: +20
| h1 | h2 | h3 | x1 | x2 | x3 |
|---|---|---|---|---|---|
| 0 | 0 | 0 | 0 | 0 | 0 |
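To see these weights in action, here is a rough sketch of one LSTM step for the gender slot only, using the standard LSTM equations. The weights and biases are the ones listed above. The embeddings are my assumption (x1 = "is a new subject", x2 = "male", x3 = "female"), and so is the starting cell value of 0.9, borrowed from the earlier small example; the interactive page may use different numbers.

```python
import numpy as np

def sigmoid(a):
    return 1.0 / (1.0 + np.exp(-a))

# Gender-slot weights over z = [h1, h2, h3, x1, x2, x3], with their biases.
w_f, b_f = np.array([0.0, 0.0, 0.0, -5.0, 0.0, 0.0]), 3.0
w_i, b_i = np.array([0.0, 0.0, 0.0, 5.0, 0.0, 0.0]), -2.0
w_c, b_c = np.array([0.0, 0.0, 0.0, 0.0, 5.0, -5.0]), 0.0
w_o, b_o = np.zeros(6), 20.0

def gender_slot_step(c_prev, h_prev, x):
    """One LSTM step restricted to the gender slot; returns (c, h, f, i)."""
    z = np.concatenate([h_prev, x])
    f = sigmoid(w_f @ z + b_f)            # forget gate
    i = sigmoid(w_i @ z + b_i)            # input gate
    c_tilde = np.tanh(w_c @ z + b_c)      # candidate value
    o = sigmoid(w_o @ z + b_o)            # output gate (about 1 with bias +20)
    c = f * c_prev + i * c_tilde          # cell update
    h = o * np.tanh(c)                    # new hidden value for the slot
    return c, h, f, i

# ASSUMED embeddings, not taken from the post.
x_mary = np.array([1.0, 0.0, 1.0])        # new subject, female
x_is = np.array([0.0, 0.0, 0.0])          # not a subject

h_prev = np.array([0.0, 0.0, 0.58])       # h1, h2 are placeholders
c_mary, h_mary, f_mary, i_mary = gender_slot_step(0.9, h_prev, x_mary)
c_is, h_is, f_is, i_is = gender_slot_step(c_mary, np.array([0.0, 0.0, h_mary]), x_is)

print(f"Mary: f={f_mary:.3f} i={i_mary:.3f} c={c_mary:.3f}")
print(f"is:   f={f_is:.3f} i={i_is:.3f} c={c_is:.3f}")
```

With these assumed embeddings, a new subject drives the forget gate low and the input gate high, so the old gender value is mostly erased and replaced. A non-subject word does the opposite, so the stored value is mostly kept.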
For consistency and better understanding, I generated an interactive HTML file with Claude by giving it these rough calculations.
Sentence: John is a doctor. Mary is a teacher.
Now let's see how exactly the gender cell value is calculated.
Website link: https://malivinayak.com/LSTM-Gender-Prediction/
THANK YOU 😊
