
Paper Reading SAM

SAM

  • Segment Anything
  • https://segment-anything.com

Main Contributions

  • data engine (the SA-1B dataset)
    • phase 1: a fully manually annotated dataset is collected and used to train SAM
    • phase 2: semi-automatic stage. For an input image A, SAM predicts a mask and annotators refine it (the model is still fairly weak at this point); the combined human + model masks become image A's ground truth, and SAM is trained on image A again. After phase 2, SAM training is essentially complete
    • phase 3: fully automatic stage. A large set of new unlabeled images is fed to SAM, and the generated masks (after post-processing such as NMS, see the sketch below) are taken as ground truth; these images and masks make up SA-1B
  • the Segment Anything Model (SAM) architecture
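
How that NMS post-processing might look in practice: a common recipe (and only a hedged sketch, not the paper's exact pipeline) is to turn each generated mask into its bounding box and run box NMS with the predicted quality scores. masks and scores below are hypothetical inputs.

import torch
from torchvision.ops import nms

def dedup_masks(masks: torch.Tensor, scores: torch.Tensor, iou_thresh: float = 0.7):
    # masks:  [N, H, W] boolean masks (hypothetical shapes)
    # scores: [N] float mask quality, e.g. the IoU head output
    boxes = []
    for m in masks:
        ys, xs = torch.where(m)
        if len(xs) == 0:
            boxes.append(torch.zeros(4))  # empty mask -> degenerate box
        else:
            boxes.append(torch.stack([xs.min(), ys.min(), xs.max(), ys.max()]).float())
    boxes = torch.stack(boxes)
    keep = nms(boxes, scores, iou_thresh)  # indices of masks to keep
    return masks[keep], scores[keep]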

Code Analysis

image encoder

  • a standard (plain) ViT backbone
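
For context, a small sketch of the shapes involved (numbers follow the default ViT-H configuration; loading the sam object, e.g. via sam_model_registry, is assumed):

import torch

# Assumed setup (segment_anything): sam = sam_model_registry["vit_h"](checkpoint=...)
# The encoder takes a preprocessed 1024x1024 image; with 16x16 patches the token
# grid is 1024 / 16 = 64 per side, and the neck projects to 256 channels.
image = torch.randn(1, 3, 1024, 1024)
# image_embedding = sam.image_encoder(image)  # [1, 256, 64, 64]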

prompt encoder

Sparse Embedding

  • point
    
    def forward_with_coords(
        self, coords_input: torch.Tensor, image_size: Tuple[int, int]
    ) -> torch.Tensor:
        """Positionally encode points that are not normalized to [0,1]."""
        coords = coords_input.clone()
        coords[:, :, 0] = coords[:, :, 0] / image_size[1]  # x coordinate
        coords[:, :, 1] = coords[:, :, 1] / image_size[0]  # y coordinate
        return self._pe_encoding(coords.to(torch.float))  # apply positional encoding; B x N x C
    
    def _pe_encoding(self, coords: torch.Tensor) -> torch.Tensor:
        """Positionally encode points that are normalized to [0,1]."""
        # assuming coords are in [0, 1]^2 square and have d_1 x ... x d_n x 2 shape
        coords = 2 * coords - 1
        # project the coordinates to the transformer dimension
        coords = coords @ self.positional_encoding_gaussian_matrix
        coords = 2 * np.pi * coords
        # outputs d_1 x ... x d_n x C shape
        return torch.cat([torch.sin(coords), torch.cos(coords)], dim=-1)
    
    
    point_embedding = self.pe_layer.forward_with_coords(points, self.input_image_size)
    point_embedding[labels == -1] = 0.0
    point_embedding[labels == -1] += self.not_a_point_embed.weight
    point_embedding[labels == 0] += self.point_embeddings[0].weight
    point_embedding[labels == 1] += self.point_embeddings[1].weight
    
    • each point is a coordinate [x, y] with a corresponding label: -1 (ignore), 0 (background), 1 (foreground)
    • points with label -1 act as padding; a pad point is added when no box prompt is given
  • box
    
    coords = boxes.reshape(-1, 2, 2)  # one box = two corner points, each point is (x, y)
    corner_embedding = self.pe_layer.forward_with_coords(coords, self.input_image_size)
    corner_embedding[:, 0, :] += self.point_embeddings[2].weight
    corner_embedding[:, 1, :] += self.point_embeddings[3].weight
    
    • each box is two corner points [x0, y0, x1, y1]; after reshaping to [n, 2, 2] a box is effectively handled as two points (a usage sketch follows this list)
  • text
    • not implemented in the official release; covered by third-party SAM variants / improvements
    • lang-sam (Grounding DINO + CLIP)
    • FastSAM (YOLOv8-seg + CLIP)
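
Pulling the sparse prompts together, a call to the prompt encoder could look like the sketch below (coordinate values are made up; shapes follow the code above):

import torch

# One foreground point and one box for a single image, in the 1024x1024
# input coordinate space (values are illustrative).
point_coords = torch.tensor([[[500.0, 375.0]]])      # [B, N, 2] as (x, y)
point_labels = torch.tensor([[1]])                    # 1 fg, 0 bg, -1 pad
box = torch.tensor([[100.0, 150.0, 700.0, 800.0]])    # [B, 4] as (x0, y0, x1, y1)

# With a loaded model (assumed: sam = sam_model_registry["vit_h"](checkpoint=...)):
# sparse_emb, dense_emb = sam.prompt_encoder(
#     points=(point_coords, point_labels),
#     boxes=box,
#     masks=None,
# )
# sparse_emb: [1, 3, 256] -> 1 point token + 2 box-corner tokens
# dense_emb:  [1, 256, 64, 64] -> no_mask_embed broadcast, since masks=None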

Dense Embedding

  • mask
    
    def _embed_masks(self, masks: torch.Tensor) -> torch.Tensor:
        """Embeds mask inputs."""
        mask_embedding = self.mask_downscaling(masks)
        return mask_embedding
    
    
    if masks is not None:
        dense_embeddings = self._embed_masks(masks)
    else:
        dense_embeddings = self.no_mask_embed.weight.reshape(1, -1, 1, 1).expand(
            bs, -1, self.image_embedding_size[0], self.image_embedding_size[1]
        ) # [1, 256, 64, 64]
    
    • when no mask prompt is given, no_mask_embed is used
    • when a mask is given, it is downscaled to the image embedding size (sketched below)
    • the mask embedding is added to the image embedding as the decoder input
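
A rough stand-in for that downscaling stack is sketched below (the real mask_downscaling also inserts LayerNorm2d between the convolutions and uses mask_in_chans=16, embed_dim=256; the mask prompt is assumed to be 256x256, i.e. 4x smaller than the 1024x1024 image):

import torch
import torch.nn as nn

# Simplified stand-in for prompt_encoder.mask_downscaling (normalization omitted).
mask_downscaling = nn.Sequential(
    nn.Conv2d(1, 4, kernel_size=2, stride=2),    # 256x256 -> 128x128
    nn.GELU(),
    nn.Conv2d(4, 16, kernel_size=2, stride=2),   # 128x128 -> 64x64
    nn.GELU(),
    nn.Conv2d(16, 256, kernel_size=1),           # project to embed_dim = 256
)

mask_prompt = torch.randn(1, 1, 256, 256)         # low-res mask input
dense_embeddings = mask_downscaling(mask_prompt)  # [1, 256, 64, 64]
# Same spatial size and channels as the image embedding, so the decoder can
# simply add them: src = image_embeddings + dense_prompt_embeddings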

mask decoder

  • input processing
    
    output_tokens = torch.cat([self.iou_token.weight, self.mask_tokens.weight], dim=0) # [5,256]
    output_tokens = output_tokens.unsqueeze(0).expand(sparse_prompt_embeddings.size(0), -1, -1) # [1, 5, 256]
    tokens = torch.cat((output_tokens, sparse_prompt_embeddings), dim=1) # [1, 5 + S, 256]
    # Expand per-image data in batch direction to be per-mask
    src = torch.repeat_interleave(image_embeddings, tokens.shape[0], dim=0)
    src = src + dense_prompt_embeddings
    pos_src = torch.repeat_interleave(image_pe, tokens.shape[0], dim=0)
    
    • src: image embedding + mask embedding
    • pos_src: image pos embedding
    • tokens: iou_token + mask_tokens + sparse_prompt_embeddings (point embeddings + box embeddings)
  • forward pass
    
    # Cross attention block, tokens attending to image embedding
    q = queries + query_pe
    k = keys + key_pe
    attn_out = self.cross_attn_token_to_image(q=q, k=k, v=keys)
    queries = queries + attn_out
    queries = self.norm2(queries)
    
    # MLP block
    mlp_out = self.mlp(queries)
    queries = queries + mlp_out
    queries = self.norm3(queries)
    
    # Cross attention block, image embedding attending to tokens
    q = queries + query_pe
    k = keys + key_pe
    attn_out = self.cross_attn_image_to_token(q=k, k=q, v=queries)
    keys = keys + attn_out
    keys = self.norm4(keys)
    
  • outputs
    
    iou_token_out = hs[:, 0, :] # [1, 256]
    mask_tokens_out = hs[:, 1 : (1 + self.num_mask_tokens), :] # [1, 4, 256]
    # Upscale mask embeddings and predict masks using the mask tokens
    src = src.transpose(1, 2).view(b, c, h, w)
    upscaled_embedding = self.output_upscaling(src)
    hyper_in_list: List[torch.Tensor] = []
    for i in range(self.num_mask_tokens):
        hyper_in_list.append(self.output_hypernetworks_mlps[i](mask_tokens_out[:, i, :]))
    hyper_in = torch.stack(hyper_in_list, dim=1)
    b, c, h, w = upscaled_embedding.shape
    # [1, 4, 32] @ [1, 32, 256*256] -> [1, 4, 256*256], then reshaped back to spatial form
    masks = (hyper_in @ upscaled_embedding.view(b, c, h * w)).view(b, -1, h, w) # [1, 4, 256, 256]
    # Generate mask quality predictions
    iou_pred = self.iou_prediction_head(iou_token_out) # [1, 4]
    
    • hs
      • iou_token_out and mask_tokens_out are sliced out of hs
        • iou_token_out gives the predicted IoU score for each mask token
        • mask_tokens_out (after the per-token hypernetwork MLPs) is matrix-multiplied with the upscaled src to produce the output masks; since there are 4 mask tokens, this can be loosely seen as multi-scale feature fusion
    • iou_token: [1, transformer_dim], a learnable token whose output predicts the IoU of each output mask
    • mask_tokens: [4, transformer_dim], learnable tokens whose outputs carry each mask's features, implementing the ambiguity-aware output (see the selection sketch after this list)
      • [dim 4] = 1 (used when multimask_output=False) + 3 (used when multimask_output=True)
    • these tokens can be understood like the [CLS] token in NLP transformers: dedicated tokens that collect the outputs
    • src
      • upscaled 4x, then combined with mask_tokens_out to form the output masks
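
After the decoder, the ambiguity-aware outputs are resolved by slicing the mask channels and, if several are kept, picking the one with the highest predicted IoU. A sketch with dummy tensors (the slicing rule mirrors the official code; the argmax selection is the common usage, not the only option):

import torch

masks = torch.randn(1, 4, 256, 256)    # channel 0: single-mask output, 1..3: multimask outputs
iou_pred = torch.rand(1, 4)             # predicted quality per mask token

multimask_output = True
mask_slice = slice(1, None) if multimask_output else slice(0, 1)
masks = masks[:, mask_slice, :, :]      # [1, 3, 256, 256] or [1, 1, 256, 256]
iou_pred = iou_pred[:, mask_slice]

# Keep the best-scoring mask for each image in the batch.
best = iou_pred.argmax(dim=1)                              # [B]
best_masks = masks[torch.arange(masks.shape[0]), best]     # [B, 256, 256]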

Automatic Mask Generation

# for crop layer 0, crop_box is just the top-left and bottom-right corners of the whole image
x0, y0, x1, y1 = crop_box
cropped_im = image[y0:y1, x0:x1, :]
cropped_im_size = cropped_im.shape[:2]
self.predictor.set_image(cropped_im)

# Get points for this crop
points_scale = np.array(cropped_im_size)[None, ::-1]
# self.point_grids is a regular grid of 32x32 points scattered over the full image (normalized coordinates)
points_for_image = self.point_grids[crop_layer_idx] * points_scale
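
The grid itself is easy to reproduce; a sketch of building such a regular grid of normalized points (mirroring what the library's build_point_grid helper does) and scaling it onto the crop:

import numpy as np

def build_point_grid(n_per_side: int = 32) -> np.ndarray:
    # n x n points at cell centers of the unit square, returned as [n*n, 2] in (x, y) order.
    offset = 1.0 / (2 * n_per_side)
    ticks = np.linspace(offset, 1.0 - offset, n_per_side)
    xs, ys = np.meshgrid(ticks, ticks)
    return np.stack([xs.ravel(), ys.ravel()], axis=-1)

grid = build_point_grid(32)                  # [1024, 2], values in [0, 1]
points_scale = np.array([1024, 768])         # (width, height) of the crop, illustrative
points_for_image = grid * points_scale       # pixel coordinates used as point prompts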