
Denoising Diffusions in Latent Space for Medical Image Segmentation

Fahim Ahmed Zaman, Mathews Jacob, Amanda Chang, Kan Liu, Milan Sonka and Xiaodong Wu

Fahim Ahmed Zaman, Mathews Jacob, Milan Sonka and Xiaodong Wu are with the Department of Electrical and Computer Engineering, University of Iowa, Iowa City, IA 52242, USA (e-mails: {fahim-zaman, mathews-jacob, milan-sonka, xiaodong-wu}@uiowa.edu).

Amanda Chang is with the Division of Cardiology, Department of Internal Medicine, University of Iowa, Iowa City, IA 52242, USA (e-mail: amanda-chang@uiowa.edu).

Kan Liu is with the Division of Cardiology and Heart and Vascular Center, School of Medicine, Washington University in St Louis, St. Louis, MO 63130, USA (e-mail: kanl@wustl.edu).

Abstract

Diffusion probabilistic models (DPMs) have demonstrated remarkable performance in image generation, often outperforming other generative models. Since their introduction, the powerful noise-to-image denoising pipeline has been extended to various discriminative tasks, including image segmentation. In medical imaging, the images are often large 3D scans, and segmenting a single image with DPMs becomes extremely inefficient due to the large memory consumption and the time-consuming iterative sampling process. In this work, we propose a novel conditional generative modeling framework (LDSeg) that performs diffusion in latent space for medical image segmentation. Our proposed framework leverages the learned inherent low-dimensional latent distribution of the target object shapes and source image embeddings. The conditional diffusion in latent space not only ensures accurate n-D image segmentation for multi-label objects, but also mitigates the major underlying problems of traditional DPM-based segmentation: (1) large memory consumption, (2) a time-consuming sampling process and (3) unnatural noise injection in the forward/reverse process. LDSeg achieved state-of-the-art segmentation accuracy on three medical image datasets with different imaging modalities. Furthermore, we show that our proposed model is significantly more robust to noise than traditional deterministic segmentation models, which is potentially useful for addressing domain shift problems in the medical imaging domain. The code is available at: https://github.com/LDSeg/LDSeg.

I Introduction

In the field of medical imaging, image segmentation is a crucial step for identifying and monitoring disease-related pathologies. The qualitative and quantitative measures of segmented objects also guide clinical decisions in treatment, surgical planning and targeted therapy by evaluating the progression of diseases [1]. Traditional deep-learning (DL) based segmentation models have achieved impressive segmentation accuracy on various imaging modalities of the medical imaging domain, often matching or outperforming field-level experts [2, 3]. These DL-based models mostly include convolutional neural networks (CNNs), vision transformers (ViTs) and graph-based models, which are generally trained end-to-end in a discriminative manner. Recently, generative models have also emerged as powerful image segmentation tools that take advantage of learning the underlying statistics of target objects, conditioned on the source image. These conditional generative models include generative adversarial networks (GANs) and diffusion probabilistic models (DPMs).

In computer vision applications, DPMs [4, 5, 6, 7] have achieved remarkable results for image generation, outperforming other generative models [8]. However, adapting DPMs to medical image segmentation is challenging due to the complex tissue structures, noisy image acquisition and large image sizes of medical image datasets. Considerable research has been directed towards adapting DPMs for medical image segmentation [9, 10, 11, 12]. Standard DPMs have two major components: a forward process that perturbs the image with added Gaussian noise, and a reverse process that starts from Gaussian noise and iteratively denoises it to generate a clean image from the original data distribution. The denoiser is trained with noisy images at different noise variances, where the objective is to learn the noise distribution of the transitional states of the forward process. DPMs used for segmentation differ from those used for image generation in that the forward/reverse process operates on the segmentation mask instead of the source image. The source image is generally used as a condition to the denoiser. The final objective of the reverse process is to sample a segmentation mask from the original mask data distribution with the source image as a condition.

Perturbing the segmentation masks by directly adding Gaussian noise creates unnatural distortion in the underlying distribution, as the masks have very few modes (depending on the semantic classes present). As a result, training the denoiser becomes challenging due to the absence of a smooth transition among the various modes. Additional thresholding is needed to obtain the final segmentation mask, which can be filled with hole-like features [11] due to the high-frequency noise. Wu et al. proposed to use frequency parser blocks in the hidden layers of the denoiser to modulate the high-frequency noise [9], but this does not guarantee a clean result after sampling and may need post-processing. Bogensperger et al. proposed to transform the discrete segmentation mask to a signed distance function (SDF-DDPM), where each pixel represents the signed Euclidean distance from the closest object boundary [11]. A limitation of this approach is that the distance map for multi-class images is ambiguous. Zaman et al. proposed to re-parameterize the segmentation masks to a graph structure, which guarantees natural perturbation of the continuous surface distances on the graph column [13]. This model also suffers from the multi-class mask representation problem, as surface positions for different objects become ambiguous. These limitations indicate that a proper re-parameterization technique is needed that can be applied to multi-class objects simultaneously and guarantees smooth state transitions.

Another key challenge of DPMs is to reduce the time-consuming iterative sampling process. Various methods have been proposed to reduce the number of sampling steps for natural image generation [14, 15, 16, 17]. Medical image datasets, which often include large 3D scans per subject, place an extra burden on GPU/CPU memory and hence increase the overall sampling time needed to generate quality segmentation results with DPMs, making them extremely inefficient. Recently, DPMs have been proposed that leverage a learned latent space for a faster training/sampling pipeline for natural image generation and segmentation [18, 19]. PNVR et al. proposed to learn the latent space of the source images and then train a denoiser with text embeddings as a condition for image generation. Finally, a Unet-shaped autoencoder is trained to segment the image based on the text embedding condition, while also leveraging the diffusion features from the denoiser through an attention mechanism [20]. To incorporate the advantages of latent space diffusion techniques, in this work we propose a novel conditional diffusion-based generative framework (LDSeg) for medical image segmentation that leverages the learned uni-variate Gaussian latent representation of the target object shapes as well as the source image embeddings for accurate segmentation. The contributions of this work can be summarized as follows:

  1. To the best of our knowledge, this is the first work to leverage the learned uni-variate Gaussian latent space of the object shapes for proper conditioning of the denoiser and a faster sampling process.

  2. The continuous latent space allows direct incorporation of standard diffusion techniques for the forward and reverse processes, solving the problem of unnatural noise injection on the labeled segmentation mask for multi-class object segmentation.

  3. The diffusion in latent space ensures less memory consumption and faster training/sampling, even for large 3D medical scans.

  4. The model is significantly more robust to noise in the source images than deterministic segmentation models due to the low-dimensional image embeddings. This mitigates the problem of segmenting images with noisy acquisition and paves the way for solving domain shift problems in the medical imaging domain.

II Background

II-A Denoising Diffusion Probabilistic Model (DDPM)

DDPM starts from a sample in random distribution and reconstructs original data via a gradual denoising process. This denoising reverse process can be modeled as $p_{\theta}(x_{0:T})$, which is a Markov chain with learned Gaussian transitions starting at $p(x_T)=\mathcal{N}(x_T;0,I)$:

$$p_{\theta}(x_{0:T}) \coloneqq p(x_T)\prod_{t=1}^{T}p_{\theta}(x_{t-1}\mid x_t) \tag{1}$$

$$p_{\theta}(x_{t-1}\mid x_t) \coloneqq \mathcal{N}(x_{t-1};\mu_{\theta}(x_t,t),\Sigma_{\theta}(x_t,t)) \tag{2}$$

where $x_0\sim q(x_0)$ is a sample from real data distributions, $x_1\cdots x_T$ are transitional states from timesteps $t=1,\cdots,T$.

The forward process in the diffusion models is also a Markov chain, which gradually adds noise to the image. Given data $x_0\sim q(x_0)$ sampled from the real distribution, the forward process at time $t\in[1,T]$ can be defined as $q(x_t\mid x_{t-1})$, where Gaussian noise is gradually added given a noise variance schedule $\beta_t\in[\beta_1,\beta_T]$:

$$q(x_t\mid x_{t-1})=\mathcal{N}(x_t;\sqrt{1-\beta_t}\,x_{t-1},\beta_t I) \tag{3}$$

The choice of a Gaussian provides a closed-form solution to generate a transitional state $x_t$ using,

$$x_t=\sqrt{\bar{\alpha}_t}\,x_0+\sqrt{1-\bar{\alpha}_t}\,\epsilon \tag{4}$$

where $\alpha_t=1-\beta_t$, $\bar{\alpha}_t=\prod_{i=1}^{t}\alpha_i$ and $\epsilon\sim\mathcal{N}(0,I)$. The training is usually performed by optimizing the variational bound on the negative log likelihood of $p_{\theta}(x_0)$:

$$\mathcal{L}\coloneqq\mathbb{E}_q\left[-\log\frac{p_{\theta}(x_{0:T})}{q(x_{1:T}\mid x_0)}\right]\geq\mathbb{E}\left[-\log p_{\theta}(x_0)\right]$$
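To make the closed-form forward step of Eq. (4) concrete, the following is a minimal NumPy sketch, not the authors' implementation. The linear schedule endpoints (`beta_1=1e-4`, `beta_T=2e-2`), the helper names `linear_beta_schedule` and `q_sample`, and the toy 1D data are all assumptions chosen for illustration. It compares iterating the one-step transition of Eq. (3) with jumping directly to timestep $t$ via Eq. (4).

```python
import numpy as np

def linear_beta_schedule(T, beta_1=1e-4, beta_T=2e-2):
    """Noise variance schedule beta_1..beta_T (linear spacing is an assumption)."""
    return np.linspace(beta_1, beta_T, T)

def q_sample(x0, t, alpha_bar, eps):
    """Closed-form forward sample of Eq. (4); t is 1-indexed."""
    a = alpha_bar[t - 1]
    return np.sqrt(a) * x0 + np.sqrt(1.0 - a) * eps

rng = np.random.default_rng(0)
T = 1000
beta = linear_beta_schedule(T)
alpha = 1.0 - beta
alpha_bar = np.cumprod(alpha)  # alpha_bar_t = prod_{i<=t} alpha_i

x0 = rng.standard_normal(50_000)  # toy data with unit variance

# Route 1: iterate the one-step transition of Eq. (3) up to timestep t.
t = 300
x_iter = x0.copy()
for i in range(t):
    noise = rng.standard_normal(x_iter.shape)
    x_iter = np.sqrt(alpha[i]) * x_iter + np.sqrt(beta[i]) * noise

# Route 2: jump directly to timestep t with Eq. (4).
x_direct = q_sample(x0, t, alpha_bar, rng.standard_normal(x0.shape))

# For unit-variance x0, both routes should correlate with x0 by sqrt(alpha_bar_t).
corr_iter = np.corrcoef(x0, x_iter)[0, 1]
corr_direct = np.corrcoef(x0, x_direct)[0, 1]
```

Under these assumptions both correlations should sit close to $\sqrt{\bar{\alpha}_t}$, and $\bar{\alpha}_T$ is nearly zero, which is why $x_T$ can be treated as pure Gaussian noise.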

However, with re-parameterization, Ho et al. [4] simplified the training objective and proposed to train on a variant of the variational bound that is beneficial to sample quality and simpler to implement,

$$\mathcal{L}_{DDPM}\coloneqq\mathbb{E}_{t,x_0,\epsilon}\left[\|\epsilon-\epsilon_{\theta}(x_t,t)\|^2\right] \tag{5}$$

where $\epsilon_{\theta}$ is a function approximator (the trained denoiser) intended to predict $\epsilon$ from $x_t$. With a trained denoiser, the data can be generated with the reverse process by iterating through $t=T,\cdots,1$. Starting from $x_T\sim\mathcal{N}(0,I)$, the transitional states can be obtained by,

$$x_{t-1}=\frac{1}{\sqrt{\alpha_t}}\left(x_t-\frac{\beta_t}{\sqrt{1-\bar{\alpha}_t}}\epsilon_{\theta}(x_t,t)\right)+\sigma_t z \tag{6}$$

where $\sigma_t$ is the noise standard deviation at timestep $t$ and $z\sim\mathcal{N}(0,I)$.
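A single reverse step of Eq. (6) can be sketched as follows. This is an illustrative NumPy example rather than the paper's code: the function name `reverse_step`, the linear schedule, and the choice $\sigma_t^2=\beta_t$ are assumptions, and the true noise `eps` is passed in place of a trained $\epsilon_{\theta}$. Setting $z=0$ at $t=1$ follows the sampling procedure of [4].

```python
import numpy as np

def reverse_step(x_t, t, eps_hat, alpha, alpha_bar, beta, rng):
    """One application of Eq. (6); t is 1-indexed.

    eps_hat stands in for the denoiser output eps_theta(x_t, t).
    """
    a_t, ab_t, b_t = alpha[t - 1], alpha_bar[t - 1], beta[t - 1]
    mean = (x_t - b_t / np.sqrt(1.0 - ab_t) * eps_hat) / np.sqrt(a_t)
    if t == 1:
        return mean  # z = 0 at the final step, so no noise is added
    sigma_t = np.sqrt(b_t)  # assumption: sigma_t^2 = beta_t
    return mean + sigma_t * rng.standard_normal(x_t.shape)

rng = np.random.default_rng(0)
beta = np.linspace(1e-4, 2e-2, 1000)  # assumed linear schedule
alpha = 1.0 - beta
alpha_bar = np.cumprod(alpha)

# Noise a toy sample to t = 1 with Eq. (4), then undo it with the true noise.
x0 = rng.standard_normal((4, 4))
eps = rng.standard_normal(x0.shape)
x1 = np.sqrt(alpha_bar[0]) * x0 + np.sqrt(1.0 - alpha_bar[0]) * eps
x0_rec = reverse_step(x1, 1, eps, alpha, alpha_bar, beta, rng)

# A step at t > 1 adds fresh noise and keeps the shape.
x4 = reverse_step(x1, 5, eps, alpha, alpha_bar, beta, rng)
```

Because $\bar{\alpha}_1=\alpha_1$ and $1-\bar{\alpha}_1=\beta_1$, the step at $t=1$ with the exact noise recovers $x_0$ up to floating-point error, which gives a simple sanity check for an implementation of Eq. (6).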

III Method

Figure 1: Proposed LDSeg model. Step 1: An autoencoder is used to learn the low dimensional latent representation $m_0=enc(M)$ for a given input ground truth label/mask image $M$ by learning the joint distribution $l_{\theta}(m_0,\bar{M}\mid M)$, where $\bar{M}$ is the reconstructed mask image. Step 2: A conditional denoiser is trained by learning the joint distribution of $d_{\theta}(m_{t-1}\mid m_t,I,t)$ for the time step $t=1,\dotsc,T$, where $T$ is the total number of diffusion steps, $m_t$ is the perturbed latent representation at time step $t$ and $I$ is the source image. The conditional image $I$ is embedded with a mask encoder and added with the degraded latent representation $m_t$. $\mathcal{G}(\cdot)$ is the Gaussian diffusion block that implements forward diffusion for $m_0$ at time step $t$. In the inference phase, $I$ is used to obtain the segmented image $S$.

The proposed LDSeg framework contains two major components: 1) A mask autoencoder, which is used to learn the low-dimensional latent representation of the target object shapes. 2) A conditional denoiser, which learns the noise distribution for each time step conditioned on an image embedding of the source image. The image embedding is learned using an image encoder. The model workflow is shown in Fig. 1.

III-A Mask Autoencoder

Injecting Gaussian noise on segmentation labels is unnatural, as the label/mask image has only a few modes (the number of object classes). It is also difficult for a denoiser to learn the intermediate noise distributions when the data distribution is a combination of multivariate Gaussians. We propose to mitigate this inherent problem by learning a uni-variate low-dimensional Gaussian representation of the label images. In other words, we want to learn a transfer function $h(\cdot)$ that projects the input masks to a latent space having a uni-variate Gaussian distribution. We also want to learn the inverse function of $h(\cdot)$ that reconstructs the input masks from their latent space representations. For this purpose, we propose to use a simple Res-Unet [21] shaped autoencoder without skip connections. The encoder learns $enc(\cdot)\sim h(\cdot)$ and the decoder learns $dec(\cdot)\sim h^{-1}(\cdot)$. Assume $M\sim p_{data}(M)$ is a ground truth mask/label image. Then the latent representation $m_0$ and reconstruction $\bar{M}$ can be obtained by,

$$m_0=enc(M),\qquad\bar{M}=dec(m_0) \tag{7}$$

Our objective is to learn the joint distribution $l_{\theta}(m_0,\bar{M}\mid M)$. The loss function of the autoencoder is the multi-class cross entropy loss,

$$\mathcal{L}_{ae}\coloneqq-\frac{1}{N}\sum_{i=1}^{N}\sum_{j=1}^{C}\left(yt_{i,j}\cdot\log(yp_{i,j})\right) \tag{8}$$

where $N$ is the number of samples, $C$ is the number of classes, $yt_{i,j}$ is the true label for class $j$ for instance $i$ and $yp_{i,j}$ is the predicted probability for class $j$ for instance $i$. The final layer of the encoder is a layer-normalization layer, which ensures that the latent representation $m_0$ is a uni-variate zero-mean Gaussian. Essentially, the mask encoder learns the low-dimensional latent representation of the object shapes of the mask images, which can be reconstructed close to its original form using the mask decoder.
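The two ingredients described above, the final layer normalization and the cross-entropy loss of Eq. (8), can be illustrated with a short NumPy sketch. It is a simplified stand-in rather than the authors' network: `layer_norm` here has no learnable scale or shift, the feature tensor is random toy data, and the function names and the small stabilizing constants are assumptions.

```python
import numpy as np

def layer_norm(z, eps=1e-6):
    """Normalize each sample over its non-batch axes to zero mean, unit variance.

    Assumption: no learnable scale/shift, unlike typical framework layers.
    """
    axes = tuple(range(1, z.ndim))
    mean = z.mean(axis=axes, keepdims=True)
    var = z.var(axis=axes, keepdims=True)
    return (z - mean) / np.sqrt(var + eps)

def cross_entropy(y_true, y_prob, tiny=1e-12):
    """Eq. (8): y_true is one-hot (N, C); y_prob holds probabilities (N, C)."""
    return -np.mean(np.sum(y_true * np.log(y_prob + tiny), axis=1))

rng = np.random.default_rng(0)

# Toy encoder features with arbitrary offset and scale, shape (batch, H, W).
features = 3.0 + 5.0 * rng.standard_normal((2, 16, 16))
m0 = layer_norm(features)  # stand-in for the latent representation m_0

# Toy labels for 4 instances and 3 classes with uniform predictions.
y_true = np.eye(3)[[0, 2, 1, 1]]
y_prob = np.full((4, 3), 1.0 / 3.0)
loss_uniform = cross_entropy(y_true, y_prob)  # equals log(3) for uniform predictions
```

After normalization each latent sample has approximately zero mean and unit variance regardless of the scale of the incoming features, which is what lets the standard Gaussian forward process be applied directly to $m_0$.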

III-B Conditional Denoiser (CD)

Figure 2: A sample from the GlaS dataset is used to demonstrate the forward and the reverse processes. In the forward process (top row), the low-dimensional latent representation $m_0$ is first obtained from the ground truth mask image using the trained mask encoder. Then, Gaussian noise is gradually injected for timestep $t=1,\dotsc,T$, given the noise variance schedule $\beta$, where $\epsilon\sim\mathcal{N}(0,\mathrm{I})$. At timestep $T$, $m_T$ converges to $\mathcal{N}(0,\mathrm{I})$. At the start of the reverse process (bottom row), $\tilde{m}_T$ is sampled from $\mathcal{N}(0,\mathrm{I})$. Then the conditional denoiser is used iteratively for timestep $t=T,\dotsc,1$ with the input image $I$ as the condition. At the end of the reverse process, the segmentation is obtained from $\tilde{m}_0$ using the trained mask decoder.

A standard denoiser of DPMs has two inputs: a noisy version of the input image and its corresponding timestep. For segmentation, the denoiser needs additional conditioning. The condition can be the source image [11, 9] or a text indicating the target object [20]. We propose to use an image embedding as the condition to the denoiser. The image embedding is a low-dimensional latent representation of the source image, which is learned using an image encoder with an architecture similar to that of the mask encoder, except that it does not have a layer normalization layer at the end. The latent image embedding is concatenated with the noisy latent representation of the mask and used as a two-channel input to the denoiser, along with timestep $t$. In the forward process, a Gaussian block $\mathcal{G}$ is used to produce the noisy $m_t$ for timestep $t$, given $m_0$ and the noise variance schedule parameters $\alpha$, $\beta$ [4, 14]. An example of the forward process is shown in Fig. 2 (top row).

The denoiser has a standard Unet shape with time embeddings and self-attention layers. Specifically, we have adapted the denoiser architecture from [4]. The image encoder and the denoiser are trained together, where our objective is to learn the transitional latent state distributions $d_{\theta}(m_{t-1}\mid m_t,I,t)$ for timestep $t=1,\dotsc,T$, given source image $I$. Here, $T$ is the final timestep and $m_T\sim\mathcal{N}(0,\mathrm{I})$. The conditional denoiser is trained by minimizing the following objective,

$$\mathcal{L}_{cd}\coloneqq\mathbb{E}_{t,m_0,\epsilon}\left[\|\epsilon-\epsilon_{\theta}(\sqrt{\bar{\alpha}_t}\,m_0+\sqrt{1-\bar{\alpha}_t}\,\epsilon,I,t)\|^2\right] \tag{9}$$

The training algorithm for the conditional denoiser, given a trained mask autoencoder, is shown in Algorithm 1.
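As a rough illustration of how the objective in Eq. (9) is evaluated for one batch, the sketch below builds the two-channel denoiser input and computes the noise-prediction error in NumPy. It does not perform a gradient update and is not the authors' code: `zero_denoiser` is a hypothetical untrained stand-in, the latents and image embeddings are random placeholders for the encoder outputs, the schedule is an assumed linear one, and the squared norm is averaged over elements for convenience.

```python
import numpy as np

def cd_loss(m0, img_emb, t, alpha_bar, denoiser, rng):
    """Evaluate the objective of Eq. (9) for one batch (no parameter update)."""
    eps = rng.standard_normal(m0.shape)
    ab = alpha_bar[t - 1]                              # alpha_bar_t, t is 1-indexed
    m_t = np.sqrt(ab) * m0 + np.sqrt(1.0 - ab) * eps   # forward diffusion of m_0
    x_in = np.concatenate([m_t, img_emb], axis=1)      # two-channel denoiser input
    eps_hat = denoiser(x_in, t)
    return np.mean((eps - eps_hat) ** 2)               # mean instead of sum (assumption)

def zero_denoiser(x_in, t):
    """Hypothetical untrained stand-in that always predicts zero noise."""
    return np.zeros_like(x_in[:, :1])

rng = np.random.default_rng(0)
T = 1000
alpha_bar = np.cumprod(1.0 - np.linspace(1e-4, 2e-2, T))  # assumed linear schedule

batch, h, w = 4, 16, 16
m0 = rng.standard_normal((batch, 1, h, w))       # stand-in for enc(M)
img_emb = rng.standard_normal((batch, 1, h, w))  # stand-in for the image embedding
t = int(rng.integers(1, T + 1))                  # t ~ Uniform({1, ..., T}), one per batch

loss = cd_loss(m0, img_emb, t, alpha_bar, zero_denoiser, rng)
```

With a predictor that always outputs zero, the loss is simply the mean squared value of the sampled noise, so it should be close to 1; a trained denoiser would be expected to drive it well below that.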

III-C Reverse Process For Segmentation

As the image encoder is independent of the denoiser, we only need to obtain the conditional image embedding once, at the start of the reverse process. In the reverse process, our main objective is to generate the latent representation $m_0$, conditioned on the image embedding. As in other image generation tasks with DPMs, we start with a Gaussian $\mathcal{N}(0,\mathrm{I})$ as the latent mask representation $\tilde{m}_T$ at timestep $T$. Then the denoiser is iterated for $t=T,\dotsc,1$. At the end of the iteration, we obtain $\tilde{m}_0$ from the denoiser, which is used as the input to the trained mask decoder to get the final segmentation $S=dec(\tilde{m}_0)$. An example of the reverse process is shown in Fig. 2 (bottom row). The sampling algorithm for segmentation using the trained CD and mask autoencoder is shown in Algorithm 2.

Algorithm 1 Training CD
1:repeat
2:  I,Mqdata(I,M)similar-to𝐼𝑀subscript𝑞𝑑𝑎𝑡𝑎𝐼𝑀I,M\sim q_{data}(I,M)italic_I , italic_M ∼ italic_q start_POSTSUBSCRIPT italic_d italic_a italic_t italic_a end_POSTSUBSCRIPT ( italic_I , italic_M )
3:  m0=enc(M)subscript𝑚0𝑒𝑛𝑐𝑀m_{0}=enc(M)italic_m start_POSTSUBSCRIPT 0 end_POSTSUBSCRIPT = italic_e italic_n italic_c ( italic_M )
4:  tUniform({1,,T})similar-to𝑡Uniform1𝑇t\sim\mathrm{Uniform}(\{1,\dotsc,T\})italic_t ∼ roman_Uniform ( { 1 , … , italic_T } )
5:  ϵ𝒩(0,I)similar-toitalic-ϵ𝒩0I\epsilon\sim\mathcal{N}(\mathrm{0,I})italic_ϵ ∼ caligraphic_N ( 0 , roman_I )
6:  Take gradient descent step on
7:θϵϵθ(α¯tm0+1α¯tϵ,I,t)2subscript𝜃superscriptnormitalic-ϵsubscriptitalic-ϵ𝜃subscript¯𝛼𝑡subscript𝑚01subscript¯𝛼𝑡italic-ϵ𝐼𝑡2\;\;\nabla_{\theta}\left\|\epsilon-\epsilon_{\theta}(\sqrt{\bar{\alpha}_{t}}m_% {0}+\sqrt{1-\bar{\alpha}_{t}}\epsilon,I,t)\right\|^{2}∇ start_POSTSUBSCRIPT italic_θ end_POSTSUBSCRIPT ∥ italic_ϵ - italic_ϵ start_POSTSUBSCRIPT italic_θ end_POSTSUBSCRIPT ( square-root start_ARG over¯ start_ARG italic_α end_ARG start_POSTSUBSCRIPT italic_t end_POSTSUBSCRIPT end_ARG italic_m start_POSTSUBSCRIPT 0 end_POSTSUBSCRIPT + square-root start_ARG 1 - over¯ start_ARG italic_α end_ARG start_POSTSUBSCRIPT italic_t end_POSTSUBSCRIPT end_ARG italic_ϵ , italic_I , italic_t ) ∥ start_POSTSUPERSCRIPT 2 end_POSTSUPERSCRIPT
8:until converged
Algorithm 2 Sampling for segmentation
1:$I\sim q_{data}(I)$, $\tilde{m}_{T}\sim\mathcal{N}(\mathrm{0,I})$
2:for $t=T,\dotsc,1$ do
3:  $z\sim\mathcal{N}(\mathrm{0,I})$ if $t>1$, else $z=0$
4:  $\tilde{m}_{t-1}=\frac{1}{\sqrt{\alpha_{t}}}\left(\tilde{m}_{t}-\frac{\beta_{t}}{\sqrt{1-\bar{\alpha}_{t}}}\epsilon_{\theta}(\tilde{m}_{t},I,t)\right)+\sigma_{t}z$
5:end for
6:$S=dec(\tilde{m}_{0})$
7:return $S$
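
The loop structure of Algorithm 2 can be sketched in NumPy as below. This is an illustrative sketch under stated assumptions: `eps_model` is a hypothetical stand-in for the trained conditional denoiser, the choice $\sigma_t=\sqrt{\beta_t}$ is one common DDPM setting and is assumed here, and the final decoding step $S=dec(\tilde{m}_0)$ is left out because it requires the trained mask decoder.

```python
import numpy as np

def ddpm_sample(eps_model, image_emb, shape, betas, rng):
    """Ancestral sampling loop following the structure of Algorithm 2.

    eps_model(m_t, image_emb, t) is a HYPOTHETICAL noise predictor.
    sigma_t = sqrt(beta_t) is an ASSUMED variance choice.
    """
    alphas = 1.0 - betas
    alpha_bar = np.cumprod(alphas)
    m = rng.standard_normal(shape)                  # m_T ~ N(0, I)
    for t in range(len(betas), 0, -1):              # t = T, ..., 1
        z = rng.standard_normal(shape) if t > 1 else np.zeros(shape)
        a_t, ab_t, b_t = alphas[t - 1], alpha_bar[t - 1], betas[t - 1]
        eps_hat = eps_model(m, image_emb, t)
        mean = (m - b_t / np.sqrt(1.0 - ab_t) * eps_hat) / np.sqrt(a_t)
        m = mean + np.sqrt(b_t) * z                 # step 4 of Algorithm 2
    return m                                        # m_0, input to the mask decoder

# Toy usage with a placeholder predictor that always returns zeros.
rng = np.random.default_rng(0)
betas = np.linspace(1e-4, 2e-2, 50)
dummy_eps = lambda m_t, emb, t: np.zeros_like(m_t)
m0_hat = ddpm_sample(dummy_eps, image_emb=None, shape=(16, 24), betas=betas, rng=rng)
```

Because the whole loop runs on a tensor of latent size (here $16\times 24$), its cost does not depend on the resolution of the original image.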

IV Experiments

IV-A Dataset

We used three datasets to demonstrate the effectiveness of LDSeg:

  1. Echo [22] is a 2D+t echocardiogram video dataset with the standard apical 4-chamber left ventricular (LV) focused view. The dataset contains 65 echocardiogram videos (2230 still-frame gray-scale images). The left ventricles (LV) and the left atria (LA) were fully manually traced by an expert.

  2. GlaS [23] is a publicly available 2D histopathology dataset of Hematoxylin and Eosin (H&E) stained slides, acquired by a team of pathologists at the University Hospitals Coventry and Warwickshire, UK. The training set contains 37 benign and 48 malignant images, whereas the test set contains 37 benign and 43 malignant images.

  3. Knee (https://data-archive.nimh.nih.gov/oai/) is a publicly available 3D MRI dataset. The dataset contains 987 randomly selected 3D MRI scans from 244 patients at different time points. Volumetric regions of size $160\times 104\times 256$ around the joint between the femur cartilage with bone (FC) and the tibia cartilage with bone (TC) are used as the region of interest (ROI). FC and TC were segmented by an automatic segmentation algorithm and validated/edited by an expert.

IV-B Experimental Setup

The mask encoder, which has a Res-Unet architecture, contains several down-sampling layers that determine how low-dimensional the projection of the mask image is. We experimented with different numbers of down-sampling layers in the mask autoencoder and chose $4$, as this produced the best results for all the datasets. The images in the Echo, GlaS and Knee datasets were resized to $256\times 384$, $256\times 256$ and $128\times 128\times 256$, respectively. Hence, the sizes of the low-dimensional $h(m)$ for Echo, GlaS and Knee data are $16\times 24$, $16\times 16$ and $8\times 8\times 16$, respectively. We observed that these are the optimal sizes: further down-sampling reduces the mask decoder accuracy, whereas less down-sampling reduces the conditional denoiser accuracy because the search space for learning noise distributions gets bigger. The image encoder is also a simple Res-Unet shaped autoencoder that produces a low-dimensional image embedding of the same size as $h(m)$. Exponentially decayed learning rates were used to train the models, with $10^{-2}$ and $10^{-3}$ as the initial learning rates for the mask autoencoder and the conditional denoiser, respectively. For the Echo and Knee datasets, we employed an $80:20$ split for training and testing, and a $90:10$ split for training and validation. The conditional denoiser was trained for $1000$ epochs with a batch size of $4$, and the noise step $t$ is an integer randomly sampled from $1$ to $1000$ for each batch. An NVIDIA A100-SXM4 (80GB) GPU was used for training, and the CPU was an AMD EPYC 7413 24-Core Processor.
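
The relation between the resized image sizes and the latent sizes quoted above can be checked with a few lines of Python. The sketch assumes that each of the four down-sampling layers halves every spatial dimension, which is consistent with the reported sizes but is not stated explicitly in the text; the function name is introduced only for illustration.

```python
def latent_size(image_size, num_downsamples=4, factor=2):
    """Spatial size after repeated down-sampling.

    ASSUMPTION: every down-sampling layer divides each dimension by `factor`.
    """
    scale = factor ** num_downsamples
    if any(dim % scale for dim in image_size):
        raise ValueError("each dimension must be divisible by factor ** num_downsamples")
    return tuple(dim // scale for dim in image_size)

sizes = {
    "Echo": latent_size((256, 384)),
    "GlaS": latent_size((256, 256)),
    "Knee": latent_size((128, 128, 256)),
}
```

With four halvings the overall reduction is a factor of $2^4=16$ per dimension, i.e. $256$ times fewer elements for a 2D image and $4096$ times fewer for a 3D volume.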

V Results

V-A Segmentation Accuracy

TABLE I: Quantitative results for Echo data segmentation.

| Method | DSC ↑ (LV) | DSC ↑ (LA) | DSC ↑ (LV+LA) | IoU ↑ (LV) | IoU ↑ (LA) | IoU ↑ (LV+LA) |
|---|---|---|---|---|---|---|
| U-net [24] | 0.86 | 0.75 | 0.83 | 0.77 | 0.62 | 0.72 |
| V-net [25] | **0.93** | 0.81 | 0.90 | **0.87** | 0.71 | 0.83 |
| Res-Unet [21] | **0.93** | 0.83 | **0.91** | **0.87** | 0.74 | **0.84** |
| MedSegDiff* [9] | 0.89 | 0.81 | 0.87 | 0.82 | 0.70 | 0.78 |
| LDSeg* (Ours) | **0.93** | **0.85** | **0.91** | **0.87** | **0.75** | **0.84** |

* denotes the DPMs.
TABLE II: Quantitative results for GlaS data segmentation.

| Method | DSC ↑ | IoU ↑ |
|---|---|---|
| U-net [24] | 0.78 | 0.65 |
| U-net++ [26] | 0.78 | 0.66 |
| Res-Unet [21] | 0.79 | 0.66 |
| MedT [27] | 0.81 | 0.70 |
| SDF-DDPM* [11] | 0.83 | 0.72 |
| MedSegDiff* [9] | 0.84 | 0.74 |
| LDSeg* (Ours) | **0.86** | **0.76** |

* denotes the DPMs.
TABLE III: Quantitative results for Knee data segmentation.

| Method | DSC ↑ (FC) | DSC ↑ (TC) | DSC ↑ (FC+TC) | IoU ↑ (FC) | IoU ↑ (TC) | IoU ↑ (FC+TC) |
|---|---|---|---|---|---|---|
| Res-Unet [21] | **0.97** | **0.96** | **0.96** | **0.93** | **0.93** | **0.93** |
| MedSegDiff †⋆ [9] | 0.05 | 0.01 | 0.04 | 0.03 | 0.01 | 0.02 |
| LDSeg⋆ (Ours) | 0.96 | **0.96** | **0.96** | **0.93** | 0.92 | **0.93** |

† partial implementation due to memory shortage. ⋆ denotes the DPMs.

We evaluated the performance of our proposed method using two standard metrics: (1) Dice Similarity Coefficient (DSC) and (2) Intersection over Union (IoU). Tables I, II and III show the quantitative results of the different methods for the Echo, GlaS and Knee datasets, respectively. LDSeg achieves the best (or tied-best) DSC and IoU scores on the Echo and GlaS datasets, and on the Knee dataset it is within $0.01$ of Res-Unet on every score. The SDF-DDPM method uses a signed distance function to represent mask images, which is ambiguous for multi-label data; hence, it is only shown for the GlaS dataset. For the 3D Knee dataset, the full MedSegDiff architecture could not be implemented due to GPU memory shortage, so we implemented it partially by removing one intermediate convolution and attention layer. This indicates that diffusion in latent space is effectively necessary for large 3D medical images when GPU memory is constrained.
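
For reference, the two metrics can be computed for a pair of binary masks as in the sketch below. It uses the standard definitions $\mathrm{DSC}=2|A\cap B|/(|A|+|B|)$ and $\mathrm{IoU}=|A\cap B|/|A\cup B|$; returning $1.0$ when both masks are empty is a convention assumed here, and how the scores are aggregated over labels and images in the reported tables is not specified by this sketch.

```python
import numpy as np

def dice_and_iou(pred, target):
    """DSC and IoU for two binary masks of the same shape."""
    pred = np.asarray(pred, dtype=bool)
    target = np.asarray(target, dtype=bool)
    inter = np.logical_and(pred, target).sum()
    union = np.logical_or(pred, target).sum()
    total = pred.sum() + target.sum()
    dsc = 2.0 * inter / total if total > 0 else 1.0   # ASSUMED convention for empty masks
    iou = inter / union if union > 0 else 1.0
    return float(dsc), float(iou)

pred = np.array([[1, 1, 0],
                 [0, 1, 0]])
target = np.array([[1, 0, 0],
                   [0, 1, 1]])
dsc, iou = dice_and_iou(pred, target)   # intersection 2, union 4, |A| + |B| = 6
```

For a single pair of masks the two metrics are tied by $\mathrm{IoU}=\mathrm{DSC}/(2-\mathrm{DSC})$, so they rank individual predictions identically; averages over a dataset need not satisfy this identity.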

V-B Computational Efficiency

Figure 3: Number of evenly spaced sampling steps vs DSC for different datasets. DDPM achieves maximum segmentation accuracy with fewer sampling steps than the DDIM algorithm. For DDPM, the minimum number of sampling steps to achieve maximum segmentation accuracy is $10$, $10$ and $15$ for Echo, GlaS and Knee data, respectively. The number of steps is plotted on a logarithmic scale for convenience.

The major difference between LDSeg and other traditional diffusion based segmentation methods is that the diffusion happens in the low-dimensional latent space. For a given sampling sequence, each denoising step therefore operates on a much smaller tensor, so the total sampling time and memory consumption of LDSeg are expected to be lower than those of methods that diffuse at full image resolution. We further experimented on the sampling sequence for the reverse process. Nichol et al. [14] observed that a model trained with the "cosine" noise scheduler performed remarkably well in generating natural images with few sampling steps $(<50)$, with close to optimal FID score. They used $K$ evenly spaced real numbers between $1$ and $T$ (inclusive) as sampling steps, and then rounded each resulting number to the nearest integer. We adopted the same sampling strategy and observed that with very few sampling steps $(\leq 15)$, LDSeg achieves maximum segmentation accuracy (the same as when using all the sampling steps) for all the test datasets. We also implemented DDIM, proposed by Song et al. [15], which deterministically maps noise to images without added stochasticity in the transitional states. In our experiments, LDSeg with the stochastic DDPM sampler always performed better than with the DDIM sampler. Fig. 3 shows the number of sampling steps vs DSC scores for all the datasets for the DDPM and DDIM sampling algorithms. Table IV shows the segmentation accuracy for the total and the minimum number of sampling steps needed to reach maximum accuracy, along with the execution time (in seconds) needed to segment a single image on a CPU. With the minimum number of sampling steps ($\leq 15$), LDSeg achieved a significant boost in sampling time efficiency (a roughly $54$ to $94$ times reduction of execution time across the three datasets, per Table IV).
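
The step-selection rule described above (K evenly spaced real numbers between 1 and T, each rounded to the nearest integer) can be sketched as follows. The function name is hypothetical, and returning the steps in descending order is a choice made here because the reverse process runs from $T$ down to $1$.

```python
import numpy as np

def evenly_spaced_steps(T, K):
    """K evenly spaced real numbers in [1, T], rounded to integers, in descending order."""
    steps = np.round(np.linspace(1, T, K)).astype(int)
    return steps[::-1].tolist()

steps = evenly_spaced_steps(T=1000, K=10)
# steps == [1000, 889, 778, 667, 556, 445, 334, 223, 112, 1]
```

This sketch only selects which noise steps are visited. When sampling over a shortened sequence, the per-step coefficients also have to be recomputed for the retained steps, which is not shown here.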

TABLE IV: Segmentation accuracy (DSC, IoU) for different datasets using all the sampling steps and the minimum number of sampling steps (to achieve the maximum DSC and IoU obtained using all the steps). Execution (exec.) time for segmenting a single image of each dataset on a CPU is shown for both numbers of sampling steps.

| Dataset | All steps: Steps | All steps: Exec. time (s) | All steps: DSC ↑ | All steps: IoU ↑ | Min. steps: Steps | Min. steps: Exec. time (s) | Min. steps: DSC ↑ | Min. steps: IoU ↑ |
|---|---|---|---|---|---|---|---|---|
| Echo | 1000 | 70.82 | 0.91 | 0.84 | 10 | 0.78 | 0.91 | 0.84 |
| GlaS | 1000 | 71.79 | 0.86 | 0.77 | 10 | 0.76 | 0.86 | 0.76 |
| Knee | 1000 | 69.68 | 0.96 | 0.93 | 15 | 1.29 | 0.96 | 0.93 |
Figure 4: a. Number of sampling steps vs DSC for LDSeg and MedSegDiff on the GlaS dataset. LDSeg achieves optimal DSC with fewer sampling steps than MedSegDiff. b. Image size vs execution time for segmenting a single image with different DPMs and Res-Unet. As expected, Res-Unet, being a deterministic model, is nearly instantaneous. Execution time for LDSeg remains constant due to the constrained low-dimensional latent space, whereas for SDF-DDPM and MedSegDiff it increases exponentially as the image size increases.

We further investigated the execution time needed to segment a single image at different image sizes for different DPMs. The minimum number of sampling steps to achieve maximum segmentation accuracy can differ between DPMs (Fig. 4a) because they learn different target noise distributions. For a fair comparison, we fixed the total number of sampling steps to $50$ and varied the image size, using the GlaS dataset as a test case. Fig. 4b shows that as the image size increases, the execution time for SDF-DDPM and MedSegDiff increases exponentially, as the reverse process happens at the actual image dimension. For LDSeg, the image size is largely irrelevant: even as the image size increases, the low-dimensional latent space does not change much, hence the execution times are close to constant.

V-C Robustness to noise

Figure 5: a-b. Added noise variance $\sigma$ vs DSC scores for Res-Unet and LDSeg on the Echo and Knee datasets, respectively. c-d. Top and bottom rows of each block show sample images/slices from the Echo and Knee datasets along with the corresponding ground truth (GT) and overlaid segmentation results of Res-Unet and LDSeg, respectively.

One of the key challenges in medical image segmentation is producing accurate segmentations from noisy image acquisitions. Deterministic segmentation models often fail in the presence of noise in the test dataset. As the denoiser in LDSeg is conditioned on the image embedding, which is a low-dimensional representation of the source image, intuitively it should be more robust to the high-frequency noise present in the source image. Moreover, the iterative denoising process naturally removes noise, so that even when noise is present in the image embedding it produces a clean $m_{0}$, from which an accurate segmentation can be obtained using the mask decoder. To test the robustness to noise, we generated noisy image data from their clean counterparts by,

$$I_{\sigma}=I+\mathcal{N}(0,\sigma) \tag{10}$$

where $I$ is a sample from the test data and $\sigma$ is the noise variance. Fig. 5(a-b) shows the DSC scores of LDSeg and a deterministic model, Res-Unet, against different variances of added noise on the Echo and Knee datasets, respectively. LDSeg shows strong robustness to added noise even for $\sigma=0.2$ and largely maintains its optimal segmentation accuracy throughout, whereas the Res-Unet accuracy drops drastically as noise is added to the source image. Fig. 5(c-d) shows sample images from each dataset and the corresponding segmentation results for LDSeg and Res-Unet.
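
The perturbation in Eq. (10) can be reproduced with NumPy as sketched below. Following the text, $\sigma$ is treated as the variance, so the standard deviation passed to the random generator is $\sqrt{\sigma}$. The all-zero placeholder image and the function name are assumptions for illustration; whether the test images were intensity-normalized or clipped after adding noise is not stated in the text and is not modeled here.

```python
import numpy as np

def add_gaussian_noise(image, sigma, rng):
    """Return image + N(0, sigma), treating sigma as the noise VARIANCE."""
    std = np.sqrt(sigma)  # NumPy's `scale` is the standard deviation, not the variance
    return image + rng.normal(loc=0.0, scale=std, size=image.shape)

rng = np.random.default_rng(0)
clean = np.zeros((256, 384))             # placeholder for a test image
noisy = add_gaussian_noise(clean, sigma=0.2, rng=rng)
```

Sweeping `sigma` over a range of values and re-running the segmentation on each noisy copy would give curves of the kind described for Fig. 5(a-b).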

VI Ablation Study

Figure 6: a. A sample test image along with its mask and predictions for different variants of LDSeg. b. Number of sampling steps vs DSC for different variants of LDSeg on GlaS dataset.

Two major components that distinguish LDSeg from other diffusion based segmentation models are the mask autoencoder and the image encoder. We tested the effectiveness of both components by creating several variants of LDSeg:

  1. LDSeg: The proposed framework that uses both the mask autoencoder and the image encoder.

  2. LDSeg(md): The mask encoder is replaced with a mask down-sampler that down-samples the mask image to the same size as $m_{0}$. The image encoder is unchanged. The final segmentation is obtained using a mask up-sampler.

  3. LDSeg(id): The image encoder is replaced with an image down-sampler that down-samples the source image to the same size as $m_{0}$. The mask encoder is unchanged.

  4. LDSeg(md,id): Both the mask autoencoder and the image encoder are replaced with mask and image down-samplers.

TABLE V: Ablation study.

| Method | Mask Autoencoder | Image Encoder | Mask Down-sample | Image Down-sample | DSC | IoU |
|---|---|---|---|---|---|---|
| LDSeg(md,id) | | | ✓ | ✓ | 0.51 | 0.35 |
| LDSeg(id) | ✓ | | | ✓ | 0.55 | 0.39 |
| LDSeg(md) | | ✓ | ✓ | | 0.76 | 0.62 |
| LDSeg | ✓ | ✓ | | | 0.86 | 0.76 |

md → Mask Down-sampled, id → Image Down-sampled.

Table V shows the results of the ablation study on the GlaS dataset. The models that directly down-sample the image/mask by nearest neighbor interpolation perform poorly. Fig. 6a shows an example image segmentation with the different models. Fig. 6b shows the number of sampling steps vs DSC scores. Notably, all the variants reach their optimal segmentation accuracy within very few steps, like the proposed LDSeg. This suggests that a denoiser trained on a low-dimensional space has strong noise prediction capability in general.
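
One way direct nearest neighbor down-sampling can discard mask information is illustrated by the toy example below. It is only a sketch: strided slicing is used as a simple form of nearest neighbor down-sampling, the factor of $16$ mirrors four halvings, and the synthetic mask is invented for illustration. The actual down-sampler and up-sampler used in the ablation variants may differ.

```python
import numpy as np

def nn_downsample(mask, factor):
    """Nearest-neighbour down-sampling by keeping every `factor`-th pixel."""
    return mask[::factor, ::factor]

def nn_upsample(mask, factor):
    """Nearest-neighbour up-sampling by repeating every pixel `factor` times per axis."""
    return np.repeat(np.repeat(mask, factor, axis=0), factor, axis=1)

mask = np.zeros((256, 256), dtype=np.uint8)
mask[100:110, 50:200] = 1                # thin structure, 10 pixels tall
small = nn_downsample(mask, 16)          # 16 x 16 grid of retained pixels
restored = nn_upsample(small, 16)        # back to 256 x 256
```

In this example none of the retained rows (0, 16, ..., 240) intersects the structure, so it disappears from `small` and cannot be recovered by up-sampling. A learned mask autoencoder is not bound to a fixed sampling grid in this way.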

VII Discussion

Medical image datasets often consist of 3D scans that cannot be down-sampled without losing important imaging features due to complex tissue structures, organ-to-organ surface interaction etc. LDSeg can be directly used on larger 3D datasets, where other traditional DPMs may not even be implementable due to lack of GPU/CPU memory. On top of that, faster sampling in the reverse process makes it computationally efficient. Furthermore, the method is significantly more robust to noise present in the source images than the traditional deterministic segmentation models, which mitigates noisy image acquisition problems. A key challenge for deterministic segmentation models is measuring prediction uncertainty. LDSeg, being generative in nature, can estimate prediction uncertainty by computing the standard deviation of predictions from multiple runs. Fig. 7 shows an example of uncertain regions on object boundary estimation using LDSeg.

Figure 7: An example of uncertainty estimation on the Echo dataset. a. A sample Echo frame with unclear LV and LA walls marked (orange arrows). b. Mean segmentation map using 100 sampling runs. c. Standard deviation (SD) map obtained from the 100 sampling runs. Orange arrows show the highly uncertain regions with maximum SD, which correlate with the locations marked in a.
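
The mean and SD maps described above can be obtained from repeated stochastic runs as in the following sketch. The function `toy_sampler` is a hypothetical stand-in for one full LDSeg sampling run (it draws a disc with a randomly jittered radius) and is used only so that the example is self-contained.

```python
import numpy as np

def uncertainty_map(sample_fn, n_runs, rng):
    """Per-pixel mean and standard deviation over repeated stochastic segmentations."""
    runs = np.stack([sample_fn(rng) for _ in range(n_runs)], axis=0).astype(float)
    return runs.mean(axis=0), runs.std(axis=0)

def toy_sampler(rng, shape=(32, 32)):
    """HYPOTHETICAL stand-in for one sampling run: a disc whose radius jitters."""
    yy, xx = np.mgrid[:shape[0], :shape[1]]
    radius = 10.0 + rng.normal(scale=1.0)
    return ((yy - 16) ** 2 + (xx - 16) ** 2 <= radius ** 2).astype(np.uint8)

rng = np.random.default_rng(0)
mean_map, sd_map = uncertainty_map(toy_sampler, n_runs=100, rng=rng)
```

In this toy case the SD is zero deep inside and far outside the disc and is largest near its boundary, which mirrors the behaviour described for the unclear LV and LA walls.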

A limitation of the proposed approach is learning the low-dimensional image embedding for complex medical imaging datasets. As the data complexity increases in terms of tissue structures with various distributions, it becomes difficult to learn a proper image embedding that preserves all the fine details, which may hamper the denoising process. One way to address this problem would be to have the image encoder learn different frequency patterns of the input images to enforce additional conditioning on the denoiser.

VIII Conclusion

Adapting DPMs to medical image segmentation is fairly challenging due to large image sizes as well as complex tissue structures and noisy image acquisitions. We present a novel diffusion based framework leveraging the learned latent space that is extremely fast in the training/inference phase as well as significantly robust to noise present in the source image. This can also pave the way to resolving the domain shift problem in medical image segmentation, where source images can be obtained from different institutes and scanners with various imaging modalities.

IX Acknowledgement

This research was supported in part by NIH Grants R56EB004640, R01HL171624, R01AG067078 and R01EB019961. There are no other conflicts of interest.

The OAI is a public-private partnership comprised of five contracts (N01-AR-2-2258; N01-AR-2-2259; N01-AR-2-2260; N01-AR-2-2261; N01-AR-2-2262) funded by the National Institutes of Health, a branch of the Department of Health and Human Services, and conducted by the OAI Study Investigators. Private funding partners include Merck Research Laboratories; Novartis Pharmaceuticals Corporation, GlaxoSmithKline; and Pfizer, Inc. Private sector funding for the OAI is managed by the Foundation for the National Institutes of Health. This manuscript was prepared using an OAI public use data set and does not necessarily reflect the opinions or views of the OAI investigators, the NIH, or the private funding partners.

References

  • [1] Preeti Aggarwal, R. Vig, S. Bhadoria, and C. Dethe. Role of segmentation in medical imaging: A comparative study. International Journal of Computer Applications, 29:54–61, 2011.
  • [2] Mohammad Hesam Hesamian, Wenjing Jia, Xiangjian He, and Paul Kennedy. Deep learning techniques for medical image segmentation: Achievements and challenges. Journal of Digital Imaging, 32, 05 2019.
  • [3] Risheng Wang, Tao Lei, Ruixia Cui, Bingtao Zhang, Hongying Meng, and Asoke K. Nandi. Medical image segmentation using deep learning: A survey. IET Image Processing, 16(5):1243–1267, 2022.
  • [4] Jonathan Ho, Ajay Jain, and Pieter Abbeel. Denoising diffusion probabilistic models. Advances in neural information processing systems, 33:6840–6851, 2020.
  • [5] Jonathan Ho, Chitwan Saharia, William Chan, David J Fleet, Mohammad Norouzi, and Tim Salimans. Cascaded diffusion models for high fidelity image generation. The Journal of Machine Learning Research, 23(1):2249–2281, 2022.
  • [6] Yang Song and Stefano Ermon. Generative modeling by estimating gradients of the data distribution. Advances in neural information processing systems, 32, 2019.
  • [7] Yang Song, Jascha Sohl-Dickstein, Diederik P Kingma, Abhishek Kumar, Stefano Ermon, and Ben Poole. Score-based generative modeling through stochastic differential equations. In International Conference on Learning Representations, 2021.
  • [8] Prafulla Dhariwal and Alexander Nichol. Diffusion models beat gans on image synthesis. Advances in neural information processing systems, 34:8780–8794, 2021.
  • [9] Junde Wu, Huihui Fang, Yu Zhang, Yehui Yang, and Yanwu Xu. Medsegdiff: Medical image segmentation with diffusion probabilistic model. arXiv preprint arXiv:2211.00611, 2022.
  • [10] Junde Wu, Rao Fu, Huihui Fang, Yu Zhang, and Yanwu Xu. Medsegdiff-v2: Diffusion based medical image segmentation with transformer. arXiv preprint arXiv:2301.11798, 2023.
  • [11] Lea Bogensperger, Dominik Narnhofer, Filip Ilic, and Thomas Pock. Score-based generative models for medical image segmentation using signed distance functions, 2023.
  • [12] Aimon Rahman, Jeya Maria Jose Valanarasu, Ilker Hacihaliloglu, and Vishal Patel. Ambiguous medical image segmentation using diffusion models. 2023 IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR), pages 11536–11546, 2023.
  • [13] Fahim Ahmed Zaman, Mathews Jacob, Amanda Chang, Kan Liu, Milan Sonka, and Xiaodong Wu. Surf-cdm: Score-based surface cold-diffusion model for medical image segmentation. arXiv preprint arXiv:2312.12649, 2023.
  • [14] Alexander Quinn Nichol and Prafulla Dhariwal. Improved denoising diffusion probabilistic models. In International Conference on Machine Learning, pages 8162–8171. PMLR, 2021.
  • [15] Jiaming Song, Chenlin Meng, and Stefano Ermon. Denoising diffusion implicit models. arXiv preprint arXiv:2010.02502, 2020.
  • [16] Cheng Lu, Yuhao Zhou, Fan Bao, Jianfei Chen, Chongxuan Li, and Jun Zhu. Dpm-solver: A fast ode solver for diffusion probabilistic model sampling in around 10 steps. Advances in Neural Information Processing Systems, 35:5775–5787, 2022.
  • [17] Yunke Wang, Xiyu Wang, Anh-Dung Dinh, Bo Du, and Charles Xu. Learning to schedule in diffusion probabilistic models. In Proceedings of the 29th ACM SIGKDD Conference on Knowledge Discovery and Data Mining, pages 2478–2488, 2023.
  • [18] Robin Rombach, Andreas Blattmann, Dominik Lorenz, Patrick Esser, and Björn Ommer. High-resolution image synthesis with latent diffusion models. In Proceedings of the IEEE/CVF conference on computer vision and pattern recognition, pages 10684–10695, 2022.
  • [19] Yuxin Mao, Jing Zhang, Mochu Xiang, Yunqiu Lv, Yiran Zhong, and Yuchao Dai. Contrastive conditional latent diffusion for audio-visual segmentation. arXiv preprint arXiv:2307.16579, 2023.
  • [20] Koutilya Pnvr, Bharat Singh, Pallabi Ghosh, Behjat Siddiquie, and David Jacobs. Ld-znet: A latent diffusion approach for text-based image segmentation. In Proceedings of the IEEE/CVF International Conference on Computer Vision, pages 4157–4168, 2023.
  • [21] Zhengxin Zhang, Qingjie Liu, and Yunhong Wang. Road extraction by deep residual u-net. IEEE Geoscience and Remote Sensing Letters, 15(5):749–753, 2018.
  • [22] Fahim Zaman, Rakesh Ponnapureddy, Yi Grace Wang, Amanda Chang, Linda M Cadaret, Ahmed Abdelhamid, Shubha D Roy, Majesh Makan, Ruihai Zhou, Manju B Jayanna, et al. Spatio-temporal hybrid neural networks reduce erroneous human “judgement calls” in the diagnosis of takotsubo syndrome. EClinicalMedicine, 40, 2021.
  • [23] Korsuk Sirinukunwattana, Josien P. W. Pluim, Hao Chen, Xiaojuan Qi, Pheng-Ann Heng, Yun Bo Guo, Li Yang Wang, Bogdan J. Matuszewski, Elia Bruni, Urko Sanchez, Anton Böhm, Olaf Ronneberger, Bassem Ben Cheikh, Daniel Racoceanu, Philipp Kainz, Michael Pfeiffer, Martin Urschler, David R. J. Snead, and Nasir M. Rajpoot. Gland segmentation in colon histology images: The glas challenge contest, 2016.
  • [24] Olaf Ronneberger, Philipp Fischer, and Thomas Brox. U-net: Convolutional networks for biomedical image segmentation. In Medical Image Computing and Computer-Assisted Intervention–MICCAI 2015: 18th International Conference, Munich, Germany, October 5-9, 2015, Proceedings, Part III 18, pages 234–241. Springer, 2015.
  • [25] Fausto Milletari, Nassir Navab, and Seyed-Ahmad Ahmadi. V-net: Fully convolutional neural networks for volumetric medical image segmentation. In 2016 fourth international conference on 3D vision (3DV), pages 565–571. IEEE, 2016.
  • [26] Zongwei Zhou, Md Mahfuzur Rahman Siddiquee, Nima Tajbakhsh, and Jianming Liang. Unet++: A nested u-net architecture for medical image segmentation. In Deep Learning in Medical Image Analysis and Multimodal Learning for Clinical Decision Support: 4th International Workshop, DLMIA 2018, and 8th International Workshop, ML-CDS 2018, Held in Conjunction with MICCAI 2018, Granada, Spain, September 20, 2018, Proceedings 4, pages 3–11. Springer, 2018.
  • [27] Jeya Maria Jose Valanarasu, Poojan Oza, Ilker Hacihaliloglu, and Vishal M Patel. Medical transformer: Gated axial-attention for medical image segmentation. In Medical Image Computing and Computer Assisted Intervention–MICCAI 2021: 24th International Conference, Strasbourg, France, September 27–October 1, 2021, Proceedings, Part I 24, pages 36–46. Springer, 2021.