arXiv:2403.09199v1 [cs.CV] 14 Mar 2024

Customizing Segmentation Foundation Model
via Prompt Learning for Instance Segmentation

Hyung-Il Kim*
ETRI
Daejeon, Korea
hikim@etri.re.kr
*Both authors contributed equally.
   Kimin Yun*
ETRI
Daejeon, Korea
kimin.yun@etri.re.kr
   Jun-Seok Yun
KITECH
Daegu, Korea
yunjs@kitech.ac.kr
   Yuseok Bae
ETRI
Daejeon, Korea
baeys@etri.re.kr
Abstract

Recently, foundation models trained on massive datasets to adapt to a wide range of domains have attracted considerable attention and are actively being explored within the computer vision community. Among these, the Segment Anything Model (SAM) stands out for its remarkable progress in generalizability and flexibility for image segmentation tasks, achieved through prompt-based object mask generation. However, despite its strength, SAM faces two key limitations when applied to customized instance segmentation that segments specific objects or those in unique environments not typically present in the training data: 1) the ambiguity inherent in input prompts and 2) the necessity for extensive additional training to achieve optimal segmentation. To address these challenges, we propose a novel method, customized instance segmentation via prompt learning tailored to SAM. Our method involves a prompt learning module (PLM), which adjusts input prompts into the embedding space to better align with user intentions, thereby enabling more efficient training. Furthermore, we introduce a point matching module (PMM) to enhance the feature representation for finer segmentation by ensuring detailed alignment with ground truth boundaries. Experimental results on various customized instance segmentation scenarios demonstrate the effectiveness of the proposed method.

1 Introduction

Instance segmentation, which identifies and segments the pixels belonging to each object instance, has long been considered a crucial component of high-level scene understanding. In addition to general instance segmentation trained on common object instances (e.g., the COCO dataset [27]), instance segmentation of specific objects (e.g., faces, salient objects) has been widely studied for various real-world applications: autonomous driving [14, 36], medical image segmentation [9, 3, 45], and image editing [44, 28].

Figure 1: Our proposed method mitigates SAM’s sensitivity to input prompts by adjusting prompt features in the embedding space to align with class-wise object mask-based user intentions via a prompt learning module (PLM). Additionally, we enhance the feature representation for finer object segmentation through training with a point matching module (PMM).

Recent advances in deep learning have led to significant progress in instance segmentation algorithms. Inspired by the Faster R-CNN [35], two-stage object detector-based instance segmentation algorithms (e.g., Mask R-CNN [15] and Mask Scoring R-CNN [18]) have been introduced, followed by one-stage object detector-based methods (e.g., YOLACT [4]). To address the limitations of relying on pre-defined anchors, anchor-free algorithms [40, 24] have been introduced. Furthermore, the integration of attention mechanisms from Transformer architectures [38, 12] has led to the development of query-based instance segmentation methods, such as ISTR [17], QueryInst [13], and SOLQ [11]. Despite these technological advances, instance segmentation still faces significant challenges, particularly in adapting to a variety of environments and improving training efficiency while minimizing the reliance on costly mask-annotated data.

More recently, foundation models [5], trained with vast datasets to adapt to a wide range of downstream tasks, have received tremendous attention in the field of computer vision [34, 19, 30]. Among them, a segmentation foundation model known as the ‘Segment Anything Model (SAM)’ [22] has received the spotlight regarding generalizability and flexibility in image segmentation. Trained on vast datasets, SAM is designed as a promptable model, adept at generating segmentation masks for object regions in response to user prompts, such as points or text. Despite its powerful generalization capability and flexibility, the SAM still faces challenges in segmenting specific objects or those in unique environments not covered in its training data. As depicted in Fig. 1, SAM’s sensitivity to input prompts causes a significant issue. The inherent ambiguity in these prompts can result in substantial variations, affecting both the segmented object type and the segmentation mask quality. Moreover, optimizing the segmentation model for uniquely shaped objects in particular environments requires extensive training with additional large-scale datasets.

To tackle these challenges, we focus on customizing the SAM to better reflect user intention in instance segmentation tasks, especially when users collect datasets for specific segmentation targets. SAM is a promptable model responsive to user intention, as illustrated in Fig. 1. Its sensitivity to input prompts can often lead to repetitive attempts or even failures in precisely segmenting the desired objects. To address this, our approach provides a ‘customized segmentation’ with an additional learning module to the SAM, utilizing datasets collected by users with mask annotations for the specific objects the user aims to segment.

Specifically, to mitigate the issue of sensitivity to input prompts, we devise a prompt learning module (PLM). This module transforms input prompts within the embedding space to accurately reflect the user’s intentions for customized segmentation. A key advantage of the PLM is its plug-and-play capability, which allows for efficient customization of the segmentation model. By selectively training only the PLM while keeping the rest of the model weights frozen, it enables effective adaptation without extensive training. Additionally, we introduce a point matching module (PMM) to further enhance the segmentation model’s performance. The PMM improves feature representation for finer segmentation by focusing features on object boundary points. We validate the efficacy of our proposed method through experiments focused on customized instance segmentation tasks: facial part segmentation, outdoor banner segmentation, and license plate segmentation. Our findings consistently demonstrate the effectiveness of our approach in instance segmentation tasks tailored to user intention.

The major contributions of our work are as follows:

  • Our method effectively tackles the problem of prompt sensitivity in the SAM, leading to more stable instance segmentation that adheres to user-intended object shapes. This approach ensures the segmentation process is not only accurate but also aligned with specific user requirements.

  • Our approach leverages a plug-and-play prompt learning module. This allows for efficient customization without necessitating comprehensive fine-tuning. Notably, this method preserves the foundational model’s generalizability, making it a versatile solution across diverse segmentation tasks.

  • Furthermore, a point matching module is devised to enhance features for boundary points on a mask, which contributes to finer segmentation.

2 Related Work

2.1 Segment Anything

As a foundation model in computer vision, SAM [22] has recently shown remarkable zero-shot image segmentation performance by harnessing the power of a large-scale dataset containing over 1B masks. This model’s outstanding generalization capability has the potential to be applied to image understanding tasks across diverse environments [32, 41, 8, 7, 39, 42, 10, 43, 6]. For instance, MedSAM [32] adapted SAM for medical image segmentation using 1M medical image-mask pairs. In 3D image understanding, Cen et al. [7] proposed 3D object segmentation through cross-view self-prompting and mask inverse rendering, utilizing single-view 2D masks generated by SAM. The tracking anything module (TAM) [42] was designed to assess and refine the quality of SAM-initiated masks for video object segmentation, aiming to address SAM’s inconsistencies in mask estimation across video frames. In addition to the task-specific approaches, there have been efforts to analyze and improve SAM [33, 47, 37, 49]. SAM-OCTA [8] fine-tuned the SAM encoder’s parameters using low-rank adaptation [16] to adapt to specific datasets while preserving the semantic understanding of SAM. The medical SAM adapter [41] was proposed to integrate medical-specific domain knowledge into the SAM using a simple adapter.

While SAM’s adoption across various applications and its effectiveness have been recognized, fully exploiting the foundation model’s potential requires further exploration in additional training and prompting strategies. In this context, we address the issue of the sensitivity of the SAM to input prompts and the customization of the foundation model for specific object segmentation. To this end, we propose a prompt learning-based customization method, enabling specific object segmentation while building upon SAM’s generalization capabilities.

2.2 Prompt Tuning

Foundation models have been utilized for specific downstream tasks through retraining or fine-tuning. These methods typically demand considerable computational resources, large-scale datasets, and significant time. To address these issues, the concept of prompt tuning [29, 26, 25, 31] has received attention, particularly in natural language processing (NLP). This approach utilizes the inherent knowledge of foundation models without necessitating the retraining of the entire model. Inspired by these studies in NLP, visual prompt tuning (VPT) [20] has demonstrated remarkable adaptation performance in computer vision by training minimal prompt parameters. VPT has been shown to be superior to other adaptation approaches, such as full model training or head-oriented tuning, in terms of both effectiveness and efficiency. Recently, prompt tuning approaches based on SAM have been actively explored. PerSAM [46] has proposed to personalize segmentation, aiming to identify areas in images that share the same foreground as a user-provided image. It introduces additional inputs to SAM’s decoder, such as target-guided attention and prompting mechanisms, thereby enabling personalized segmentation. HQ-SAM [21] has tackled the degradation of mask quality in complex structures by using a learnable output token, enhancing mask details through the aggregation of global-local features and fine-tuning small parameters with fine-grained mask datasets.

In this regard, we propose a prompt learning method designed to segment specific object instances by customizing the segmentation foundation model. Our approach freezes the SAM’s model parameters, focusing on training only the proposed prompt learning module for customization. The proposed method dynamically modulates the prompt features in the embedding space depending on the input, learning a shape prior for the customized segmentation task. This approach not only streamlines the process but also facilitates efficient, plug-and-play customization for customized segmentation tasks.

3 Ambiguity in Prompts

In this section, we investigate the prompt ambiguity problem in SAM, one of the foundation models targeted in this paper.

Figure 2: (a) Instance segmentation results with the SAM for ambiguous input prompts and (b) visualization of IoU maps for multiple masks estimated by the SAM, where each pixel denotes the IoU value between the GT mask and the estimated mask. Note that each pixel location corresponds to the location of the input prompt.

SAM’s architecture comprises an image encoder, a prompt encoder, and a mask decoder. Given an input image $x$, SAM estimates the segmentation mask $m$ of the object according to the user’s input prompt $p$ as follows:

$$m = D\left(E_I(x),\, E_P(p)\right), \qquad (1)$$

where $E_I$, $E_P$, and $D$ denote the image encoder, prompt encoder, and mask decoder, respectively. As pointed out in [22], when $p$ is given ambiguously (e.g., located at the boundary of an object or in an area where two different objects overlap), semantically different objects can be segmented. To deal with this problem, the approach in [22] generates three output masks based on the confidence score (i.e., the estimated intersection over union (IoU)), using a small number of output tokens. This model was trained through the backpropagation of the minimum loss between the ground truth (GT) and the three estimated masks.
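To make the composition in Eq. 1 concrete, the following minimal sketch shows prompt-conditioned mask prediction with the three components; the class, attribute names, and simplified call signatures are illustrative assumptions and do not reproduce SAM's actual interface.

```python
import torch
import torch.nn as nn

class PromptableSegmenter(nn.Module):
    """Toy stand-in for Eq. 1: m = D(E_I(x), E_P(p))."""

    def __init__(self, image_encoder: nn.Module, prompt_encoder: nn.Module,
                 mask_decoder: nn.Module):
        super().__init__()
        self.image_encoder = image_encoder    # E_I: image -> image embedding
        self.prompt_encoder = prompt_encoder  # E_P: point/box prompt -> prompt tokens
        self.mask_decoder = mask_decoder      # D: (image embedding, prompt tokens) -> mask logits

    @torch.no_grad()  # all three components are used frozen at inference time
    def forward(self, image: torch.Tensor, prompt: torch.Tensor) -> torch.Tensor:
        f_img = self.image_encoder(image)       # E_I(x)
        f_prompt = self.prompt_encoder(prompt)  # E_P(p)
        return self.mask_decoder(f_img, f_prompt)  # m = D(., .)
```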

Despite SAM’s flexibility to infer multiple masks, there are still issues to consider when employing the segmentation foundation model in customized instance segmentation scenarios. To investigate the prompt ambiguity problem, let us consider two segmentation scenarios: 1) face segmentation, where the user aims to segment only the skin region, and 2) outdoor banner segmentation, focusing on segmenting the whole banner area. As illustrated in Fig. 2(a), if the prompt is near the face’s border, the segmented area may include unintended parts, like hair. Even among the multiple output masks, the mask covering the whole face includes the hair rather than only the skin desired by the user.

Figure 3: Overall framework of the proposed method. Building upon the SAM (left) with two encoders and a mask decoder, the proposed method (right) introduces two additional modules. The prompt learning module (PLM) $\phi$ adjusts the prompt feature so that the user’s desired object can be segmented well. In addition, the point matching module (PMM) $\varphi$ enables finer segmentation through learning to minimize the distance between the GT points and the points estimated by $\varphi$.

Similarly, in outdoor banner segmentation, depending on the prompt’s placement, unintended objects like parts of a person can be segmented, deviating from the user’s intention. To further examine the sensitivity to input prompts, we visualize IoU values comparing the GT mask with the SAM output mask while moving the prompt to all positions in Fig. 2(b), where red indicates an IoU value close to 1 and blue indicates a value near 0. As expected, successful segmentation occurs when the input prompt lies in a place where there are no specific object instances that differ from the intention (e.g., the unobstructed skin area of the face or the flat banner area without people or text). This is mainly because the SAM predicts the most probable contour containing the prompt, rather than aligning with a specific user intention. Consequently, users need to make fine adjustments to their input prompts to achieve effective segmentation. To streamline this process, we introduce a customized instance segmentation approach that refines prompts for more accurate segmentation of the user’s intended objects by leveraging the generalization capabilities of SAM.
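The IoU maps in Fig. 2(b) can be reproduced conceptually by sweeping a single point prompt across the image and scoring each output mask against the GT; a rough sketch is given below, assuming a hypothetical `segment(image, point)` callable that returns a binary mask.

```python
import numpy as np

def mask_iou(pred: np.ndarray, gt: np.ndarray) -> float:
    """IoU between two binary masks."""
    inter = np.logical_and(pred, gt).sum()
    union = np.logical_or(pred, gt).sum()
    return float(inter) / float(union) if union > 0 else 0.0

def prompt_sensitivity_map(image, gt_mask: np.ndarray, segment, stride: int = 8) -> np.ndarray:
    """Sweep a single point prompt over the image grid and record the IoU of the
    resulting mask against the GT (red/blue in Fig. 2(b) correspond to values near 1 and 0)."""
    h, w = gt_mask.shape
    rows, cols = h // stride, w // stride
    iou_map = np.zeros((rows, cols), dtype=np.float32)
    for yi in range(rows):
        for xi in range(cols):
            pred = segment(image, (xi * stride, yi * stride))  # hypothetical single-point call
            iou_map[yi, xi] = mask_iou(pred, gt_mask)
    return iou_map
```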

4 Proposed Method

In this paper, we propose a method for customized instance segmentation through prompt learning applied to a segmentation foundation model, i.e., SAM. To address the sensitivity issue to input prompts, we devise a prompt learning module (PLM). This module transforms input prompts within the embedding space, reflecting the user’s intention for customized segmentation. Additionally, we introduce a point matching module (PMM) to enhance instance segmentation performance by focusing on features related to object boundary points, which aids in matching the segmentation more closely to the ground truth boundary.

As shown in Fig. 3, given an input image $x_i$ and a prompt $p_i$, image and prompt features $f_I^i$ and $f_P^i$ are extracted by the image and prompt encoders $E_I$ and $E_P$, i.e., $f_I^i = E_I(x_i)$ and $f_P^i = E_P(p_i)$. As discussed in Section 3, the PLM is designed to address the issue of ambiguous input prompts by learning to transform $f_P^i$. The training process of the PLM is driven by a dataset containing instances of the user’s target objects, teaching the PLM how to modify $f_P^i$ effectively. It does this by estimating the necessary adjustments within the embedding space, informed by both $f_I^i$ and $f_P^i$. Through this training, the PLM learns the optimal transformation of the prompt feature in a residual form, aligning it more closely with the specific segmentation needs of the user. Next, based on this transformed prompt feature and the original image embedding, the object instance mask $m_i$ is obtained by the mask decoder $D$. Note that the PMM is introduced in the training phase to improve the quality of the mask generated by $D$. By focusing on features related to boundary points extracted by the mask decoder, the PMM improves the precision of the segmentation contours. To maintain the generalization capability of the segmentation foundation model, we keep the architecture of the image encoder, prompt encoder, and mask decoder unchanged, freezing their pre-trained weights.
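The data flow described above could be wired roughly as in the sketch below; `sam` and `plm` are placeholders, and the attribute names and call signatures are simplified assumptions rather than the actual SAM API.

```python
import torch

def customized_forward(sam, plm, image, prompt):
    """Schematic forward pass of the proposed framework: frozen SAM encoders and
    decoder, with only the PLM adjusting the prompt feature as a residual."""
    with torch.no_grad():                       # E_I and E_P stay frozen
        f_img = sam.image_encoder(image)        # f_I^i
        f_prompt = sam.prompt_encoder(prompt)   # f_P^i
    delta = plm(f_prompt, f_img)                # offset estimated in the embedding space
    # decoder weights are frozen, but gradients still flow back to the PLM through delta
    mask_logits = sam.mask_decoder(f_img, f_prompt + delta)
    return mask_logits
```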

4.1 Prompt Learning Module

The PLM $\phi$ aims to adjust prompts, ensuring effective segmentation of the desired object by the user, regardless of the initial prompt provided for instance segmentation. Our approach primarily utilizes sparse prompts (e.g., a point or bounding box) that users can easily provide. Instead of making adjustments in a low-dimensional space ($p_i \in \mathbb{R}^2$) with limited information, the PLM operates in a higher-dimensional space ($f_P^i \in \mathbb{R}^{256}$) for prompt feature adjustments.

The PLM, depicted in Fig. 3, consists of three consecutive operations: self-attention of prompt features, prompt-to-image attention (with the attended prompt features as queries), and a multi-layer perceptron (MLP). Utilizing multi-head ($T$) attention, the PLM calculates the necessary adjustment for the prompts in the embedding space based on the image features $f_I^i$ and the prompt features $f_P^i$. This adjustment (offset) is denoted as $\triangle f_P^i$. Specifically, in the self-attention block, $f_P^i$ is embedded based on the attention operation between tokens in $f_P^i$. Next, in the prompt-to-image attention block, an attention operation is performed using the prompt features embedded by the self-attention as the query and $f_I^i$ as the key/value. Then, the MLP layer updates each token of the attended embedding. Each attention block adds positional embeddings to its inputs, and a layer normalization [2] follows each block. In summary, the PLM estimates the prompt feature offset $\triangle f_P^i$, given $f_I^i$ and $f_P^i$, to enhance the segmentation result:

$$\triangle f_P^i = \phi\left(f_P^i, f_I^i\right). \qquad (2)$$

Then, the transformed prompt feature $\tilde{f}_P^i$ is obtained by adding the estimated offset $\triangle f_P^i$ to $f_P^i$, i.e., $\tilde{f}_P^i = f_P^i + \triangle f_P^i$. Based on $\tilde{f}_P^i$ and $f_I^i$, the segmentation mask $\tilde{m}_i$ is then estimated by the mask decoder, $\tilde{m}_i = D(f_I^i, \tilde{f}_P^i)$. That is, the ambiguity problem is handled by learning to estimate the optimal prompt embedding change $\triangle f_P^i$ from random prompt samples during training.
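A sketch of how the three PLM operations could be composed is given below; the use of `nn.MultiheadAttention`, the MLP width, and the omission of positional embeddings are simplifications for illustration, not the exact implementation.

```python
import torch
import torch.nn as nn

class PromptLearningModule(nn.Module):
    """Sketch of the PLM: self-attention over prompt tokens, prompt-to-image
    cross-attention, and an MLP, producing a residual offset for f_P (Eq. 2)."""

    def __init__(self, dim: int = 256, num_heads: int = 8, mlp_ratio: int = 4):
        super().__init__()
        self.self_attn = nn.MultiheadAttention(dim, num_heads, batch_first=True)
        self.norm1 = nn.LayerNorm(dim)
        self.cross_attn = nn.MultiheadAttention(dim, num_heads, batch_first=True)
        self.norm2 = nn.LayerNorm(dim)
        self.mlp = nn.Sequential(nn.Linear(dim, mlp_ratio * dim), nn.GELU(),
                                 nn.Linear(mlp_ratio * dim, dim))
        self.norm3 = nn.LayerNorm(dim)

    def forward(self, f_p: torch.Tensor, f_i: torch.Tensor) -> torch.Tensor:
        # f_p: (B, N_prompt, 256) prompt tokens; f_i: (B, N_image, 256) flattened image tokens
        q, _ = self.self_attn(f_p, f_p, f_p)   # attention among prompt tokens
        q = self.norm1(q)
        a, _ = self.cross_attn(q, f_i, f_i)    # prompt-to-image attention (queries = prompts)
        a = self.norm2(a)
        return self.norm3(self.mlp(a))         # offset for the prompt feature

# usage: f_p_tilde = f_p + plm(f_p, f_i)      # residual update of the prompt feature
```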

4.2 Point Matching Module

To enhance the mask decoder’s ability to estimate the segmentation mask $\tilde{m}_i$ using the adjusted prompt feature, we introduce the PMM $\varphi$ during training, focusing on refining features, particularly those associated with mask details and quality. Motivated by TextBPN++ [48], the proposed PMM is designed to guide mask edge points toward the corresponding GT points; that is, it learns the offset required to align with these GT points as an auxiliary task. Unlike TextBPN++, which updates points iteratively, our module estimates these points through an end-to-end training approach. However, since extracting boundary points from the mask via contour fitting is non-differentiable, this process cannot be trained through standard backpropagation. To deal with this issue, we utilize points on the edge of the GT mask, augmented with jittering, for training. Suppose that the set of GT points and its jittered counterpart are $\mathcal{G}_i = \{c_i^1, c_i^2, \cdots, c_i^K\}$ and $\mathcal{G}_i^{\ast} = \{c_i^{\ast 1}, c_i^{\ast 2}, \cdots, c_i^{\ast K}\}$, respectively. The feature extracted by the two-way attention block in the mask decoder, denoted as $f_{D_T}^i$, is used for extracting point features. However, since the size of the feature map $f_{D_T}^i \in \mathbb{R}^{H_f \times W_f \times C}$ differs from that of the image space $x_i \in \mathbb{R}^{H_I \times W_I}$ in which the jittered points $\mathcal{G}_i^{\ast}$ are defined ($H_I > H_f$, $W_I > W_f$), we align the coordinates of $\mathcal{G}_i^{\ast}$ to the feature map space through interpolation. This step allows us to collect a feature matrix $\mathcal{W}_i \in \mathbb{R}^{K \times C}$ for the points in $\mathcal{G}_i^{\ast}$. For convenience, we define this function (i.e., interpolation and point feature extraction) as $c(\cdot)$, i.e., $\mathcal{W}_i = c(f_{D_T}^i, \mathcal{G}_i^{\ast})$. Based on $\mathcal{W}_i$ for $\mathcal{G}_i^{\ast}$, the boundary transformer [48] with an encoder-decoder structure is utilized to estimate the refined boundary points $\tilde{\mathcal{G}}_i$. Following the design in [48], the encoder consists of three transformer blocks with residual connections, and the decoder employs a three-layered $1\times 1$ convolution with ReLU activation. In summary, the PMM $\varphi$ refines $\mathcal{G}_i^{\ast}$ based on $\mathcal{W}_i$, returning $\tilde{\mathcal{G}}_i = \varphi(\mathcal{W}_i)$. The set $\tilde{\mathcal{G}}_i$ of refined boundary points is used for training the proposed network.
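The point-feature extraction $c(\cdot)$ and the boundary refinement could be sketched as follows; the `grid_sample`-based interpolation and the generic transformer encoder are stand-ins for the boundary transformer of TextBPN++ [48], with dimensions chosen only for illustration.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

def sample_point_features(feat: torch.Tensor, points: torch.Tensor,
                          image_hw: tuple) -> torch.Tensor:
    """c(.): bilinearly interpolate point features from the decoder feature map.
    feat: (B, C, H_f, W_f); points: (B, K, 2) image-space (x, y) coordinates."""
    h, w = image_hw
    grid = points.clone().float()
    grid[..., 0] = grid[..., 0] / (w - 1) * 2 - 1   # normalize x to [-1, 1]
    grid[..., 1] = grid[..., 1] / (h - 1) * 2 - 1   # normalize y to [-1, 1]
    sampled = F.grid_sample(feat, grid.unsqueeze(2), align_corners=True)  # (B, C, K, 1)
    return sampled.squeeze(-1).transpose(1, 2)                            # W_i: (B, K, C)

class PointMatchingModule(nn.Module):
    """Sketch of the PMM: refine jittered boundary points toward the GT contour."""

    def __init__(self, dim: int = 256):
        super().__init__()
        layer = nn.TransformerEncoderLayer(d_model=dim, nhead=8, batch_first=True)
        self.encoder = nn.TransformerEncoder(layer, num_layers=3)      # three blocks
        self.decoder = nn.Sequential(nn.Linear(dim, dim), nn.ReLU(),   # per-point head in place of
                                     nn.Linear(dim, dim), nn.ReLU(),   # the 1x1-conv decoder
                                     nn.Linear(dim, 2))

    def forward(self, point_feats: torch.Tensor, points: torch.Tensor) -> torch.Tensor:
        offsets = self.decoder(self.encoder(point_feats))  # per-point (dx, dy)
        return points + offsets                            # refined boundary points
```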

4.3 Training

Our network integrates the two proposed modules with the SAM for end-to-end training. In this setup, all parameters within the SAM are frozen. The objective function $\mathcal{L}$ for training is the sum of two loss functions:

$$\mathcal{L} = \frac{1}{N}\sum_{i=1}^{N}\left(\mathcal{L}_{seg}(m_i, \tilde{m}_i) + \lambda\,\mathcal{L}_{pm}(\mathcal{G}_i, \tilde{\mathcal{G}}_i)\right), \qquad (3)$$

where $\mathcal{L}_{seg}$ and $\mathcal{L}_{pm}$ denote the loss functions for segmentation and point matching, respectively, and $\lambda$ is a balancing hyperparameter. For $\mathcal{L}_{seg}$, we follow the loss function detailed in [22], which computes a linear combination of focal loss and dice loss at a 20:1 ratio between the GT mask $m_i$ and the estimated mask $\tilde{m}_i$. Besides the mask loss, an IoU loss that predicts the IoU score itself is also included in $\mathcal{L}_{seg}$. The loss function $\mathcal{L}_{pm}$ is designed to enhance the precision of mask estimation through the PMM $\varphi$ and is defined as

$$\mathcal{L}_{pm}(\mathcal{G}_i, \tilde{\mathcal{G}}_i) = \frac{1}{K}\sum_{k=1}^{K}\left(\inf_{\tilde{c}_i^j \in \tilde{\mathcal{G}}_i}\left\{\left\|c_i^k - \tilde{c}_i^j\right\|^2\right\}\right). \qquad (4)$$

This function computes the distance between the set $\mathcal{G}_i$ of GT points and the points $\tilde{\mathcal{G}}_i$ predicted by $\varphi$.
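Per sample, the combined objective of Eqs. 3 and 4 might be computed as in the sketch below; the focal and dice terms are simplified stand-ins for the segmentation loss of [22], while the point term follows Eq. 4 directly.

```python
import torch
import torch.nn.functional as F

def point_matching_loss(gt_points: torch.Tensor, pred_points: torch.Tensor) -> torch.Tensor:
    """Eq. 4: for each GT point, the squared distance to its nearest refined point,
    averaged over the K GT points. gt_points, pred_points: (K, 2)."""
    d2 = torch.cdist(gt_points, pred_points) ** 2   # pairwise squared distances (K, K)
    return d2.min(dim=1).values.mean()

def segmentation_loss(mask_logits: torch.Tensor, gt_mask: torch.Tensor,
                      focal_weight: float = 20.0, dice_weight: float = 1.0) -> torch.Tensor:
    """Simplified focal + dice combination at the 20:1 ratio described in [22].
    gt_mask is a float tensor of 0/1 values with the same shape as mask_logits."""
    p = torch.sigmoid(mask_logits)
    bce = F.binary_cross_entropy_with_logits(mask_logits, gt_mask, reduction="none")
    focal = ((1 - p) * gt_mask + p * (1 - gt_mask)).pow(2) * bce   # focal modulation, gamma = 2
    dice = 1 - (2 * (p * gt_mask).sum() + 1) / (p.sum() + gt_mask.sum() + 1)
    return focal_weight * focal.mean() + dice_weight * dice

def total_loss(mask_logits, gt_mask, gt_points, pred_points, lam: float = 1.0):
    """Eq. 3 for a single sample: L_seg + lambda * L_pm."""
    return segmentation_loss(mask_logits, gt_mask) + lam * point_matching_loss(gt_points, pred_points)
```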

5 Experiments

5.1 Experimental Settings

In this section, we conform to the single-point valid mask evaluation method as established in the original SAM study. Our analysis spans three tasks: facial part segmentation, outdoor banner segmentation, and license plate segmentation. The evaluation focuses on segmenting an object from a single foreground point. Given that SAM can predict three masks, we assess both approaches: the model’s most confident mask and scenarios involving multiple masks. In cases of multiple masks, we report the highest IoU, assuming hypothetical user selection from the three masks. We refer to this as SAM (oracle), consistent with the terminology in [22]. We employ the standard mean IoU (mIoU) metric to evaluate the match between predicted and GT masks. We report results in four scenarios: standard results for SAM+PLM and SAM+PLM+PMM, plus their oracle versions (SAM+PLM (oracle) and SAM+PLM+PMM (oracle)), showcasing the effectiveness of our methods under regular and ideal selection conditions.
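The standard and oracle protocols could be scored as in the following sketch, where `predict` is a hypothetical callable that returns a list of candidate masks (ordered by confidence) for a single-point prompt.

```python
import numpy as np

def mask_iou(pred: np.ndarray, gt: np.ndarray) -> float:
    inter = np.logical_and(pred, gt).sum()
    union = np.logical_or(pred, gt).sum()
    return float(inter) / float(union) if union > 0 else 0.0

def evaluate_miou(samples, predict, oracle: bool = False) -> float:
    """samples: iterable of (image, gt_mask, point). Standard mode scores the most
    confident candidate (index 0); oracle mode scores the candidate with the highest
    IoU against the GT, mimicking an ideal user selection among the three masks."""
    ious = []
    for image, gt_mask, point in samples:
        candidates = predict(image, point)
        if oracle:
            ious.append(max(mask_iou(m, gt_mask) for m in candidates))
        else:
            ious.append(mask_iou(candidates[0], gt_mask))
    return float(np.mean(ious))
```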

Table 1: The number of training and validation data of each part in the CelebA-HQ.
  Part Name ‘skin’ ‘nose’ ‘eye_g’ ‘l_eye’ ‘r_eye’ ‘l_brow’ ‘r_brow’ ‘l_ear’ ‘r_ear’ ‘mouth’ ‘u_lip’ ‘l_lip’ ‘hair’ ‘hat’ ‘ear_r’ ‘neck_l’ ‘neck’ ‘cloth’
  Training 21,000 20,999 979 20,537 20,531 20,368 20,310 10,897 9,903 12,487 20,913 20,924 20,551 880 5,790 1,307 20,406 12,009
Validation 6,000 6,000 380 5,809 5,811 5,784 5,778 3,325 2,998 3,452 5,978 5,983 5,834 266 1,497 319 5,840 3,763
 
Table 2: Segmentation performance for facial part. It represents the mean IoU for each of the 18 individual parts.
  Method ‘skin’ ‘nose’ ‘eye_g’ ‘l_eye’ ‘r_eye’ ‘l_brow’ ‘r_brow’ ‘l_ear’ ‘r_ear’
  SAM 71.37 12.62 15.57 50.18 55.83 24.82 22.71 44.93 47.97
SAM-F 74.77 8.66 23.77 47.07 45.70 13.36 11.74 45.72 46.42
Proposed (SAM+PLM) 84.38 82.09 90.10 76.93 77.57 61.11 63.76 76.92 74.37
Proposed (SAM+PLM+PMM) 84.95 82.58 90.65 77.66 78.21 61.74 63.80 77.07 76.18
SAM (oracle) 80.76 14.90 35.88 60.30 59.93 25.74 24.30 49.58 50.00
Proposed (SAM+PLM+PMM) (oracle) 87.56 83.17 90.80 81.22 81.18 64.21 66.72 77.11 77.07
  Method ‘mouth’ ‘u_lip’ ‘l_lip’ ‘hair’ ‘hat’ ‘ear_r’ ‘neck_l’ ‘neck’ ‘cloth’ Average
  SAM 55.85 8.57 20.35 20.35 58.95 28.62 29.09 33.91 29.14 35.05
SAM-F 30.80 11.03 21.23 24.12 64.54 27.51 31.76 15.24 29.50 31.83
Proposed (SAM+PLM) 76.44 70.80 73.94 73.54 82.66 34.77 54.01 51.47 76.63 71.19
Proposed (SAM+PLM+PMM) 76.68 70.88 74.10 73.88 83.34 34.17 55.74 51.59 76.84 71.67
SAM (oracle) 62.75 17.17 31.51 52.36 77.29 31.39 36.39 30.87 39.44 43.36
Proposed (SAM+PLM+PMM) (oracle) 78.92 72.85 76.69 82.78 86.81 39.30 62.11 56.34 79.64 74.69

In the training phase, to align with how users typically interact with the segmentation model, often choosing plausible foreground positions over exact centers, we refined our training approach: an arbitrary point from the training mask was selected as the input prompt, with the corresponding GT mask as the target output. We employed a probabilistic function to determine prompt positions, prioritizing those close to the center. During the testing phase, we evaluated the model’s accuracy using the center point of the mask as the input. For comparison, we implemented a method named SAM-F by applying the scale-aware fine-tuning proposed in PerSAM [46]. This approach, adapted from the SAM, utilizes three scales of masks, $M_1$, $M_2$, and $M_3$, and employs two learnable parameters, $w_1$ and $w_2$, to fine-tune the final mask such that $M = w_1 M_1 + w_2 M_2 + (1 - w_1 - w_2) M_3$. Through this, we aim to evaluate the capability to achieve a user-desired mask by only combining the output scales of SAM.
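The two ingredients described above, center-biased prompt sampling and the SAM-F scale fusion $M = w_1 M_1 + w_2 M_2 + (1 - w_1 - w_2) M_3$, could be approximated as below; the softmax-based sampling is an illustrative choice rather than the exact probabilistic function used.

```python
import torch
import torch.nn as nn

def sample_center_biased_point(gt_mask: torch.Tensor) -> tuple:
    """Pick a foreground pixel of the training mask as the prompt, favoring pixels
    close to the mask centroid (stand-in for the probabilistic selection above)."""
    ys, xs = torch.nonzero(gt_mask, as_tuple=True)
    cy, cx = ys.float().mean(), xs.float().mean()
    dist2 = (ys.float() - cy) ** 2 + (xs.float() - cx) ** 2
    probs = torch.softmax(-dist2 / (dist2.mean() + 1e-6), dim=0)   # closer to center -> more likely
    idx = torch.multinomial(probs, 1).item()
    return int(xs[idx]), int(ys[idx])

class ScaleAwareFusion(nn.Module):
    """SAM-F baseline: two learnable scalars blend SAM's three scale masks (after PerSAM [46])."""

    def __init__(self):
        super().__init__()
        self.w1 = nn.Parameter(torch.tensor(1.0 / 3.0))
        self.w2 = nn.Parameter(torch.tensor(1.0 / 3.0))

    def forward(self, m1: torch.Tensor, m2: torch.Tensor, m3: torch.Tensor) -> torch.Tensor:
        return self.w1 * m1 + self.w2 * m2 + (1 - self.w1 - self.w2) * m3
```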

Our model builds upon the pre-trained SAM, i.e., the MAE pre-trained ViT-H image encoder. Our updates are confined to the PLM (1.6M parameters) and the PMM (1.2M parameters) while keeping the other parameters (641M of SAM-ViT-H) frozen. In testing, only the parameters of the PLM $\phi$ (except $\varphi$) are utilized. However, the parameters of $\varphi$ can also be employed in certain scenarios through a two-step process (please refer to the discussion of Fig. 5). For model optimization, we set the learning rate to 0.00001 and utilized an AdamW optimizer with a weight decay of 0.01. The model was trained on eight A40 GPUs over 20,000 steps with a batch size of 10 and an image size of 1,024×1,024 pixels. A learning rate scheduler incorporating a warmup phase incrementally increased the learning rate for the initial 250 steps. After the warmup period, we applied a stepped decay to the learning rate, reducing it at predetermined milestones by a factor of gamma, set to 0.1 in our case. Specifically, the learning rate was reduced at two stages: after 66.667% and 86.666% of the total number of training batches. The $\lambda$ in Eq. 3 and the number of heads ($T$) in $\phi$ were set to 1 and 8, respectively.
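The optimization recipe could be expressed roughly as follows; the warmup shape and the `LambdaLR` formulation are one plausible realization of the described schedule, with simple linear layers standing in for the PLM and PMM.

```python
import itertools
import torch

plm = torch.nn.Linear(256, 256)   # placeholders for the trainable PLM and PMM;
pmm = torch.nn.Linear(256, 256)   # the SAM backbone itself stays frozen

total_steps, warmup_steps = 20_000, 250
optimizer = torch.optim.AdamW(itertools.chain(plm.parameters(), pmm.parameters()),
                              lr=1e-5, weight_decay=0.01)

def lr_lambda(step: int) -> float:
    """Linear warmup for the first 250 steps, then stepped decay (gamma = 0.1)
    after 66.667% and 86.666% of the total training steps."""
    if step < warmup_steps:
        return (step + 1) / warmup_steps
    if step < int(total_steps * 0.66667):
        return 1.0
    if step < int(total_steps * 0.86666):
        return 0.1
    return 0.01

scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)
# training loop (schematic): optimizer.step(); scheduler.step() once per batch of 10 images
```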

5.2 Facial Part Segmentation

For facial part segmentation, we utilize the CelebA-HQ dataset [23], which provides 18 different face part masks. Figure 4(a) shows sample images overlaid with GT masks for each facial part. Given the detailed masks for facial features in this dataset, our method can be trained to segment facial parts specifically as desired by the user. We used 21k images for training and 6k images for validation. Despite using identical images, occlusions of facial parts lead to different annotation counts for each facial part. The specific counts for each part are presented in Table 1.

Figure 4: Qualitative results on facial part segmentation: (a) Ground truth, (b) SAM, (c) SAM-F, (d) SAM+PLM, and (e) SAM+PLM+PMM. Each column, sequentially from left to right, represents the following parts: skin, nose, eye glasses, right brow, upper lip, hair, left ear, right eye, left brow, and lower lip.

Quantitative Results. Table 2 presents the single-point segmentation results for each facial part, with the ‘Average’ indicating the mean mIoU over all 18 parts. Performance is evaluated using the mask of highest confidence for both SAM and our method. For SAM, the highest-IoU mask against the GT is denoted as SAM (oracle) in the table, representing an optimal selection from the multiple candidate masks. Our method, which employs prompt learning (SAM+PLM), consistently outperforms SAM, SAM-F, and SAM (oracle). SAM-F, utilizing learnable weights to combine SAM’s multiple masks into a final mask, often shows limited performance improvements. This approach, while useful for blending SAM’s multiple masks for objects of similar size, faces challenges in customized segmentation tasks because it cannot control the multiple masks output by SAM, constraining its adaptability to customized segmentation objectives. Moreover, incorporating the proposed PMM enhances the overall performance, except in cases of irregular shapes like earrings. This integration acts as an auxiliary task, serving two key roles: it prevents the model from overfitting to simple shapes and supports extended training through the auxiliary point matching objective. Thus, our results confirm the effectiveness of the PMM in aligning the generalized model, trained on large-scale datasets, with user intentions. Notably, even with a fixed mask decoder and solely employing prompt learning, our method outperforms SAM (oracle).

Qualitative Results. As shown in Fig. 4, while the original SAM performs well in larger facial areas like ‘skin’, it struggles to accurately segment other facial parts such as the nose and ears. Despite these parts having distinct edges, their similar color and texture to the ‘skin’ make it challenging for SAM to segment them. SAM-F, through scale-aware fine-tuning, occasionally achieves more successful segmentation than SAM. However, SAM-F fundamentally does not modify the mask generation process due to its reliance on SAM’s mask decoder, thereby inheriting the same limitations in accurately segmenting facial parts where nuanced shape understanding is crucial. By contrast, our proposed method makes it possible to customize the segmenter into the desired forms using data collected by the user. Our approach allows the creation of a SAM model that reflects the user’s intent with only 0.25% (= 1.6M / 641M of SAM-ViT-H) of the parameters of the segmentation foundation model.

Figure 5: Refinement results with the PMM applied at test time. From left to right: the initial mask; point adjustments, with blue dots denoting the initial boundary points and green dots the refined boundary points; and the mask reconstructed from the refined points.

Effect of PMM. Our method shows an additional benefit when utilizing the PMM during testing. While PMM serves as an auxiliary task in training, its application in a two-step refinement process post-training can enhance segmentation masks. Initially, we extract contour points from a mask, which are then reprocessed through the PMM to yield refined contours and a more precise polygon mask. Fig. 5 illustrates these refinement results. The first column shows the initial masks from the mask decoder. The middle column visualizes the adjustment of points by the PMM, where blue dots denote the initial boundary points and green dots show the refined points. The third column shows the masks reconstructed with these refined points, leading to improved coverage and detail, especially in larger or initially vague areas like ‘skin’.
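This optional two-step refinement could be organized as in the sketch below, assuming hypothetical `point_feature_fn` and `pmm` callables for the decoder point-feature lookup and the trained PMM, respectively.

```python
import cv2
import numpy as np
import torch

def refine_mask_with_pmm(mask: np.ndarray, point_feature_fn, pmm, k: int = 64) -> np.ndarray:
    """Extract boundary points from the initial mask, refine them with the PMM,
    and rebuild a polygon mask from the refined points."""
    contours, _ = cv2.findContours(mask.astype(np.uint8), cv2.RETR_EXTERNAL,
                                   cv2.CHAIN_APPROX_NONE)
    contour = max(contours, key=cv2.contourArea).squeeze(1)       # (P, 2) boundary points (x, y)
    idx = np.linspace(0, len(contour) - 1, k).astype(int)         # resample K points
    points = torch.as_tensor(contour[idx], dtype=torch.float32).unsqueeze(0)  # (1, K, 2)
    with torch.no_grad():
        refined = pmm(point_feature_fn(points), points)           # refined boundary points
    polygon = refined.squeeze(0).round().numpy().astype(np.int32)
    out = np.zeros_like(mask, dtype=np.uint8)
    cv2.fillPoly(out, [polygon], 1)                               # reconstructed polygon mask
    return out
```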

Figure 6: Cross-model testing results. It shows the segmentation results when a model trained on a specific part is applied to a prompt for a part it was not trained on, such as using a nose-trained model to segment the facial skin area.

Cross-model Testing. Figure 6 shows the adaptability of our method when tested across various facial parts. The first column illustrates successful cross-application from a model trained on the right ear to segmenting the left ear, leveraging facial symmetry. In the second column, a model trained on the upper lip is applied to the lower lip, resulting in predictions mirroring the upper lip’s shape. This indicates the distinct shape characteristics of the upper and lower lips, such as the orientations of the mouth corners. In the third column, a nose-trained model applied to the skin predicts a nose-shaped mask. It shows that our approach can adapt the segmentation foundation model to produce user-specified shapes, even in the absence of clear delineations. This ability emphasizes the technique’s effectiveness in accommodating the diverse and specific characteristics of facial features.

5.3 Outdoor Banner Segmentation

We applied our method to segmenting outdoor banners, a task where users typically expect rectangular segmentation outcomes. The complexity of banners, with their text, images, and patterns, often challenges the conventional SAM, leading to imprecise segmentations. To tackle this, we created a dataset with minimal labeling, consisting of banner and background images sourced online. Banners were randomly attached to backgrounds, with affine transformations applied to mimic different camera perspectives. The dataset featured 3k banner images matched with one of 4k background options, and we conducted tests on 980 synthesized images using a consistent validation set.
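The dataset synthesis could be approximated as in the following sketch; the jitter ranges, scales, and compositing details are illustrative guesses rather than the exact generation pipeline.

```python
import cv2
import numpy as np

def synthesize_banner_sample(banner: np.ndarray, background: np.ndarray,
                             rng: np.random.Generator):
    """Warp a banner with a random affine transform, paste it onto a background,
    and return the composite image together with its GT mask."""
    h, w = background.shape[:2]
    bh, bw = banner.shape[:2]
    src = np.float32([[0, 0], [bw, 0], [0, bh]])
    jitter = rng.uniform(-0.15, 0.15, size=(3, 2)) * [bw, bh]   # mimic camera perspective
    scale = rng.uniform(0.3, 0.6)
    offset = rng.uniform([0, 0], [max(w - bw * scale, 1), max(h - bh * scale, 1)])
    dst = ((src + jitter) * scale + offset).astype(np.float32)
    M = cv2.getAffineTransform(src, dst)
    warped = cv2.warpAffine(banner, M, (w, h))
    mask = cv2.warpAffine(np.ones((bh, bw), np.uint8), M, (w, h))
    composite = np.where(mask[..., None] > 0, warped, background)
    return composite, mask
```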

Figure 7: Qualitative results for outdoor banner segmentation: (a) SAM and (b) Proposed (SAM+PLM+PMM).

Figure 7 shows that SAM usually prioritizes text or patterns, resulting in inaccurate segmentations. Our approach, however, enhances SAM’s focus on the intended banner area, significantly improving segmentation in the presence of intricate patterns. For a detailed comparison of performance metrics, see Table 3.

5.4 License Plate Segmentation

As the third application of our method, we tackled license plate segmentation using the Kaggle Car License Plate dataset [1], which comprises 433 images with bounding box annotations. This dataset was divided into training and validation sets with an 80:20 ratio, and annotations were converted to polygons to capture the varied shapes of license plates more accurately.

Table 3: Segmentation performance for outdoor banner and license plate.
  Method Outdoor Banner License Plate
  SAM 30.27 63.79
SAM-F 30.60 53.90
Proposed (SAM+PLM) 95.31 74.24
Proposed (SAM+PLM+PMM) 97.33 76.29
SAM (oracle) 93.84 81.98
Proposed (SAM+PLM+PMM) (oracle) 97.44 81.80
 
Figure 8: Qualitative results for license plate segmentation: (a) SAM and (b) Proposed (SAM+PLM+PMM).

Figure 9: Visualization of IoU maps by input prompt position.

Figure 10: Examples of failure cases for outdoor banner segmentation.

Figure 8 reveals that SAM typically prioritizes text or patterns, which often leads to inaccurate segmentations. However, with our method applied, the segmentation model is modified to infer the license plate area in alignment with user intentions. For a comprehensive performance analysis, refer to Table 3.

6 Discussion

Prompt Sensitivity. To evaluate our model’s robustness to the input prompt’s location, we experimented with varying prompt positions and measured the IoU scores of the resulting masks against the GT. Figure 9 shows these findings, with the IoU map reflecting IoU scores relative to prompt positions. Our method robustly detects the banner area, accurately reflecting the user’s intent in selecting positions within the banner.

Failure Cases. As shown in Fig. 10, the proposed method may segment similar objects as one instance when they overlap or lie near the input prompt. This can be understood because the proposed method is trained to treat segmenting the object containing the input prompt as an instance as more important than instance-wise discrimination.

Dataset Dependency. The proposed method’s varying effectiveness across datasets is due to inherent differences: for banners and license plates with strong rectangular priors and clear edges, SAM performs well, limiting dramatic improvements by our method. In contrast, for less-defined shapes like facial parts, our method effectively customizes SAM, showing more significant improvements.

Sparse Prompt. For practical applications, we initially focused on single-point sparse prompts due to their efficiency and straightforwardness in customized segmentation tasks. As in [22], the proposed method can be applied with a bounding box prompt without loss of generality. However, we hypothesized that bounding boxes, which contain more user-defined prior information, might not benefit from customization as significantly as point prompts do.

Number of Parameters. For improved customized segmentation, we devise an additional PLM ($\phi$) and PMM ($\varphi$) on top of the segmentation foundation model (i.e., SAM). Since only the additional modules are trained on top of the SAM instead of extensively training the large model, efficient learning is possible. Table 4 shows the number of training parameters in the SAM model (i.e., image encoder, mask decoder) and the proposed modules. With the SAM parameters frozen, we train 2.8M parameters in the training stage, and 1.6M parameters are used for inference. Therefore, we can train the customized instance segmentation model efficiently. Moreover, our design offers significant flexibility, allowing plug-and-play use by either excluding the module for general instance segmentation or replacing it with a module learned for another task.
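The trainable-versus-frozen breakdown reported in Table 4 can be verified with a small utility of the following form (a generic helper, not tied to the released code).

```python
import torch

def count_parameters(module: torch.nn.Module, trainable_only: bool = False) -> int:
    """Count parameters; with trainable_only=True this reports what is actually
    optimized (PLM + PMM, about 2.8M) as opposed to the frozen SAM backbone."""
    return sum(p.numel() for p in module.parameters()
               if p.requires_grad or not trainable_only)
```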

Table 4: Number of training parameters for the SAM and the proposed modules.
  Module Number of Parameters
  Image Encoder $E_I$ (ViT-H) 637M
Prompt Encoder $E_P$ 6.2K
Mask Decoder $D$ 4.1M
Prompt Learning Module (PLM) $\phi$ 1.6M
Point Matching Module (PMM) $\varphi$ 1.2M
 

7 Conclusion

In this paper, we proposed a novel method to customize the segmentation foundation model via prompt learning for instance segmentation. In particular, to tackle the issue of prompt sensitivity, we designed the prompt learning module (PLM), which transforms input prompts within the embedding space. To enhance segmentation quality, we devised the point matching module (PMM), which aligns the boundary points of the estimated mask with those of the GT. Through experiments on customized instance segmentation scenarios, we validated the efficacy of the proposed method. Furthermore, by training only the PLM apart from the foundation model, our method can be used in a plug-and-play manner without compromising the generalization capability of the foundation model.

Acknowledgement

This work was supported by Institute of Information & Communications Technology Planning & Evaluation (IITP) grant funded by the Korean Government (MSIT) (No. 2014-3-00123, Development of High Performance Visual BigData Discovery Platform for Large-Scale Realtime Data Analysis and No. 2022-0-00124, Development of Artificial Intelligence Technology for Self-Improving Competency-Aware Learning Capabilities).

References

  • [1] Kaggle car license plate detection. https://www.kaggle.com/datasets/andrewmvd/car-plate-detection. Accessed: 2024-3-7.
  • Ba et al. [2016] Jimmy Lei Ba, Jamie Ryan Kiros, and Geoffrey E Hinton. Layer normalization. arXiv preprint arXiv:1607.06450, 2016.
  • Bai et al. [2023] Yunhao Bai, Duowen Chen, Qingli Li, Wei Shen, and Yan Wang. Bidirectional copy-paste for semi-supervised medical image segmentation. In CVPR, pages 11514–11524, 2023.
  • Bolya et al. [2019] Daniel Bolya, Chong Zhou, Fanyi Xiao, and Yong Jae Lee. Yolact: Real-time instance segmentation. In ICCV, pages 9157–9166, 2019.
  • Bommasani et al. [2021] Rishi Bommasani, Drew A Hudson, Ehsan Adeli, Russ Altman, Simran Arora, Sydney von Arx, Michael S Bernstein, Jeannette Bohg, Antoine Bosselut, Emma Brunskill, et al. On the opportunities and risks of foundation models. arXiv preprint arXiv:2108.07258, 2021.
  • Cao et al. [2023] Yunkang Cao, Xiaohao Xu, Chen Sun, Yuqi Cheng, Zongwei Du, Liang Gao, and Weiming Shen. Segment any anomaly without training via hybrid prompt regularization. arXiv preprint arXiv:2305.10724, 2023.
  • Cen et al. [2023] Jiazhong Cen, Zanwei Zhou, Jiemin Fang, Wei Shen, Lingxi Xie, Xiaopeng Zhang, and Qi Tian. Segment anything in 3d with nerfs. In NeurIPS, 2023.
  • Chen et al. [2023] Xinrun Chen, Chengliang Wang, Haojian Ning, and Shiying Li. Sam-octa: Prompting segment-anything for octa image segmentation. arXiv preprint arXiv:2310.07183, 2023.
  • Chen et al. [2022] Zhang Chen, Zhiqiang Tian, Jihua Zhu, Ce Li, and Shaoyi Du. C-cam: Causal cam for weakly supervised semantic segmentation on medical image. In CVPR, pages 11676–11685, 2022.
  • Cheng et al. [2023] Yangming Cheng, Liulei Li, Yuanyou Xu, Xiaodi Li, Zongxin Yang, Wenguan Wang, and Yi Yang. Segment and track anything. arXiv preprint arXiv:2305.06558, 2023.
  • Dong et al. [2021] Bin Dong, Fangao Zeng, Tiancai Wang, Xiangyu Zhang, and Yichen Wei. Solq: Segmenting objects by learning queries. In NeurIPS, pages 21898–21909, 2021.
  • Dosovitskiy et al. [2020] Alexey Dosovitskiy, Lucas Beyer, Alexander Kolesnikov, Dirk Weissenborn, Xiaohua Zhai, Thomas Unterthiner, Mostafa Dehghani, Matthias Minderer, Georg Heigold, Sylvain Gelly, et al. An image is worth 16x16 words: Transformers for image recognition at scale. In ICLR, 2020.
  • Fang et al. [2021] Yuxin Fang, Shusheng Yang, Xinggang Wang, Yu Li, Chen Fang, Ying Shan, Bin Feng, and Wenyu Liu. Instances as queries. In ICCV, pages 6910–6919, 2021.
  • Feng et al. [2020] Di Feng, Christian Haase-Schütz, Lars Rosenbaum, Heinz Hertlein, Claudius Glaeser, Fabian Timm, Werner Wiesbeck, and Klaus Dietmayer. Deep multi-modal object detection and semantic segmentation for autonomous driving: Datasets, methods, and challenges. IEEE Trans. Intell. Transp. Syst., 22(3):1341–1360, 2020.
  • He et al. [2017] Kaiming He, Georgia Gkioxari, Piotr Dollár, and Ross Girshick. Mask r-cnn. In ICCV, pages 2961–2969, 2017.
  • Hu et al. [2021a] Edward J Hu, Yelong Shen, Phillip Wallis, Zeyuan Allen-Zhu, Yuanzhi Li, Shean Wang, Lu Wang, and Weizhu Chen. Lora: Low-rank adaptation of large language models. arXiv preprint arXiv:2106.09685, 2021a.
  • Hu et al. [2021b] Jie Hu, Liujuan Cao, Yao Lu, ShengChuan Zhang, Yan Wang, Ke Li, Feiyue Huang, Ling Shao, and Rongrong Ji. Istr: End-to-end instance segmentation with transformers. arXiv preprint arXiv:2105.00637, 2021b.
  • Huang et al. [2019] Zhaojin Huang, Lichao Huang, Yongchao Gong, Chang Huang, and Xinggang Wang. Mask scoring r-cnn. In CVPR, pages 6409–6418, 2019.
  • Jia et al. [2021] Chao Jia, Yinfei Yang, Ye Xia, Yi-Ting Chen, Zarana Parekh, Hieu Pham, Quoc Le, Yun-Hsuan Sung, Zhen Li, and Tom Duerig. Scaling up visual and vision-language representation learning with noisy text supervision. In ICML, pages 4904–4916, 2021.
  • Jia et al. [2022] Menglin Jia, Luming Tang, Bor-Chun Chen, Claire Cardie, Serge Belongie, Bharath Hariharan, and Ser-Nam Lim. Visual prompt tuning. In ECCV, pages 709–727. Springer, 2022.
  • Ke et al. [2023] Lei Ke, Mingqiao Ye, Martin Danelljan, Yifan Liu, Yu-Wing Tai, Chi-Keung Tang, and Fisher Yu. Segment anything in high quality. In NeurIPS, 2023.
  • Kirillov et al. [2023] Alexander Kirillov, Eric Mintun, Nikhila Ravi, Hanzi Mao, Chloe Rolland, Laura Gustafson, Tete Xiao, Spencer Whitehead, Alexander C Berg, Wan-Yen Lo, Piotr Dollár, and Ross Girshick. Segment anything. In ICCV, pages 4015–4026, 2023.
  • Lee et al. [2020] Cheng-Han Lee, Ziwei Liu, Lingyun Wu, and Ping Luo. Maskgan: Towards diverse and interactive facial image manipulation. In CVPR, 2020.
  • Lee and Park [2020] Youngwan Lee and Jongyoul Park. Centermask: Real-time anchor-free instance segmentation. In CVPR, pages 13906–13915, 2020.
  • Lester et al. [2021] Brian Lester, Rami Al-Rfou, and Noah Constant. The power of scale for parameter-efficient prompt tuning. In EMNLP, pages 3045–3059, 2021.
  • Li and Liang [2021] Xiang Lisa Li and Percy Liang. Prefix-tuning: Optimizing continuous prompts for generation. arXiv preprint arXiv:2101.00190, 2021.
  • Lin et al. [2014] Tsung-Yi Lin, Michael Maire, Serge Belongie, James Hays, Pietro Perona, Deva Ramanan, Piotr Dollár, and C Lawrence Zitnick. Microsoft coco: Common objects in context. In ECCV, pages 740–755, 2014.
  • Ling et al. [2021] Huan Ling, Karsten Kreis, Daiqing Li, Seung Wook Kim, Antonio Torralba, and Sanja Fidler. Editgan: High-precision semantic image editing. In NeurIPS, pages 16331–16345, 2021.
  • Liu et al. [2023a] Pengfei Liu, Weizhe Yuan, Jinlan Fu, Zhengbao Jiang, Hiroaki Hayashi, and Graham Neubig. Pre-train, prompt, and predict: A systematic survey of prompting methods in natural language processing. ACM Computing Surveys, 55(9):1–35, 2023a.
  • Liu et al. [2023b] Shilong Liu, Zhaoyang Zeng, Tianhe Ren, Feng Li, Hao Zhang, Jie Yang, Chunyuan Li, Jianwei Yang, Hang Su, Jun Zhu, et al. Grounding dino: Marrying dino with grounded pre-training for open-set object detection. arXiv preprint arXiv:2303.05499, 2023b.
  • Liu et al. [2021] Xiao Liu, Kaixuan Ji, Yicheng Fu, Weng Lam Tam, Zhengxiao Du, Zhilin Yang, and Jie Tang. P-tuning v2: Prompt tuning can be comparable to fine-tuning universally across scales and tasks. arXiv preprint arXiv:2110.07602, 2021.
  • Ma et al. [2024] Jun Ma, Yuting He, Feifei Li, Lin Han, Chenyu You, and Bo Wang. Segment anything in medical images. Nature Communications, 15(1):654, 2024.
  • Mazurowski et al. [2023] Maciej A Mazurowski, Haoyu Dong, Hanxue Gu, Jichen Yang, Nicholas Konz, and Yixin Zhang. Segment anything model for medical image analysis: an experimental study. Medical Image Analysis, 89:102918, 2023.
  • Radford et al. [2021] Alec Radford, Jong Wook Kim, Chris Hallacy, Aditya Ramesh, Gabriel Goh, Sandhini Agarwal, Girish Sastry, Amanda Askell, Pamela Mishkin, Jack Clark, et al. Learning transferable visual models from natural language supervision. In ICML, pages 8748–8763, 2021.
  • Ren et al. [2015] Shaoqing Ren, Kaiming He, Ross Girshick, and Jian Sun. Faster r-cnn: Towards real-time object detection with region proposal networks. In NeurIPS, 2015.
  • Siam et al. [2018] Mennatullah Siam, Mostafa Gamal, Moemen Abdel-Razek, Senthil Yogamani, Martin Jagersand, and Hong Zhang. A comparative study of real-time semantic segmentation for autonomous driving. In CVPRW, pages 587–597, 2018.
  • Tang et al. [2023] Lv Tang, Haoke Xiao, and Bo Li. Can sam segment anything? when sam meets camouflaged object detection. arXiv preprint arXiv:2304.04709, 2023.
  • Vaswani et al. [2017] Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit, Llion Jones, Aidan N Gomez, Łukasz Kaiser, and Illia Polosukhin. Attention is all you need. In NeurIPS, 2017.
  • Wang et al. [2023] Di Wang, Jing Zhang, Bo Du, Dacheng Tao, and Liangpei Zhang. Scaling-up remote sensing segmentation dataset with segment anything model. arXiv preprint arXiv:2305.02034, 2023.
  • Wang et al. [2020] Xinlong Wang, Tao Kong, Chunhua Shen, Yuning Jiang, and Lei Li. Solo: Segmenting objects by locations. In ECCV, pages 649–665, 2020.
  • Wu et al. [2023] Junde Wu, Rao Fu, Huihui Fang, Yuanpei Liu, Zhaowei Wang, Yanwu Xu, Yueming Jin, and Tal Arbel. Medical sam adapter: Adapting segment anything model for medical image segmentation. arXiv preprint arXiv:2304.12620, 2023.
  • Yang et al. [2023] Jinyu Yang, Mingqi Gao, Zhe Li, Shang Gao, Fangjing Wang, and Feng Zheng. Track anything: Segment anything meets videos. arXiv preprint arXiv:2304.11968, 2023.
  • Yu et al. [2023] Tao Yu, Runseng Feng, Ruoyu Feng, Jinming Liu, Xin Jin, Wenjun Zeng, and Zhibo Chen. Inpaint anything: Segment anything meets image inpainting. arXiv preprint arXiv:2304.06790, 2023.
  • Zhang et al. [2020] Jianfu Zhang, Peiming Yang, Wentao Wang, Yan Hong, and Liqing Zhang. Image editing via segmentation guided self-attention network. IEEE Sign. Process. Letters, 27:1605–1609, 2020.
  • Zhang and Zhuang [2022] Ke Zhang and Xiahai Zhuang. Cyclemix: A holistic strategy for medical image segmentation from scribble supervision. In CVPR, pages 11656–11665, 2022.
  • Zhang et al. [2023a] Renrui Zhang, Zhengkai Jiang, Ziyu Guo, Shilin Yan, Junting Pan, Hao Dong, Peng Gao, and Hongsheng Li. Personalize segment anything model with one shot. arXiv preprint arXiv:2305.03048, 2023a.
  • Zhang and Metaxas [2023] Shaoting Zhang and Dimitris Metaxas. On the challenges and perspectives of foundation models for medical image analysis. arXiv preprint arXiv:2306.05705, 2023.
  • Zhang et al. [2023b] Shi-Xue Zhang, Chun Yang, Xiaobin Zhu, and Xu-Cheng Yin. Arbitrary shape text detection via boundary transformer. IEEE TMM, 2023b.
  • Zhou et al. [2023] Tao Zhou, Yizhe Zhang, Yi Zhou, Ye Wu, and Chen Gong. Can sam segment polyps? arXiv preprint arXiv:2304.07583, 2023.