【Pytorch基础】线性模型

共计 4945 个字符,预计需要花费 13 分钟才能阅读完成。

线性模型

一般流程

  1. 准备数据集(训练集,开发集,测试集)
  2. 选择模型(泛化能力,防止过拟合)
  3. 训练模型
  4. 测试模型

例子

学生每周学习时间与期末得分的关系

x(hours) y(points)
1 2
2 4
3 6
4 ?

设计模型

观察数据分布可得应采用线性模型:

$$ \hat y = x * w + b $$

其中 $\hat y$ 为预测值,不妨简化一下模型为:

$$ \hat y = x* w $$

我们的目的就是得到一个尽可能好的 $w$ 值。使模型的预测值越 接近 真实值,因此我们需要一个衡量接近程度的指标 $loss$,可用绝对值或差的平方表示单 g 个样本预测的损失为(Training Loss):

$$loos = (\hat y - y)^2 = (x*w - y)^2 \geq 0$$

这里使用差的平方,其中 $y$ 为真实值。

因此,对于多样本预测的平均损失函数为(Mean Square Error):
$$MSE = \frac{\sum_{i=0}^{n}(\hat y_i - y_i)^2}{n}$$

# 定义模型函数
def forward(x):
    return x * w;

# 定义损失函数
def loss(x, y):
    y_predict = forward(x)
    return (y - y_predict) ** 2

过程模拟

由于不知道 $w$ 的具体值因此我们给它一个随机初始值,假设 $w = 3$

x(hours) y(points) y_predict loss
1 2 3 1
2 4 6 4
3 6 9 9
MSE=14/3

可知本轮预测平均损失为 14/3

为找到最佳权重,可枚举权重值判断损失,损失最小为最佳

# 存放枚举到的权重 w 的取值
w_list = []
# 对应权重的平均误差
mse_list = []

# 枚举权重,步长为 0.1
for w in np.arange(0.0, 4.1, 0.1): # 从 0.0 到 4.1
    print("w=", w)
    loss_sum = 0 # 损失和
    for x_val, y_val in zip(x_data, y_data): # zip 函数传入可迭代对象
        y_predict_val = forward(x_val) # 计算预测值
        loss_val = loss(x_val, y_val) # 计算单样本损失
        loss_sum += loss_val # 更新损失和
        print('\t\t',x_val, y_val, format(y_predict_val, '0.2f'),format(loss_val,'0.2f'))
    print('MSE=',loss_sum / len(x_data))
    w_list.append(w)
    mse_list.append(loss_sum / len(x_data))

具体实现

import numpy as np
import matplotlib.pyplot as plt

# 准备数据集
x_data = [1.0, 2.0, 3.0]
y_data = [2.0, 4.0, 6.0]

# 定义模型函数
def forward(x):
    return x * w;

# 定义损失函数
def loss(x, y):
    y_predict = forward(x)
    return (y - y_predict) ** 2

# 权重 w 的取值
w_list = []
# 对应权重的平均误差
mse_list = []

# 枚举权重,步长为 0.1
for w in np.arange(0.0, 4.1, 0.1): # 从 0.0 到 4.1
    print("w=", w)
    loss_sum = 0 # 损失和
    for x_val, y_val in zip(x_data, y_data): # zip 函数传入可迭代对象
        y_predict_val = forward(x_val) # 计算预测值
        loss_val = loss(x_val, y_val) # 计算单样本损失
        loss_sum += loss_val # 更新损失和
        print('\t\t',x_val, y_val, format(y_predict_val, '0.2f'),format(loss_val,'0.2f'))
    print('MSE=',loss_sum / len(x_data))
    w_list.append(w)
    mse_list.append(loss_sum / len(x_data))

得到每轮的预测结果

w= 0.0
         1.0 2.0 0.00 4.00
         2.0 4.0 0.00 16.00
         3.0 6.0 0.00 36.00
