Logistic Regression


Sigmoid函数
input_array = np.array([1,2,3])
exp_array = np.exp(input_array)
# output
[1 2 3]
[ 2.71828183 7.3890561 20.08553692]def sigmoid(z):
"""
Compute the sigmoid of z
Args:
z (ndarray): A scalar, numpy array of any size.
Returns:
g (ndarray): sigmoid(z), with the same shape as z
"""
g = 1/(1+np.exp(-z))
return g
# Generate an array of evenly spaced values between -10 and 10
z_tmp = np.arange(-10,11)
# Use the function implemented above to get the sigmoid values
y = sigmoid(z_tmp)
# Code for pretty printing the two arrays next to each other
np.set_printoptions(precision=3)
print("Input (z), Output (sigmoid(z))")
print(np.c_[z_tmp, y])Decision Boundary - 决策边界


Logistic Loss - 逻辑损失
平方误差损失不适合逻辑损失
成本函数在线性回归中效果很好,自然也可以用于逻辑回归。不过,正如上面的幻灯片所指出的, 现在有一个非线性成分,即 sigmoid 函数: 。让我们在之前的实验示例中尝试一下平方误差成本,现在包含了 sigmoid 函数

Logistic Loss Function - 对数损失函数 - 交叉熵损失函数
逻辑回归使用的损失函数更适合目标为 0 或 1 而不是任何数字的分类任务
是单个数据点的成本,即
是模型的预测值,而 是目标值。
其中函数 𝑔 是 sigmoid 函数。
该损失函数的特点是使用两条不同的曲线。一条适用于目标值为 0 或 ( ) 的情况,另一条适用于目标值为 1 ( ) 的情况。这两条曲线结合在一起,提供了对损失函数有用的行为,即当预测值与目标值相匹配时为零,而当预测值与目标值相差较大时,其值迅速增加。

上述损失函数可以改写,以便更容易实现
当 时,左边项被消除:
当 时,右边项被消除:


Cost Function for Logisic Regression
import numpy as np
X_train = np.array([[0.5, 1.5], [1,1], [1.5, 0.5], [3, 0.5], [2, 2], [1, 2.5]]) #(m,n)
y_train = np.array([0, 0, 0, 1, 1, 1]) 对于逻辑回归,成本函数的形式是
是单个数据点的成本,即
其中 m 是数据集中训练样本的数量
compute_cost_logistic算法会循环计算所有示例,计算每个示例的损失并累加总数
请注意,变量 X 和 y 不是标量值,而是形状分别为 ( ) 和 ( ,) 的矩阵,其中 是特征的数量, 是训练实例的数量
def compute_cost_logistic(X, y, w, b):
"""
Computes cost
Args:
X (ndarray (m,n)): Data, m examples with n features
y (ndarray (m,)) : target values
w (ndarray (n,)) : model parameters
b (scalar) : model parameter
Returns:
cost (scalar): cost
"""
m = X.shape[0]
cost = 0.0
for i in range(m):
z_i = np.dot(X[i],w) + b
f_wb_i = sigmoid(z_i)
cost += -y[i]*np.log(f_wb_i) - (1-y[i])*np.log(1-f_wb_i)
cost = cost / m
return cost
w_tmp = np.array([1,1])
b_tmp = -3
print(compute_cost_logistic(X_train, y_train, w_tmp, b_tmp))output: 0.36686678640551745Last updated
Was this helpful?