做垂直类网站,网站数据库有什么用,制作触屏版网站开发,商务网站大全论文《Scaling Local Self-Attention for Parameter Efficient Visual Backbones》
1、作用
HaloNet通过引入Haloing机制和高效的注意力实现#xff0c;在图像识别任务中达到了最先进的准确性。这些模型通过局部自注意力机制#xff0c;有效地捕获像素间的全局交互#xf…论文《Scaling Local Self-Attention for Parameter Efficient Visual Backbones》
1、作用
HaloNet通过引入Haloing机制和高效的注意力实现在图像识别任务中达到了最先进的准确性。这些模型通过局部自注意力机制有效地捕获像素间的全局交互同时通过分块和Haloing策略显著提高了处理速度和内存效率。
2、机制
1、Haloing策略
为了克服传统自注意力的计算和内存限制HaloNet采用了Haloing策略将图像分割成多个块并为每个块扩展一定的Halo区域仅在这些区域内计算自注意力。这种方法减少了计算量同时保持了较大的感受野。
2、多尺度特征层次
HaloNet构建了多尺度特征层次结构通过分层采样和跨尺度的信息流有效捕获不同尺度的图像特征增强了模型对图像中对象大小变化的适应性。
3、高效的自注意力实现
通过改进的自注意力算法包括非中心化的局部注意力和分层自注意力下采样操作HaloNet在保持高准确性的同时提高了训练和推理速度。
3、独特优势
1、参数效率
HaloNet通过局部自注意力机制和Haloing策略大幅度减少了所需的计算量和内存需求实现了与当前最佳卷积模型相当甚至更好的性能但使用更少的参数。
2、适应多尺度
多尺度特征层次结构使得HaloNet能够有效处理不同尺度的对象提高了对复杂视觉任务的适应性和准确性。
3、提升速度和效率
通过优化的自注意力实现HaloNet在不牺牲准确性的前提下实现了比现有技术更快的训练和推理速度使其更适合实际应用。
4、代码
import torch
from torch import nn, einsum
import torch.nn.functional as Ffrom einops import rearrange, repeat# 将设备和数据类型转换为字典格式def to(x):return {device: x.device, dtype: x.dtype}# 确保输入是元组形式
def pair(x):return (x, x) if not isinstance(x, tuple) else x# 在指定维度上扩展张量
def expand_dim(t, dim, k):t t.unsqueeze(dimdim)expand_shape [-1] * len(t.shape)expand_shape[dim] kreturn t.expand(*expand_shape)# 将相对位置编码转换为绝对位置编码
def rel_to_abs(x):b, l, m x.shaper (m 1) // 2col_pad torch.zeros((b, l, 1), **to(x))x torch.cat((x, col_pad), dim2)flat_x rearrange(x, b l c - b (l c))flat_pad torch.zeros((b, m - l), **to(x))flat_x_padded torch.cat((flat_x, flat_pad), dim1)final_x flat_x_padded.reshape(b, l 1, m)final_x final_x[:, :l, -r:]return final_x# 生成一维的相对位置logits
def relative_logits_1d(q, rel_k):b, h, w, _ q.shaper (rel_k.shape[0] 1) // 2logits einsum(b x y d, r d - b x y r, q, rel_k)logits rearrange(logits, b x y r - (b x) y r)logits rel_to_abs(logits)logits logits.reshape(b, h, w, r)logits expand_dim(logits, dim2, kr)return logits# 相对位置嵌入类
class RelPosEmb(nn.Module):def __init__(self,block_size,rel_size,dim_head):super().__init__()height width rel_sizescale dim_head ** -0.5self.block_size block_sizeself.rel_height nn.Parameter(torch.randn(height * 2 - 1, dim_head) * scale)self.rel_width nn.Parameter(torch.randn(width * 2 - 1, dim_head) * scale)def forward(self, q):block self.block_sizeq rearrange(q, b (x y) c - b x y c, xblock)rel_logits_w relative_logits_1d(q, self.rel_width)rel_logits_w rearrange(rel_logits_w, b x i y j- b (x y) (i j))q rearrange(q, b x y d - b y x d)rel_logits_h relative_logits_1d(q, self.rel_height)rel_logits_h rearrange(rel_logits_h, b x i y j - b (y x) (j i))return rel_logits_w rel_logits_h# HaloAttention类class HaloAttention(nn.Module):def __init__(self,*,dim,block_size,halo_size,dim_head64,heads8):super().__init__()assert halo_size 0, halo size must be greater than 0self.dim dimself.heads headsself.scale dim_head ** -0.5self.block_size block_sizeself.halo_size halo_sizeinner_dim dim_head * headsself.rel_pos_emb RelPosEmb(block_sizeblock_size,rel_sizeblock_size (halo_size * 2),dim_headdim_head)self.to_q nn.Linear(dim, inner_dim, biasFalse)self.to_kv nn.Linear(dim, inner_dim * 2, biasFalse)self.to_out nn.Linear(inner_dim, dim)def forward(self, x):# 验证输入特征图维度是否符合要求b, c, h, w, block, halo, heads, device *x.shape, self.block_size, self.halo_size, self.heads, x.deviceassert h % block 0 and w % block 0, assert c self.dim, fchannels for input ({c}) does not equal to the correct dimension ({self.dim})q_inp rearrange(x, b c (h p1) (w p2) - (b h w) (p1 p2) c, p1block, p2block)kv_inp F.unfold(x, kernel_sizeblock halo * 2, strideblock, paddinghalo)kv_inp rearrange(kv_inp, b (c j) i - (b i) j c, cc)#生成查询、键、值q self.to_q(q_inp)k, v self.to_kv(kv_inp).chunk(2, dim-1)# 拆分头部q, k, v map(lambda t: rearrange(t, b n (h d) - (b h) n d, hheads), (q, k, v))# 缩放查询向量q * self.scale# 计算注意力sim einsum(b i d, b j d - b i j, q, k)# 添加相对位置偏置sim self.rel_pos_emb(q)# 掩码填充mask torch.ones(1, 1, h, w, devicedevice)mask F.unfold(mask, kernel_sizeblock (halo * 2), strideblock, paddinghalo)mask repeat(mask, () j i - (b i h) () j, bb, hheads)mask mask.bool()max_neg_value -torch.finfo(sim.dtype).maxsim.masked_fill_(mask, max_neg_value)# 注意力机制attn sim.softmax(dim-1)# 聚合out einsum(b i j, b j d - b i d, attn, v)# 合并和组合头部out rearrange(out, (b h) n d - b n (h d), hheads)out self.to_out(out)# 将块合并回原始特征图out rearrange(out, (b h w) (p1 p2) c - b c (h p1) (w p2), bb, h(h // block), w(w // block), p1block,p2block)return out# 输入 N C H W, 输出 N C H W
if __name__ __main__:block HaloAttention(dim512,block_size2,halo_size1, ).cuda()# 创建HaloAttention实例input torch.rand(1, 512, 64, 64).cuda()# 创建随机输入output block(input) # 前向传播print(output.shape)