铜仁做网站,电力建设集团网站,把网页做成软件,网站内部代码优化相关阅读
Pytorch基础https://blog.csdn.net/weixin_45791458/category_12457644.html?spm1001.2014.3001.5482 在Pytorch中#xff0c;squeeze和unsqueeze是Tensor的一个重要方法#xff0c;同时它们也是torch模块中的一个函数#xff0c;它们的语法如下所示。
Tensor.…相关阅读
Pytorch基础https://blog.csdn.net/weixin_45791458/category_12457644.html?spm1001.2014.3001.5482 在Pytorch中squeeze和unsqueeze是Tensor的一个重要方法同时它们也是torch模块中的一个函数它们的语法如下所示。
Tensor.squeeze(dimNone) → Tensor
torch.squeeze(input, dimNone) → Tensorinput (Tensor) – the input tensor.
dim (int or tuple of ints, optional) – if given, the input will be squeezed only in the specified dimensions.Tensor.unsqueeze(dim) → Tensor
torch.unsqueeze(input, dim) → Tensorinput (Tensor) – the input tensor.
dim (int) – the index at which to insert the singleton dimension
一、squeeze squeeze函数或方法返回一个新的张量该张量移除了原张量中大小为1的维度例如输入张量的形状是(A×1×B×C×1×D)使用了squeeze函数或方法后输出张量的形状是(A×B×C×D)。请注意输出张量将与输入张量共享底层存储因此改变一个张量的内容将改变另一个张量的内容。默认情况下squeeze将移除所有尺寸为1的维度如果传递了dim参数则会将dim中的维度展开。dim的范围可以是[-input.dim()-1, input.dim()]其中负数索引表示从后往前数的位置例如-1代表最后一个维度。 可以看下面的例子以更好的理解
import torch# 创建一个形状为 (2, 1, 2, 1, 2) 的张量
x torch.zeros(2, 1, 2, 1, 2)
print(x, x.size(), id(x))# 移除所有大小为1的维度
a torch.squeeze(x) # 等价于 a x.squeeze()
print(a, a.size(), id(a))# 尝试移除第0维度由于第0维度大小不为1因此不改变形状
b torch.squeeze(x, 0) # 等价于 b x.squeeze(0)
print(b, b.size(), id(b))# 移除第1维度第1维度大小为1
c torch.squeeze(x, 1) # 等价于 c x.squeeze(1)
print(c, c.size(), id(c))# 移除第1、第2和第3维度第1和第3维度大小为1第2维度不变
d torch.squeeze(x, (1, 2, 3)) # 等价于 d x.squeeze((1, 2, 3))
print(d, d.size(), id(d))# 验证所有张量共享底层存储空间
print(x.storage().data_ptr() a.storage().data_ptr() b.storage().data_ptr() c.storage().data_ptr() d.storage().data_ptr()) # 共享底层存储空间输出
tensor([[[[[0., 0.]],[[0., 0.]]]],[[[[0., 0.]],[[0., 0.]]]]]) torch.Size([2, 1, 2, 1, 2]) 1899057117680tensor([[[0., 0.],[0., 0.]],[[0., 0.],[0., 0.]]]) torch.Size([2, 2, 2]) 1899057158240tensor([[[[[0., 0.]],[[0., 0.]]]],[[[[0., 0.]],[[0., 0.]]]]]) torch.Size([2, 1, 2, 1, 2]) 1899737467296tensor([[[[0., 0.]],[[0., 0.]]],[[[0., 0.]],[[0., 0.]]]]) torch.Size([2, 2, 1, 2]) 1899737467376tensor([[[0., 0.],[0., 0.]],[[0., 0.],[0., 0.]]]) torch.Size([2, 2, 2]) 1899737467216
True 二、 unsqueeze unsqueeze函数或方法函数返回一个新的张量该张量在指定维度(dim)插入一个大小为1的维度。使用unsqueeze函数或方法后输入张量的形状会相应增加一个维度。例如输入张量的形状是A×B×C在第1维度使用unsqueeze后输出张量的形状将变为A×1×B×C。请注意输出张量将与输入张量共享底层存储因此改变一个张量的内容将改变另一个张量的内容。dim的范围可以是[-input.dim(), input.dim()-1]其中负数索引表示从后往前数的位置例如-1代表最后一个维度。 可以看下面的例子以更好的理解
import torch# 创建一个形状为 (2, 2, 2) 的张量
x torch.zeros(2, 2, 2)
print(x, x.size(), id(x))# 在第0维度插入单维度
a torch.unsqueeze(x, 0) # 等价于 a x.unsqueeze(0)
print(a, a.size(), id(a))# 在第1维度插入单维度
b torch.unsqueeze(x, 1) # 等价于 b x.unsqueeze(1)
print(b, b.size(), id(b))# 在第2维度插入单维度
c torch.unsqueeze(x, 2) # 等价于 c x.unsqueeze(2)
print(c, c.size(), id(c))# 在第3维度插入单维度
d torch.unsqueeze(x, 3) # 等价于 d x.unsqueeze(3)
print(d, d.size(), id(d))# 验证所有张量共享底层存储空间
print(x.storage().data_ptr() a.storage().data_ptr() b.storage().data_ptr() c.storage().data_ptr() d.storage().data_ptr()) # 共享底层存储空间输出
tensor([[[0., 0.],[0., 0.]],[[0., 0.],[0., 0.]]]) torch.Size([2, 2, 2]) 1509028592032tensor([[[[0., 0.],[0., 0.]],[[0., 0.],[0., 0.]]]]) torch.Size([1, 2, 2, 2]) 1509028632592tensor([[[[0., 0.],[0., 0.]]],[[[0., 0.],[0., 0.]]]]) torch.Size([2, 1, 2, 2]) 1507561225888tensor([[[[0., 0.]],[[0., 0.]]],[[[0., 0.]],[[0., 0.]]]]) torch.Size([2, 2, 1, 2]) 1507561391824tensor([[[[0.],[0.]],[[0.],[0.]]],[[[0.],[0.]],[[0.],[0.]]]]) torch.Size([2, 2, 2, 1]) 1507561391904
True