欢迎您访问 最编程 本站为您分享编程语言代码,编程技术文章!
您现在的位置是: 首页

Python实现交叉熵损失函数的调用方法

最编程 2024-01-29 20:37:46
...


1、交叉熵损失函数

交叉熵损失函数:

python 调用交叉熵损失 交叉熵损失函数代码_深度学习


在二分类问题中,该函数通常对应:

python 调用交叉熵损失 交叉熵损失函数代码_python_02


其中 表示样本i的标签,正确为1,错误为0. 表示样本i预测为正确的概率。

交叉熵损失函数常被用于分类任务中,由于交叉熵涉及到计算每个类别的概率,所以交叉熵几乎每次都和sigmoid(或softmax)函数一起出现。将神经网络最后一层的输出通过Softmax方法转换为概率分布再与真实类别的 one-hot 形式进行交叉熵的计算。

使用pytorch来计算:

import torch
import torch.nn as nn

batch = 2
num_class = 4

logits = torch.randn(batch,num_class)
target = torch.randint(num_class,size=(batch,))
print(logits)
print('target',target)
ce_loss_fn = torch.nn.CrossEntropyLoss()
ce_loss_fn(logits,target)
tensor([[ 0.0500, -0.6461,  0.9130, -0.0553],
        [-0.3201,  1.5033,  0.2070,  1.8456]])
target tensor([0, 0])
tensor(2.2152)

手算:

m =  nn.Softmax(dim=1)
logits2 = m(logits)
logits2 #tensor([[0.2097, 0.1045, 0.4970, 0.1887],
        #[0.0568, 0.3517, 0.0962, 0.4953]])
(-torch.log(torch.tensor([0.2097]))-torch.log(torch.tensor([0.0568])))/2  #tensor([2.2151])

随即生成2个batch的数据,共有4类标签,对应的标签都是0这个类别,将最后一层的输出进行softmax转换,得到的概率分布为logits2,由于对应的都是类别0,所以取第一个的概率计算即可。

2、负对数似然函数

在概率论中,概率是一个事件发生的可能性,而似然指的是影响概率的未知参数。也就是说,概率是在该未知参数已知的情况下所得到的结果;而似然是该参数未知,我们需要根据观察结果,来估计概率模型的参数。

最大似然估计就是要使得根据目前已经有的数据X,来估计参数,使得在这个参数下出现X的概率值最大。通常似然函数具有如下这个形式:

python 调用交叉熵损失 交叉熵损失函数代码_python_05


通过取对数的方式可以将连续的乘法转换为加法:

python 调用交叉熵损失 交叉熵损失函数代码_深度学习_06


之后就是寻找一个合适的 来使得上面这个式子取到最大值,由于上面的每一项都是小于0的,可以再添加一个负号,就得到了负对数似然函数:

python 调用交叉熵损失 交叉熵损失函数代码_深度学习_08


与上面的交叉熵损失函数相比,这一损失函数在每一项前面少了一个p,但是因为这些都是已经发生的观测样本,所以事实上p都是1.

m =  nn.Softmax(dim=1)
loss = torch.nn.NLLLoss()
input = torch.tensor([[ 0.1076, -1.4376, -0.6307,  0.6451, -1.5122],
        [ 1.5105,  0.7662, -1.7587, -1.4581,  1.1357],
        [-1.4673, -0.5111, -0.0779, -0.7404,  1.4447]], requires_grad=True)
target = torch.tensor([1, 0, 4])

output = loss(m(input), target) #tensor(1.3537)
torch.log(m(input))
#tensor([[-1.2812, -2.8264, -2.0195, -0.7437, -2.9010],
        #[-0.8118, -1.5561, -4.0810, -3.7804, -1.1866],
        #[-3.3349, -2.3787, -1.9455, -2.6080, -0.4229]], grad_fn=<LogBackward>)

按照target给出的索引选取

python 调用交叉熵损失 交叉熵损失函数代码_pytorch_09


推荐阅读