2.2.5 越来越接近
记住线性回归的训练阶段的任务:我们想要找到一条能够拟合样本数据的直线。也就是说,想要从X和Y的测量值中计算出w,我们可以使用下面的迭代算法完成这个任务:
train()函数会一遍又一遍地遍历相同的样本数据,直到学会如何拟合这些样本数据为止。train()函数的参数是X、Y、一些迭代和另一个称为学习率lr的值(稍后我们将对此进行解释)。训练算法首先将w初始化为0。这个w表示图上的一条线。这条直线不太可能是一个很好的拟合函数,但它是一个开始。
然后,train()函数进入一个循环。每次迭代都是从计算当前损失开始,然后考察另一条直线,即对参数w加上一点增量时得到的一条新的直线。通常将这个增量称为“步长”,但是在代码中我们使用机器学习的术语,称之为学习率,缩写为lr。
我们只是通过将学习率加到w上获得一条新直线。这条新直线的损失比现有直线的损失低吗?如果是,那么w+lr就成为新的当前w,继续下一步的循环。否则,算法将尝试另一条参数为w-lr的直线。同样,如果该直线的损失低于当前直线w的损失,那么将当前的w值更新为w-lr,并继续下一步的循环。
如果w+lr和w-lr都不能产生比当前w更小的损失,那么训练就完成了。此时,我们已经尽可能地逼近了样本数据,并将w返回给调用者。
更具体地说,这个算法是绕着坐标原点上下旋转这条直线,让这条直线的斜率在每次迭代时稍微陡峭一点或稍微平缓一点,并关注直线所对应损失发生的变化。学习率越高,则旋转的速度就越快。想象一个老式的无线电操作员,慢慢地转动一个旋钮,让耳机里的声音变得更为清晰,直到最后让声音变得清晰得不能再清晰为止。
迭代算法有时会陷入无限循环(如果算法不收敛)。计算机科学家证明了我们的这个特定的训练算法不存在这个问题。只要有足够的时间或迭代次数,这个训练算法总是会收敛的。然而,在此之前,算法可能会用完最长的迭代时间或最大的迭代次数,此时train()函数会放弃迭代计算并进行异常退出。
我知道你很想运行这段代码。那么就开始吧。