MSE= 18.666666666666668
w= 0.1
         1.0 2.0 0.10 3.61
         2.0 4.0 0.20 14.44
         3.0 6.0 0.30 32.49
MSE= 16.846666666666668
w= 0.2
         1.0 2.0 0.20 3.24
         2.0 4.0 0.40 12.96
         3.0 6.0 0.60 29.16
MSE= 15.120000000000003
w= 0.30000000000000004
         1.0 2.0 0.30 2.89
         2.0 4.0 0.60 11.56
         3.0 6.0 0.90 26.01
MSE= 13.486666666666665
w= 0.4
         1.0 2.0 0.40 2.56
         2.0 4.0 0.80 10.24
         3.0 6.0 1.20 23.04
MSE= 11.946666666666667
w= 0.5
         1.0 2.0 0.50 2.25
         2.0 4.0 1.00 9.00
         3.0 6.0 1.50 20.25
MSE= 10.5
w= 0.6000000000000001
         1.0 2.0 0.60 1.96
         2.0 4.0 1.20 7.84
         3.0 6.0 1.80 17.64
MSE= 9.146666666666663
w= 0.7000000000000001
         1.0 2.0 0.70 1.69
         2.0 4.0 1.40 6.76
         3.0 6.0 2.10 15.21
MSE= 7.886666666666666
w= 0.8
         1.0 2.0 0.80 1.44
         2.0 4.0 1.60 5.76
         3.0 6.0 2.40 12.96
MSE= 6.719999999999999
w= 0.9
         1.0 2.0 0.90 1.21
         2.0 4.0 1.80 4.84
         3.0 6.0 2.70 10.89
MSE= 5.646666666666666
w= 1.0
         1.0 2.0 1.00 1.00
         2.0 4.0 2.00 4.00
         3.0 6.0 3.00 9.00
MSE= 4.666666666666667
w= 1.1
         1.0 2.0 1.10 0.81
         2.0 4.0 2.20 3.24
         3.0 6.0 3.30 7.29
MSE= 3.779999999999999
w= 1.2000000000000002
         1.0 2.0 1.20 0.64
         2.0 4.0 2.40 2.56
         3.0 6.0 3.60 5.76
MSE= 2.986666666666665
w= 1.3
         1.0 2.0 1.30 0.49
         2.0 4.0 2.60 1.96
         3.0 6.0 3.90 4.41
MSE= 2.2866666666666657
w= 1.4000000000000001
         1.0 2.0 1.40 0.36
         2.0 4.0 2.80 1.44
         3.0 6.0 4.20 3.24
MSE= 1.6799999999999995
w= 1.5
         1.0 2.0 1.50 0.25
         2.0 4.0 3.00 1.00
         3.0 6.0 4.50 2.25
MSE= 1.1666666666666667
w= 1.6
         1.0 2.0 1.60 0.16
         2.0 4.0 3.20 0.64
         3.0 6.0 4.80 1.44
MSE= 0.746666666666666
w= 1.7000000000000002
         1.0 2.0 1.70 0.09
         2.0 4.0 3.40 0.36
         3.0 6.0 5.10 0.81
MSE= 0.4199999999999995
w= 1.8
         1.0 2.0 1.80 0.04
         2.0 4.0 3.60 0.16
         3.0 6.0 5.40 0.36
MSE= 0.1866666666666665
w= 1.9000000000000001
         1.0 2.0 1.90 0.01
         2.0 4.0 3.80 0.04
         3.0 6.0 5.70 0.09
MSE= 0.046666666666666586
w= 2.0
         1.0 2.0 2.00 0.00
         2.0 4.0 4.00 0.00
         3.0 6.0 6.00 0.00
MSE= 0.0
w= 2.1
         1.0 2.0 2.10 0.01
         2.0 4.0 4.20 0.04
         3.0 6.0 6.30 0.09
MSE= 0.046666666666666835
w= 2.2
         1.0 2.0 2.20 0.04
         2.0 4.0 4.40 0.16
         3.0 6.0 6.60 0.36
