Masked_Autoencoder(MAE)

  • 论文:Masked Autoencoders Are Scalable Vision Learners

模型图

Approach

Masking

MAE使用了ViT的方法,将图片分割成一个个小块,然后在这些小块中随机、均匀地选取一部分保留,剩下的全部遮蔽。作者强调了要遮蔽大量的像素块(约75%),从而减少像素块之间的冗余信息,使整个任务更具有挑战性,迫使模型去学习图像的全局特征而非局部特征,从而获得更优的图像重构能力。

MAE encoder

encode用的就是一个完整的ViT,其输入仅为没有被mask的patch,节省了计算开支,可以通过较小的计算和内存来训练出较大的encoder。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
class MAE_Encoder(torch.nn.Module):
def __init__(self,
image_size=32,
patch_size=2,
emb_dim=192,
num_layer=12,
num_head=3,
mask_ratio=0.75,
) -> None:
super().__init__()

self.cls_token = torch.nn.Parameter(torch.zeros(1, 1, emb_dim))
self.pos_embedding = torch.nn.Parameter(torch.zeros((image_size // patch_size) ** 2, 1, emb_dim))

# 对patch进行shuffle 和 mask
self.shuffle = PatchShuffle(mask_ratio)

# 这里得到一个 (3, dim, patch, patch)
self.patchify = torch.nn.Conv2d(3, emb_dim, patch_size, patch_size)

self.transformer = torch.nn.Sequential(*[Block(emb_dim, num_head) for _ in range(num_layer)])

# ViT的laynorm
self.layer_norm = torch.nn.LayerNorm(emb_dim)

self.init_weight()
# 初始化类别编码和向量编码
def init_weight(self):
trunc_normal_(self.cls_token, std=.02)
trunc_normal_(self.pos_embedding, std=.02)

def forward(self, img):
patches = self.patchify(img)
patches = rearrange(patches, 'b c h w -> (h w) b c')
patches = patches + self.pos_embedding

patches, forward_indexes, backward_indexes = self.shuffle(patches)

patches = torch.cat([self.cls_token.expand(-1, patches.shape[1], -1), patches], dim=0)
patches = rearrange(patches, 't b c -> b t c')
features = self.layer_norm(self.transformer(patches))
features = rearrange(features, 'b t c -> t b c')

return features, backward_indexes

MAE decoder

decoder 的输入是整个 tokens 的集合,包含

  • 编码好的可见的 patches
  • mask tokens

每一个 mask token 都是共享的,学习的向量,表示这里存在一个有待预测的缺失 patch。我们将位置 embedding 添加到这个完整集合中的所有 tokens 中 ;如果不这样做, mask tokens 就没有他们在图像中位置的信息。

解码器还有另一系列Transformer块

decoder 只用在预训练过程中使用。因此,decoder 结构可以以独立于 encoder 设计的方式灵活设计。我们用很小的 decoder进行实验,比 encoder 更小,更窄。例如,我们默认的 decoder 对每个 token 的计算量小于 encoder 的 10%。

在这种非对称设计中,完整的 token 集合只由轻量级 decoder 处理(encoder 只处理可见的 tokens),这大大减少了预训练时间。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
class MAE_Decoder(torch.nn.Module):
def __init__(self,
image_size=32,
patch_size=2,
emb_dim=192,
num_layer=4,
num_head=3,
) -> None:
super().__init__()

self.mask_token = torch.nn.Parameter(torch.zeros(1, 1, emb_dim))
self.pos_embedding = torch.nn.Parameter(torch.zeros((image_size // patch_size) ** 2 + 1, 1, emb_dim))

self.transformer = torch.nn.Sequential(*[Block(emb_dim, num_head) for _ in range(num_layer)])

self.head = torch.nn.Linear(emb_dim, 3 * patch_size ** 2)
self.patch2img = Rearrange('(h w) b (c p1 p2) -> b c (h p1) (w p2)', p1=patch_size, p2=patch_size, h=image_size//patch_size)

self.init_weight()

def init_weight(self):
trunc_normal_(self.mask_token, std=.02)
trunc_normal_(self.pos_embedding, std=.02)

def forward(self, features, backward_indexes):
T = features.shape[0]
backward_indexes = torch.cat([torch.zeros(1, backward_indexes.shape[1]).to(backward_indexes), backward_indexes + 1], dim=0)
features = torch.cat([features, self.mask_token.expand(backward_indexes.shape[0] - features.shape[0], features.shape[1], -1)], dim=0)
features = take_indexes(features, backward_indexes)
features = features + self.pos_embedding # 加上了位置编码的信息

features = rearrange(features, 't b c -> b t c')
features = self.transformer(features)
features = rearrange(features, 'b t c -> t b c')
features = features[1:] # remove global feature 去掉全局信息,得到图像信息

patches = self.head(features) # 用head得到patchs
mask = torch.zeros_like(patches)
mask[T:] = 1 # mask其他的像素全部设为 1
mask = take_indexes(mask, backward_indexes[1:] - 1)
img = self.patch2img(patches) # 得到 重构之后的 img
mask = self.patch2img(mask)

return img, mask

Reconstruction target

我们的 MAE 通过预测为 masked patch 预测每个像素值来重建输入,每个元素在 decoder 中的输出都是表达一个 patch 中像素值的向量。decoder 的最后一层是一个 linear projection,其输出的通道数等于一个 patch 中的像素值的个数。decoder 的输出被 reshape 来组成重建的图。

我们的 loss function 在像素空间上计算了重建图像和原图像的 MSE(平方差)。我们值计算了 masked patches 的 loss,就像 BEiT 那样。

我们还研究了一个变种,他的重建目标是每个 masked patch 的归一化像素值。具体来说,我们计算一个patch中所有像素的均值和标准差,然后用它们来归一化这个patch。在实验中,使用归一化像素作为重建目标提高了表示质量。

Simple implementation

我们的MAE预训练可以高效地实施,而且重要的是,不需要任何专门的稀疏操作。

首先,我们为每个input patch生成一个 token(先使用 Linear Projection,然后加上位置 embedding)。

接下来,我们将随机洗牌 tokens list,并根据屏蔽比率删除 list 的最后一部分。这个过程为encoder产生了一个tokens 的小的集合,和无替换地采样 patches 是一样的。

在 encoding 之后,我们向 encoded patch list 中添加一个 mask tokens list,然后反洗牌整个list(和前面的洗牌操作是相反的)以将所有的tokens和tragets对齐

decoder 在整个 list 上应用(带有添加的位置 embedding)。

如前所述,不需要进行稀疏操作。这个简单的实现引入的开销可以忽略不计,因为洗牌和反洗牌操作非常快。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
class MAE_ViT(torch.nn.Module):
def __init__(self,
image_size=32,
patch_size=2,
emb_dim=192,
encoder_layer=12,
encoder_head=3,
decoder_layer=4,
decoder_head=3,
mask_ratio=0.75,
) -> None:
super().__init__()

self.encoder = MAE_Encoder(image_size, patch_size, emb_dim, encoder_layer, encoder_head, mask_ratio)
self.decoder = MAE_Decoder(image_size, patch_size, emb_dim, decoder_layer, decoder_head)

def forward(self, img):
features, backward_indexes = self.encoder(img)
predicted_img, mask = self.decoder(features, backward_indexes)
return predicted_img, mask

本博客所有文章除特别声明外,均采用 CC BY-SA 4.0 协议 ,转载请注明出处!