PaddlePaddle动手实现三个版本的CrossEntropy

网友投稿 271 2022-08-27

PaddlePaddle动手实现三个版本的CrossEntropy

今天有人对CrossEntropy有疑惑,我做了一个小实验,这里把我的代码分享了出来:

import paddleimport paddle.nn as nnimport paddle.nn.functional as Fy_true=paddle.to_tensor([[1],[2]])y_pred=paddle.to_tensor([[0.05,0.95,0],[0.1,0.8,0.1]])one_hot = F.one_hot(y_true,num_classes=3)res=paddle.sum(paddle.exp(y_pred),axis=1)res = paddle.reshape(res, [-1, 1])softmax = paddle.exp(y_pred)/reslogsoftmax = paddle.log(softmax)nllloss = -paddle.sum(one_hot.squeeze(1)*logsoftmax)/y_true.shape[0]print(nllloss)logsoftmax = F.log_softmax(y_pred, axis = 1)nllloss = F.nll_loss(logsoftmax, y_true)print(nllloss)nllloss=paddle.nn.functional.cross_entropy(input=y_pred,label=y_true)print(nllloss)

输出:

Tensor(shape=[1], dtype=float32, place=CPUPlace, stop_gradient=True, [0.98689497])Tensor(shape=[1], dtype=float32, place=CPUPlace, stop_gradient=True, [0.98689508])Tensor(shape=[1], dtype=float32, place=CPUPlace, stop_gradient=True, [0.98689508])

应该是符合预期啦

参考文献

​​paddle文档​​​​吃透torch.nn.CrossEntropyLoss()​​

版权声明:本文内容由网络用户投稿,版权归原作者所有,本站不拥有其著作权,亦不承担相应法律责任。如果您发现本站中有涉嫌抄袭或描述失实的内容,请联系我们jiasou666@gmail.com 处理,核实后本网站将在24小时内删除侵权内容。

上一篇:git error: RPC failed; curl 92 HTTP/2 stream 0 was not closed cleanly: PROTOCOL_ERROR (err 1)
下一篇:遏制过度营销、野蛮生长,给盲盒市场划一条法律红线!
相关文章

 发表评论

暂时没有评论,来抢沙发吧~