Vision_Transformer

  • 论文:AN IMAGE IS WORTH 16X16 WORDS: TRANSFORMERS FOR IMAGE RECOGNITION AT SCALE

方法

首先结构上,我们采取的是原始Transformer模型,方便开箱即用。

整体结构如图所示

图像处理

为了处理2D图像,我们将图像\(\mathbf{x} \in \mathbb{R}^{H \times W \times C}\)重新建模为一个扁平的2D图像块序列\(\mathbf{x}_{p} \in \mathbb{R}^{N \times\left(P^{2} \cdot C\right)}\),其中\((W, H)\)是原始图像的分辨率,\(C\)是通道数,\((P,P)\)是每个图像块的分辨率,\(N=H W / P^{2}\)是生成的图像块数,它也是Transformer的有效输入序列长度。

1
x = rearrange(img, 'b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1=p, p2=p)

它使用的是一个einops的拓展包,完成了上述的变换工作

Patch Embedding

接着对每个向量都做一个线性变换(即全连接层),压缩维度为D,这里我们称其为 Patch Embedding。

在代码里是初始化一个全连接层,输出维度为dim,然后将分块后的数据输入

1
2
3
4
5
self.patch_to_embedding = nn.Linear(patch_dim, dim)

# forward前向代码
x = rearrange(img, 'b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1=p, p2=p)
x = self.patch_to_embedding(x)

Position Embedding

原始的Transformer引入了一个 Positional encoding 来加入序列的位置信息,同样在这里也引入了pos_embedding,是用一个可训练的变量替代。

1
self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, dim))

class token

因为传统的Transformer采取的是类似seq2seq编解码的结构

而ViT只用到了Encoder编码器结构,缺少了解码的过程,假设你9个向量经过编码器之后,你该选择哪一个向量进入到最后的分类头呢?因此这里作者给了额外的一个用于分类的向量,与输入进行拼接。同样这是一个可学习的变量。

1
2
3
4
5
6
7
8
9
# 假设dim=128,这里shape为(1, 1, 128)
self.cls_token = nn.Parameter(torch.randn(1, 1, dim))

# forward前向代码
# 假设batchsize=10,这里shape为(10, 1, 128)
cls_tokens = repeat(self.cls_token, '() n d -> b n d', b=b)
# 跟前面的分块为x(10,64, 128)的进行concat
# 得到(10, 65, 128)向量
x = torch.cat((cls_tokens, x), dim=1)

知道这个操作,我们也就能明白为什么前面的pos_embedding的第一维也要加1了,后续将pos_embedding也加入到x

1
x += self.pos_embedding[:, :(n + 1)]

分类

分类头很简单,加入了LayerNorm和两层全连接层实现的,采用的是GELU激活函数。

1
2
3
4
5
6
7
self.mlp_head = nn.Sequential(
nn.LayerNorm(dim),
nn.Linear(dim, mlp_dim),
nn.GELU(),
nn.Dropout(dropout),
nn.Linear(mlp_dim, num_classes)
)

最终分类我们只取第一个,也就是用于分类的token,输入到分类头里,得到最后的分类结果

1
2
3
4
5
self.to_cls_token = nn.Identity()
# forward前向部分
x = self.transformer(x, mask)
x = self.to_cls_token(x[:, 0])
return self.mlp_head(x)

整体代码

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
class ViT(nn.Module):
def __init__(self, *, image_size, patch_size, num_classes, dim, depth, heads, mlp_dim, channels=3, dropout=0.,
emb_dropout=0.):
super().__init__()
assert image_size % patch_size == 0, 'Image dimensions must be divisible by the patch size.'
num_patches = (image_size // patch_size) ** 2
patch_dim = channels * patch_size ** 2
assert num_patches > MIN_NUM_PATCHES, f'your number of patches ({num_patches}) is way too small for attention to be effective (at least 16). Try decreasing your patch size'

self.patch_size = patch_size

self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, dim))
self.patch_to_embedding = nn.Linear(patch_dim, dim)
self.cls_token = nn.Parameter(torch.randn(1, 1, dim))
self.dropout = nn.Dropout(emb_dropout)

self.transformer = Transformer(dim, depth, heads, mlp_dim, dropout)

self.to_cls_token = nn.Identity()

self.mlp_head = nn.Sequential(
nn.LayerNorm(dim),
nn.Linear(dim, mlp_dim),
nn.GELU(),
nn.Dropout(dropout),
nn.Linear(mlp_dim, num_classes)
)

def forward(self, img, mask=None):
p = self.patch_size

x = self.patch_to_embedding(x)
b, n, _ = x.shape

cls_tokens = repeat(self.cls_token, '() n d -> b n d', b=b)
x = torch.cat((cls_tokens, x), dim=1)
x += self.pos_embedding[:, :(n + 1)]
x = self.dropout(x)

x = self.transformer(x, mask)

x = self.to_cls_token(x[:, 0])
return self.mlp_head(x)

本博客所有文章除特别声明外,均采用 CC BY-SA 4.0 协议 ,转载请注明出处!