构建线性回归-理解Loss函数-梯度下降

    技术2022-07-10  77

    # Load the Boston housing data and explore the relationship between
    # the average number of rooms (column 5, feature "RM") and the price.
    # NOTE(review): load_boston was removed in scikit-learn 1.2; on newer
    # versions the data must be fetched from the original CMU source.
    from sklearn.datasets import load_boston
    import random
    import matplotlib.pyplot as plt
    # %matplotlib inline  — notebook-only magic, kept as a comment

    data = load_boston()
    X, y = data['data'], data['target']

    # Quick sanity checks on the loaded arrays (notebook cell outputs).
    X[1]
    y[1]
    len(X[:, 0])
    len(y)


    def draw_rm_and_price():
        """Scatter-plot RM (rooms per dwelling, column 5) against price."""
        plt.scatter(X[:, 5], y)


    draw_rm_and_price()


    def price(rm, k, b):
        """f(x) = k * x + b"""
        return k * rm + b


    X_rm = X[:, 5]

    # Draw one completely random line and overlay it on the data.
    k = random.randint(-100, 100)
    b = random.randint(-100, 100)
    price_by_random_k_and_b = [price(r, k, b) for r in X_rm]

    plt.scatter(X[:, 5], y)
    plt.scatter(X_rm, price_by_random_k_and_b)

    [1, 1, 1]
    [2, 2, 2]

    如何衡量每条直线到底好不好呢?

    loss

    $$loss = \frac{1}{n}\sum_i \left(y_i - \hat{y}_i\right)^2$$
    $$loss = \frac{1}{n}\sum_i \left(y_i - (k x_i + b)\right)^2$$
    $$\frac{\partial loss}{\partial k} = -\frac{2}{n}\sum_i \left(y_i - (k x_i + b)\right) x_i = -\frac{2}{n}\sum_i \left(y_i - \hat{y}_i\right) x_i$$
    $$\frac{\partial loss}{\partial b} = -\frac{2}{n}\sum_i \left(y_i - \hat{y}_i\right)$$

    def loss(y, y_hat):
        """Mean squared error between true values and predictions.

        Args:
            y: iterable of true target values.
            y_hat: iterable of predicted values, same length as ``y``.

        Returns:
            float: ``(1/n) * sum((y_i - y_hat_i) ** 2)``.
        """
        # Materialize once instead of the original's repeated list() calls.
        y = list(y)
        y_hat = list(y_hat)
        return sum((y_i - y_hat_i) ** 2 for y_i, y_hat_i in zip(y, y_hat)) / len(y)

    First-Method: Random generation: get best k and best b

    # First method: pure random search — sample (k, b) uniformly from
    # [-100, 100) each round and keep the best pair seen so far.
    trying_times = 2000

    min_loss = float('inf')
    best_k, best_b = None, None

    for i in range(trying_times):
        k = random.random() * 200 - 100
        b = random.random() * 200 - 100
        price_by_random_k_and_b = [price(r, k, b) for r in X_rm]

        current_loss = loss(y, price_by_random_k_and_b)

        if current_loss < min_loss:
            min_loss = current_loss
            best_k, best_b = k, b
            print('When time is : {}, get best_k: {} best_b: {}, and the loss is: {}'.format(i, best_k, best_b, min_loss))

    10 ** 0.5  # notebook output: 3.1622776601683795

    # Visualize one good line found by the search.
    X_rm = X[:, 5]
    k = 15
    b = -68
    price_by_random_k_and_b = [price(r, k, b) for r in X_rm]

    draw_rm_and_price()
    plt.scatter(X_rm, price_by_random_k_and_b)

    2nd-Method: Direction Adjusting

    # Second method: direction adjusting — start from a random (k, b) and
    # take fixed-size steps; keep the current step direction while it keeps
    # lowering the loss, otherwise pick a new random direction.
    trying_times = 2000

    min_loss = float('inf')

    best_k = random.random() * 200 - 100
    best_b = random.random() * 200 - 100

    direction = [
        (+1, -1),  # first element: k's change direction, second element: b's change direction
        (+1, +1),
        (-1, -1),
        (-1, +1),
    ]

    next_direction = random.choice(direction)

    scalar = 0.1  # step size

    update_time = 0

    for i in range(trying_times):
        k_direction, b_direction = next_direction
        current_k, current_b = best_k + k_direction * scalar, best_b + b_direction * scalar
        price_by_k_and_b = [price(r, current_k, current_b) for r in X_rm]

        current_loss = loss(y, price_by_k_and_b)

        if current_loss < min_loss:  # performance became better
            min_loss = current_loss
            best_k, best_b = current_k, current_b
            # keep next_direction unchanged: the same direction is tried
            # again as long as it keeps improving (the original spelled this
            # as the no-op `next_direction = next_direction`)
            update_time += 1

            if update_time % 10 == 0:
                print('When time is : {}, get best_k: {} best_b: {}, and the loss is: {}'.format(i, best_k, best_b, min_loss))
        else:
            next_direction = random.choice(direction)

    如果我们想得到更快的更新,在更短的时间内获得更好的结果,我们需要一件事情:

    找对改变的方向。如何找对改变的方向呢?2nd-method:监督它变化 → 监督学习

    导数

    def partial_k(x, y, y_hat):
        """Partial derivative of the MSE loss w.r.t. the slope k.

        Implements: dloss/dk = -2/n * sum((y_i - y_hat_i) * x_i)
        """
        n = len(y)
        gradient = 0
        for x_i, y_i, y_hat_i in zip(list(x), list(y), list(y_hat)):
            gradient += (y_i - y_hat_i) * x_i
        return -2 / n * gradient


    def partial_b(x, y, y_hat):
        """Partial derivative of the MSE loss w.r.t. the intercept b.

        Implements: dloss/db = -2/n * sum(y_i - y_hat_i)
        """
        n = len(y)
        gradient = 0
        for y_i, y_hat_i in zip(list(y), list(y_hat)):
            gradient += (y_i - y_hat_i)
        return -2 / n * gradient


    # Third method: gradient descent — follow the negative gradient of the
    # loss instead of guessing directions.
    # (dropped the original's unused `from icecream import ic`)
    trying_times = 2000

    X, y = data['data'], data['target']

    min_loss = float('inf')

    current_k = random.random() * 200 - 100
    current_b = random.random() * 200 - 100

    learning_rate = 1e-04

    update_time = 0

    for i in range(trying_times):
        price_by_k_and_b = [price(r, current_k, current_b) for r in X_rm]
        current_loss = loss(y, price_by_k_and_b)

        if current_loss < min_loss:  # performance became better
            min_loss = current_loss
            if i % 50 == 0:
                # BUG FIX: the original printed best_k/best_b, which are not
                # assigned anywhere in this cell (stale values from the earlier
                # method); report the parameters actually being optimized.
                print('When time is : {}, get best_k: {} best_b: {}, and the loss is: {}'.format(i, current_k, current_b, min_loss))

        k_gradient = partial_k(X_rm, y, price_by_k_and_b)
        b_gradient = partial_b(X_rm, y, price_by_k_and_b)

        # Step against the gradient, scaled by the learning rate.
        current_k = current_k + (-1 * k_gradient) * learning_rate
        current_b = current_b + (-1 * b_gradient) * learning_rate

    # Plot the line found by gradient descent (values from one run).
    X_rm = X[:, 5]
    k = 11.431551629413757
    b = -49.52403584539048
    price_by_random_k_and_b = [price(r, k, b) for r in X_rm]

    plt.scatter(X[:, 5], y)
    print(len(X_rm), len(price_by_random_k_and_b))
    plt.scatter(X_rm, price_by_random_k_and_b)
    Processed: 0.011, SQL: 9