ai - llm - 单神经网络的入门例子 MSE MAS loss 梯度 平方误差 学习率 lr learn ratio

访问量: 8

refer to: https://www.doubao.com/thread/wb2b77b45e37d1b78

平方误差 MSE     (预测值 - 正确值) ** 2      这个更加好用,导数是2e, 可以一直小步快跑,在e < 1 的时候收敛比 MAE慢。

适合绝大多数场景 。 初期猛冲,后期略慢。 整体快很多  ( 300步迭代)

绝对值误差  MAE    abs (预测值 - 正确值)    这个 导数是 1 (正负), 一直不变。 所以初期散步,后期继续散步。 整体慢不少   ( 400步迭代)

原因:收敛速度取决于 对应的导数,而不是   loss的大小。

MSE:  2e    非常平滑

MAE:  +- 1  

MAE 仅适合 某些 数值不确定的场景

在靠近最优解时, MAE (绝对值误差)会有转折,而MSE则一直平滑

loss 与 梯度的计算:

  4 lr = 0.05
5 x = torch.tensor([1,2,3]) 6 y_true = torch.tensor([5,7, 9]) 7 8 w = torch.tensor([10.0], requires_grad=True) 9 b = torch.tensor([1.0], requires_grad= True)
w -= lr * w.grad
PS C:\workspace\llm\test_pytorch> python .\6_two_layer_nn.py ==== 训练开始,目标: w= 2 b =3 === --- 初始: w = 10.00, b=1.00 step: 0 | w: 10.000, b: 1.000, loss: 238.67 step: 1 | w: 6.667, b: -0.400, loss: 49.72 step: 2 | w: 5.169, b: -0.993, loss: 12.19 ```

第一步:  x = 1 , w = 10, b = 1

x = 1,  y = x * 10 + b = 11

x = 2, y = 20 + 1 = 21

x = 3 , y = 30  + 1 = 31

所以,平方差 =  ( 11 -  5 ) ** 2 =  [ 36,   (21 - 7) ** 2 = 196,  (31- 9 ) ** 2 = 484 ] , 

[ 36, 156 ,484 ] ,   sum 之后  / 3 =  716 / 3 = 238.67

情况2: 使用 绝对值误差的:

```
PS C:\workspace\llm\test_pytorch> python .\6_two_layer_nn.py
==== 训练开始,目标: w= 2 b =3 ===
--- 初始: w =  10.00, b=1.00
 step:  0 | w: 10.000, b:  1.000, loss:    14.00
 step:  1 | w:  9.900, b:  0.950, loss:    13.75
 step:  2 | w:  9.800, b:  0.900, loss:    13.50

x = 1, y = 11,    差值 11 - 5 = 6

x = 2 , y = 21   差值  21 - 7 = 14

x = 3 , y = 31,  差值  31 - 9 = 22

所以 sum 之后 / 3 :  (6 + 14 +22 )/ 3 = 42 / 3 = 14

如何计算得到 step1 的 w: 6.667 ?

PS C:\workspace\llm\test_pytorch> python .\6_two_layer_nn.py
==== 训练开始,目标: w= 2 b =3 ===
--- 初始: w =  10.00, b=1.00
 step:  0 | w: 10.000, b:  1.000, loss:   238.67
 step:  1 | w:  6.667, b: -0.400, loss:    49.72
 step:  2 | w:  5.169, b: -0.993, loss:    12.19

计算过程: 

y_pred = x @ w  + b 

loss = mean ( (y_pred - y_true) ** 2 )

对w 求导:  w.grad =  2 * mean( x * (y_pred - y_true) )

学习率 lr 与 梯度爆炸  https://www.doubao.com/thread/wc3344678a17633f6

订阅/RSS Feed

Subscribe