MSE= 0.18666666666666698
w= 2.3000000000000003
         1.0 2.0 2.30 0.09
         2.0 4.0 4.60 0.36
         3.0 6.0 6.90 0.81
MSE= 0.42000000000000054
w= 2.4000000000000004
         1.0 2.0 2.40 0.16
         2.0 4.0 4.80 0.64
         3.0 6.0 7.20 1.44
MSE= 0.7466666666666679
w= 2.5
         1.0 2.0 2.50 0.25
         2.0 4.0 5.00 1.00
         3.0 6.0 7.50 2.25
MSE= 1.1666666666666667
w= 2.6
         1.0 2.0 2.60 0.36
         2.0 4.0 5.20 1.44
         3.0 6.0 7.80 3.24
MSE= 1.6800000000000008
w= 2.7
         1.0 2.0 2.70 0.49
         2.0 4.0 5.40 1.96
         3.0 6.0 8.10 4.41
MSE= 2.2866666666666693
w= 2.8000000000000003
         1.0 2.0 2.80 0.64
         2.0 4.0 5.60 2.56
         3.0 6.0 8.40 5.76
MSE= 2.986666666666668
w= 2.9000000000000004
         1.0 2.0 2.90 0.81
         2.0 4.0 5.80 3.24
         3.0 6.0 8.70 7.29
MSE= 3.780000000000003
w= 3.0
         1.0 2.0 3.00 1.00
         2.0 4.0 6.00 4.00
         3.0 6.0 9.00 9.00
MSE= 4.666666666666667
w= 3.1
         1.0 2.0 3.10 1.21
         2.0 4.0 6.20 4.84
         3.0 6.0 9.30 10.89
MSE= 5.646666666666668
w= 3.2
         1.0 2.0 3.20 1.44
         2.0 4.0 6.40 5.76
         3.0 6.0 9.60 12.96
MSE= 6.720000000000003
w= 3.3000000000000003
         1.0 2.0 3.30 1.69
         2.0 4.0 6.60 6.76
         3.0 6.0 9.90 15.21
MSE= 7.886666666666668
w= 3.4000000000000004
         1.0 2.0 3.40 1.96
         2.0 4.0 6.80 7.84
         3.0 6.0 10.20 17.64
MSE= 9.14666666666667
w= 3.5
         1.0 2.0 3.50 2.25
         2.0 4.0 7.00 9.00
         3.0 6.0 10.50 20.25
MSE= 10.5
w= 3.6
         1.0 2.0 3.60 2.56
         2.0 4.0 7.20 10.24
         3.0 6.0 10.80 23.04
MSE= 11.94666666666667
w= 3.7
         1.0 2.0 3.70 2.89
         2.0 4.0 7.40 11.56
         3.0 6.0 11.10 26.01
MSE= 13.486666666666673
w= 3.8000000000000003
         1.0 2.0 3.80 3.24
         2.0 4.0 7.60 12.96
         3.0 6.0 11.40 29.16
MSE= 15.120000000000005
w= 3.9000000000000004
         1.0 2.0 3.90 3.61
         2.0 4.0 7.80 14.44
         3.0 6.0 11.70 32.49
MSE= 16.84666666666667
w= 4.0
         1.0 2.0 4.00 4.00
         2.0 4.0 8.00 16.00
         3.0 6.0 12.00 36.00
MSE= 18.666666666666668

画出权重与平均损失的关系图

# 绘图(权重与平均损失的关系)
plt.plot(w_list, mse_list)
plt.ylabel('Loss')
plt.xlabel('W')
plt.show()

【Pytorch 基础】线性模型

由上图可知,但 $w = 2.0$ 时损失最小,该点也是损失函数图像的最小值。

正文完
 
yhlin
版权声明:本站原创文章,由 yhlin 2023-01-19发表,共计4945字。
转载说明:除特殊说明外本站文章皆由CC-4.0协议发布,转载请注明出处。