欢迎您访问 最编程 本站为您分享编程语言代码,编程技术文章!
您现在的位置是: 首页

入门机器学习:监督学习中的交叉熵损失函数详解

最编程 2024-02-02 09:21:06
...

交叉熵损失(Cross-Entropy Loss)是一种常用的目标函数,主要用于二分类或多分类问题中衡量预测结果与真实标签之间的差异。它是基于信息论的概念,通过比较两个概率分布的差异来计算损失。

在二分类问题中,假设有两个类别,真实标签可以表示为 y{0,1}y \in \{0,1\},而模型的预测概率为 y[0,1]y \in [0,1]。交叉熵损失通过计算真实标签的对数概率与预测概率的乘积来衡量预测结果与真实标签之间的差异。

公式如下所示:

CrossEntropy=1Ni=1N(yilog(y^i)+(1yi)log(1y^i))CrossEntropy = -\frac{1}{N}\sum_{i=1}^{N}(y_i\log(\hat y_i) + (1-y_i)\log(1-\hat y_i))

其中,N 表示样本数量,yi 表示第 i 个样本的真实标签(0 或 1),y^i\hat y_i 表示第 i 个样本的预测概率。

在多分类问题中,交叉熵损失可以推广为多个类别的情况。假设有 K 个类别,真实标签可以表示为 y{0,1...,,k1}y \in \{0,1...,,k-1\},而模型的预测概率为 y^[0,1]k\hat y \in [0,1]^k。交叉熵损失的公式如下:

CrossEntropy=1Ni=1Nk=0K1yiklog(y^ik)CrossEntropy = -\frac{1}{N}\sum_{i=1}^{N}\sum_{k=0}^{K-1}y_{ik}\log(\hat y_{ik})

其中,N 表示样本数量,yiky_{ik} 表示第 i 个样本属于第 k 个类别的真实标签,y^ik\hat y_{ik} 表示第 i 个样本属于第 k 个类别的预测概率。

以下是 Python 代码示例,用于计算交叉熵损失:

import numpy as np

def cross_entropy_loss(y_true, y_pred):
    N = len(y_true)
    epsilon = 1e-15  # 用于防止取对数时出现无穷大值
    ce_loss = -np.sum(y_true * np.log(y_pred + epsilon)) / N
    return ce_loss

在示例代码中,y_true 是真实标签的数组,y_pred 是预测概率的数组。使用 NumPy 库计算了交叉熵损失,其中添加了一个很小的值 epsilon 以避免在计算对数时出现无穷大的情况。

需要注意的是,在实际应用中,为了数值稳定性和避免过拟合,通常会对交叉熵损失函数添加正则化项或其他改进措施。此外,还可以使用优化算法(例如梯度下降)来最小化交叉熵损失,以获得更好的模型预测结果。