python tips003 ——DataLoader的collate_fn参数使用详解
2022/1/3 11:07:24
本文主要是介绍python tips003 ——DataLoader的collate_fn参数使用详解,对大家解决编程问题具有一定的参考价值,需要的程序猿们随着小编来一起学习吧!
背景
最近在看sentences-transformers
的源码,在有一个模块发现了dataloader.collate_fn
,当时没搞懂是什么意思,后来查了一下,感觉还是很有意思的,因此来分享一下。
dataloader
dataloader肯定都是知道的,就是为数据提供一个迭代器。
基本工作机制:
在dataloader按照batch进行取数据的时候, 是取出大小等同于batch size的index列表,然后将列表列表中的index输入到dataset的getitem()函数中,取出该index对应的数据,最后, 对每个index对应的数据进行堆叠,就形成了一个batch的数据。
完整参数列表
DataLoader完整的参数表如下:
class torch.utils.data.DataLoader( dataset, batch_size=1, shuffle=False, sampler=None, batch_sampler=None, num_workers=0, collate_fn=<function default_collate>, pin_memory=False, drop_last=False, timeout=0, worker_init_fn=None)
- shuffle:设置为True的时候,每个世代都会打乱数据集。
- collate_fn:如何取样本的,我们可以定义自己的函数来准确地实现想要的功能。
- drop_last:告诉如何处理数据集长度除于batch_size余下的数据。True就抛弃,否则保留。
collate_fn作用
在最后一步堆叠的时候可能会出现问题: 如果一条数据中所含有的每个数据元的长度不同, 那么将无法进行堆叠. 如: multi-hot类型的数据, 序列数据。在使用这些数据时, 通常需要先进行长度上的补齐, 再进行堆叠. 以现在的流程, 是没有办法加入该操作的。此外, 某些优化方法是要对一个batch的数据进行操作。
collate_fn函数就是手动将抽取出的样本堆叠起来的函数。
案例说明
import torch from torch.utils.data import DataLoader, TensorDataset import numpy as np test = np.arange(11) input = torch.tensor(np.array([test[i:(i + 3)] for i in range(10 - 1)])) target = torch.tensor(np.array([test[i:(i + 1)] for i in range(10 - 1)])) torch_dataset = TensorDataset(input, target) batch = 3 #> input data shape: torch.Size([9, 3]) #> target data shape: torch.Size([9, 1])
需要注意的是上面的input数据shape为(9, 3);target数据shape为(9,1)。我们设置每一次的batch为3
1. 不设置collate_fn参数
my_dataloader = DataLoader( dataset=torch_dataset, batch_size=batch ) for (i, j) in my_dataloader: print('*' * 30) print(i) print(j)
查看上面的结果就可以看到每一批都返回两个结果,一个是input的样本,一个是target的样本。
input样本、target样本的维度和原始保持一致,但是大小尺寸全部为batch。
2. 设置collate_fn参数为lambda x: x
my_dataloader = DataLoader( dataset=torch_dataset, batch_size=4, collate_fn=lambda x: x ) for i in my_dataloader: print('*' * 30) print(i)
这个时候每一批都是返回了一个列表,这个列表的大小为3,列表里面的每一个对象就是一个成对的input和target。
如果我们继续想把上面的列表解析成第一个的情况,我们可以这么做:
a = i list((torch.cat([a[i][j].unsqueeze(0) for i in range(len(a))]).unsqueeze(0) for j in range(len(a[0]))))
上面其实是很哇塞的,他其实是什么意思,就是把输出的长度为batch的列表转换为一个矩阵了。看着是挺复杂的,其实就是对list做了数据抽取和合并。非常简单。大概可以有这么个拆解路线:
my_dataloader = DataLoader( dataset=torch_dataset, batch_size=4, collate_fn=lambda x: x, drop_last=True ) for i in my_dataloader: print('*' * 30) print(i) a = i a
然后,查看视频:
视频演示
3. 自定义collate_fn参数
现在结合上面的步骤,我们自定义自己的参数,然后实现默认的效果。大概代码如下:
my_dataloader = DataLoader( dataset=torch_dataset, batch_size=batch, collate_fn=lambda x:( torch.cat([x[i][j].unsqueeze(0) for i in range(len(x))],dim=0) for j in range(len(x[0])) ) ) for i,j in my_dataloader: print('*' * 30) print(i) print(j)
最后
- 后面会逐渐关于python更加冷门的东西,也会写pytorch的更多小细节。主要是用来记录自己的学习过程。将中间的一些比较复杂的东西给简单化。
参考链接
- https://blog.csdn.net/weixin_42028364/article/details/81675021
- https://zhuanlan.zhihu.com/p/361830892
- https://github.com/pytorch/pytorch/blob/master/torch/utils/data/_utils/collate.py
这篇关于python tips003 ——DataLoader的collate_fn参数使用详解的文章就介绍到这儿,希望我们推荐的文章对大家有所帮助,也希望大家多多支持为之网!
- 2024-05-08有遇到过吗?同样的规则 Excel 中 比Python 结果大
- 2024-03-30开始python成长之路
- 2024-03-29python optparse
- 2024-03-29python map 函数
- 2024-03-20invalid format specifier python
- 2024-03-18pool.map python
- 2024-03-18threads in python
- 2024-03-14python Ai 应用开发基础训练,字符串,字典,文件
- 2024-03-13id3 algorithm python
- 2024-03-13sum array elements python