损失函数


损失函数

nn.CrossEntropyLoss()这个损失函数用于多分类问题虽然说的是交叉熵,但是和我理解的交叉熵不一样。
nn.CrossEntropyLoss()是nn.logSoftmax()和nn.NLLLoss()的整合,可以直接使用它来替换网络中的这两个操作。下面我们来看一下计算过程:

首先输入是size是(minibatch,C)。这里的C是类别数。损失函数的计算如下:

png

损失函数中也有权重weight参数设置,若设置权重,则公式为:

png

这里的标签值class,并不参与直接计算,而是作为一个索引,索引对象为实际类别
举个栗子,我们一共有三种类别,批量大小为1(为了好计算),那么输入size为(1,3),具体值为torch.Tensor([[-0.7715, -0.6205,-0.2562]])。
标签值为target = torch.tensor([0]),这里标签值为0,表示属于第0类。loss计算如下:

import torch
import torch.nn as nn
import math
entroy=nn.CrossEntropyLoss()
input=torch.Tensor([[-0.7715, -0.6205,-0.2562]])
target = torch.tensor([0])g
output = entroy(input, target)
print(output)

输出:
tensor(1.3447)


评论
  目录