1 Department of Electronic and Computer Engineering, The Hong Kong University of Science and Technology, Hong Kong, China
{mkfmelbatel,eexmli}@.ust.hk
2 Department of Engineering Science, University of Oxford, Oxford, UK
konstantinos.kamnitsas@eng.ox.ac.uk
3 HKUST Shenzhen-Hong Kong Collaborative Innovation Research Institute, Shenzhen, China
4 Department of Computing, Imperial College London, London, UK
5 School of Computer Science, University of Birmingham, Birmingham, UK

An Organism Starts with a Single Pix-Cell: A Neural Cellular Diffusion for High-Resolution Image Synthesis

Marawan Elbatel (1), Konstantinos Kamnitsas (2, 4, 5), Xiaomeng Li (1, 3; correspondence: eexmli@ust.hk)
Abstract

Generative modeling seeks to approximate the statistical properties of real data, enabling synthesis of new data that closely resembles the original distribution. Generative Adversarial Networks (GANs) and Denoising Diffusion Probabilistic Models (DDPMs) represent significant advancements in generative modeling, drawing inspiration from game theory and thermodynamics, respectively. Nevertheless, the exploration of generative modeling through the lens of biological evolution remains largely untapped. In this paper, we introduce a novel family of models termed Generative Cellular Automata (GeCA), inspired by the evolution of an organism from a single cell. GeCAs are evaluated as an effective augmentation tool for retinal disease classification across two imaging modalities: Fundus and Optical Coherence Tomography (OCT). In the context of OCT imaging, where data is scarce and the distribution of classes is inherently skewed, GeCA significantly boosts classification performance across 11 different ophthalmological conditions, achieving a 12% increase in the average F1 score compared to conventional baselines. GeCAs outperform diffusion methods that incorporate either UNet or state-of-the-art transformer-based denoising models, under similar parameter constraints. Code is available at: https://github.com/xmed-lab/GeCA.

Keywords: Generative Cellular Automata (GeCA), Diffusion Models
Figure 1: Selected synthetic images from our GeCA trained on Fundus and OCT.

1 Introduction

Retinal diseases rank among the leading causes of vision disabilities and blindness if they remain untreated. Medical imaging modalities such as fundus photography and Optical Coherence Tomography (OCT) are widely used for diagnosing retinal conditions. OCT, offering a comprehensive view of the retinal layers compared to the fundus, is the preferred modality for diagnosing specific diseases such as Diabetic Retinopathy (DR) and Age-related Macular Degeneration (AMD) [17]. Recently, deep learning approaches have been introduced for retinal disease screening, utilizing both fundus [15] and OCT [30]. Nevertheless, the development of these approaches is significantly hindered by the scarcity of publicly accessible datasets, particularly for OCT. Despite its advantages, OCT imaging is more costly and less employed than fundus photography, leading to a scarcity of OCT datasets. Therefore, it becomes crucial to develop a novel solution for retinal disease diagnosis using OCT imaging, especially considering its scarcity as well as its skewed disease distribution.

Expanding datasets with synthetic images through generative modeling has been shown to significantly enhance diagnostic accuracy in medical imaging, particularly in scenarios where data is scarce and class distribution is skewed [6, 32, 20]. Current generative models primarily utilize diffusion-based optimization [8], relying heavily on architectures such as UNet [24, 6] and transformers [3, 23]. Despite their effectiveness, these models require a vast number of parameters, training on large-scale datasets, and often segmentation priors [36]. These inefficiencies present considerable challenges, particularly in medical imaging, where datasets, annotations, and computational resources are often scarce. Inspired by biological processes, Neural Cellular Automata (NCA) [18] emerge as a promising alternative, offering advancements in diverse tasks with fewer parameters [27, 21, 13, 11]. While NCA have shown promise in enabling medical image segmentation tasks under resource-constrained settings [13, 11], their application in generative tasks results in low-resolution outputs [22, 26, 12] and lacks comprehensive performance comparisons, particularly in the evaluation of downstream tasks, where NCA’s efficiency for image generation remains an unresolved challenge.

To address these challenges, we propose a novel approach for incorporating NCA in image generation by integrating diffusion objectives specifically devised for NCA’s unique structure. Operating in the latent space, scaling Neural Cellular Automata (NCA) with transformers, and introducing a novel Gene Heredity guidance method for enhanced reverse sampling, we present Generative Cellular Automata (GeCA). GeCA surpasses the state-of-the-art Diffusion Transformers (DiTs) [23] in image generation across two modalities: Fundus and OCT. By extending the application of GeCA to dataset expansion, we augment the scarce OCT dataset with synthetic images, resulting in a 12% improvement in the average F1-score for multi-label retinal disease classification compared to conventional baselines. Our contributions can be summarized as:

  • We introduce Generative Cellular Automata (GeCA), a novel model that integrates Neural Cellular Automata (NCA) with diffusion objectives, tailored specifically for NCA’s unique structure.

  • We propose Gene Heredity Guidance (GHG) to improve GeCA’s image sampling. GHG enabled GeCA to surpass SOTA DiT in image generation and retinal disease classification with half of DiT’s parameters.

  • Through a detailed examination of diffusion models in OCT image generation, we demonstrate their capability to augment training datasets with synthetic images, boosting OCT’s multi-label retinal disease classification.

Figure 2: GeCA overall framework.

2 Generative Cellular Automata

2.1 An Organism Starts With a Single Pix-Cell

NCA [18] model an input image with height $H$ and width $W$ as a grid comprising $H \times W$ entities, which we refer to as pix-cells in our methodology. Each pix-cell represents a time-dependent state space representation, facilitating dynamic evolution akin to cellular processes in an organism, i.e., image. We parameterize the state of each pix-cell at step $m$ as a vector of scalars, defined as:

$$\text{pix-cell}_{m} = \{C^{in}, C^{\gamma}, C^{out}, C^{h}\}, \tag{1}$$

where $C^{in}$ denotes the input values of the image (e.g., one scalar for grayscale and three for RGB input images), $C^{\gamma}$ represents a positional encoding, defined by a continuous and smooth sinusoidal function facilitating spatial awareness within the grid [5, 29], $C^{out}$ refers to the output values indicating a pix-cell's targeted state (equivalent to $C^{in}$ in dimension), and $C^{h}$ represents the hidden state variables reflecting the pix-cell's internal state during evolution.
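The channel layout of Eq. 1 can be pictured with a small container type. This is only an illustrative sketch: the class name `PixCell`, the helper `empty_cell`, and the channel counts in the usage line are assumptions made for the example and are not taken from the paper or its code release.

```python
from dataclasses import dataclass
from typing import List


@dataclass
class PixCell:
    """One pix-cell state, split into the four channel groups of Eq. 1."""
    c_in: List[float]     # input values (1 scalar for grayscale, 3 for RGB)
    c_gamma: List[float]  # positional encoding, constant during evolution
    c_out: List[float]    # output values, same length as c_in
    c_h: List[float]      # hidden state carried across update steps

    def as_vector(self) -> List[float]:
        # Concatenate the groups in the order used in Eq. 1.
        return self.c_in + self.c_gamma + self.c_out + self.c_h


def empty_cell(n_in: int, n_gamma: int, n_hidden: int) -> PixCell:
    # All-zero cell; the channel counts are hypothetical, chosen by the caller.
    return PixCell([0.0] * n_in, [0.0] * n_gamma, [0.0] * n_in, [0.0] * n_hidden)


cell = empty_cell(n_in=3, n_gamma=4, n_hidden=8)
state_size = len(cell.as_vector())  # 3 + 4 + 3 + 8 channels
```

Keeping the four groups separate makes it explicit which channels an update rule may write to ($C^{out}$ and $C^{h}$) and which it may only read.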

To evolve a single pix-cell into a more complex organism, i.e., an image, we follow the NCA convention of adopting a stochastic update rule [18]. This means a pix-cell is updated at step $m$ randomly with a probability $p$, reflecting the non-simultaneous nature of cellular updates in self-organizing organisms. The update of a pix-cell changes only $C^{h}$ and $C^{out}$, given that $C^{in}$ and $C^{\gamma}$ are constant. This process, illustrated as the GeCA step in Fig. 2, is defined as:

$$\text{pix-cell}_{m+1} = \Theta(\text{pix-cell}_{m}, \text{Neighborhood}_{8}) + \{0, 0, C^{out}_{m}, C^{h}_{m}\} \tag{2}$$

Departing from the hierarchical modeling with $M$ layers in the SOTA Diffusion Transformer (DiT), we parameterize $\Theta$ as a single DiT block featuring a localized self-attention mechanism, specifically computed across the 8 closest neighboring pix-cells. The localized attention strategy, implemented similarly to those in localized transformer-based methods [33, 2, 27], allows each pix-cell to grow independently by applying Eq. 2 for $M$ times, using the same $\Theta$. GeCA's approach shifts the focus in image generation towards local spatial interactions, moving away from the global context reliance observed in traditional models such as UNet [25] and standard transformers [29]. Nevertheless, GeCA achieves global coherence by accumulating long-term state-space representation via $C^{h}$, aligning with the foundational concepts documented in NCA [18, 26, 27, 11, 13, 21], Mamba [7], universal transformers [4], and MLP-mixers [28].
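The mechanics of Eq. 2 (a shared rule, an 8-neighborhood, a residual connection, and a stochastic per-cell update) can be sketched in a few lines. This is a toy under stated assumptions: `toy_theta` is a placeholder averaging rule, not the DiT block used in the paper; only the updatable channels are tracked; and cells at the border simply see fewer neighbors, which is a choice made here for brevity.

```python
import random
from typing import Callable, List

Cell = List[float]       # updatable channels of one pix-cell (C^out and C^h)
Grid = List[List[Cell]]  # grid[row][col]


def neighborhood(grid: Grid, r: int, c: int) -> List[Cell]:
    """Return the cell at (r, c) followed by its up-to-8 in-bounds neighbours."""
    h, w = len(grid), len(grid[0])
    cells = [grid[r][c]]
    for dr in (-1, 0, 1):
        for dc in (-1, 0, 1):
            rr, cc = r + dr, c + dc
            if (dr, dc) != (0, 0) and 0 <= rr < h and 0 <= cc < w:
                cells.append(grid[rr][cc])
    return cells


def toy_theta(cells: List[Cell]) -> Cell:
    """Placeholder rule (hypothetical): neighbourhood mean minus own state."""
    own, n = cells[0], len(cells)
    return [sum(cell[k] for cell in cells) / n - own[k] for k in range(len(own))]


def geca_step(grid: Grid, theta: Callable[[List[Cell]], Cell],
              p: float, rng: random.Random) -> Grid:
    """One application of Eq. 2: each cell is updated with probability p."""
    new_grid = []
    for r in range(len(grid)):
        row = []
        for c in range(len(grid[0])):
            state = grid[r][c]
            if rng.random() < p:
                delta = theta(neighborhood(grid, r, c))
                state = [s + d for s, d in zip(state, delta)]  # residual update
            row.append(list(state))
        new_grid.append(row)
    return new_grid


# A single seeded cell in the centre of a 5x5 grid with one channel.
grid = [[[0.0] for _ in range(5)] for _ in range(5)]
grid[2][2] = [1.0]
rng = random.Random(0)
after_one = geca_step(grid, toy_theta, p=1.0, rng=rng)
after_two = geca_step(after_one, toy_theta, p=1.0, rng=rng)
frozen = geca_step(grid, toy_theta, p=0.0, rng=rng)
```

With `p=1.0` the seed reaches the corner of the 5x5 grid only after two steps, which illustrates why repeated local updates are needed before information can spread across the whole image.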

2.2 Cellular Diffusion: Evolving Cells into Organisms

To train our model parameters $\Theta$, we utilize the well-established diffusion process first introduced in [8] with specific modifications in the forward and reverse steps. During the forward diffusion process, we initialize $C^{out}$ and $C^{h}$ with zeros, except for a single pix-cell located at the center of the $H \times W$ grid, which is initialized with random scalars to serve as the starting point for the cellular process. $C^{\gamma}$ is initialized with a sinusoidal positional encoding. $C^{in}$ can be described in the forward diffusion process on a per pix-cell level as:

$$C^{in}_{t} = \sqrt{\alpha_{t}}\, C^{in}_{0} + \sqrt{1 - \alpha_{t}}\, \epsilon, \quad \epsilon \sim \mathcal{N}(0, I), \tag{3}$$

where $\epsilon$, following a normal distribution, represents the noise added at each step, and $\alpha_{t}$, which is part of a pre-defined variance schedule, takes values within the interval $(0, 1)$ for each time step $t = 1$ to $T$.
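As a concrete reading of Eq. 3, the sketch below noises the input channels of a single pix-cell. The function name `forward_noise` and the example values are illustrative assumptions; the variance schedule itself is not specified here, so `alpha_t` is simply passed in.

```python
import math
import random
from typing import List, Tuple


def forward_noise(c0_in: List[float], alpha_t: float,
                  rng: random.Random) -> Tuple[List[float], List[float]]:
    """Apply Eq. 3 to one pix-cell and return (C_t^in, epsilon)."""
    if not 0.0 < alpha_t < 1.0:
        raise ValueError("alpha_t must lie in (0, 1)")
    # One standard-normal draw per input channel.
    eps = [rng.gauss(0.0, 1.0) for _ in c0_in]
    ct_in = [math.sqrt(alpha_t) * x + math.sqrt(1.0 - alpha_t) * e
             for x, e in zip(c0_in, eps)]
    return ct_in, eps


c0 = [1.0, -1.0, 0.5]
ct, eps = forward_noise(c0, alpha_t=0.25, rng=random.Random(0))
```

Values of `alpha_t` close to 1 leave the input almost untouched, while values close to 0 replace it almost entirely with noise.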

We then perform $M$ cellular updates with Eq. 2 to develop $C^{out}_{t}$ and $C^{h}_{t}$. When $T \rightarrow \infty$, $C^{in}_{T}$ becomes equivalent to an isotropic Gaussian distribution [8]. Thus, the theoretical objective simplifies to predicting the noise $\epsilon$ from each pix-cell:

$$L = \mathbb{E}_{t \sim [1, T],\, C_{0,t}} \left[ \| \epsilon - C^{out}_{t} \|^{2} \right] \tag{4}$$

This formulation allows reverse sampling from Gaussian noise $C^{in}_{T} \sim \mathcal{N}(0, \mathbf{I})$. Additionally, it allows adjusting $M$ during sampling to control the intensity of generation, from undergrowth to overgrowth; see Fig. 6 in the appendix.
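For a single image and a single sampled timestep, the quantity inside the expectation of Eq. 4 compares the injected noise with the $C^{out}$ channels produced after the $M$ updates. The sketch below averages the squared error over pix-cells; the averaging convention and the function name `cellular_diffusion_loss` are assumptions made for illustration.

```python
from typing import List

Grid = List[List[List[float]]]  # grid[row][col] -> channel vector


def cellular_diffusion_loss(eps: Grid, c_out: Grid) -> float:
    """Mean over pix-cells of ||eps - C_out||^2 for one image."""
    total, n_cells = 0.0, 0
    for eps_row, out_row in zip(eps, c_out):
        for e, o in zip(eps_row, out_row):
            total += sum((a - b) ** 2 for a, b in zip(e, o))
            n_cells += 1
    return total / n_cells


# 2x2 grid with two channels per pix-cell.
eps = [[[1.0, 0.0], [0.0, 0.0]], [[0.0, 2.0], [0.0, 0.0]]]
zeros = [[[0.0, 0.0], [0.0, 0.0]], [[0.0, 0.0], [0.0, 0.0]]]
loss_untrained = cellular_diffusion_loss(eps, zeros)  # (1 + 4) / 4 cells
loss_perfect = cellular_diffusion_loss(eps, eps)
```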

2.3 Improved Reverse Sampling via Gene Heredity

Representing an input image with pix-cells, a time-dependent state space representation, GeCA preserves long-term information within its internal hidden states, $C^{h}$, analogous to genetic material. Thus, we propose leveraging $C^{h}$ at time $t+1$ to guide the reverse generation of time $t$, mirroring the inheritance of genetic traits. Specifically, we modify each step in the reverse process to initialize the pix-cell hidden states, $C^{h}$, as:

$$C^{h}_{t} = \begin{cases} \epsilon \sim \mathcal{N}(0, I) & \text{if } t = T \text{ and grid-center pix-cell}, \\ C^{h}_{t+1} & \text{otherwise}. \end{cases} \tag{5}$$

Simultaneously, $C^{out}$ for the grid-center pix-cell at each timestep is defined as:

$$C^{out}_{t} \sim \mathcal{N}(0, I) \tag{6}$$

Our proposed process, termed Gene Heredity Guidance (GHG), sets the stage for denoising $C^{in}_{t}$ and refining $C^{h}_{t}$ from a plausible starting point. Following GHG, the denoising process to sample a synthetic pix-cell, $C^{in}_{0}$, adheres to traditional diffusion steps until $t \rightarrow 0$ as:

$$C^{in}_{t-1} = \frac{1}{\sqrt{\alpha_{t}}} \left( C^{in}_{t} - \frac{1 - \alpha_{t}}{\sqrt{1 - \alpha_{t-1}}}\, C^{out}_{t} \right), \tag{7}$$

Note that without our proposed GHG, the application of NCA in generative modeling is suboptimal (see Fig. 5).
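The state handling of Eqs. 5 and 6 can be sketched as a reverse loop in which the hidden state is carried from timestep $t+1$ to $t$ while $C^{out}$ is re-seeded at every timestep. Several simplifications are assumed here and are not taken from the paper's code: the grid is flattened to a 1-D strip with one scalar per channel group, non-center cells start from zeros (as in the forward initialization), the `use_ghg=False` branch is one possible reading of the "without inheritance" ablation, and both the $M$ cellular updates (`evolve`) and the denoising update of Eq. 7 (`denoise`) are passed in as hypothetical callables.

```python
import random
from typing import Callable, List, Tuple

Vec = List[float]
EvolveFn = Callable[[Vec, Vec, Vec, int], Tuple[Vec, Vec]]  # -> (c_out, c_h)
DenoiseFn = Callable[[Vec, Vec, int], Vec]                  # -> c_in at t - 1


def ghg_sample(n_cells: int, T: int, evolve: EvolveFn, denoise: DenoiseFn,
               rng: random.Random, use_ghg: bool = True) -> Vec:
    """Reverse sampling with hidden-state inheritance across timesteps."""
    center = n_cells // 2
    c_in = [rng.gauss(0.0, 1.0) for _ in range(n_cells)]  # C_T^in ~ N(0, I)
    c_h = [0.0] * n_cells
    for t in range(T, 0, -1):
        if t == T or not use_ghg:
            # Eq. 5, first case: random seed at the centre, zeros elsewhere.
            c_h = [0.0] * n_cells
            c_h[center] = rng.gauss(0.0, 1.0)
        # Otherwise (Eq. 5, second case) c_h from step t + 1 is kept as is.
        # Eq. 6: C^out is re-seeded at the centre at every timestep.
        c_out = [0.0] * n_cells
        c_out[center] = rng.gauss(0.0, 1.0)
        c_out, c_h = evolve(c_in, c_out, c_h, t)  # stands in for M uses of Eq. 2
        c_in = denoise(c_in, c_out, t)            # stands in for Eq. 7
    return c_in


calls = []  # records what the toy model receives and returns


def toy_evolve(c_in: Vec, c_out: Vec, c_h: Vec, t: int) -> Tuple[Vec, Vec]:
    new_h = [h + 1.0 for h in c_h]  # placeholder dynamics, not a trained model
    calls.append({"t": t, "h_in": list(c_h), "h_out": list(new_h)})
    return list(c_out), new_h


def toy_denoise(c_in: Vec, c_out: Vec, t: int) -> Vec:
    return [0.9 * x - 0.1 * o for x, o in zip(c_in, c_out)]  # placeholder


sample = ghg_sample(n_cells=5, T=3, evolve=toy_evolve, denoise=toy_denoise,
                    rng=random.Random(0))
```

In the recorded `calls`, the hidden state received at each timestep after the first is exactly the one returned at the previous timestep, which is the inheritance that GHG describes.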

2.4 Retinal Disease Classification

Classifying retinal disease from OCT images presents significant challenges due to data scarcity and skewed class distributions. In light of these challenges, we leverage generative modeling to augment the dataset effectively, a strategy proven to significantly enhance downstream classification tasks compared to conventional augmentation techniques [6, 35].

Following [35], we synthesize a training set expanded five-fold, mirroring the original training set's distribution. Given the original dataset's class distribution $p_{\text{orig}}(y)$, with $y$ representing the dataset labels and $N_{\text{orig}}$ as the original dataset size, the objective is to expand the dataset five-fold to $N_{\text{aug}} = 5 \times N_{\text{orig}}$, while preserving $p_{\text{orig}}(y)$. This is achieved by ensuring that the count of each label $y$ in the augmented dataset, $\text{Count}_{\text{aug}}(y)$, is five times its original count as:

$$p_{\text{aug}}(y) = p_{\text{orig}}(y), \quad \text{where} \quad \text{Count}_{\text{aug}}(y) = 5 \times \text{Count}_{\text{orig}}(y) \tag{8}$$

By preserving the original label distribution $p_{\text{orig}}(y)$ in the augmented dataset, we maintain the dataset's inherent distribution to avoid any potential bias.
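Eq. 8 amounts to a per-label generation plan. The sketch below treats each distinct label combination as one value of $y$, which is an assumption made here for the multi-label case; the function name `synthetic_label_plan` and the toy labels are illustrative only.

```python
from collections import Counter
from typing import Dict, Hashable, Sequence


def synthetic_label_plan(labels: Sequence[Hashable],
                         factor: int = 5) -> Dict[Hashable, int]:
    """Number of samples to hold per label so that Count_aug = factor * Count_orig."""
    counts = Counter(labels)
    return {y: factor * n for y, n in counts.items()}


# Toy training labels; tuples stand for multi-label combinations.
labels = [("DR",), ("DR",), ("normal",), ("dAMD", "PED")]
plan = synthetic_label_plan(labels)

n_orig = len(labels)
n_aug = sum(plan.values())
# Label proportions before and after expansion.
p_orig = {y: n / n_orig for y, n in Counter(labels).items()}
p_aug = {y: n / n_aug for y, n in plan.items()}
```

Because every count is scaled by the same factor, the label proportions of the expanded set match those of the original set.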

Figure 3: Results summary on a public fundus dataset.

3 Experiments

Datasets. We evaluate our model using two different datasets: OCT and Fundus. The multi-label OCT dataset, OCT-ML, is an in-house dataset consisting of 1,435 samples from 369 eyes of 203 patients, covering normal, dry age-related macular degeneration (dAMD), wet age-related macular degeneration (wAMD), diabetic retinopathy (DR), central serous chorioretinopathy (CSC), pigment epithelial detachment (PED), macular epiretinal membrane (MEM), fluid (FLD), exudation (EXU), choroid neovascularization (CNV), and retinal vascular occlusion (RVO). Additionally, we provide the code necessary for both the generation process and the classification task, applied to DeepDRiD [16], a publicly available fundus imaging dataset encompassing five grading classes, for which we follow the MedMNIST split [31] (1,080 train, 120 val, 400 test). For the OCT-ML dataset, we adopt five-fold cross-validation.

Table 1: Quantitative image quality evaluation for two datasets. KID values are expressed in terms of $10^{-3}$ for each model. All baselines are trained and evaluated with classifier free guidance (CFG) [9] and $T = 250$.

| Method | # Param. (↓) | Fundus KID (↓) | Fundus LPIPS (↑) | Fundus GG (> 0) | OCT KID (↓) | OCT LPIPS (↑) | OCT GG (> 0) |
| --- | --- | --- | --- | --- | --- | --- | --- |
| LDM-B [24] | 17.3 M | 11.64 ± 2.1 | 0.37 ± 0.09 | -10.67 | 64.5 ± 10 | 0.39 ± 0.16 | -2.31 |
| DiT-S [23] | 32.7 M | 12.45 ± 2.8 | 0.31 ± 0.09 | -14.55 | 62.3 ± 5.9 | 0.37 ± 0.14 | -0.44 |
| GeCA-S (ours) | 13.3 M | 7.42 ± 1.6 | 0.39 ± 0.11 | 2.02 | 49.1 ± 8.0 | 0.53 ± 0.16 | 0.34 |

Baselines. Unlike previous NCA approaches [22, 12], which exhibited suboptimal performance and were not compared with SOTA generative benchmarks, we compare our GeCA against DiT [23], state-of-the-art diffusion transformers, as well as the U-Net-based diffusion models from LDM [24], modifying the label embedding to support multi-label OCT generation. Training and inference for all baseline models adhere to the same hyperparameters with Classifier Free Guidance (CFG) [9] to facilitate conditional generation on downstream tasks. For the DiT, we report DiT-S, with an optimal patch size of 2. Given that our GeCA trains a single DiT layer, we take $M = 12$, equivalent to the number of layers in DiT-S; see Appendix for details.
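For readers unfamiliar with CFG [9], the guided noise estimate is a linear combination of a conditional and an unconditional prediction. The sketch below uses the common form `uncond + scale * (cond - uncond)`; the function name `cfg_combine` and the example numbers are illustrative, and how the two predictions are obtained in any of the compared models (for GeCA they would presumably be two $C^{out}$ predictions) is an assumption not detailed in this section.

```python
from typing import List


def cfg_combine(eps_uncond: List[float], eps_cond: List[float],
                scale: float) -> List[float]:
    """Classifier-free guidance: move from the unconditional towards the
    conditional prediction by a factor of `scale`."""
    return [u + scale * (c - u) for u, c in zip(eps_uncond, eps_cond)]


uncond = [0.0, 1.0]
cond = [1.0, 3.0]
plain_conditional = cfg_combine(uncond, cond, scale=1.0)
unconditional = cfg_combine(uncond, cond, scale=0.0)
amplified = cfg_combine(uncond, cond, scale=2.0)
```

A scale of 1 recovers the conditional prediction, a scale of 0 recovers the unconditional one, and larger scales push further in the direction of the label condition.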

Implementation Details. For all methods, generation is conducted in the latent space akin to LDM [24] with an output resolution of $256 \times 256$. All methods are trained with mixed precision for acceleration. We utilize a batch size of 128 and train all models for 14,000 epochs until convergence. For the downstream classification task, ResNet-34 is utilized with the Adam optimizer.

Figure 4: Qualitative examples for the Fundus and OCT datasets are provided; images are downsampled for visualization purposes, with high-resolution versions available in the supplementary material.
Table 2: Performance results on our in-house multi-label OCT dataset, employing a five-fold cross-validation approach at the patient level. Each fold trains a separate diffusion model to generate synthetic data for the downstream classification task. All downstream experiments use ResNet34 as the backbone. We follow prior works [30, 14] for eye-level performance evaluation, considering the multiple scans per eye in our dataset. The $\text{F1}_{sen/pe}$ quantifies the harmonic mean of Sensitivity (Sen.) and Specificity (Spe.). (****) denote statistical significance with a p-value less than 0.0001. All reported metrics are macro-averaged.

| Synthetic Data | Sen. | Spe. | AUC | F1 | $\text{F1}_{sen/pe}$ | mAP | $p <$ |
| --- | --- | --- | --- | --- | --- | --- | --- |
| Baseline (Geometric Aug) | 54.66 ± 1.53 | 96.50 ± 0.16 | 92.47 ± 0.85 | 55.47 ± 0.99 | 60.80 ± 1.49 | 68.85 ± 1.41 | - |
| Baseline w/o Aug. | 48.34 ± 1.45 | 96.39 ± 0.20 | 89.99 ± 0.82 | 54.56 ± 1.77 | 50.07 ± 0.89 | 64.58 ± 1.19 | ** |
| LDM-B [24] | 58.83 ± 1.90 | 96.12 ± 0.29 | 91.22 ± 0.74 | 59.65 ± 3.19 | 67.74 ± 2.97 | 70.49 ± 2.64 | ** |
| DiT-S [23] | 59.25 ± 4.54 | 95.87 ± 0.37 | 91.80 ± 1.74 | 59.11 ± 2.57 | 67.13 ± 4.87 | 69.89 ± 3.34 | *** |
| GeCA-S (ours) | 59.95 ± 5.32 | 96.38 ± 0.40 | 92.74 ± 2.21 | 61.62 ± 3.93 | 68.38 ± 4.61 | 73.28 ± 5.58 | **** |
(a) Public Fundus dataset.
(b) OCT-ML.
Figure 5: Ablation of the proposed Gene Heredity Guidance (GHG).

Generative Modeling Evaluation. Tab. 1 presents the quantitative results to assess the quality of the generated samples. Noting the limitations of the Fréchet Inception Distance (FID) score observed in prior works [19, 10], in particular its sensitivity to dataset size [10, 1], we employ the Kernel Inception Distance (KID) to measure fidelity. Additionally, we report the perceptual LPIPS diversity [34] to measure the image variability. Finally, we present the generalization gap (GG) as quantified by the Feature Likelihood Divergence (FLD) [10], encapsulating the triplet of novelty (difference from the training samples), fidelity, and diversity of the synthetic samples. Overall, our GeCA demonstrates superior image quality, both quantitatively and qualitatively, as depicted in Fig. 4. We show samples from the high-resolution GeCA model in Fig. 1 and the appendix.
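As background on the fidelity metric, KID [1] is an unbiased estimate of the squared maximum mean discrepancy between real and generated feature sets under the cubic polynomial kernel $k(x, y) = (x^{\top} y / d + 1)^{3}$. The sketch below computes it on tiny hand-made 2-D vectors; in practice the features come from an Inception network and the estimate is typically averaged over subsets, neither of which is reproduced here. The feature values and function names are illustrative assumptions.

```python
from typing import List, Sequence

Vec = Sequence[float]


def poly_kernel(x: Vec, y: Vec) -> float:
    """Cubic polynomial kernel (x . y / d + 1)^3, with d the feature dimension."""
    d = len(x)
    return (sum(a * b for a, b in zip(x, y)) / d + 1.0) ** 3


def _mean_offdiag(feats: List[Vec]) -> float:
    # Average kernel value over all ordered pairs i != j (unbiased term).
    n = len(feats)
    total = sum(poly_kernel(feats[i], feats[j])
                for i in range(n) for j in range(n) if i != j)
    return total / (n * (n - 1))


def kid(real: List[Vec], fake: List[Vec]) -> float:
    """Unbiased squared-MMD estimate between two feature sets."""
    cross = sum(poly_kernel(x, y) for x in real for y in fake)
    cross /= len(real) * len(fake)
    return _mean_offdiag(real) + _mean_offdiag(fake) - 2.0 * cross


real = [[0.0, 0.0], [1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
fake_near = [[0.1, 0.0], [0.9, 0.1], [0.0, 0.9], [1.0, 1.1]]
fake_far = [[5.0, 5.0], [6.0, 5.0], [5.0, 6.0], [6.0, 6.0]]
kid_near = kid(real, fake_near)
kid_far = kid(real, fake_far)
```

Because the estimator is unbiased, it can be slightly negative when the two sets are very similar; lower values indicate closer feature distributions.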

Retinal Disease Classification. Tab. 2 presents the multi-label classification results over the 11 classes of OCT-ML, with the training set expanded by synthetic data via the generative modeling discussed in Sec. 2.4. All generative models remarkably improved the performance across various metrics. Notably, expanding the training dataset with our proposed GeCA achieved the highest mean average precision (mAP of 73.28%). GeCA significantly surpasses the baseline with geometric augmentation by 4.43 percentage points in mAP and 7.58 percentage points in the harmonic mean of Sensitivity and Specificity ($\text{F1}_{sen/pe}$). Furthermore, in terms of the traditional F1-score, which evaluates precision and recall, GeCA achieved a significant gain of 6.15 percentage points over the baseline. Despite being significantly more parameter-efficient, requiring only about 40% of the parameters of the SOTA DiT-S [23], GeCA still manages to surpass it by 3.39 percentage points in mAP. Furthermore, GeCA not only exceeds the performance of the leading baseline, LDM-B [24], by 2.79 percentage points in mAP, but it also secures the highest degree of statistical significance (****). These results highlight GeCA's very promising performance in the realm of generative modeling.
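The gains quoted above are absolute differences between rows of Tab. 2, and the parameter ratio comes from Tab. 1. The short check below recomputes them from the reported means; the dictionary layout and helper name `gain` are just a convenience for this example.

```python
# Mean values copied from Tab. 2 (macro-averaged, in percent).
table2 = {
    "Baseline (Geometric Aug)": {"F1": 55.47, "F1_sen_pe": 60.80, "mAP": 68.85},
    "LDM-B": {"F1": 59.65, "F1_sen_pe": 67.74, "mAP": 70.49},
    "DiT-S": {"F1": 59.11, "F1_sen_pe": 67.13, "mAP": 69.89},
    "GeCA-S": {"F1": 61.62, "F1_sen_pe": 68.38, "mAP": 73.28},
}


def gain(metric: str, other: str, ours: str = "GeCA-S") -> float:
    """Absolute difference (percentage points) between two rows."""
    return round(table2[ours][metric] - table2[other][metric], 2)


map_vs_baseline = gain("mAP", "Baseline (Geometric Aug)")
f1sp_vs_baseline = gain("F1_sen_pe", "Baseline (Geometric Aug)")
f1_vs_baseline = gain("F1", "Baseline (Geometric Aug)")
map_vs_dit = gain("mAP", "DiT-S")
map_vs_ldm = gain("mAP", "LDM-B")

# Parameter counts from Tab. 1 (in millions).
param_ratio = 13.3 / 32.7  # GeCA-S relative to DiT-S
```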

GHG Ablation. Fig. 5 reveals the impact of Gene Heredity Guidance (GHG), introduced in Sec. 2.3, on two datasets. On the Fundus dataset, without inheritance, the model yields a moderate KID of 19.37, lacking the benefits of long-range dependencies. Inheriting $C^{out}$ alone drastically impairs performance, spiking the KID to 44.22, suggesting that inheriting $C^{out}$ propagates noise. Conversely, inheriting both $C^{out}$ and hidden states $C^{h}$ partially mitigates this effect, reducing the KID to 12.84. Optimal performance is observed when only $C^{h}$ is inherited, achieving the lowest KID of 7.42. In contrast to $C^{out}$, whose primary function is to predict noise, inheriting $C^{h}$ facilitates the propagation of long-range dependencies, capturing the global context across the image.

4 Conclusion

We present GeCA, an innovative model outperforming current image generation benchmarks through neural cellular automata, demonstrated on challenging multi-label OCT classification. Future directions include broadening GeCA’s validation across various domains and exploiting its unique capabilities, such as channel dimension selective sampling and temporal scheduling of its updates.


4.0.1 Acknowledgements

This work was supported in part by the grants from Foshan HKUST Projects, Grant Nos. FSUST21-HKUST10E and FSUST21-HKUST11E and in part by Project of Hetao Shenzhen-Hong Kong Science and Technology Innovation Cooperation Zone (HZQB-KCZYB-2020083). Marawan Elbatel is supported by the Hong Kong PhD Fellowship Scheme (HKPFS) from the Hong Kong Research Grants Council (RGC), and by the Belt and Road Initiative from the HKSAR Government.

4.0.2 Disclosure of Interests

The authors have no competing interests to declare that are relevant to the content of this article.

References

  • [1] Bińkowski, M., Sutherland, D.J., Arbel, M., Gretton, A.: Demystifying MMD GANs. In: International Conference on Learning Representations (2018)
  • [2] Chen, C.F., Panda, R., Fan, Q.: Regionvit: Regional-to-local attention for vision transformers. In: International Conference on Learning Representations (2022)
  • [3] Chowdary, G.J., Yin, Z.: Diffusion transformer u-net for medical image segmentation. In: Medical Image Computing and Computer Assisted Intervention – MICCAI 2023. pp. 622–631. Springer Nature Switzerland, Cham (2023)
  • [4] Dehghani, M., Gouws, S., Vinyals, O., Uszkoreit, J., Kaiser, L.: Universal transformers. In: International Conference on Learning Representations (2019)
  • [5] Dosovitskiy, A., Beyer, L., et al.: An image is worth 16x16 words: Transformers for image recognition at scale. In: ICLR (2021)
  • [6] Frisch, Y., Fuchs, M., et al.: Synthesising rare cataract surgery samples with guided diffusion models. In: Medical Image Computing and Computer Assisted Intervention – MICCAI 2023. pp. 354–364. Springer Nature Switzerland, Cham (2023)
  • [7] Gu, A., Dao, T.: Mamba: Linear-time sequence modeling with selective state spaces. ArXiv abs/2312.00752 (2023)
  • [8] Ho, J., Jain, A., Abbeel, P.: Denoising diffusion probabilistic models. In: Proceedings of the 34th International Conference on Neural Information Processing Systems. NIPS’20, Curran Associates Inc., Red Hook, NY, USA (2020)
  • [9] Ho, J., Salimans, T.: Classifier-free diffusion guidance. In: NeurIPS 2021 Workshop on Deep Generative Models and Downstream Applications (2021)
  • [10] Jiralerspong, M., Bose, J., Gemp, I., Qin, C., Bachrach, Y., Gidel, G.: Feature likelihood score: Evaluating the generalization of generative models using samples. In: Thirty-seventh Conference on Neural Information Processing Systems (2023)
  • [11] Kalkhof, J., González, C., Mukhopadhyay, A.: Med-nca: Robust and lightweight segmentation with neural cellular automata. In: International Conference on Information Processing in Medical Imaging. pp. 705–716. Springer (2023)
  • [12] Kalkhof, J., Kühn, A., Frisch, Y., Mukhopadhyay, A.: Frequency-time diffusion with neural cellular automata. ArXiv abs/2401.06291 (2024)
  • [13] Kalkhof, J., Mukhopadhyay, A.: M3d-nca: Robust 3d segmentation with built-in quality control. In: International Conference on Medical Image Computing and Computer-Assisted Intervention. pp. 169–178. Springer (2023)
  • [14] Li, X., Zhou, Y., Wang, J., Lin, H., Zhao, J., Ding, D., Yu, W., Chen, Y.: Multi-modal multi-instance learning for retinal disease recognition. In: Proceedings of the 29th ACM International Conference on Multimedia. p. 2474–2482. MM ’21, Association for Computing Machinery, New York, NY, USA (2021)
  • [15] Li, Y., Zhang, R., et al.: Predicting systemic diseases in fundus images: systematic review of setting, reporting, bias, and models’ clinical availability in deep learning studies. Eye (Jan 2024)
  • [16] Liu, R., Wang, X., et al.: Deepdrid: Diabetic retinopathy—grading and image quality estimation challenge. Patterns p. 100512 (2022)
  • [17] Midena, E., Frizziero, L., et al.: Optical coherence tomography and color fundus photography in the screening of age-related macular degeneration: A comparative, population-based study. Plos one 15(8), e0237352 (2020)
  • [18] Mordvintsev, A., Randazzo, E., Niklasson, E., Levin, M.: Growing neural cellular automata. Distill (2020). https://doi.org/10.23915/distill.00023
  • [19] Naeem, M.F., Oh, S.J., Uh, Y., Choi, Y., Yoo, J.: Reliable fidelity and diversity metrics for generative models. In: III, H.D., Singh, A. (eds.) Proceedings of the 37th International Conference on Machine Learning. Proceedings of Machine Learning Research, vol. 119, pp. 7176–7185. PMLR (13–18 Jul 2020)
  • [20] Oh, H.J., Jeong, W.K.: Diffmix: Diffusion model-based data synthesis for nuclei segmentation and classification in imbalanced pathology image datasets. In: Medical Image Computing and Computer Assisted Intervention – MICCAI 2023. pp. 337–345. Springer Nature Switzerland, Cham (2023)
  • [21] Pajouheshgar, E., Xu, Y., Zhang, T., Süsstrunk, S.: Dynca: Real-time dynamic texture synthesis using neural cellular automata. CVPR pp. 20742–20751 (2022)
  • [22] Palm, R.B., Duque, M.G., Sudhakaran, S., Risi, S.: Variational neural cellular automata. In: International Conference on Learning Representations (2022)
  • [23] Peebles, W., Xie, S.: Scalable diffusion models with transformers. In: 2023 IEEE/CVF International Conference on Computer Vision (ICCV). pp. 4172–4182. IEEE Computer Society, Los Alamitos, CA, USA (oct 2023)
  • [24] Rombach, R., Blattmann, A., Lorenz, D., Esser, P., Ommer, B.: High-resolution image synthesis with latent diffusion models (2021)
  • [25] Ronneberger, O., Fischer, P., Brox, T.: U-net: Convolutional networks for biomedical image segmentation. In: Medical Image Computing and Computer-Assisted Intervention – MICCAI 2015. pp. 234–241. Springer International Publishing, Cham (2015)
  • [26] Sudhakaran, S., Najarro, E., Risi, S.: Goal-guided neural cellular automata: Learning to control self-organising systems. In: From Cells to Societies: Collective Learning across Scales (2022)
  • [27] Tesfaldet, M., Nowrouzezahrai, D., Pal, C.: Attention-based neural cellular automata. In: Oh, A.H., Agarwal, A., Belgrave, D., Cho, K. (eds.) Advances in Neural Information Processing Systems (2022)
  • [28] Tolstikhin, I., Houlsby, N., et al.: MLP-mixer: An all-MLP architecture for vision. In: Advances in Neural Information Processing Systems (2021)
  • [29] Vaswani, A., Shazeer, N., Parmar, N., Uszkoreit, J., Jones, L., Gomez, A.N., Kaiser, L., Polosukhin, I.: Attention is all you need. In: NeurIPS. p. 6000–6010. NIPS’17, Curran Associates Inc., Red Hook, NY, USA (2017)
  • [30] Wang, L., Dai, W., Jin, M., Ou, C., Li, X.: Fundus-enhanced disease-aware distillation model for retinal disease classification from oct images. In: International Conference on Medical Image Computing and Computer-Assisted Intervention. pp. 639–648. Springer (2023)
  • [31] Yang, J., Shi, R., Wei, D., Liu, Z., Zhao, L., Ke, B., Pfister, H., Ni, B.: Medmnist v2-a large-scale lightweight benchmark for 2d and 3d biomedical image classification. Scientific Data 10(1),  41 (2023)
  • [32] Yang, Y., Fu, H., Aviles-Rivero, A.I., Schönlieb, C.B., Zhu, L.: Diffmic: Dual-guidance diffusion network for medical image classification. In: Medical Image Computing and Computer Assisted Intervention – MICCAI 2023. pp. 95–105. Springer Nature Switzerland, Cham (2023)
  • [33] Zhang, P., Dai, X., et al.: Multi-scale vision longformer: A new vision transformer for high-resolution image encoding. In: ICCV (2021)
  • [34] Zhang, R., Isola, P., Efros, A.A., Shechtman, E., Wang, O.: The unreasonable effectiveness of deep features as a perceptual metric. 2018 IEEE/CVF Conference on Computer Vision and Pattern Recognition pp. 586–595 (2018)
  • [35] Zhang, Y., Zhou, D., Hooi, B., Wang, K., Feng, J.: Expanding small-scale datasets with guided imagination. In: Thirty-seventh Conference on Neural Information Processing Systems (2023), https://openreview.net/forum?id=82HeVCqsfh
  • [36] Zhao, H., Li, H., Maurer-Stroh, S., Guo, Y., Deng, Q., Cheng, L.: Supervised segmentation of un-annotated retinal fundus images by synthesis. IEEE Transactions on Medical Imaging 38, 46–56 (2019)

Appendix for “GeCA”

Figure 6: Diverging from conventional hierarchical models, GeCA offers denoising strength adjustment via M updates at each denoising step without the need for re-training. Future efforts will investigate M-scheduling techniques. Performance is demonstrated under T = 250. While speed concerns exist, GeCA’s promising performance and optimization prospects highlight its significance.
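The idea in the Figure 6 caption, that one fixed local update rule can be applied M times per denoising step to vary denoising strength without retraining, can be illustrated with a toy sketch. Everything here is an assumption made for illustration: the hand-written neighbourhood-averaging rule stands in for GeCA's learned update, the grid is one-dimensional, and the names `local_update`, `denoise_step`, `roughness`, and `rate` are hypothetical rather than taken from the GeCA code.

```python
import random


def local_update(cells, rate=0.5):
    """One toy cellular update: each cell moves toward its neighbourhood mean.

    This fixed averaging rule is only a stand-in for a learned update rule.
    """
    n = len(cells)
    updated = []
    for i in range(n):
        left = cells[(i - 1) % n]   # wrap around at the borders
        right = cells[(i + 1) % n]
        neighbourhood_mean = (left + cells[i] + right) / 3.0
        updated.append(cells[i] + rate * (neighbourhood_mean - cells[i]))
    return updated


def denoise_step(cells, m_updates):
    """Apply the same update rule m_updates times within one denoising step."""
    for _ in range(m_updates):
        cells = local_update(cells)
    return cells


def roughness(cells):
    """Sum of absolute differences between neighbouring cells (circular)."""
    n = len(cells)
    return sum(abs(cells[i] - cells[(i + 1) % n]) for i in range(n))


random.seed(0)
noisy = [random.gauss(0.0, 1.0) for _ in range(64)]

# The rule itself is unchanged; only the number of applications M differs.
weak = denoise_step(noisy, m_updates=1)
strong = denoise_step(noisy, m_updates=8)

print(roughness(noisy), roughness(weak), roughness(strong))
```

In this toy setting a larger M simply smooths more. An M-schedule, as mentioned in the caption, would correspond to passing a different `m_updates` value at each denoising step.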
Table 3: Distribution of diseases across our in-house OCT images illustrating the uneven distribution of various ocular conditions within the dataset (class imbalance). All models trained on OCT images were subjected to rigorous validation using a 5-fold cross-validation process, with patient-wise splitting. Generative models are trained exclusively on the training set of each fold.
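| Diagnosis | Normal | dAMD | wAMD | DR | CSC | PED | MEM | FLD | EXU | CNV | RVO | Total |
| --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- |
| Count | 278 | 160 | 145 | 502 | 95 | 133 | 196 | 613 | 573 | 138 | 34 | 1435 |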
Figure 7: High-resolution output visualization.