Monday, June 29, 2020

PyTorch 入門:使用 ResNet9 辨識鳥類品種

前言
機器學習(Machine Learning,下稱 ML)在近年越來越備受關注,因為它透過模仿神經細胞的結構,能在一大堆數據中,找出資料和結果之間的線性 (linear relationship) 和非線性關係 (non-linear relationship)。例如,假設有極大量不同的動物圖片,它能透過這些圖片,找出不同動物的特性(例如顏色、眼晴等),從而「學會」判斷不同動物。

筆者大約在六七年前,仍在大學二年級左右,已經開始見到有人在 Kaggle 中運用 ML 玩簡單的推算遊戲,以及用來訓練股票系統。當年,筆者只接觸過關於 linear regression 和 multi-layer perceptron 這些基本的機器學習概念。但由於 ML 涉及的數學太多、效率也不太高,所以筆者也放棄深究下去。直至 2015 年,DeepMind 運用 deep neural network 學習圍棋玩法,並挑贏了世界冠軍,ML 就開始變成熱門話題。時至今日,ML 發展速度越來越快,幾乎每天都有新的 training model 和 neural network 誕生。

有見及此,筆者為了 catch up 一下科技進步,參加了 Deep Learning with PyTorch 的網上課程。以下的內容,正正就是 course project 的一部分。

簡介
本次實作基於一個鳥類資料庫。資料庫中有 200 種不同鳥類,每種鳥類 5 張彩色圖片,總共 1000 張圖片。目的是建立一個可以認到 200 種鳥類的神經網絡 (Neural Network)。

使用到的技術
卷積神經網絡 Convolutional Neural Network: 這裏是指由原本 224*224*3(224 像素的正方形圖片,三原色 RGB),演變成長闊數值少,但深度變深(512*20*20,512 層,每層 20*20 px)的一種深度神經網絡。背後概念主要是,透過將層數變多,就能使每一層反映的意義更加精準。例如:原本彩色圖片只有 3 層,代表三種顏色。但當變成 64 層,就可以用第一層表示輪廓、第二層表示黑白比例、第三層表示眼晴數目等等。而同時,當像素變少,物件辨識的演算法就能夠更歸納化,不會因為物件向左右移動了些少,就認不到內容。

Residual Connection 殘差連接:這裏是指在 CNN 的結果中,加回輸入值的做法。由於 ML 的重點是降低 loss function 的數值 (loss function's results minimization),透過加回輸入值,loss function 就更能反映出輸入與輸出的關係(減去了因為輸入值大少的影響)。具體邏輯可以看成品中的解說。

Data Augmentation 數據增強:原本是指將現有的圖片改一下,然後放入 training set 去增加取樣率。但這個 project 中運用的,作用並不是增加取樣率,而是由於數據有限,我們不希望神經網絡記錯一些過於特殊的細節,反而沒有找到同品種鳥類的共通點。例如它可能記了綠色背景,卻沒有記住鳥的形狀(也就是擬合過度 overfitting)。這種過程,也稱作歸納化 (generalization)。

Adam Optimizer:由於 ML 的重點在於 loss function's results minimization,如果只用傳統的 Stochastic Gradient Descent,只看斜率變化,就有機會找不到 global maxima,或者會出現跳動。運用這個 optimizer 就可以預測跳動,從而調節 learning rate。具體來講十分複雜,算是 ML 的專業學術範疇,所以不在此詳述。

CUDA 硬體加速:由於顯示卡計算 matrix 加減乘法比普通的處理器快 (general purpose CPU),所以本次使用了 NVIDIA 的加速功能去計算 tensors 的斜率變化。

成品
筆者由於不太熟習 PyTorch,故此只基於 Lecture 5 作出以下小改:
  • 將資料改成鳥類
  • 將 training set, validation set 和 testing set 分開
  • 取消了 data normalization(因考慮到鳥類的顏色比例很重要,而且很麻煩)
  • 加減了一些不必要的 code blocks 和 comments
結果如下:


準確率
由完全隨機開始訓練,結果能去到大約 87% 準確率,有點意想不到呢。
如果有時間的話,筆者會再 post 一下使用 pre-trained model 的結果,我相信應該更好。

參考
https://jovian.ml/forum/t/lecture-5-data-augmentation-regularization-and-resnets/1546