Paper Reading WinCLIP

WinCLIP

  • WinCLIP: Zero-/Few-Shot Anomaly Classification and Segmentation
  • No official code; community reference implementations:
    • https://github.com/caoyunkang/WinClip
    • https://github.com/zqhang/Accurate-WinCLIP-pytorch

Main Contributions

  • Compositional Prompt Ensemble
  • WinCLIP (zero-shot) / WinCLIP+ (few-shot) architecture
    • multi-scale spatial features aligned with language for zero-/few-shot anomaly segmentation

Compositional Prompt Ensemble

template_level_prompts = [
    'a cropped photo of the {}',
    'a cropped photo of a {}',
    'a close-up photo of a {}',
    'a close-up photo of the {}',
    'a bright photo of a {}',
    'a bright photo of the {}',
    'a dark photo of the {}',
    'a dark photo of a {}',
    'a jpeg corrupted photo of a {}',
    'a jpeg corrupted photo of the {}',
    'a blurry photo of the {}',
    'a blurry photo of a {}',
    'a photo of a {}',
    'a photo of the {}',
    'a photo of a small {}',
    'a photo of the small {}',
    'a photo of a large {}',
    'a photo of the large {}',
    'a photo of the {} for visual inspection',
    'a photo of a {} for visual inspection',
    'a photo of the {} for anomaly detection',
    'a photo of a {} for anomaly detection'
]
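
These templates are crossed with state-level prompts that describe normal versus anomalous object states; the full prompt set is the Cartesian product of the two levels. Below is a sketch of the state lists, inferred from the reference repos and the counts in the code later in this post (7 normal × 22 templates = 154, 4 abnormal × 22 templates = 88); the exact wording may differ from the original implementation.

# state-level prompts (a sketch; wording inferred from the reference repos,
# with 7 normal and 4 abnormal entries to match the 154/88 counts below)
state_level_normal_prompts = [
    '{}',
    'flawless {}',
    'perfect {}',
    'unblemished {}',
    '{} without flaw',
    '{} without defect',
    '{} without damage',
]
state_level_abnormal_prompts = [
    'damaged {}',
    'broken {}',
    '{} with flaw',
    '{} with defect',
]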

WinCLIP architecture

WinCLIP+ architecture

multi-scale windows

Code Analysis

build text prompt tokens and embeddings

# generate text tokens
for template_prompt in template_level_prompts:
    # normal prompts
    for normal_prompt in state_level_normal_prompts:
        phrase = template_prompt.format(normal_prompt.format(category))
        normal_phrases += [phrase]
    # abnormal prompts
    for abnormal_prompt in state_level_abnormal_prompts:
        phrase = template_prompt.format(abnormal_prompt.format(category))
        abnormal_phrases += [phrase]
# 154 normal phrases; 88 abnormal phrases
normal_phrases = self.tokenizer(normal_phrases).to(self.device) # [154, 77]
abnormal_phrases = self.tokenizer(abnormal_phrases).to(self.device) # [88, 77]
# text encoder
normal_text_features = self.encode_text(normal_phrases) # [154, 640]
abnormal_text_features = self.encode_text(abnormal_phrases) # [88, 640]
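
The per-phrase embeddings are then reduced to a single normal and a single abnormal class embedding; this is what self.text_features (a [2, 640] matrix) refers to in the scoring code below. A minimal sketch of that step, assuming the standard CLIP practice of L2-normalizing before and after mean-pooling:

# sketch of the pooling step (assumed, not shown in the excerpt above):
# average the 154/88 phrase embeddings into one vector per class, then renormalize
normal_text_features = normal_text_features / normal_text_features.norm(dim=-1, keepdim=True)
abnormal_text_features = abnormal_text_features / abnormal_text_features.norm(dim=-1, keepdim=True)
text_features = torch.stack([
    normal_text_features.mean(dim=0),    # normal class embedding [640]
    abnormal_text_features.mean(dim=0),  # abnormal class embedding [640]
], dim=0)                                # [2, 640]
self.text_features = text_features / text_features.norm(dim=-1, keepdim=True)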

window transformer

generate window masks

index_mask = torch.arange(self.grid_size[0] * self.grid_size[1], dtype=torch.int64) # [0, 1, 2, ...] one index per grid cell
index_mask = index_mask.reshape(self.grid_size[0], self.grid_size[1]) # [15, 15]
masks = []
scale_begin_indx = []
for scale in scales:
    scale_begin_indx += [len(masks)]
    for i in range(self.grid_size[0]):
        for j in range(self.grid_size[1]):
            # skip windows that would extend past the grid boundary
            if i + scale > self.grid_size[0] or j + scale > self.grid_size[1]:
                continue
            masks += [index_mask[i:i + scale, j:j + scale]]

self.scale_begin_indx = scale_begin_indx
self.masks = masks
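
For the 15×15 grid used here with scales (2, 3), each scale s contributes (15 - s + 1)² sliding windows, so the mask list holds 196 + 169 = 365 entries, matching the len() = 365 seen in the image encoder below. A quick standalone check:

# count sliding windows per scale on a 15x15 grid (values taken from this post)
grid, scales = 15, (2, 3)
counts = [(grid - s + 1) ** 2 for s in scales]
print(counts, sum(counts))  # [196, 169] 365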

window masking

mask_xs = []
for mask in self.masks:
    mask = torch.reshape(mask, [-1]).unsqueeze(0).to(x.device) # flatten the window's grid indices to [1, k]
    x_masked = torch.gather(x, dim=1, index=mask.unsqueeze(-1).repeat(N, 1, D)) # gather the window's tokens by index
    mask_xs += [x_masked]
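
Here x is the [N, L, D] patch-token tensor (N images, L = 225 grid tokens, D = 896), and torch.gather selects each window's tokens along the sequence dimension. A minimal standalone illustration with the shapes assumed from the comments in this post:

import torch

N, L, D = 2, 225, 896                          # batch, grid tokens, embed dim (assumed shapes)
x = torch.randn(N, L, D)                       # patch tokens
mask = torch.tensor([[0, 1, 15, 16]])          # flattened grid indices of one 2x2 window, shape [1, 4]
index = mask.unsqueeze(-1).repeat(N, 1, D)     # broadcast the indices to [N, 4, D]
x_masked = torch.gather(x, dim=1, index=index) # [N, 4, D]: only the window's tokens remain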

image encoder

mask_xs = self.window_masking(x) # len() = 365, mask_xs[i].shape = [83, 4 or 9, 896]
pooled_list, tokens_list = [], []
for scale_index in range(len(self.scale_begin_indx)):
    if scale_index == len(self.scale_begin_indx) - 1:
        scale_xs = mask_xs[self.scale_begin_indx[scale_index]:]
    else:
        scale_xs = mask_xs[self.scale_begin_indx[scale_index]:self.scale_begin_indx[scale_index + 1]]

    mx = torch.stack(scale_xs, dim=0) # [196, 83, 4, 896]
    n_windows, n_batches, _, _ = mx.shape

    mx = mx.reshape((-1, mx.shape[2], mx.shape[3])) # [16268, 4, 896]
    # prepend a cls token to each window # [16268, 1, 896]
    # the cls token aggregates global information within its window
    cls_tokens = self.class_embedding.to(x.dtype) + torch.zeros(mx.shape[0], 1, mx.shape[-1], dtype=x.dtype, device=x.device)
    cls_tokens = cls_tokens + self.positional_embedding.to(mx.dtype)[0, :]
    mx = torch.cat((cls_tokens, mx), dim=1) # [16268, 5, 896]
    
    mx = mx.permute(1, 0, 2)  # NLD -> LND
    mx = self.transformer(mx)
    mx = mx.permute(1, 0, 2)  # LND -> NLD

    pooled, tokens = self._global_pool(mx) # [16268, 896], [16268, 4, 896]
    pooled = self.ln_post(pooled)

    pooled = pooled @ self.proj # [16268, 896] @ [896, 640]

    pooled = pooled.reshape((n_windows, n_batches, pooled.shape[1])) # [196, 83, 640]
    tokens = tokens.reshape((n_windows, n_batches, tokens.shape[1], tokens.shape[2])) # [196, 83, 4, 896]

    pooled_list += [p for p in pooled]
    tokens_list += [t for t in tokens]
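
One detail the excerpt does not show: before the similarity computations below, the pooled window embeddings are presumably L2-normalized so that features @ self.text_features.T is a cosine similarity, as is standard in CLIP. A one-line sketch of that assumed step:

# assumed normalization step (standard CLIP practice, not shown in the excerpt above)
pooled = pooled / pooled.norm(dim=-1, keepdim=True)  # unit length, so dot products become cosine similarities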

Anomaly Score Map

language-guided anomaly score map

scale_anomaly_scores = []
token_anomaly_scores = torch.zeros((N, self.grid_size[0] * self.grid_size[1])) # [83, 225], 225 = 15*15
token_weights = torch.zeros((N, self.grid_size[0] * self.grid_size[1])) # [83, 225]
for indx, (features, mask) in enumerate(zip(visual_features, self.masks)):
    normality_and_abnormality_score = (100.0 * features @ self.text_features.T).softmax(dim=-1) # [83, 640] @ [640, 2]
    normality_score = normality_and_abnormality_score[:, 0]
    abnormality_score = normality_and_abnormality_score[:, 1]
    mask = mask.reshape(-1) # [2,2] or [3,3] -> [4] or [9]
    
    cur_token_anomaly_score = torch.zeros((N, self.grid_size[0] * self.grid_size[1])) # [83, 225]
    cur_token_anomaly_score[:, mask] = (1. / normality_score).unsqueeze(1) # accumulate 1/s_w for harmonic averaging # [83, 225]
    # cur_token_anomaly_score[:, mask] = (1. - normality_score).unsqueeze(1)
    cur_token_weight = torch.zeros((N, self.grid_size[0] * self.grid_size[1]))
    cur_token_weight[:, mask] = 1.

    if indx in self.scale_begin_indx[1:]:
        # finalize the current scale before moving on to the next one
        token_anomaly_scores = token_anomaly_scores / token_weights
        scale_anomaly_scores.append(token_anomaly_scores)

        # another scale, calculate from scratch
        token_anomaly_scores = torch.zeros((N, self.grid_size[0] * self.grid_size[1]))
        token_weights = torch.zeros((N, self.grid_size[0] * self.grid_size[1]))

    token_weights += cur_token_weight
    token_anomaly_scores += cur_token_anomaly_score # len()==scale nums

# deal with the last one
token_anomaly_scores = token_anomaly_scores / token_weights
scale_anomaly_scores.append(token_anomaly_scores)

scale_anomaly_scores = torch.stack(scale_anomaly_scores, dim=0) # [2, 83, 225]
scale_anomaly_scores = torch.mean(scale_anomaly_scores, dim=0) # [83, 225]
scale_anomaly_scores = 1. - 1. / scale_anomaly_scores

anomaly_map = scale_anomaly_scores.reshape((N, self.grid_size[0], self.grid_size[1])).unsqueeze(1) # [83, 1, 15, 15]
return anomaly_map
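
What this loop implements is the paper's harmonic aggregation: within each scale the reciprocals 1/s_w are averaged over all windows covering a token, the per-scale results are averaged, and the final 1 - 1/(·) maps the result back to an anomaly score. Per token i (up to the per-scale averaging):

$$\text{score}(i) \;=\; 1 - \Big(\frac{1}{|W_i|}\sum_{w \in W_i}\frac{1}{s_w}\Big)^{-1}$$

where W_i is the set of windows covering token i and s_w is a window's normality score from the softmax. The harmonic mean is dominated by small s_w, so a single strongly anomalous window is enough to flag a token.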

image-based anomaly score map (for few-shot)

scale_anomaly_scores = []
token_anomaly_scores = torch.zeros((N, self.grid_size[0] * self.grid_size[1]))
token_weights = torch.zeros((N, self.grid_size[0] * self.grid_size[1]))

cur_scale_indx = 0
cur_visual_gallery = self.visual_gallery[cur_scale_indx] # [196, 640]

for indx, (features, mask) in enumerate(zip(visual_features, self.masks)):
    # scored against the reference image embeddings; otherwise nearly identical to the text-guided branch
    # note: despite its name, this is a distance (anomaly) score: half the cosine distance to the nearest reference window
    normality_score = 0.5 * (1 - (features @ cur_visual_gallery.T).max(dim=1)[0]) # max( [83, 640] @ [640, 196], dim=1).shape = [83]
    mask = mask.reshape(-1)

    cur_token_anomaly_score = torch.zeros((N, self.grid_size[0] * self.grid_size[1]))
    cur_token_anomaly_score[:, mask] = normality_score.unsqueeze(1)
    # cur_token_anomaly_score[:, mask] = (1. - normality_score).unsqueeze(1)
    cur_token_weight = torch.zeros((N, self.grid_size[0] * self.grid_size[1]))
    cur_token_weight[:, mask] = 1.

    if indx in self.scale_begin_indx[1:]:
        # change to next scale
        cur_scale_indx += 1
        cur_visual_gallery = self.visual_gallery[cur_scale_indx]

        token_anomaly_scores = token_anomaly_scores / token_weights
        scale_anomaly_scores.append(token_anomaly_scores)

        # another scale, calculate from scratch
        token_anomaly_scores = torch.zeros((N, self.grid_size[0] * self.grid_size[1]))
        token_weights = torch.zeros((N, self.grid_size[0] * self.grid_size[1]))

    token_weights += cur_token_weight
    token_anomaly_scores += cur_token_anomaly_score

# deal with the last one
token_anomaly_scores = token_anomaly_scores / token_weights
scale_anomaly_scores.append(token_anomaly_scores)

scale_anomaly_scores = torch.stack(scale_anomaly_scores, dim=0)
scale_anomaly_scores = torch.mean(scale_anomaly_scores, dim=0)

anomaly_map = scale_anomaly_scores.reshape((N, self.grid_size[0], self.grid_size[1])).unsqueeze(1)
return anomaly_map
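
In formula form, the few-shot visual branch scores each window by half its cosine distance to the nearest reference window embedding at the same scale (embeddings assumed unit-normalized):

$$a_w \;=\; \frac{1}{2}\Big(1 - \max_{r \in \mathcal{R}}\langle f_w, r\rangle\Big)$$

Unlike the text-guided branch, these scores are already distances in [0, 1], so they are averaged directly without the 1/s and 1 - 1/(·) transforms.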

generate reference image features

1
2
3
4
5
6
7
8
9
10
# reference images go through the image encoder in exactly the same way as query images
self.visual_gallery = []
visual_features = self.encode_image(normal_images)
for scale_index in range(len(self.scale_begin_indx)):
    if scale_index == len(self.scale_begin_indx) - 1:
        scale_features = visual_features[self.scale_begin_indx[scale_index]:]
    else:
        scale_features = visual_features[self.scale_begin_indx[scale_index]:self.scale_begin_indx[scale_index+1]]

    self.visual_gallery += [torch.cat(scale_features, dim=0)]
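
A minimal usage sketch of the few-shot setup; the method names here are hypothetical stand-ins, not verbatim from the reference repos:

# hypothetical few-shot setup (illustrative method names)
normal_images = next(iter(reference_loader))      # k normal reference images, e.g. k in {1, 2, 4}
model.build_image_feature_gallery(normal_images)  # encode once, fill self.visual_gallery per scale
anomaly_map = model(query_images)                 # fuses the textual and visual score maps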

trick

# in zero-shot mode, visual_anomaly_map = textual_anomaly_map
if self.visual_gallery is not None:
    visual_anomaly_map = self.calculate_visual_anomaly_score(visual_features)
else:
    visual_anomaly_map = textual_anomaly_map
# fuse the two score maps
anomaly_map = 1. / (1. / textual_anomaly_map + 1. / visual_anomaly_map)
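
This fusion is, up to a factor of two, the harmonic mean of the two maps:

$$M \;=\; \Big(\frac{1}{M_{\text{text}}} + \frac{1}{M_{\text{visual}}}\Big)^{-1} \;=\; \frac{M_{\text{text}}\,M_{\text{visual}}}{M_{\text{text}} + M_{\text{visual}}}$$

Either map being close to zero pulls the fused score down, so both branches must agree for a region to be flagged. In zero-shot mode the two maps are identical and M reduces to textual_anomaly_map / 2, a monotone rescaling that leaves rankings (and hence AUROC) unchanged.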

Thoughts

  • WinCLIP requires no training; the prompts are designed by hand
  • Zero-shot and few-shot are unified under this architecture, which better matches real-world anomaly detection scenarios
  • The template prompts cannot capture every situation, and there is a lot of overlap among them
    • The templates expand into 100+ prompts, including variants like "a photo of a [small/large] {class}". Is all of that really necessary?
  • The generated score map is already pixel-level, so it would be natural to add a segmentation module on top
  • Local attention is computed with sliding windows; could a Swin Transformer be used to capture locality instead?