||
主要介绍torchnet中的常用模块方法
github网址:https://github.com/facebookarchive/torchnet
import torchnet as tnt
acc_meter = tnt.meter.ClassErrorMeter(accuracy=True) #计算精度
loss_meter = tnt.meter.AverageValueMeter() #计算存储数据的均值和标准差
(1)acc_meter使用示例
pred=torch.Tensor([[0.5,0.3,0.4],[0.1,0.4,0.8],[0.6,0.8,0.9]]) #预测标签[0,2,2]
y = torch.Tensor([0,1,2])
acc_meter.add(pred, y)
pred1=torch.Tensor([[0.5,0.3,0.4],[0.1,0.4,0.8],[0.6,0.8,0.9]]) #预测标签[0,2,2]
y1 = torch.Tensor([0,1,2])
acc_meter.add(pred1, y1)
pred2=torch.Tensor([[0.5,0.3,0.4],[0.1,0.8,0.2],[0.6,0.8,0.9]]) #预测标签[0,1,2]
y2 = torch.Tensor([0,1,2])
acc_meter.add(pred2, y2)
acc_meter.value() #返回结果是三个预测结果的精度平均值
Out[42]: [77.77777777777779]
(2)loss_meter使用示例
from torchnet import meter
loss_meter = meter.AverageValueMeter()
loss_meter.reset()
for i in range(10):
loss_meter.add(i)
loss_meter.value() #返回存储数据的均值和标准差
Out[11]: (4.5, 3.0276503540974917) #均值,标准差
点滴分享,福泽你我!Add oil!
Archiver|手机版|科学网 ( 京ICP备07017567号-12 )
GMT+8, 2024-11-9 07:11
Powered by ScienceNet.cn
Copyright © 2007- 中国科学报社