论文:CoAtNet: Marrying Convolution and Attention for All Data
Sizes
模型
融合convolution和self-attention
对于卷积,我们主要关注 MBConv 块 ,它采用深度卷积来捕获空间交互。
这种选择的一个关键原因是 Transformer 和 MBConv 中的 FFN
模块都采用了“反向瓶颈”(inverted
bottleneck)的设计,首先将输入的通道大小扩展了 4 倍,然后将 4
倍宽的隐藏状态投影回原始状态 通道大小以启用残差连接。
除了倒置瓶颈的相似性之外,我们还注意到深度卷积和自注意力都可以表示为预定义感受野中值的加权和。
具体来说,卷积依赖于一个固定的内核来从局部感受野收集信息。 \[
y_{i}=\sum_{j \in \mathcal{L}(i)} w_{i-j} \odot x_{j} \ (depthwise \
convolution), \ \ \ (1)
\] 其中\(x_{i}, y_{i} \in
\mathbb{R}^{D}\) 分别是位置 \(i\)
的输入和输出,\(\mathcal{L}(i)\) 表示
\(i\) 的局部邻域,例如图像处理中以
\(i\) 为中心的 3x3 网格。
相比之下,self-attention 允许感受野是整个空间位置,并根据\((x_i,
x_j)\) 对之间重新归一化的成对相似度(re-normalized pairwise
similarity)计算权重: \[
y_{i}=\sum_{j \in \mathcal{G}} \underbrace{\frac{\exp \left(x_{i}^{\top}
x_{j}\right)}{\sum_{k \in \mathcal{G}} \exp \left(x_{i}^{\top}
x_{k}\right)}}_{A_{i, j}} x_{j}\ \ \ \ (self-attention),\ \ \ (2)
\] 其中 \(\mathcal{G}\)
表示全局空间空间。在讨论如何最好地组合它们之前,值得比较它们的相对优势和劣势,这有助于找出我们希望保留的优良特性。
首先,depthwise 卷积核 \(w_{i-j}\)
是一个静态值的输入独立参数,而注意力权重 \(A_{i,j}\)
动态地取决于输入的表示。因此,自注意力更容易捕捉不同空间位置之间复杂的关系交互,这是我们在处理高级概念时最想要的属性。然而,灵活性伴随着更容易过度拟合的风险,尤其是在数据有限的情况下。
其次,注意给定任何位置对(i; j),对应的卷积权重 \(w_{i-j}\) 只关心它们之间的相对位移,即
\(i -j\) ,而不是 \(i\) 或 \(j\)
的具体值。这个属性通常被称为平移等变性,已经发现它可以提高有限大小数据集下的泛化能力。由于使用绝对位置嵌入,标准
Transformer (ViT)
缺少此属性。这部分解释了为什么当数据集不是很大时,ConvNets 通常比
Transformers 更好。
最后,感受野的大小是自注意力和卷积之间最重要的区别之一。一般来说,更大的感受野提供更多的上下文信息,这可能导致更高的模型容量。因此,全局感受野一直是在视觉中使用自注意力的关键动机。然而,一个大的感受野需要更多的计算。在全局注意力的情况下,复杂性是二次方
w.r.t.空间大小,这是应用自注意力模型的基本权衡。
鉴于上述比较,理想模型应该能够结合表 1 中的 3 个理想属性。
可以实现这一点的一个简单的想法是简单地将全局静态卷积核与自适应注意矩阵相加,无论是在
Softmax 归一化之后还是之前,即 \[
y_{i}^{\text {post }}=\sum_{j \in \mathcal{G}}\left(\frac{\exp
\left(x_{i}^{\top} x_{j}\right)}{\sum_{k \in \mathcal{G}} \exp
\left(x_{i}^{\top} x_{k}\right)}+w_{i-j}\right) x_{j} \text { or }
y_{i}^{\text {pre }}=\sum_{j \in \mathcal{G}} \frac{\exp
\left(x_{i}^{\top} x_{j}+w_{i-j}\right)}{\sum_{k \in \mathcal{G}} \exp
\left(x_{i}^{\top} x_{k}+w_{i-k}\right)} x_{j} .\ \ \ (3)
\] 有趣的是,虽然这个想法似乎过于简化,但预标准化版本 \(y^{pre}\)
对应的是相对自注意力的一个特定变体。在这种情况下,注意力权重 \(A_{i,j}\) 由平移等变性的 \(w_{i-j}\) 和输入自适应的 \(x_{i}^{\top} x_{j}\)
共同决定,根据它们的相对大小可以同时享受这两种效果。重要的是,注意到为了启用全局卷积核而不会造成参数爆炸,我们重新加载了
\(w_{i-j}\) 的符号作为标量(即 \(w \in \mathbb{R}^{O(|\mathcal{G}|)}\)
)而不是Eqn(1)中的向量。\(w\)
的标量形式的另一个优点是,对于所有的\((i,j)\) ,检索 \(w_{i-j}\)
显然是通过计算成对点积注意力来包含的,从而产生最小的额外成本(见附录A.1)。考虑到好处,我们将使用Eqn(3)中带有预标准化相对注意力变量的Transformer块作为Co
At Net模型的关键组成部分。
Vertical Layout Design
在找到了一种将卷积和注意力结合的好方法之后,我们接下来考虑如何利用它来堆叠整个网络。
正如我们在上面讨论的,全局上下文具有空间大小的二次复杂度。因此,如果直接套用Eqn(3)中的相对关注度,对于输入的原始图像,由于任意大小的图像中像素数量较多,计算会过于缓慢。因此,为了构建一个在实践中可行的网络,我们主要有三种选择:
A
执行一些下采样以减小空间大小,并在特征图达到可管理级别后使用全局相对注意力。
B 加强局部注意力,将注意力中的全局感受野 \(\mathcal{G}\) 限制到一个局部领域 \(\mathcal{L}\) ,就像卷积中一样
C
将二次Softmax注意力替换为某些线性注意力变量,该变量仅具有空间大小的线性复杂度
我们对选项( C )进行了简单的实验,但没有得到合理的结果。对于选项( B
),我们发现实现本地注意力涉及许多需要密集内存访问的非平凡形状格式化操作。在我们的选择加速器(
TPU
)上,这样的操作被证明是极其缓慢的,这不仅挫败了加速全球注意力的最初目的,也伤害了模型的容量。因此,正如最近的一些工作研究了这个变体一样,我们将把重点放在选项(
A )上,并在我们的实证研究中比较我们的结果 。
对于方案( A ),下采样可以通过( 1 )
ViT中具有激进步幅的卷积主干(例如,步幅16x16)或( 2
)卷积神经网络中具有渐进池化的多级网络来实现。通过这些选择,我们得到了5个变量的搜索空间,并在受控实验中进行了比较。
当使用ViT Stem时,我们直接堆叠带有相对注意力的 \(L\) Transformer块,我们将其表示为 \(ViT_{REL}\) 。
当使用多级布局时,我们模仿卷积神经网络来构建5级(
S0、S1、S2、S3和S4)的网络,空间分辨率从S0逐渐降低到S4。在每个阶段开始时,我们总是将空间大小减少2倍,并增加通道数 (有关详细的下采样实现,请参见附录A.1)。
第一阶段S0是一个简单的2层卷积Stem,S1始终使用带有squeeze-excitation
(SE)操作和MBConv块,因为空间尺寸对于全局注意力来说太大了。从S2到S4,我们要么考虑MBConv,要么考虑Transformer块,约束卷积阶段必须出现在Transformer阶段之前。该约束基于这样的先验,即卷积在处理早期阶段更常见的局部模式方面更好 。这导致了4种变体,其Transformer级数越来越多,C
- C - C - C、C - C - C - T、C - C - T - T和C - T - T -
T,其中C和T分别表示Convolution和Transformer。
为了系统地研究设计选择,我们考虑两个基本方面的泛化能力和模型能力:对于泛化,我们感兴趣的是训练损失和评估精度之间的差距。如果两个模型具有相同的训练损失,那么具有较高评估精度的模型具有更好的泛化能力,因为它可以更好地泛化到看不见的评估数据集。当训练数据量有限时,泛化能力对数据效率尤为重要。对于模型容量,我们衡量了模型对大型训练数据集的拟合能力。当训练数据较为丰富且不存在过拟合问题时,容量较高的模型在经过合理的训练步骤后,最终更好的最终性能。值得注意的是,由于简单地增加模型大小可以导致更高的模型容量,为了进行有意义的比较,我们确保5个变体的模型大小具有可比性。
为了比较模型的泛化能力和模型容量,我们在ImageNet - 1K ( 1.3M )和JFT (
> 300M
)数据集上分别训练了300和3个epoch的混合模型的不同变体,两者都没有任何正则化或增强。两个数据集上的训练损失和评估准确率总结如图1所示。
图1
从ImageNet -
1K的结果中,一个关键的观察是,在泛化能力(即,培训和评价指标之间的差距)方面,我们有
\[
\mathrm{C}-\mathrm{C}-\mathrm{C}-\mathrm{C} \approx
\mathrm{C}-\mathrm{C}-\mathrm{C}-\mathrm{T} \geq
\mathrm{C}-\mathrm{C}-\mathrm{T}-\mathrm{T}>\mathrm{C}-\mathrm{T}-\mathrm{T}-\mathrm{T}
\gg \mathrm{VIT}_{\mathrm{REL}}
\]
特别是,\(ViT_{REL}\) 比变种要差得多,我们猜想这与在其积极的下采样Stem中缺乏适当的低级信息处理有关。在多级变体中,总体趋势是模型的卷积级数越多,泛化差距越小。
在模型能力方面,从JFT比较来看,训练结束时的训练指标和评价指标都显示出以下排序:
\[
\mathrm{C}-\mathrm{C}-\mathrm{T}-\mathrm{T} \approx
\mathrm{C}-\mathrm{T}-\mathrm{T}-\mathrm{T}>\mathrm{VIT}_{\mathrm{REL}}>\mathrm{C}-\mathrm{C}-\mathrm{C}-\mathrm{T}>\mathrm{C}-\mathrm{C}-\mathrm{C}-\mathrm{C}
\]
重要的是,这表明简单地拥有更多的Transformer块并不一定意味着更高的视觉处理能力 。一方面,\(ViT_{REL}\)
在最初表现较差的情况下,最终赶上了MBConv级数较多的两个变体,体现了Transformer模块的容量优势。另一方面,C
- C - T - T和C - T - T - T的表现明显优于\(ViT_{REL}\) ,说明大步长的ViT
Stem可能丢失了过多的信息,从而限制了模型的容量 。更有趣的是,C-C-T-T≈C-T-T-T表明,对于处理低级信息,像卷积这样的静态局部操作可以像自适应全局注意力机制一样有效,同时节省大量的计算和内存使用。
最后,为了决定C - C - T - T和C - T - T -
T之间的关系,我们进行了另一个可迁移性测试3 - -我们在ImageNet -
1K上对上述两个JFT预训练模型进行了30个epoch的微调,并比较了它们的迁移性能。从表2中可以看出,尽管预训练性能相同,C-C-T-T取得了明显优于C-T-T-T的传输精度。
综合考虑模型的泛化性、模型容量、可移植性和效率等因素,我们采用C
- C - T - T的多级布局方式 。附录A.1中包含了更多的模型细节。
附录
A.1 模型细节
首先,CoAtNet概述如图所示
A.2 超参数
模型搭建
s0-stage
def conv_3x3_bn (in_c, out_c, image_size, downsample=False ): stride = 2 if downsample else 1 layer = nn.Sequential( nn.Conv2d(in_c, out_c, 3 , stride, 1 , bias=False ), nn.BatchNorm2d(out_c), nn.GELU() ) return layer
s1-MBConv
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 class MBConv (nn.Module): def __init__ (self, in_c, out_c, image_size, downsample=False , expansion=4 ): super (MBConv, self).__init__() self.downsample = downsample stride = 2 if downsample else 1 hidden_dim = int (in_c * expansion) if self.downsample: self.downsample_layer = nn.Sequential( nn.MaxPool2d(kernel_size=3 , stride=2 , padding=1 ), nn.Conv2d(in_c, out_c, 1 , 1 , 0 , bias=False ) ) layers = OrderedDict() expand_conv = nn.Sequential( nn.Conv2d(in_c, hidden_dim, 1 , stride, 0 , bias=False ), nn.BatchNorm2d(hidden_dim), nn.GELU(), ) layers.update({"expand_conv" : expand_conv}) dw_conv = nn.Sequential( nn.Conv2d(hidden_dim, hidden_dim, 3 , 1 , 1 , groups=hidden_dim, bias=False ), nn.BatchNorm2d(hidden_dim), nn.GELU(), ) layers.update({"dw_conv" : dw_conv}) layers.update({"se" : SE(in_c, hidden_dim)}) pro_conv = nn.Sequential( nn.Conv2d(hidden_dim, out_c, 1 , 1 , 0 , bias=False ), nn.BatchNorm2d(out_c) ) layers.update({"pro_conv" : pro_conv}) self.block = nn.Sequential(layers) def forward (self, x ): if self.downsample: return self.downsample_layer(x) + self.block(x) else : return x + self.block(x)
其中,Se模块代码为
class SE (nn.Module): def __init__ (self, in_c, out_c, expansion=0.25 ): super (SE, self).__init__() self.avg_pool = nn.AdaptiveAvgPool2d(1 ) self.fc = nn.Sequential( nn.Linear(out_c, int (in_c * expansion), bias=False ), nn.GELU(), nn.Linear(int (in_c * expansion), out_c, bias=False ), nn.Sigmoid() ) def forward (self, x ): b, c, _, _ = x.size() y = self.avg_pool(x).view(b, c) y = self.fc(y).view(b, c, 1 , 1 ) return x * y
s3-TFM_rel
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 class Transformer (nn.Module): def __init__ (self, in_c, out_c, image_size, heads=8 , dim_head=32 , downsample=False , dropout=0. , expansion=4 , norm_layer=nn.LayerNorm ): super (Transformer, self).__init__() self.downsample = downsample hidden_dim = int (in_c * expansion) self.ih, self.iw = image_size if self.downsample: self.pool1 = nn.MaxPool2d(kernel_size=3 , stride=2 , padding=1 ) self.pool2 = nn.MaxPool2d(kernel_size=3 , stride=2 , padding=1 ) self.proj = nn.Conv2d(in_c, out_c, 1 , 1 , 0 , bias=False ) self.attn = Attention(in_c, out_c, image_size, heads, dim_head, dropout) self.ffn = FFN(out_c, hidden_dim) self.norm1 = norm_layer(in_c) self.norm2 = norm_layer(out_c) def forward (self, x ): x1 = self.pool1(x) if self.downsample else x x1 = rearrange(x1, 'b c h w -> b (h w) c' ) x1 = self.attn(self.norm1(x1)) x1 = rearrange(x1, 'b (h w) c -> b c h w' , h=self.ih, w=self.iw) x2 = self.proj((self.pool2(x))) if self.downsample else x x3 = x1 + x2 x4 = rearrange(x3, 'b c h w -> b (h w) c' ) x4 = self.ffn(self.norm2(x4)) x4 = rearrange(x4, 'b (h w) c -> b c h w' , h=self.ih, w=self.iw) out = x3 + x4 return out
attention(相对位置编码)
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 class Attention (nn.Module): def __init__ (self, in_c, out_c, image_size, heads=8 , dim_head=32 , dropout=0. ): super (Attention, self).__init__() inner_dim = dim_head * heads project_out = not (heads == 1 and dim_head == in_c) self.ih, self.iw = image_size if len (image_size) == 2 else (image_size, image_size) self.heads = heads self.scale = dim_head ** -0.5 self.relative_bias_table = nn.Parameter( torch.zeros((2 * self.ih - 1 ) * (2 * self.iw - 1 ), heads) ) coords = torch.meshgrid((torch.arange(self.ih), torch.arange(self.iw))) coords = torch.flatten(torch.stack(coords), 1 ) relative_coords = coords[:, :, None ] - coords[:, None , :] relative_coords[0 ] += self.ih - 1 relative_coords[1 ] += self.iw - 1 relative_coords[0 ] *= 2 * self.iw - 1 relative_coords = rearrange(relative_coords, 'c h w -> h w c' ) relative_index = relative_coords.sum (-1 ).flatten().unsqueeze(1 ) """ PyTorch中定义模型时,self.register_buffer('name', Tensor), 该方法的作用是定义一组参数,该组参数的特别之处在于: 模型训练时不会更新(即调用 optimizer.step() 后该组参数不会变化,只可人为地改变它们的值), 但是保存模型时,该组参数又作为模型参数不可或缺的一部分被保存。 """ self.register_buffer("relative_index" , relative_index) self.attend = nn.Softmax(dim=-1 ) self.qkv = nn.Linear(in_c, inner_dim * 3 , bias=False ) self.proj = nn.Sequential( nn.Linear(inner_dim, out_c), nn.Dropout(dropout) ) if project_out else nn.Identity() def forward (self, x ): qkv = self.qkv(x).chunk(3 , dim=-1 ) q, k, v = map (lambda t: rearrange( t, 'b n (h d) -> b h n d' , h=self.heads), qkv) dots = torch.matmul(q, k.transpose(-1 , -2 )) * self.scale relative_bias = self.relative_bias_table.gather( 0 , self.relative_index.repeat(1 , self.heads) ) relative_bias = rearrange( relative_bias, '(h w) c -> 1 c h w' , h=self.ih * self.iw, w=self.ih * self.iw ) dots = dots + relative_bias attn = self.attend(dots) out = torch.matmul(attn, v) out = rearrange(out, 'b h n d -> b n (h d)' ) out = self.proj(out) return out
FFN(MLP)
class FFN (nn.Module): def __init__ (self, dim, hidden_dim, dropout=0. ): super (FFN, self).__init__() self.ffn = nn.Sequential( nn.Linear(dim, hidden_dim), nn.GELU(), nn.Dropout(dropout), nn.Linear(hidden_dim, dim), nn.Dropout(dropout) ) def forward (self, x ): return self.ffn(x)
CoAtNet
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 class CoAtNet (nn.Module): def __init__ (self, image_size=(224 , 224 ), in_channels: int = 3 , num_blocks: list = [2 , 2 , 3 , 5 , 2 ], channels: list = [64 , 96 , 192 , 384 , 768 ], num_classes: int = 1000 , block_types=['C' , 'C' , 'T' , 'T' ] ): super (CoAtNet, self).__init__() assert len (image_size) == 2 , "image size must be: {H,W}" assert len (channels) == 5 assert len (block_types) == 4 ih, iw = image_size block = {'C' : MBConv, 'T' : Transformer} self.s0 = self._make_layer( conv_3x3_bn, in_channels, channels[0 ], num_blocks[0 ], (ih // 2 , iw // 2 ) ) self.s1 = self._make_layer( block[block_types[0 ]], channels[0 ], channels[1 ], num_blocks[1 ], (ih // 4 , iw // 4 ) ) self.s2 = self._make_layer( block[block_types[1 ]], channels[1 ], channels[2 ], num_blocks[2 ], (ih // 8 , iw // 8 ) ) self.s3 = self._make_layer( block[block_types[2 ]], channels[2 ], channels[3 ], num_blocks[3 ], (ih // 16 , iw // 16 ) ) self.s4 = self._make_layer( block[block_types[3 ]], channels[3 ], channels[4 ], num_blocks[4 ], (ih // 32 , iw // 32 ) ) self.pool = nn.AvgPool2d(ih // 32 , 1 ) self.fc = nn.Linear(channels[-1 ], num_classes, bias=False ) for m in self.modules(): if isinstance (m, nn.Conv2d): nn.init.kaiming_normal_(m.weight, mode='fan_out' , nonlinearity='relu' ) elif isinstance (m, (nn.BatchNorm2d, nn.GroupNorm, nn.LayerNorm)): nn.init.constant_(m.weight, 1 ) nn.init.constant_(m.bias, 0 ) def forward (self, x ): x = self.s0(x) x = self.s1(x) x = self.s2(x) x = self.s3(x) x = self.s4(x) x = self.pool(x) x = torch.flatten(x, 1 ) x = self.fc(x) return x def _make_layer (self, block, in_c, out_c, depth, image_size ): layers = nn.ModuleList([]) for i in range (depth): if i == 0 : layers.append(block(in_c, out_c, image_size, downsample=True )) else : layers.append(block(out_c, out_c, image_size, downsample=False )) return nn.Sequential(*layers)