mafei863分享 http://blog.sciencenet.cn/u/mafei863 道法自然,道即合理

博文

[转载]交叉熵分类的标签类型必须换成long

已有 294 次阅读 2021-4-25 17:48 |个人分类:点滴记录|系统分类:科研笔记|文章来源:转载

用pytorch完成字符识别分类任务时,发现loss = lossFunction(out, labels)报错

同样的代码在MNIST数据集上就没有报错,原因是数据载入类型不符合规范


输入labels维度应该为1维,且精度不能是Double,必须换成long


修改后的数据导入代码:


lossFunction = torch.nn.CrossEntropyLoss()

loss = lossFunction(out, labels.long())  # 修改数据精度 




https://wap.sciencenet.cn/blog-538909-1283704.html

上一篇:[转载]hardmax和softmax的区分和优缺点
下一篇:[转载]No module named \'cv2\' (安装cv2)

0

该博文允许注册用户评论 请点击登录 评论 (0 个评论)

数据加载中...
扫一扫,分享此博文

Archiver|手机版|科学网 ( 京ICP备07017567号-12 )

GMT+8, 2021-8-6 06:20

Powered by ScienceNet.cn

Copyright © 2007- 中国科学报社

返回顶部