河北建设部网站,成都网站建设 seo,广告公司前景怎么样,前端开发工具哪个好Pytorch中utils.data 与torchvision简介1 数据处理工具概述2 utils.data简介3 torchvision简介3.1 transforms3.2 ImageFolder1 数据处理工具概述
Pytorch涉及数据处理#xff08;数据装载、数据预处理、数据增强等#xff09;主要工具包及相互关系如下图所示#xff0c;主…
Pytorch中utils.data 与torchvision简介1 数据处理工具概述2 utils.data简介3 torchvision简介3.1 transforms3.2 ImageFolder1 数据处理工具概述
Pytorch涉及数据处理数据装载、数据预处理、数据增强等主要工具包及相互关系如下图所示主要使用torch.utils.data 与 torchvision
torch.utils.data工具包它包括以下三个类 1Dataset是一个抽象类其它数据集需要继承这个类并且覆写其中的两个方法(getitem、len)。 2DataLoader定义一个新的迭代器实现批量batch读取打乱数据shuffle并提供并行加速等功能。 3random_split把数据集随机拆分为给定长度的非重叠新数据集。 4*sampler多种采样函数。
可视化处理工具torchvision:Pytorch的一个视觉处理工具包独立于Pytorch需要另外安装使用pip或conda安装即可,包含四个类
1datasets:提供常用的数据集加载设计上都是继承torch.utils.data.Dataset主要包括MMIST、CIFAR10/100、ImageNet、COCO等。 2models:提供深度学习中各种经典的网络结构以及训练好的模型(如果选择pretrainedTrue)包括AlexNet, VGG系列、ResNet系列、Inception系列等。 3transforms:常用的数据预处理操作主要包括对Tensor及PIL Image对象的操作。 4utils:含两个函数一个是make_grid它能将多张图片拼接在一个网格中另一个是save_img它能将Tensor保存成图片。
2 utils.data简介
utils.data包括Dataset和 DataLoader 。 torch.utils.data.Dataset:为抽象类。自定义数据集需要继承这个类并实现两个函数。一个是__len__另一个是__getitem__前者提供数据的大小(size)后者通过给定索引获取数据和标签。 由于__getitem__一次只能获取一个数据所以通过torch.utils.data.DataLoader来定义一个新的迭代器实现batch读取。
下面通过举例来比较Dataset 和DataLoader
1,导入相关模块
import torch
from torch.utils import data
import numpy as np2定义获取数据集的类该类继承基类Dataset自定义一个数据集及对应标签。
class TestDataset(data.Dataset):#继承Datasetdef __init__(self):self.Datanp.asarray([[1,2],[3,4],[2,1],[3,4],[4,5]])#一些由2维向量表示的数据集self.Labelnp.asarray([0,1,0,1,2])#这是数据集对应的标签def __getitem__(self, index):#把numpy转换为Tensortxttorch.from_numpy(self.Data[index])labeltorch.tensor(self.Label[index])return txt,label def __len__(self):return len(self.Data)3获取数据集中数据
TestTestDataset()
print(Test[2]) #相当于调用__getitem__(2)
print(Test.__len__())#輸出
#(tensor([2, 1]), tensor(0))
#5上面使用Dataset的方式每次只返回一个样本。如果希望批处理同时还要shuffle和并行加速等操作可选择DataLoader。
data.DataLoader(dataset,batch_size1,shuffleFalse,samplerNone,batch_samplerNone,num_workers0,collate_fn,pin_memoryFalse,drop_lastFalse,timeout0,worker_init_fnNone,
)主要参数说明 dataset: 加载的数据集 batch_size: 批大小 shuffle是否将数据打乱 sampler样本抽样 num_workers使用多进程加载的进程数0代表不使用多进程 collate_fn如何将多个样本数据拼接成一个batch一般使用默认的拼接方式即可 pin_memory是否将数据保存在pin memory区pin memory中的数据转到GPU会快一些 drop_lastdataset 中的数据个数可能不是 batch_size的整数倍drop_last为True会将多出来不足一个batch的数据丢弃。
test_loader data.DataLoader(Test,batch_size2,shuffleFalse,num_workers2)
for i,traindata in enumerate(test_loader):print(i:,i)Data,Labeltraindataprint(data:,Data)print(Label:,Label)从这个结果可以看出这是批量读取。我们可以像使用迭代器一样使用它,如对它进行循环操作。不过它不是迭代器我们可以通过iter命令转换为迭代器。
一般用data.Dataset处理同一个目录下的数据。如果数据在不同目录下不同目录代表不同类别这种情况比较普遍使用data.Dataset来处理就不很方便。
不过可以使用Pytorch另一种可视化数据处理工具即torchvision就非常方便不但可以自动获取标签还提供很多数据预处理、数据增强等转换函数。
3 torchvision简介
torchvision有4个功能模块
modeldatasetstransforms如何使用transforms对源数据进行预处理、增强等utils
3.1 transforms
transforms提供了对PIL Image对象和Tensor对象的常用操作
1对PIL Image的常见操作如下 Scale/Resize: 调整尺寸长宽比保持不变 CenterCrop、RandomCrop、RandomSizedCrop裁剪图片CenterCrop和RandomCrop在crop时是固定sizeRandomResizedCrop则是random size的crop Pad: 填充 ToTensor: 把一个取值范围是[0,255]的PIL.Image 转换成 Tensor。形状为(H,W,C)的numpy.ndarray转换成形状为[C,H,W]取值范围是[0,1.0]的torch.FloatTensor。 RandomHorizontalFlip:图像随机水平翻转翻转概率为0.5; RandomVerticalFlip: 图像随机垂直翻转; ColorJitter: 修改亮度、对比度和饱和度。
2对Tensor的常见操作如下 Normalize: 标准化即减均值除以标准差 ToPILImage:将Tensor转为PIL Image。
如果要对数据集进行多个操作可通过Compose将这些操作像管道一样拼接起来类似于nn.Sequential。以下为示例代码
transforms.Compose([#将给定的 PIL.Image 进行中心切割得到给定的 size#size 可以是 tuple(target_height, target_width)。#size 也可以是一个 Integer在这种情况下切出来的图片形状是正方形。 transforms.CenterCrop(10),#切割中心点的位置随机选取transforms.RandomCrop(20, padding0),#把一个取值范围是 [0, 255] 的 PIL.Image 或者 shape 为 (H, W, C) 的 numpy.ndarray#转换为形状为 (C, H, W)取值范围是 [0, 1] 的 torch.FloatTensortransforms.ToTensor(),#规范化到[-1,1]transforms.Normalize(mean (0.5, 0.5, 0.5), std (0.5, 0.5, 0.5))
])3.2 ImageFolder
当文件依据标签处于不同文件下时如
可以利用 torchvision.datasets.ImageFolder 来直接构造出 dataset代码如下
loader datasets.ImageFolder(path)
loader data.DataLoader(dataset)ImageFolder 会将目录中的文件夹名自动转化成序列那么DataLoader载入时标签自动就是整数序列了。 下面我们利用ImageFolder读取不同目录下图片数据然后使用transorms进行图像预处理预处理有多个我们用compose把这些操作拼接在一起。然后使用DataLoader加载。
对处理后的数据用torchvision.utils中的save_image保存为一个png格式文件然后用Image.open打开该png文件详细代码如下
from torchvision import transforms, utils
from torchvision import datasets
import torch
import matplotlib.pyplot as plt
%matplotlib inlinemy_transtransforms.Compose([transforms.RandomResizedCrop(224),transforms.RandomHorizontalFlip(),transforms.ToTensor()
])
train_data datasets.ImageFolder(./data/torchvision_data, transformmy_trans)
train_loader data.DataLoader(train_data,batch_size8,shuffleTrue,)for i_batch, img in enumerate(train_loader):if i_batch 0:print(img[1])fig plt.figure()grid utils.make_grid(img[0])plt.imshow(grid.numpy().transpose((1, 2, 0)))plt.show()utils.save_image(grid,test01.png)break其他功能模块待更新 参考python深度学习-基于pytorch