當模型的效能低下時,它無法準確地預測資料。它的主因可能是 overfitting 或 underfitting。如果是 overfitting 的情況,我們可以使用 regularization 來解決模型 overfitting 的問題。
Regularization
Regularization 是一個用來防止 overfitting 的方法。假設有一個 overfitted 模型如左下,我們可以讓 w3
和 w4
很小或趨於 0 來減少 x3
和 x4
對模型的影響,以使得模型變得較平緩,如右下。這就是 regularization 的基本想法。
Gradient descent 會在 cost function 中尋找最小的值。在 cost function 中,如果我們加上 和 ,這會使得 gradient descent 找到的最小值中, 和 會很小或趨於 0。所以,藉由修改 cost function,我們可以在訓練時,減少 和 對模型的影響。
以下的式子是加上 regularization 的 cost function。在後面加上去的式子叫做 regularization term,其中的 λ
叫 regularization parameter。如果將 λ
設為很大的值,如 1010,則所有的 W
都會趨於 0。因此,我們可以藉由調整 λ
來縮小 W
。
Regularized Linear Regression
Regularized linear regression 是 linear regression 的 cost function 加上 regularization term,如下。
Gradient descent 演算法如下。
我們將導數的部分展開後,會變成以下的式子。
簡化後的 wj
的式子會變得如下。可以清楚的看到,我們可以藉由調整 λ
來縮小 wj
。
Regularized Logistic Regression
Logistic regression 的 cost function 加上 regularization term 會變成以下的式子。
Gradient descent 演算法如下。
將導數的部分展開後,會變成以下的式子。看起來和 regularized linear regression 的一模一樣,但是,要注意的是,式子中的 fw,b
是 logistic regression。
結語
Regularization 可以降低 parameters 的小大來解決 overfitting。當 parameters 越大時,受到的處罰會越大,也就是一次會縮小很多。
參考
- Andrew Ng, Machine Learning Specialization, Coursera.