簡介
傳統機器學習只找出每點數據中的規律,並不考慮時間相關性。例如,今天的天氣,其實很大機會會與明天的天氣相關(假設天氣突變並不常見)。又例如,在一句句子當中,若我們得到前面的句子,我們很容易會猜到下一句的內容(例如 Mary is a _,因為看到 Mary,我們能夠很快估計出,空格應該是 Girl 或者 Lady 等名詞)。而 LSTM (Long Short Term Memory),正正利用了不同的函數,幫助我們找出這些數據的前後之間的關聯性。
LSTM 極簡(忽略數學公式)解說
LSTM 由以下不同部分組成,使它同時擁有「短暫記憶」及「長期記憶」:
- Input Gate:也就是新數據的入口,大約就是正常的記憶入口
- Forget Gate:大約制定忘記多少上一個記憶
- Output Gate:大約制定殘餘去 hidden state 的數據量,並輸出 output state
由此可見,每一粒 LSTM Cell 都會有自己的長期記憶(儲存在 W),而且它也會輸出一些殘餘記憶去 Hidden State,供下一個 cycle 輸入,以至不會完全「忘記」掉。而隨著時間,這些殘餘記憶將會慢慢因為有更多新數據而被「溝淡」。詳細數學公式解釋,可參考此頁。
(有一個動畫演示 LSTM 的網站也十分值得看一下)
實例:股票預測例子
假設今天的股票波幅與將來相若的話,LSTM 就可以用來推算將來的價格。
以下用 TensorFlow 演示:
https://github.com/cmcvista/MLHelloWorld/blob/main/LSTMExample.ipynb
解說:
- 因為我們假定今天的股價會與明天相若,所以,資料排序的時候,採取了這種排法:
輸入 x1, x11, x21 ... x91,得出 x2, x12, x22 ... x92 - 輸入的數據組為 [x1, x11, x21 ... x91],[x2, x12, x22 ... x92],[x3, x13, x23 ... x93] ⋯
故此,同一 LSTM Cell 將會先輸入 x1,然後下一個 cycle 輸入 x2,
這樣才能運用 LSTM 表達數據間的關係。 - TensorFlow 輸入數據中,可以做到一個 batch 入面,數據有時間關係。
假設數據組為 [x1, x2 ... x10],[x11, x12 ... x20],你也可以設定 time step 為 10,
但是次例子中,我們並沒有用到這個功能,所以設置了 time step 為 1。
這個就是train_X_reshaped = np.reshape(train_X, (len(train_X), 1, len(train_X[0])))
的原因。
筆者暫時仍未掌握得很好。但總括而言,用法大致如此。
其實 MATLAB 的做法更加直觀,詳情參見此頁。
參考
[1] - https://www.datacamp.com/community/tutorials/lstm-python-stock-market
[2] - https://www.analyticsvidhya.com/blog/2017/12/fundamentals-of-deep-learning-introduction-to-lstm/
[3] - https://www.tensorflow.org/guide/keras/rnn
[4] - https://www.mathworks.com/help/deeplearning/ug/long-short-term-memory-networks.html
[5] - https://towardsdatascience.com/animated-rnn-lstm-and-gru-ef124d06cf45