资源经验分享[深度学习] Pytorch nn.CrossEntropyLoss()和nn.NLLLoss() 区别

[深度学习] Pytorch nn.CrossEntropyLoss()和nn.NLLLoss() 区别

2020-01-09 | |  52 |   0

原标题:深度学习] Pytorch nn.CrossEntropyLoss()和nn.NLLLoss() 区别

原文来自:CSDN      原文链接:https://blog.csdn.net/zwqjoy/article/details/96282788


 

  • 函数 Softmax(x)

      是一个 non-linearity, 但它的特殊之处在于它通常是网络中一次操作. 这是因为它接受了一个实数向量并返回一个概率分布.

其定义如下.定义 x 是一个实数的向量(正数或负数都无所谓, 没有限制). 然后, 第i个 Softmax(x) 的组成是 
04.png

输出是一个概率分布: 每个元素都是非负的, 并且所有元素的总和都是1.
 

  • NLLLoss(negative log likelihood loss):最大似然 / log似然代价函数

  • CrossEntropyLoss: 交叉熵损失函数。交叉熵描述了两个概率分布之间的距离,当交叉熵越小说明二者之间越接近。

Cross Entropy(也就是交叉熵)Loss:交叉熵损失函数,通常用于多分类,其中yi是one_hot标签,pi是softmax层的输出结果,交叉熵损失EE定义为:

05.png

Negative Log Liklihood(NLL) Loss:负对数似然损失函数,X是log_softmax()的输出,label是对应的标签位置

06.png

损失函数NLLLoss() 的输入是一个对数概率向量和一个目标标签,并对应位置相乘相加,最后再取负(也就是说,这里的Xlabel,,对于独热码来说,实际上就是取的X中,对应于label中为1的那个x).  它不会为我们计算对数概率,适合最后一层是log_softmax()log_softmax也就是对softmax的输出取对数)的网络. 损失函数 CrossEntropyLoss() 与 NLLLoss()类似, 唯一的不同是它为我们去做 softmax并取对数.可以理解为:

CrossEntropyLoss()=log_softmax() + NLLLoss()

CrossEntropyLoss()=log_softmax() + NLLLoss()
import torch
import torch.nn.functional as F
 
#output是神经网络最后的输出张量,b_y是标签(1hot)
loss1 = torch.nn.CrossEntropyLoss(output,b_y)
loss2 =F.nll_loss(F.log_softmax(output,1),b_y)
#loss1与loss2等效
import torch
import torch.nn as nn
import torch.nn.functional as F
 
data = torch.randn(2, 6)
print('data:', data, 'n')
 
log_soft = F.log_softmax(data, dim=1)
print('log_soft:', log_soft, 'n')
 
target = torch.tensor([1, 2])
 
entropy_out = F.cross_entropy(data, target)
nll_out = F.nll_loss(log_soft, target)
 
print('entropy_out:', entropy_out)
print('nll_out:', nll_out)

07.png

 

免责声明:本文来自互联网新闻客户端自媒体,不代表本网的观点和立场。

合作及投稿邮箱:E-mail:editor@tusaishared.com

上一篇:[深度学习]自然语言处理 --- ELMo

下一篇:Python---上下文管理器(contextor)

用户评价
全部评价

热门资源

  • Python 爬虫(二)...

    所谓爬虫就是模拟客户端发送网络请求,获取网络响...

  • TensorFlow从1到2...

    原文第四篇中,我们介绍了官方的入门案例MNIST,功...

  • TensorFlow从1到2...

    “回归”这个词,既是Regression算法的名称,也代表...

  • 机器学习中的熵、...

    熵 (entropy) 这一词最初来源于热力学。1948年,克...

  • TensorFlow2.0(10...

    前面的博客中我们说过,在加载数据和预处理数据时...