东莞沙田网站建设,网站制作主题,糖果网站建设策划书,怎样建设学校网站首页对于初学者#xff0c;NLP中最烦人的问题之一就数据集的构建问题#xff0c;处理不好就会引起shape问题#xff08;各种由于shape错乱导致的问题#xff09;。这里给出一个模版#xff0c;大家可根据这个模版来构建。
torch.utils.data是PyTorch中用于数据加载和预处理的…对于初学者NLP中最烦人的问题之一就数据集的构建问题处理不好就会引起shape问题各种由于shape错乱导致的问题。这里给出一个模版大家可根据这个模版来构建。
torch.utils.data是PyTorch中用于数据加载和预处理的模块。其中包括Dataset和DataLoader两个类它们通常结合使用来加载和处理数据。
一、Dataset torch.utils.data.Dataset是一个抽象类用于表示数据集。它需要用户自己实现两个方法__ len__ 和__getitem__。其中__len__方法返回数据集的大小__getitem__方法用于根据给定的索引返回一个数据样本。
以下是一个简单的示例展示了如何定义一个数据集
import torch
from torch import nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader, Datasetclass MyDataset(Dataset):def __init__(self, texts, labels):self.texts textsself.labels labelsdef __len__(self):return len(self.labels)def __getitem__(self, idx):texts self.texts[idx]labels self.labels[idx]return texts, labels在这个示例中MyDataset继承了torch.utils.data.Dataset类并实现了__len__和__getitem__方法。__len__方法返回数据集的大小这里使用了Python内置函数len。__getitem__方法根据给定的索引返回一个数据样本这里返回的是数据列表中对应的元素。
二、DataLoader
torch.utils.data.DataLoader是PyTorch中一个重要的类用于高效加载数据集。它可以处理数据的批次化、打乱顺序、多线程数据加载等功能。 以下是一个简单的示例
# 假设我们有以下三个样本分别由不同数量的单词索引组成
text_data [torch.tensor([1, 2, 3, 4], dtypetorch.long), # 样本1torch.tensor([4, 3, 2], dtypetorch.long), # 样本2torch.tensor([1, 2], dtypetorch.long) # 样本3
]# 对应的标签
labels torch.tensor([1, 0, 1], dtypetorch.float)# 创建数据集和数据加载器
my_dataset MyDataset(text_data, labels)
data_loader DataLoader(my_dataset, batch_size2, shuffleTrue, collate_fnlambda x: x)for batch in data_loader:print(batch)代码输出
[(tensor([4, 3, 2]), tensor(0.)), (tensor([1, 2]), tensor(1.))]
[(tensor([1, 2, 3, 4]), tensor(1.))]在这个示例中我们首先创建了一个MyDataset实例my_dataset它包含了一个整数列表。然后我们使用DataLoader类创建了一个数据加载器data_loader它将data_loader作为输入并将数据分成大小为4的批次并对数据进行随机化。最后遍历data_loader并打印出每个批次的数据。
三、DataLoader参数讲解
函数原型 DataLoader(dataset, batch_size1, shuffleFalse, samplerNone, batch_samplerNone, num_workers0, collate_fnNone, pin_memoryFalse, drop_lastFalse, timeout0, worker_init_fnNone, *, prefetch_factor2, persistent_workersFalse) 常用的参数 1.dataset一个数据集对象必须实现__len__和__getitem__方法。 2.batch_size每个batch的大小。 3.shuffle是否对数据进行洗牌随机打乱。 4.sampler一个数据采样器用于对数据进行自定义采样。 5.batch_sampler一个batch采样器用于对batch进行自定义采样。 6.num_workers用于数据加载的子进程数量。默认值为0表示在主进程中加载数据。 7.collate_fn用于将一个batch的数据合并成一个张量或者元组。 8.pin_memory是否将数据存储在pin memory中锁定物理内存用于GPU加速数据传输默认值为False。 9.drop_last如果数据不能完全分成batch是否删除最后一批数据。默认为False。 10.timeout当数据加载器陷入死锁时等待数据准备的最大秒数。默认值为0表示无限等待。 11.worker_init_fn用于每个数据加载器进程的初始化函数。可以用来设置特定的随机种子。 12.multiprocessing_context用于创建数据加载器子进程的上下文。 以上是torch.utils.data.DataLoader中一些常用的参数使用时根据实际情况选择相应的参数组合。
sampler参数详解 sampler是一个用于指定数据集采样方式的类它控制DataLoader如何从数据集中选取样本。PyTorch提供了多种Sampler类例如RandomSampler和SequentialSampler分别用于随机采样和顺序采样。 以下是一个示例
from torch.utils.data.sampler import RandomSamplermy_sampler RandomSampler(my_dataset)my_dataloader data.DataLoader(my_dataset, batch_size4, shuffleFalse, samplermy_sampler)在这个示例中我们使用RandomSampler类来指定随机采样方式然后将其传递给DataLoader的sampler参数。这将覆盖默认的shuffle参数使数据集按照sampler指定的采样方式进行
四、自定义Dataset类
除了使用torchvision.datasets中提供的数据集我们还可以使用torch.utils.data.Dataset类来自定义自己的数据集。自定义数据集需要实现__len__和__getitem__方法。 ●__init__ 用来初始化数据集 ●__len__方法返回数据集中样本的数量 ●__getitem__给定索引值返回该索引值对应的数据它是python built-in方法其主要作用是能让该类可以像list一样通过索引值对数据进行访问 class MyDataset(data.Dataset):def __init__(self, data_path):self.data_list torch.load(data_path)def __len__(self):return len(self.data_list)def __getitem__(self, index):x self.data_list[index][0]y self.data_list[index][1]return x, y在这个示例中MyDataset类继承自torch.utils.data.Dataset类实现了__len__和__getitem__方法。MyDataset类的构造函数接受一个数据路径作为参数数据集被保存为一个由数据-标签对组成的列表。
五、自定义Sampler类
除了使用torch.utils.data.sampler中提供的采样器我们还可以使用Sampler类来自定义自己的采样器。自定义采样器需要实现__iter__和__len__方法。 ●__iter__方法返回一个迭代器用于遍历数据集中的样本索引。 ●__len__方法返回数据集中样本的数量。 以下是一个示例
class MySampler(Sampler):def __init__(self, data_source):self.data_source data_sourcedef __iter__(self):return iter(range(len(self.data_source)))def __len__(self):return len(self.data_source)在这个示例中MySampler类继承自torch.utils.data.sampler.Sampler类实现了__iter__和__len__方法。
六、自定义Transform类
除了使用torchvision.transforms中提供的变换我们还可以使用transforms模块中的Compose类来自定义自己的变换。Compose类将多个变换组合在一起并按照顺序应用它们。
以下是一个示例
class MyTransform(object):def __call__(self, x):x self.crop(x)x self.to_tensor(x)return xdef crop(self, x):# 这里实现裁剪变换# .......return xdef to_tensor(self, x):# 这里实现张量化变换# .......return xmy_transform transforms.Compose([MyTransform()
])# 创建数据集和数据加载器
my_dataset MyDataset(data_path)
my_dataloader DataLoader(my_dataset, batch_size32, shuffleTrue, num_workers4)# 遍历数据集
for batch in my_dataloader:# 在这里处理数据批次pass在这个示例中MyTransform类实现了一个自定义的变换它将裁剪和张量化两个变换组合在一起。transforms.Compose将这个自定义变换组合成一个变换序列并在数据集中的每个样本上应用这个序列。