请稍等 ...
×

采纳答案成功!

向帮助你的同学说点啥吧!感谢那些助人为乐的人

正在回答

1回答

首先,这是一个分类任务,也就是分成数字0-9十类, 那么对于一张图片其输出也就有十种可能。分类网络的任务就是计算这张图是十类中的每一类的概率,然后再选出最大概率的索引值表示类别。比如一张图片的经过网络的输出结果是:[0.1, 0.1, 0.1, 0.7, 0, 0, 0, 0, 0, 0],  索引值为3的时候概率最大,为0.7, 那么网络就会预测这张图是3

 至 _,pred = output.max(1) ,  就是取最大索引的操作。1实际上是表示取第一维度上的最大索引, 为什么还要有个第一维度呢?是因为我们是按batch_size进行计算的,假设batch_size = 3(也就是对三张图片分类,那就有三个类似[0.1, 0.1, 0.1, 0.7, 0, 0, 0, 0, 0, 0]的结果):

[

[0.1, 0.1, 0.1, 0.7, 0, 0, 0, 0, 0, 0],

[0.1, 0.9, 0, 0, 0, 0, 0, 0, 0, 0],

[1.0, 0, 0, 0, 0, 0, 0, 0, 0, 0],

]

 _,pred = output.max(1)  的结果是 [3, 1, 0],其意思就是对每一张图片的预测结果求最大索引,给定维度1就是对这个batch_size里面的每一个[1.0, 0, 0, 0, 0, 0, 0, 0, 0, 0]求最大索引


2 回复 有任何疑惑可以回复我~
  • 那个逗号前的'_'是做什么的呀
    回复 有任何疑惑可以回复我~ 2021-01-24 21:58:59
  • _    作为变量名,表示无关紧要的变量。一般不会用到的变量可以用这个来表示
    回复 有任何疑惑可以回复我~ 2021-01-24 22:14:39
问题已解决,确定采纳
还有疑问,暂不采纳
微信客服

购课补贴
联系客服咨询优惠详情

帮助反馈 APP下载

慕课网APP
您的移动学习伙伴

公众号

扫描二维码
关注慕课网微信公众号