胡天贵
Python中四个转tensor函数的区别
2023-11-8 08:59
阅读:681

Python中的torch包中包含torch.Tensor(a)、torch.tensor(a)、torch.from_numpy(a)、torch.as_tensor(a)四个转tensor函数。区别在于:

  1. torch.Tensor(a)是类构造函数,转出来的tensor格式数据dtype是全局默认dtype(一般为torch.float32),全局默认类型可以通过torch.get_default_dtype()函数来查询;而其它三个都是工厂函数,转出来的tensor格式数据dtype是根据输入a的dtype来推断。

因此:torch.Tensor(a)与torch.tensor(a, dtype=torch.float32)几乎一致。


2. torch.Tensor(a)、torch.tensor(a)是深拷贝,会创造一个新的内存空间,不共享内存,因此a改变时,torch.Tensor(a)、torch.tensor(a)不会改变;而torch.from_numpy(a)、torch.as_tensor(a)不会创造新的内存空间,因此a改变时,torch.from_numpy(a)、torch.as_tensor(a)也会发生改变。


3. torch.from_numpy(a)、torch.as_tensor(a)对比:torch.from_numpy(a)的输入a只能是ndarray格式,并输出一个与a的dtype、device都一样的tensor数据;torch.as_tensor(a)适用性更广,它的输入a可以是非ndarray格式,同时还可以改变dtype和device。torch.as_tensor(a)当a是ndarray格式,且dtype和device都默认时等同于torch.from_numpy(a),是浅拷贝,而当dtype和device不默认时,会创建一个新的内存空间,变为深拷贝。


转载本文请联系原作者获取授权,同时请注明本文来自胡天贵科学网博客。

链接地址:https://wap.sciencenet.cn/blog-3447891-1408870.html?mobile=1

收藏

分享到:

当前推荐数:0
推荐到博客首页
网友评论0 条评论
确定删除指定的回复吗?
确定删除本博文吗?