加入收藏 | 设为首页 | 会员中心 | 我要投稿 李大同 (https://www.lidatong.com.cn/)- 科技、建站、经验、云计算、5G、大数据,站长网!
当前位置: 首页 > 编程开发 > Python > 正文

python – 不平衡数据和加权交叉熵

发布时间:2020-12-20 10:33:02 所属栏目:Python 来源:网络整理
导读:我正在尝试用不平衡的数据训练网络.我有A(198个样本),B个(436个样本),C个(710个样本),D个(272个样本),我读过“weighted_cross_entropy_with_logits”但我发现的所有例子都是二进制分类所以我不是很对如何设置这些重量充满信心. 样本总数:1616 A_weight:198
我正在尝试用不平衡的数据训练网络.我有A(198个样本),B个(436个样本),C个(710个样本),D个(272个样本),我读过“weighted_cross_entropy_with_logits”但我发现的所有例子都是二进制分类所以我不是很对如何设置这些重量充满信心.

样本总数:1616

A_weight:198/1616 = 0.12?

如果我理解的话,背后的想法是惩罚少数民族阶级的错误,并且更加积极地评价少数民族中的命中,对吧?

我的代码:

weights = tf.constant([0.12,0.26,0.43,0.17])
cost = tf.reduce_mean(tf.nn.weighted_cross_entropy_with_logits(logits=pred,targets=y,pos_weight=weights))

我已经阅读了this one和其他二进制分类的例子,但仍然不是很清楚.

提前致谢.

解决方法

请注意,weighted_cross_entropy_with_logits是sigmoid_cross_entropy_with_logits的加权变体. Sigmoid交叉熵通常用于二进制分类.是的,它可以处理多个标签,但是S形交叉熵基本上对每个标签做出(二元)决定 – 例如,对于面部识别网,那些(不是互斥的)标签可以是“受试者是否戴眼镜? “,”主题女性?“等

在二进制分类中,每个输出通道对应于二进制(软)决策.因此,加权需要在损失的计算中发生.这就是weighted_cross_entropy_with_logits通过对交叉熵的一个项加权而对另一个加权.

在互斥的多标签分类中,我们使用softmax_cross_entropy_with_logits,其行为不同:每个输出通道对应于候选类的分数.通过比较每个通道的相应输出来做出决定.

因此,在最终决定之前加权是在比较它们之前修改得分的简单问题,通常通过乘以权重.例如,对于三元分类任务,

# your class weights
class_weights = tf.constant([[1.0,2.0,3.0]])
# deduce weights for batch samples based on their true label
weights = tf.reduce_sum(class_weights * onehot_labels,axis=1)
# compute your (unweighted) softmax cross entropy loss
unweighted_losses = tf.nn.softmax_cross_entropy_with_logits(onehot_labels,logits)
# apply the weights,relying on broadcasting of the multiplication
weighted_losses = unweighted_losses * weights
# reduce the result to get your final loss
loss = tf.reduce_mean(weighted_losses)

您还可以依靠tf.losses.softmax_cross_entropy来处理最后三个步骤.

在您需要解决数据不平衡的情况下,类权重确实与其列车数据中的频率成反比.将它们归一化以使它们总计为一个或多个类也是有意义的.

请注意,在上文中,我们根据样本的真实标签惩罚了损失.我们还可以通过简单定义来根据估计的标签惩罚损失

weights = class_weights

由于广播魔术,其余的代码不需要改变.

在一般情况下,您需要权重取决于您所犯的错误类型.换句话说,对于每对标签X和Y,您可以选择在真实标签为Y时如何惩罚选择标签X.最终得到一个完整的先前权重矩阵,这导致权重高于满(num_samples,num_classes)张量.这有点超出了你想要的范围,但是知道在上面的代码中只需要改变你的权重张量的定义可能是有用的.

(编辑:李大同)

【声明】本站内容均来自网络,其相关言论仅代表作者个人观点,不代表本站立场。若无意侵犯到您的权利,请及时与联系站长删除相关内容!

    推荐文章
      热点阅读