PyTorch中的交叉熵损失函数详解:nn.CrossEntropyLoss
最编程
2024-02-02 09:00:01
...
nn.CrossEntropyLoss()
是nn.logSoftmax()
和nn.NLLLoss()
的整合,可以直接使用它来替换网络中的这两个操作,这个函数可以用于多分类问题。具体的计算过程可以参考官网的公式或者一下这个链接。
https://pytorch.org/docs/stable/generated/torch.nn.CrossEntropyLoss.html#torch.nn.CrossEntropyLoss
https://blog.****.net/geter_CS/article/details/84857220
https://zhuanlan.zhihu.com/p/98785902
1. 函数的参数
- weight(Tensor, optional):如果输入这个参数的话必须是一个1维的tensor,长度为类别数C,每个值对应每一类的权重。
- reduction (string, optional) :指定最终的输出类型,默认为’mean’。
none | 无操作 |
---|---|
mean | 输出的结果求均值 |
sum | 输出的结果求和 |
- 其他参数暂时没有用到,回头用到再补充。
2. 函数的使用方法
- 传入input以及target即可,但是需要注意两者的格式,具体见第3点。
- 调用的时候需要注意,不能直接调用,查看第三点例子。
3. 使用注意事项
- 如果不同类别对应的权重不同,传入的权重参数应该是一个1维的tensor。
- 输入的每一类的置信度得分(
input
)应该是原始的,未经过softmax或者normalized。原因是这个函数会首先对输入的原始得分进行softmax,所以必须保证输入的是每一类的原始得分。不能写成[0.2, 0.36, 0.44]
这种softmax之后的或者[0, 1, 0]
这种one-hot编码。 - 输入的
target
也不能是one-hot标签,直接输入每个例子对应的类别编号就行了(0 < target_value < C-1
),比如产生的结果数为N*C
(N为个数,C为类别数),那么输入的target
必须输入一个长度为N
的一维tensor(指明每个结果属于哪一类,如[1, 3, 0]
,函数内部会自动转化为one-hot标签)。参考https://www.jianshu.com/p/19b461421fe7 - 举个例子。
import torch
import torch.nn as nn
loss_fn = nn.CrossEntropyLoss()
# 方便理解,此处假设batch_size = 1
x_input = torch.randn(2, 3) # 预测2个对象,每个对象分别属于三个类别分别的概率
# 需要的GT格式为(2)的tensor,其中的值范围必须在0-2(0<value<C-1)之间。
x_target = torch.tensor([0, 2]) # 这里给出两个对象所属的类别标签即可,此处的意思为第一个对象属于第0类,第二个我对象属于第2类
loss = loss_fn(x_input, x_target)
print('loss:\n', loss)
--output:
loss:
tensor(0.5060)
- 注意写的过程中不能直接写
loss = nn.CrossEntropyLoss(x_input, x_target)
,必须要向上面那样先把求交叉熵的库函数定义到一个自己命名的函数再调用,否则会报错。RuntimeError: bool value of Tensor with more than one value is ambiguous
4. 输入维度>2的情况
- 上面的例子展示了输入的tensor的维度为2维
即shape为(N, C)
的计算过程以及写代码中需要注意的点,我们实际用的过程中输入的tensor维度一般大于2,比如一个(B, N, C)
的tensor作为输入,下面介绍一下应该怎么写。 - 首先看官方文档中关于若输入高维tensor的情况介绍。
- 如果我们输入的数据为
(B, N, C)
,分别对应batch_size
、 预测的N个对象、C个类别。根据图中文档的规定,第1维必须是类别数目,所以要先把输入换成(B, C, N)
; - 这是一个3维的tensor,相当于此处图片里面的K=1,图片中input的shape就为
(N, C, d_1)
,可以得到对应的Target的shape应该为(N, d_1)
。我们输入的(B, C, N)
对应的Target的格式就应该为(B, N)
。【很好理解,Target的shape就是input的第0维和第2维】。 - 总结一下,首先把Input的shape调整为
(B, C, N)
后,确保Target的输入为(B, N)
即可。 - 举个例子。
mport torch
import torch.nn as nn
loss_fn = nn.CrossEntropyLoss()
# 假设batch_size = 2, 预测5个对象,类别C=18。
x_input = torch.randn(2, 5, 18)
x_input = x_input.permute(0, 2, 1) # 此处的意思是第0维的数据换成第1维的,第1维换成第2维,第2维换成第0维。即(B,N,C)--->(B,C,N)
# print('x_input_permute:\n', x_input) # 不理解的话可以打印出permute前后的shape进行观察
# 根据前面对target的分析,需要的GT格式为(2, 5),其中的值范围必须在0-17之间。
x_target = torch.tensor([[1, 2, 17, 5, 0],
[3, 15, 7, 10, 8]])
loss = loss_fn(x_input, x_target)
# 或者如下
# loss = loss_fn(x_input.permute(0, 2, 1), x_target)
print('loss:\n', loss)
--output:
loss:
tensor(3.5132)
推荐阅读
-
交叉熵损失函数的推导(逻辑回归)
-
【python实现卷积神经网络】损失函数的定义(均方误差损失、交叉熵损失)
-
玩转神经网络:自定义损失函数、交叉熵与softmax的优化技巧
-
深度学习入门5:交叉熵损失函数、MSE与CTC在序列问题中的应用,以及Balanced L1 Loss在目标检测中的妙用
-
入门机器学习:监督学习中的交叉熵损失函数详解
-
交叉熵损失的计算步骤详解
-
深入理解交叉熵损失函数的原理
-
揭秘交叉熵损失函数的运作原理,让你轻松理解
-
深度学习新手必看:交叉熵损失函数、MSE、CTC损失与序列问题的解决方案,以及Balanced L1 Loss在目标检测中的应用
-
搞懂交叉熵损失函数与平方损失的差异:你真的明白了吗?