Python中的torch包中包含torch.Tensor(a)、torch.tensor(a)、torch.from_numpy(a)、torch.as_tensor(a)四个转tensor函数。区别在于:
torch.Tensor(a)是类构造函数,转出来的tensor格式数据dtype是全局默认dtype(一般为torch.float32),全局默认类型可以通过torch.get_default_dtype()函数来查询;而其它三个都是工厂函数,转出来的tensor格式数据dtype是根据输入a的dtype来推断。
因此:torch.Tensor(a)与torch.tensor(a, dtype=torch.float32)几乎一致。
2. torch.Tensor(a)、torch.tensor(a)是深拷贝,会创造一个新的内存空间,不共享内存,因此a改变时,torch.Tensor(a)、torch.tensor(a)不会改变;而torch.from_numpy(a)、torch.as_tensor(a)不会创造新的内存空间,因此a改变时,torch.from_numpy(a)、torch.as_tensor(a)也会发生改变。
3. torch.from_numpy(a)、torch.as_tensor(a)对比:torch.from_numpy(a)的输入a只能是ndarray格式,并输出一个与a的dtype、device都一样的tensor数据;torch.as_tensor(a)适用性更广,它的输入a可以是非ndarray格式,同时还可以改变dtype和device。torch.as_tensor(a)当a是ndarray格式,且dtype和device都默认时等同于torch.from_numpy(a),是浅拷贝,而当dtype和device不默认时,会创建一个新的内存空间,变为深拷贝。
转载本文请联系原作者获取授权,同时请注明本文来自胡天贵科学网博客。
链接地址:https://wap.sciencenet.cn/blog-3447891-1408870.html?mobile=1
收藏