线性回归算法原理

    技术2024-01-05  87

    1.线性回归(Linear Regression)

    1.1什么是线性回归

    我们首先用弄清楚什么是线性,什么是非线性。

    (1)线性:两个变量之间的关系是一次函数关系的——图象是直线,叫做线性。 注意:题目的线性是指广义的线性,也就是数据与数据之间的关系。

    (2)非线性:两个变量之间的关系不是一次函数关系的——图象不是直线,叫做非线性。 相信通过以上两个概念大家已经很清楚了,其次我们经常说的回归回归到底是什么意思呢。

    回归:人们在测量事物的时候因为客观条件所限,求得的都是测量值,而不是事物真实的值,为了能够得到真实值,无限次的进行测量,最后通过这些测量数据计算回归到真实值,这就是回归的由来。 通俗的说就是用一个函数去逼近这个真实值,那又有人问了,线性回归不是用来做预测吗?是的,通过大量的数据我们是可以预测到真实值的。

    1.2线性回归要解决什么问题

    对大量的观测数据进行处理,从而得到比较符合事物内部规律的数学表达式。也就是说寻找到数据与数据之间的规律所在,从而就可以模拟出结果,也就是对结果进行预测。解决的就是通过已知的数据得到未知的结果。例如:对房价的预测、判断信用评价、电影票房预估等。

    1.3线性回归的原理公式推导

    首先,我们先定义我们的数据集: 然后我们可以将数据集划分为特征部分和标签部分: 然后,我们线性回归的主要目的就是要让预测值和真实值的误差达到最小,所以损失函数的公式就等于: 下面是对该公式进行化简的一些步骤: 因为对矩阵求导不太熟悉,所以就查了一些在机器学习中常用的矩阵求导公式: 在我们得到了,最后公式后,我们要让他尽可能趋近于0,这样才能将我们的误差降到最小。过程如下: 这样我们得到了,w的公式,可以求得w的值。因为,我们训练集中有一属性的值为1,这一列是为了代替类似于二元线性回归中的截距参数的,所以我们不用特别计算这一参数。

    以下是python代码实现线性回归算法:

    from numpy import * import pandas as pd # 导入数据 def loadDataSet(fileName): numFeat = len(open(fileName).readline().split(' ')) - 1 dataMat = [] labelMat = [] fr = open(fileName) for line in fr.readlines(): lineArr = [] curLine = line.strip().split(' ') for i in range(numFeat): lineArr.append(float(curLine[i])) dataMat.append(lineArr) labelMat.append(float(curLine[-1])) return dataMat, labelMat # 求回归系数 def standRegres(xArr, yArr): xMat = mat(xArr) yMat = mat(yArr).T xTx = xMat.T * xMat if linalg.det(xTx) == 0.0: # 判断行列式是否为0 print("行列式为0") return ws = xTx.I * (xMat.T * yMat) # 也可以用NumPy库的函数求解:ws=linalg.solve(xTx,xMat.T*yMatT) return ws if __name__ == "__main__": xArr, yArr = loadDataSet('Salary_Data.txt') ws = standRegres(xArr, yArr) xMat = mat(xArr) yMat = mat(yArr) # 预测值 yHat = xMat * ws # 计算预测值和真实值得相关性 corrcoef(yHat.T, yMat) # 0.986 # 绘制数据集散点图和最佳拟合直线图 # 创建图像并绘出原始的数据 import matplotlib.pyplot as plt fig = plt.figure() ax = fig.add_subplot(111) ax.scatter(xMat[:, 1].flatten().A[0], yMat.T[:, 0].flatten().A[0]) # 绘最佳拟合直线,需先要将点按照升序排列 xCopy = xMat.copy() xCopy.sort(0) yHat = xCopy * ws ax.plot(xCopy[:, 1], yHat) plt.show()
    Processed: 0.011, SQL: 9