模型简单理解为1个函数,通过测试集训练出模型,然后对未知值进行预测
bais:偏置 weight:权重
代码验证,鸢尾花数据集为例,使用花瓣长度与宽度,实现简单线性回归
import numpy as np import pandas as pd import matplotlib.pyplot as plt import seaborn as sns from sklearn.linear_model import LinearRegression from sklearn.model_selection import train_test_split from sklearn.datasets import load_iris plt.rcParams['font.family'] = 'SimHei'## 用来正常显示中文标签 plt.rcParams['axes.unicode_minus'] = False ## 用来正常显示负号 #设置小数点后精度为2,默认为8 np.set_printoptions(precision=2) iris=load_iris() #print(iris.data[:,2]) #print(iris.data[:,2].reshape(-1,1)) X,y=iris.data[:,2].reshape(-1,1),iris.data[:,3] X_train,X_test,y_train,y_test=train_test_split(X,y,test_size=0.25,random_state=0) #LinearRegression的因变量x需要时二维数组 lr=LinearRegression() #lr.fit进行拟合 lr.fit(X_train,y_train) print('权重',lr.coef_) print('截距',lr.intercept_) #进行预测 y_hat=lr.predict(X_test) print('预测值',y_hat) print('实际值',y_test) #help(train_test_split) #画图展示 plt.figure(figsize=(20,8),dpi=80) plt.scatter(X_test,y_test,color='green',marker='D',label='测试集') plt.scatter(X_train,y_train,color='orange',label='训练集') plt.plot(X,lr.predict(X),'r-') plt.legend() plt.show()观察实际值与预测值得差异
plt.figure(figsize=(15,6),dpi=80) plt.plot(y_test,color='r',marker='o',label='真实数据') plt.plot(y_hat,ls='--',color='g',marker='o',label='预测数据') plt.legend() plt.xlabel('序号') plt.ylabel('数据值')有以下四个指标:
MSERMSEMAER²MSE(Mean Squared Error):平均平方误差,是所有样本数据误差(真实值与预测值之差)的平方和取均值
RMSE(Root Mean Squared Error):平均平方误差的平方根,在MSE基础上取平方根
MAEX(Mean Absolute Error):平均绝对值误差,是所有样本数据误差的绝对值和
R²最常用
from sklearn.metrics import mean_squared_error,mean_absolute_error,r2_score print("MSE:{}".format(mean_squared_error(y_test,y_hat))) print("RMSE:{}".format(np.sqrt(mean_squared_error(y_test,y_hat)))) print("MAE:{}".format(mean_absolute_error(y_test,y_hat))) # score:求解R^2的值,r2_score与score方法都能求R^2的值,但是其传递的参数不同 print("训练集R^2:{}".format(r2_score(y_train,lr.predict(X_train)))) print("测试集R^2:{}".format(r2_score(y_test,lr.predict(X_test)))) print("训练集R^2:{}".format(lr.score(X_train,y_train))) print("测试集R^2:{}".format(lr.score(X_test,y_test)))