Post

Paper Reading CLIP

Paper Reading CLIP

CLIP

  • Learning Transferable Visual Models From Natural Language Supervision
  • https://github.com/OpenAI/CLIP

Main Contributions

  • Contrastive pre-training
  • Zero-shot prediction
  • Pseudocode

Pseudocode

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
# image_encoder - ResNet or Vision Transformer
# text_encoder - CBOW or Text Transformer
# I[n, h, w, c] - minibatch of aligned images
# T[n, l] - minibatch of aligned texts
# W_i[d_i, d_e] - learned proj of image to embed
# W_t[d_t, d_e] - learned proj of text to embed
# t - learned temperature parameter

# extract feature representations of each modality
I_f = image_encoder(I) #[n, d_i]
T_f = text_encoder(T) #[n, d_t]

# joint multimodal embedding [n, d_e]
I_e = l2_normalize(np.dot(I_f, W_i), axis=1)
T_e = l2_normalize(np.dot(T_f, W_t), axis=1)

# scaled pairwise cosine similarities [n, n]
# (I_e and T_e are l2-normalized above, so each dot product is a cosine; np.exp(t) is the scale)
logits = np.dot(I_e, T_e.T) * np.exp(t)

# symmetric loss function
# images and texts are aligned by index, so the target for position k is k
labels = np.arange(n)
loss_i = cross_entropy_loss(logits, labels, axis=0)
loss_t = cross_entropy_loss(logits, labels, axis=1)
# average of the cross-entropy taken along each of the two axes of logits
loss = (loss_i + loss_t)/2

Code Analysis

clip architecture

1
2
3
4
5
6
7
8
9
10
11
12
13
14
# Excerpt from the model's forward pass (method header not shown): encode both
# inputs, normalize them, and return the scaled similarity matrix in both orientations.
image_features = self.encode_image(image)
text_features = self.encode_text(text)

# normalize each feature vector by its norm along dim 1
image_features = image_features / image_features.norm(dim=1, keepdim=True)
text_features = text_features / text_features.norm(dim=1, keepdim=True)

# cosine similarity as logits, scaled by exp(self.logit_scale)
logit_scale = self.logit_scale.exp()
logits_per_image = logit_scale * image_features @ text_features.t()
# the text-side logits are the transpose of the image-side logits
logits_per_text = logits_per_image.t()

# shape = [global_batch_size, global_batch_size]
return logits_per_image, logits_per_text

image encoder

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
# Excerpt from the image encoder's forward pass (method header not shown).
x = self.conv1(x)  # shape = [*, width, grid, grid]
x = x.reshape(x.shape[0], x.shape[1], -1)  # shape = [*, width, grid ** 2]
x = x.permute(0, 2, 1)  # shape = [*, grid ** 2, width]
# prepend the cls token: adding class_embedding to a zeros tensor broadcasts it to one token per batch item
x = torch.cat([self.class_embedding.to(x.dtype) + torch.zeros(x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device), x], dim=1)  # shape = [*, grid ** 2 + 1, width]
x = x + self.positional_embedding.to(x.dtype)
x = self.ln_pre(x)

x = x.permute(1, 0, 2)  # NLD -> LND
x = self.transformer(x)
x = x.permute(1, 0, 2)  # LND -> NLD

# keep only the token at index 0 (the cls token concatenated first above)
x = self.ln_post(x[:, 0, :])
# project the output to the target dimension
if self.proj is not None:
    x = x @ self.proj

return x

text encoder

1
2
3
4
5
6
7
8
9
10
11
12
13
# Excerpt from the text encoder's forward pass (method header not shown).
x = self.token_embedding(text).type(self.dtype)  # [batch_size, n_ctx, d_model]

x = x + self.positional_embedding.type(self.dtype)
x = x.permute(1, 0, 2)  # NLD -> LND
x = self.transformer(x)
x = x.permute(1, 0, 2)  # LND -> NLD
x = self.ln_final(x).type(self.dtype)

# x.shape = [batch_size, n_ctx, transformer.width]
# take features from the eot embedding (eot_token is the highest number in each sequence)
# text.argmax(dim=-1) gives that position per sequence; the selected vector is then multiplied by text_projection
x = x[torch.arange(x.shape[0]), text.argmax(dim=-1)] @ self.text_projection

return x

Appendix

vit

1
2
3
# `layers` ResidualAttentionBlock instances chained in an nn.Sequential (presumably set in __init__)
self.resblocks = nn.Sequential(*[ResidualAttentionBlock(width, heads, attn_mask) for _ in range(layers)])
def forward(self, x: torch.Tensor):
    """Pass x through self.resblocks and return the result."""
    return self.resblocks(x)

residual attention

1
2
3
4
5
6
7
8
9
# attention module used by attention() below (presumably set in __init__)
self.attn = nn.MultiheadAttention(d_model, n_head)
def attention(self, x: torch.Tensor):
    """Run self.attn with x as query, key and value, and return the attention output."""
    # move the stored mask, if any, to x's dtype and device before use
    self.attn_mask = self.attn_mask.to(dtype=x.dtype, device=x.device) if self.attn_mask is not None else None
    # need_weights=False: only element [0] of the returned tuple is used
    return self.attn(x, x, x, need_weights=False, attn_mask=self.attn_mask)[0]

def forward(self, x: torch.Tensor):
    """Apply the attention and MLP sub-blocks, each on ln_1/ln_2 of x, adding the result back to x."""
    x = x + self.attention(self.ln_1(x))
    x = x + self.mlp(self.ln_2(x))
    return x
  • MultiheadAttention有4个requires_grad=True的Parameter
    • 将qkv映射到dim的线性层
      • (in_proj_weight, in_proj_bias)
    • 将输出映射到dim的线性层
      • (out_proj.weight, out_proj.bias)

demo

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
# load model and preprocess
device = "cuda" if torch.cuda.is_available() else "cpu"
model, preprocess = clip.load("ViT-B/32", device=device)
# preprocess -> transforms image; unsqueeze(0) adds a leading batch dimension
image = preprocess(Image.open("cat.png")).unsqueeze(0).to(device)
# text tokenize
text = clip.tokenize(["dog", "cat", "lion", "tiger"]).to(device)

# inference only: gradient tracking is disabled inside this block
with torch.no_grad():
    # NOTE: these two feature tensors are not used by the lines below;
    # model(image, text) is called on the raw inputs
    image_features = model.encode_image(image)
    text_features = model.encode_text(text)
    
    logits_per_image, logits_per_text = model(image, text)
    # softmax over the last dim turns the image's logits into one probability per text label
    probs = logits_per_image.softmax(dim=-1).cpu().numpy()

print("Label probs:", probs) 
This post is licensed under CC BY 4.0 by the author.