Linear regression 是一個資料分析技術,且使用線性函數來預測未知的資料。雖然 linear regression 模型相對簡單,但它是一個成熟的統計技術。
Table of Contents
Simple Linear Regression
Model Function
Linear regression 的 model function 如下。其中 w
和 b
是參數。當我們將變數 x
帶入後,函數 fw,b
會回傳一個預測值 ŷ
。
要注意的是,fw,b
回傳的是一個預測值 ŷ
,而非真正的值 y
,如下圖所示。當 fw,b
預測出來的 ŷ
很接近 y
時,則我們可以說 fw,b
的準確很高。我們會透過調整 w
和 b
來提高 fw,b
的準確率。
Cost Function
Cost function 是用來衡量 fw,b 的準確率。有了 cost function,我們就可以衡量調整後的 w
和 b
是否比原來的好。
Linear regression 的 cost function 是 squared error,如下。
我們最終想要找到一組 w
和 b
,使得 J(w,b)
會是最小。
Gradient Descent
雖然有了 cost function,但是我們還是不知道要怎麼挑一組比目前好的 w
和 b
。因此,我們需要用 gradient descent 來幫助我們挑選一組比目前好的 w
和 b
。
基本上 Gradient descent 的原理如下圖。首先,隨便挑選一組 w
和 b
,所以可能會在拋物線的左邊或是右邊,然後下一組選擇靠谷底方向的 w
和 b
。重複此動作直到 J(w,b)
無法再更小。
下一個問題是要如何知道 cost function 的谷底是在哪個方向呢?我們可以對 cost function J(w,b) 做微分,然後就可以得到目前的斜率。有了斜率,我們就知道要往哪邊移動。重複此動作,一直到斜率為 0 的時候,我們就知道到了谷底了。
以下為 gradient descent 演算法。
其中 cost function 的導數由以下方式計算出來。
Learning Rate
在 gradient descent 中,有一個 learning rate 叫 。由 gradient descent 演算法可以看出,當 較小時,需要較多的次數才能接近谷底,因此效能會較差。反之,則會比較快速地到達底谷。但是,當 太大時,則有可能會無法到達谷底。
Multiple Linear Regression
Model Function
至目前為止,我們介紹的是 simple linear regression,只有一個變數,實用性可能不大。接下來,我們將介紹 Multiple linear regression。
Multiple linear regression 有一個以上的變數,下面是 4 個變數的 model function。
所以,multiple linear regression 的 model function 如下。
向量化的 model function 則如下。
Cost Function
Multiple linear regression 的 cost function 如下。
向量化的 cost function 如下。
Gradient Descent
Multiple linear regression 的 gradient descent 如下。
其中 cost function 的導數由以下方式計算出來。
結語
Linear regression 是一個相較簡單的模型。很適合用來講解模型背後的數學原理。不過,它還是很實用的。本文章中介紹的是 simple linear regression,只有一個變數,實用性可能不大。Multiple linear regression 可以處理比較複雜的預測。
參考
- Andrew Ng, Machine Learning Specialization, Coursera.