Processing math: 100%

[ML Notes] 线性回归:Normal Equation

1. 基本形式

  线性回归的形式为一系列特征的线性组合

f(x)=w0+w1x1+w2x2++wnxn

其中 xi 为特征,wi 为参数。w0 为偏置,w1,w2,,wn 为对应特征的权重。

  令

w=[w0w1w2wn],x=[1x1x2xn]

则式 (1) 可以写为

f(x)=wTx

  对于式 (2) 所示的模型,希望能找到合适的参数 w,使得预测值 f(x) 和真实值 y 间的差别最小。选择 w 的关键在于如何衡量 f(x)y 之间的差别,即预测误差带来的损失。如定义损失为误差的平方和或均方误差,代价函数可以定义为

J(w)=12mi=1(y(i)f(x(i)))2

其中 m 为样例总数,x(i)y(i) 分别为第 i 个样例的特征和标签。

  记

X=[(x(1))T(x(2))T(x(m))T],y=[y(1)y(2)y(m)]

此时模型可以写为

f(X)=[(x(1))Tw(x(2))Tw(x(m))Tw]=Xw

于是

yf(X)=yXw

故式 (3) 也可以写为

J(w)=12||yXw||22=12(yXw)T(yXw)

  基于均方误差最小化对模型参数进行求解的方法称为最小二乘法(least square method)。

2. Normal Equation

  结合以下规则

  • Rn 中的向量 ab,有 aTb=bTa
  • f:RnR 的函数 f(x)=aTx,有 f(x)=a
  • g:RnR 的函数 g(x)=xTAx,其中 An×n 的对称矩阵, 有 g(x)=2Ax

计算式 (4) 关于 w 的梯度,可得

wJ=w12(yXw)T(yXw)=12w(yTyyTXw(Xw)Ty+(Xw)TXw)=12w(yTy2yTXw+wTXTXw)=12(2XTy+2XTXw)=XTXwXTy

wJ=0,有

XTXwXTy=0

XTXw=XTy

XTX 可逆,可以解得

w=(XTX)1XTy