张伟
Using DataLoader and Dataset Together in PyTorch
2020-10-28 16:24

The relationship between DataLoader and Dataset

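       DataLoader is the class PyTorch uses to load data. It is essentially an iterable: looping over it yields the data batch by batch. But where does the data it iterates over come from? That is what Dataset is for.
       Dataset is the class that wraps the data. It is where you put custom processing of the data (for example, cropping images or defining labels), and it returns the requested sample through its __getitem__ method.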
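As a minimal sketch of this division of labor, the snippet below wraps two tensors in `TensorDataset` (a ready-made `Dataset` in `torch.utils.data`) and hands it to `DataLoader`. The toy data and the batch size of 4 are assumptions chosen only for illustration. Indexing the dataset returns one sample, while iterating over the loader yields whole batches.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

# Toy data (assumed for illustration): 10 samples with 3 features each, plus one integer label per sample
features = torch.arange(30, dtype=torch.float32).reshape(10, 3)
labels = torch.arange(10)

dataset = TensorDataset(features, labels)   # the Dataset: knows how to return sample i
loader = DataLoader(dataset, batch_size=4)  # the DataLoader: iterates over the Dataset in batches

x0, y0 = dataset[0]                         # a single sample: a (3,) feature tensor and a scalar label
print(x0, y0)

for batch_x, batch_y in loader:             # iterating the loader yields batches, not single samples
    print(batch_x.shape, batch_y.tolist())
```

With 10 samples and `batch_size=4`, the loop yields batches of 4, 4 and 2 samples, because `drop_last` defaults to False.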

Introduction to the Dataset class

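Generally, you define a new class that inherits from Dataset and then load its data through DataLoader.
A Dataset subclass usually overrides __getitem__, which returns the sample at a given index. It also commonly overrides __len__, which returns the size of the whole dataset; DataLoader's default samplers call len() on the dataset to know how many indices to generate.
Example: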

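import os

from PIL import Image
from torch.utils.data import Dataset


class MyData(Dataset):

    def __init__(self, imag_path, transform=None):
        self.imag_path = imag_path                           # directory that holds the images
        self.imag_path_list = sorted(os.listdir(imag_path))  # file names, sorted so the index order is reproducible
        self.transform = transform                           # optional preprocessing, e.g. torchvision's transforms.ToTensor()

    def __getitem__(self, item):
        imag_name = self.imag_path_list[item]
        imag_item_path = os.path.join(self.imag_path, imag_name)
        imag = Image.open(imag_item_path).convert("RGB")     # load the image as 3-channel RGB
        if self.transform is not None:
            imag = self.transform(imag)                      # PIL images must become tensors before DataLoader can batch them
        label = imag_name                                    # here the file name itself is used as the label
        return imag, label   # the item-th image and its label

    def __len__(self):
        return len(self.imag_path_list)                      # dataset size = number of files in the directory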
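Because `MyData` above needs a directory of images, here is a file-free sketch of the same two methods. The class name `SquaresDataset` and its data (sample i is the pair i, i squared) are hypothetical, chosen only so the behavior is easy to check. Note that `ds[i]` is simply a call to `__getitem__(i)`, and `len(ds)` calls `__len__()`.

```python
import torch
from torch.utils.data import Dataset


class SquaresDataset(Dataset):
    """Hypothetical toy dataset: sample i is the pair (i, i**2) as float tensors."""

    def __init__(self, n):
        self.n = n  # number of samples

    def __getitem__(self, index):
        if not 0 <= index < self.n:
            raise IndexError(index)             # reject out-of-range indices, as a list would
        x = torch.tensor([float(index)])        # input feature
        y = torch.tensor([float(index) ** 2])   # target
        return x, y

    def __len__(self):
        return self.n                           # lets len() and DataLoader's default samplers know the size


ds = SquaresDataset(5)
print(len(ds))   # 5
print(ds[3])     # (tensor([3.]), tensor([9.]))
```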

Introduction to the DataLoader class

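       DataLoader is normally used directly as torch.utils.data.DataLoader. It iterates over the data in a Dataset: it calls the Dataset's __getitem__ to fetch the item-th sample, combines the samples into a batch (via collate_fn), and hands each batch to the training program.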
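To make the "fetch by index, then group into a batch" step concrete, here is a deliberately simplified, pure-Python model of what happens when `shuffle=False` and `num_workers=0`. The function `simple_loader` is hypothetical and is not PyTorch's implementation: the real DataLoader also handles samplers, worker processes and pinned memory, and its default collate function stacks tensors instead of returning a plain list.

```python
def simple_loader(dataset, batch_size, drop_last=False, collate_fn=list):
    """Hypothetical, simplified model of DataLoader's single-process loop (not PyTorch's real code).

    Sequential indices -> dataset[i] (i.e. __getitem__) -> group into batches -> collate_fn.
    """
    batch = []
    for i in range(len(dataset)):       # the indices a sequential sampler would produce: 0, 1, ..., len-1
        batch.append(dataset[i])        # fetch one sample through __getitem__
        if len(batch) == batch_size:
            yield collate_fn(batch)     # merge the collected samples into one mini-batch
            batch = []
    if batch and not drop_last:
        yield collate_fn(batch)         # the last, smaller batch


data = ["a", "b", "c", "d", "e"]        # any object with __getitem__ and __len__ works here
for b in simple_loader(data, batch_size=2):
    print(b)
# ['a', 'b']
# ['c', 'd']
# ['e']
```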

For a detailed explanation of the DataLoader parameters, see: http://blog.sciencenet.cn/blog-3428464-1253283.html

Parameters of torch.utils.data.DataLoader:

  • dataset (Dataset) – dataset from which to load the data.

  • batch_size (int, optional) – how many samples per batch to load (default: 1).

  • shuffle (bool, optional) – set to True to have the data reshuffled at every epoch (default: False).

  • sampler (Sampler, optional) – defines the strategy to draw samples from the dataset. If specified, shuffle must be False.

  • batch_sampler (Sampler, optional) – like sampler, but returns a batch of indices at a time. Mutually exclusive with batch_size, shuffle, sampler, and drop_last.

  • num_workers (int, optional) – how many subprocesses to use for data loading. 0 means that the data will be loaded in the main process. (default: 0)

  • collate_fn (callable, optional) – merges a list of samples to form a mini-batch.

  • pin_memory (bool, optional) – If True, the data loader will copy tensors into CUDA pinned memory before returning them.

  • drop_last (bool, optional) – set to True to drop the last incomplete batch, if the dataset size is not divisible by the batch size. If False and the size of dataset is not divisible by the batch size, then the last batch will be smaller. (default: False)

  • timeout (numeric, optional) – if positive, the timeout value for collecting a batch from workers. Should always be non-negative. (default: 0)

  • worker_init_fn (callable, optional) – If not None, this will be called on each worker subprocess with the worker id (an int in [0, num_workers - 1]) as input, after seeding and before data loading. (default: None)
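A minimal sketch combining several of these parameters follows; the toy data (seven integer sequences of lengths 1 to 7) and the helper `pad_collate` are hypothetical. It illustrates why `collate_fn` exists: the default collate function stacks samples into one tensor, which fails when samples have different lengths, so a custom function pads them instead. It also shows `drop_last=True` reducing the number of batches from 3 to 2.

```python
import torch
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import DataLoader

# Toy data (assumed): 7 variable-length integer sequences with lengths 1..7
sequences = [torch.arange(1, n + 1) for n in range(1, 8)]


def pad_collate(batch):
    """Hypothetical collate_fn: pad a list of 1-D tensors of different lengths into one 2-D batch."""
    lengths = torch.tensor([len(seq) for seq in batch])
    padded = pad_sequence(batch, batch_first=True, padding_value=0)  # shape: (batch, max_len)
    return padded, lengths


loader = DataLoader(
    sequences,               # a plain list also works: it has __getitem__ and __len__
    batch_size=3,
    shuffle=True,            # a new random order every epoch
    drop_last=True,          # 7 samples // 3 = 2 full batches; the leftover sample is dropped
    collate_fn=pad_collate,  # the default collate would fail to stack tensors of different lengths
    num_workers=0,           # load in the main process
)

print(len(loader))           # 2
for padded, lengths in loader:
    print(padded.shape, lengths.tolist())
```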

The relationship among DataLoader, Sampler and Dataset

[Figure: diagram of the relationship among DataLoader, Sampler and Dataset]

See the blog post: https://zhuanlan.zhihu.com/p/76893455
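The relationship among the three classes can be sketched in a few lines, assuming a toy `TensorDataset` of the integers 0 to 9: a `Sampler` yields indices, a `BatchSampler` groups them, and `DataLoader` uses those index lists to call the Dataset's `__getitem__` and collate the results. Passing `shuffle=True` to `DataLoader` would make it use a `RandomSampler` in place of the `SequentialSampler` shown here.

```python
import torch
from torch.utils.data import BatchSampler, DataLoader, SequentialSampler, TensorDataset

dataset = TensorDataset(torch.arange(10))   # toy Dataset (assumed): sample i is the tuple (tensor(i),)

sampler = SequentialSampler(dataset)        # Sampler: decides the order of the indices
print(list(sampler))                        # [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]

batch_sampler = BatchSampler(sampler, batch_size=4, drop_last=False)  # groups indices into batches
print(list(batch_sampler))                  # [[0, 1, 2, 3], [4, 5, 6, 7], [8, 9]]

# DataLoader: asks batch_sampler for index lists, fetches dataset[i] for each index, then collates
loader = DataLoader(dataset, batch_sampler=batch_sampler)
for (batch,) in loader:
    print(batch.tolist())                   # [0, 1, 2, 3], then [4, 5, 6, 7], then [8, 9]
```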

Sharing bit by bit, to the benefit of us all! Add oil!

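To repost this article, please contact the original author for permission and credit it as coming from 张伟's (Zhang Wei) ScienceNet blog.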

Link: https://wap.sciencenet.cn/blog-3428464-1256110.html?mobile=1
