MaxViT

  • Paper: MaxViT: Multi-Axis Vision Transformer

  • Published at ECCV 2022

Motivation

Research has shown that, without large-scale pre-training, ViT underperforms on image recognition. This is because the Transformer has strong modeling capacity but lacks inductive bias, which leads to overfitting. One effective remedy is to control model capacity and improve scalability, gaining performance while using fewer parameters, as in Twins, LocalViT, and Swin Transformer. These models typically reintroduce hierarchical designs to compensate for the loss of non-locality; for example, Swin Transformer applies self-attention within shifted, non-overlapping windows. However, while flexibility and scalability improve, these models generally give up ViT-style non-locality, i.e. they have limited model capacity, and therefore fail to scale to larger datasets (ImageNet-21K, JFT, etc.).

In summary, it is worthwhile to study how to combine local and global interaction to increase model flexibility. The goals of this paper are therefore to adapt to different amounts of data and to effectively combine the advantages of local and global computation.

This paper designs a simple yet effective vision backbone, called the multi-axis Transformer (MaxViT), built by hierarchically stacking repeated blocks composed of Max-SA and convolution.

  • MaxViT is a general-purpose Transformer architecture: every block can perform both local and global spatial interaction, and it adapts to inputs of different resolutions.
  • Max-SA decomposes the spatial axes into window attention (block attention) and grid attention (grid attention), reducing the quadratic complexity of conventional attention to linear.
  • MBConv complements the self-attention computation, using its inherent inductive bias to improve generalization and avoid overfitting.

Method

This paper introduces a new attention module, multi-axis self-attention (Max-SA), which decomposes the conventional self-attention mechanism into two sparse forms, window attention (block attention) and grid attention (grid attention), reducing the quadratic complexity of vanilla attention to linear without losing non-locality. Thanks to the flexibility and scalability of Max-SA, a vision backbone called MaxViT can be built simply by stacking Max-SA with MBConv in a hierarchical architecture, as shown in the figure.

Attention

MaxViT adopts pre-normalized relative self-attention as its key operator.
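Concretely, pre-normalization means LayerNorm is applied before the attention and FFN operators (the PreNormResidual wrapper in the source code below), and relative self-attention adds a learned relative position bias B to the scaled dot-product logits before the softmax, as in the Attention class below:

\[ \operatorname{RelAttention}(Q, K, V)=\operatorname{softmax}\left(\frac{Q K^{\top}}{\sqrt{d}}+B\right) V \]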

Multi-axis Attention

Compared with local convolution, global interaction is one of the strengths of self-attention. However, directly applying attention over the entire spatial extent is computationally infeasible, because the attention operator has quadratic complexity. To address the quadratic complexity of global self-attention, this paper proposes multi-axis attention, which decomposes the spatial axes into a local (block attention) and a global (grid attention) sparse form, as follows:

block attention

\[ \text{Block}: (H, W, C) \rightarrow \left(\frac{H}{P} \times P, \frac{W}{P} \times P, C\right) \rightarrow \left(\frac{HW}{P^{2}}, P^{2}, C\right) \]

grid attention

\[ \text{Grid}: (H, W, C) \rightarrow \left(G \times \frac{H}{G}, G \times \frac{W}{G}, C\right) \rightarrow \underbrace{\left(G^{2}, \frac{HW}{G^{2}}, C\right) \rightarrow \left(\frac{HW}{G^{2}}, G^{2}, C\right)}_{\text{swapaxes}(\text{axis1}=-2,\ \text{axis2}=-3)} \]

Note: following the window size used in Swin Transformer, P = G = 7. The proposed Max-SA module can directly replace the Swin attention module with exactly the same number of parameters and FLOPs. Moreover, it enjoys global interaction capability without requiring masking, padding, or cyclic-shifting, which makes it easier to implement and preferable to the shifted-window scheme.
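To make the decomposition concrete, here is a minimal shape-level sketch (not the official implementation) using einops, assuming P = G = 7 and a hypothetical 56 × 56 × 96 feature map. Since attention runs over groups of only P² (or G²) tokens, the total cost is (HW/P²) · O(P⁴) = O(HW · P²), i.e. linear in the number of pixels for a fixed window size, instead of the quadratic O((HW)²) of full self-attention.

import torch
from einops import rearrange

H = W = 56
C = 96
P = G = 7
x = torch.randn(1, H, W, C)

# Block attention: split into HW / P^2 non-overlapping P x P windows;
# self-attention runs over the P^2 contiguous tokens of each window (local).
blocks = rearrange(x, 'b (nx p1) (ny p2) c -> b (nx ny) (p1 p2) c', p1 = P, p2 = P)
print(blocks.shape)  # torch.Size([1, 64, 49, 96]) = (B, HW/P^2, P^2, C)

# Grid attention: overlay a fixed G x G grid; each group gathers G^2 tokens
# spaced H/G apart, so self-attention mixes tokens globally but sparsely.
grid = rearrange(x, 'b (g1 nx) (g2 ny) c -> b (nx ny) (g1 g2) c', g1 = G, g2 = G)
print(grid.shape)    # torch.Size([1, 64, 49, 96]) = (B, HW/G^2, G^2, C)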

MaxViT block

Before the multi-axis attention, an MBConv block with a squeeze-and-excitation (SE) module is added; as observed in the paper, combining MBConv with attention further improves the generalization and trainability of the network. Placing the MBConv layer before attention brings another advantage: its depthwise convolution can be viewed as conditional position encoding (CPE), so the model needs no explicit positional encoding layer.
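Putting the pieces together, here is a rough sketch of one MaxViT block; the names mbconv, block_attn_ffn, and grid_attn_ffn are placeholders for the modules defined in the source code that follows.

def maxvit_block(x, mbconv, block_attn_ffn, grid_attn_ffn):
    # MBConv with SE first: its depthwise convolution doubles as conditional position encoding
    x = mbconv(x)
    # window (block) attention + FFN: local interaction inside P x P windows
    x = block_attn_ffn(x)
    # grid attention + FFN: sparse global interaction across the G x G grid
    x = grid_attn_ffn(x)
    return x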

Source code

  • The annotations below use MaxViT-T as the example, i.e. dim_conv_stem = 64, depth = (2, 2, 5, 2), dim = 64
from functools import partial

import torch
from torch import nn, einsum

from einops import rearrange, repeat
from einops.layers.torch import Rearrange, Reduce


# helpers
def exists(val):
    return val is not None


def default(val, d):
    return val if exists(val) else d


def cast_tuple(val, length = 1):
    return val if isinstance(val, tuple) else ((val,) * length)


# helper classes
class PreNormResidual(nn.Module):
    def __init__(self, dim, fn):
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        self.fn = fn

    def forward(self, x):
        return self.fn(self.norm(x)) + x


class FeedForward(nn.Module):
    def __init__(self, dim, mult = 4, dropout = 0.):
        super().__init__()
        inner_dim = int(dim * mult)
        self.net = nn.Sequential(
            nn.Linear(dim, inner_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(inner_dim, dim),
            nn.Dropout(dropout)
        )

    def forward(self, x):
        return self.net(x)


# MBConv
class SqueezeExcitation(nn.Module):
    def __init__(self, dim, shrinkage_rate = 0.25):
        super().__init__()
        hidden_dim = int(dim * shrinkage_rate)

        self.gate = nn.Sequential(
            Reduce('b c h w -> b c', 'mean'),
            nn.Linear(dim, hidden_dim, bias = False),
            nn.SiLU(),
            nn.Linear(hidden_dim, dim, bias = False),
            nn.Sigmoid(),
            Rearrange('b c -> b c 1 1')
        )

    def forward(self, x):
        return x * self.gate(x)


class MBConvResidual(nn.Module):
    def __init__(self, fn, dropout = 0.):
        super().__init__()
        self.fn = fn
        self.dropsample = Dropsample(dropout)

    def forward(self, x):
        out = self.fn(x)
        out = self.dropsample(out)
        return out + x


class Dropsample(nn.Module):
    # stochastic depth: randomly drops whole samples in the batch during training
    def __init__(self, prob = 0):
        super().__init__()
        self.prob = prob

    def forward(self, x):
        device = x.device

        if self.prob == 0. or (not self.training):
            return x

        # per-sample keep mask of shape (batch, 1, 1, 1)
        keep_mask = torch.empty((x.shape[0], 1, 1, 1), device = device).uniform_() > self.prob
        return x * keep_mask / (1 - self.prob)


def MBConv(
    dim_in,
    dim_out,
    *,
    downsample,
    expansion_rate = 4,
    shrinkage_rate = 0.25,
    dropout = 0.
):
    hidden_dim = int(expansion_rate * dim_out)
    stride = 2 if downsample else 1

    net = nn.Sequential(
        # [dim_in, n, n] -> [hidden_dim, n, n]
        nn.Conv2d(dim_in, hidden_dim, 1),
        nn.BatchNorm2d(hidden_dim),
        nn.GELU(),
        # if stride == 2, [hidden_dim, n, n] -> [hidden_dim, n / 2, n / 2]
        # if stride == 1, [hidden_dim, n / 2, n / 2] -> [hidden_dim, n / 2, n / 2]
        nn.Conv2d(hidden_dim, hidden_dim, 3, stride = stride, padding = 1, groups = hidden_dim),
        nn.BatchNorm2d(hidden_dim),
        nn.GELU(),
        # [hidden_dim, n / 2, n / 2] -> [hidden_dim, n / 2, n / 2]
        SqueezeExcitation(hidden_dim, shrinkage_rate = shrinkage_rate),
        # [hidden_dim, n / 2, n / 2] -> [dim_out, n / 2, n / 2]
        nn.Conv2d(hidden_dim, dim_out, 1),
        nn.BatchNorm2d(dim_out)
    )

    # add a residual connection (with stochastic depth) when this block does not downsample
    if dim_in == dim_out and not downsample:
        net = MBConvResidual(net, dropout = dropout)

    return net


# attention related classes
class Attention(nn.Module):
    def __init__(
        self,
        dim,
        dim_head = 32,
        dropout = 0.,
        window_size = 7
    ):
        super().__init__()
        assert (dim % dim_head) == 0, 'dimension should be divisible by dimension per head'

        self.heads = dim // dim_head
        self.scale = dim_head ** -0.5

        self.to_qkv = nn.Linear(dim, dim * 3, bias = False)

        self.attend = nn.Sequential(
            nn.Softmax(dim = -1),
            nn.Dropout(dropout)
        )

        self.to_out = nn.Sequential(
            nn.Linear(dim, dim, bias = False),
            nn.Dropout(dropout)
        )

        # relative positional bias
        self.rel_pos_bias = nn.Embedding((2 * window_size - 1) ** 2, self.heads)

        pos = torch.arange(window_size)
        grid = torch.stack(torch.meshgrid(pos, pos, indexing = 'ij'))
        grid = rearrange(grid, 'c i j -> (i j) c')
        rel_pos = rearrange(grid, 'i ... -> i 1 ...') - rearrange(grid, 'j ... -> 1 j ...')
        rel_pos += window_size - 1
        rel_pos_indices = (rel_pos * torch.tensor([2 * window_size - 1, 1])).sum(dim = -1)

        self.register_buffer('rel_pos_indices', rel_pos_indices, persistent = False)

    def forward(self, x):
        batch, height, width, window_height, window_width, _, device, h = *x.shape, x.device, self.heads

        # flatten
        x = rearrange(x, 'b x y w1 w2 d -> (b x y) (w1 w2) d')

        # project for queries, keys, values
        q, k, v = self.to_qkv(x).chunk(3, dim = -1)

        # split heads
        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), (q, k, v))

        # scale
        q = q * self.scale

        # sim
        sim = einsum('b h i d, b h j d -> b h i j', q, k)

        # add positional bias
        bias = self.rel_pos_bias(self.rel_pos_indices)
        sim = sim + rearrange(bias, 'i j h -> h i j')

        # attention
        attn = self.attend(sim)

        # aggregate
        out = einsum('b h i j, b h j d -> b h i d', attn, v)

        # merge heads
        out = rearrange(out, 'b h (w1 w2) d -> b w1 w2 (h d)', w1 = window_height, w2 = window_width)

        # combine heads out
        out = self.to_out(out)
        return rearrange(out, '(b x y) ... -> b x y ...', x = height, y = width)


class MaxViT(nn.Module):
    def __init__(
        self,
        *,
        num_classes,
        dim,
        depth,
        dim_head = 32,
        dim_conv_stem = None,
        window_size = 7,
        mbconv_expansion_rate = 4,
        mbconv_shrinkage_rate = 0.25,
        dropout = 0.1,
        channels = 3
    ):
        super().__init__()
        assert isinstance(depth, tuple), 'depth needs to be a tuple of integers indicating the number of transformer blocks at each stage'

        # convolutional stem
        dim_conv_stem = default(dim_conv_stem, dim)

        self.conv_stem = nn.Sequential(
            # [-1, 3, 224, 224] -> [-1, 64, 112, 112]
            nn.Conv2d(channels, dim_conv_stem, 3, stride = 2, padding = 1),
            # [-1, 64, 112, 112] -> [-1, 64, 112, 112]
            nn.Conv2d(dim_conv_stem, dim_conv_stem, 3, padding = 1)
        )

        # variables, depth = (2, 2, 5, 2)
        num_stages = len(depth)

        # dims = (64, 128, 256, 512), the dimensions of the last four stages
        dims = tuple(map(lambda i: (2 ** i) * dim, range(num_stages)))
        # dims = (64, 64, 128, 256, 512), the dimensions of all five stages
        dims = (dim_conv_stem, *dims)
        # dim_pairs = ((64, 64), (64, 128), (128, 256), (256, 512))
        dim_pairs = tuple(zip(dims[:-1], dims[1:]))

        self.layers = nn.ModuleList([])

        # shorthand for window size for efficient block - grid like attention
        # window_size = 7
        w = window_size

        # iterate through stages
        # (((64, 64), 2),
        #  ((64, 128), 2),
        #  ((128, 256), 5),
        #  ((256, 512), 2))
        for ind, ((layer_dim_in, layer_dim), layer_depth) in enumerate(zip(dim_pairs, depth)):
            for stage_ind in range(layer_depth):
                is_first = stage_ind == 0
                stage_dim_in = layer_dim_in if is_first else layer_dim

                block = nn.Sequential(
                    # [-1, stage_dim_in, n, n] -> [-1, layer_dim, n / 2, n / 2] if is_first
                    # [-1, layer_dim, n / 2, n / 2] -> [-1, layer_dim, n / 2, n / 2] if not is_first
                    MBConv(
                        stage_dim_in,
                        layer_dim,
                        downsample = is_first,
                        expansion_rate = mbconv_expansion_rate,
                        shrinkage_rate = mbconv_shrinkage_rate
                    ),
                    Rearrange('b d (x w1) (y w2) -> b x y w1 w2 d', w1 = w, w2 = w),  # block-like attention
                    PreNormResidual(layer_dim, Attention(dim = layer_dim, dim_head = dim_head, dropout = dropout, window_size = w)),
                    PreNormResidual(layer_dim, FeedForward(dim = layer_dim, dropout = dropout)),
                    Rearrange('b x y w1 w2 d -> b d (x w1) (y w2)'),

                    Rearrange('b d (w1 x) (w2 y) -> b x y w1 w2 d', w1 = w, w2 = w),  # grid-like attention
                    PreNormResidual(layer_dim, Attention(dim = layer_dim, dim_head = dim_head, dropout = dropout, window_size = w)),
                    PreNormResidual(layer_dim, FeedForward(dim = layer_dim, dropout = dropout)),
                    Rearrange('b x y w1 w2 d -> b d (w1 x) (w2 y)'),
                )

                self.layers.append(block)

        # mlp head out
        self.mlp_head = nn.Sequential(
            Reduce('b d h w -> b d', 'mean'),
            nn.LayerNorm(dims[-1]),
            nn.Linear(dims[-1], num_classes)
        )

    def forward(self, x):
        x = self.conv_stem(x)

        for stage in self.layers:
            x = stage(x)

        return self.mlp_head(x)
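
A usage sketch for the MaxViT class above, instantiated with the MaxViT-T hyper-parameters mentioned earlier; the class count (1000) and the 224 × 224 input are assumptions for illustration:

model = MaxViT(
    num_classes = 1000,
    dim_conv_stem = 64,
    dim = 64,
    dim_head = 32,
    depth = (2, 2, 5, 2),
    window_size = 7,
    dropout = 0.1
)

img = torch.randn(1, 3, 224, 224)
logits = model(img)  # shape: (1, 1000); spatial size shrinks 224 -> 112 -> 56 -> 28 -> 14 -> 7 across the stem and four stages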


