Single-cell Curriculum Learning-based Deep Graph Embedding Clustering

Huifa Li Shanghai Key Laboratory of Trustworthy
Computing
East China Normal University
Shanghai, China
huifali@stu.ecnu.edu.cn
   Jie Fu Department of Computer Science
Stevens Institute of Technology
Hoboken, USA
jfu13@stevens.edu
   Xinpeng Ling Shanghai Key Laboratory of Trustworthy
Computing
East China Normal University
Shanghai, China
xpling@stu.ecnu.edu.cn
   Zhiyu Sun Shanghai Key Laboratory of Trustworthy
Computing
East China Normal University
Shanghai, China
51265902067@stu.ecnu.edu.cn
   Kuncan Wang Shanghai Key Laboratory of Trustworthy
Computing
East China Normal University
Shanghai, China
10204804424@stu.ecnu.edu.cn
   Zhili Chen * Corresponding author: Zhili Chen. Shanghai Key Laboratory of Trustworthy
Computing
East China Normal University
Shanghai, China
zhlchen@sei.ecnu.edu.cn
Abstract

The swift advancement of single-cell RNA sequencing (scRNA-seq) technologies enables the investigation of cellular-level tissue heterogeneity. Cell annotation significantly contributes to the extensive downstream analysis of scRNA-seq data. However, the analysis of scRNA-seq data for biological inference is challenging owing to its intricate and indeterminate data distribution, characterized by a substantial volume and a high frequency of dropout events. Furthermore, the quality of training samples varies greatly, and the performance of graph neural networks (GNNs), a popular solution for scRNA-seq data clustering, can be harmed by two types of low-quality training nodes: 1) nodes on the boundary; 2) nodes that contribute little additional information to the graph. To address these problems, we propose a single-cell curriculum learning-based deep graph embedding clustering method (scCLG). We first propose a Chebyshev graph convolutional autoencoder with multi-decoder (ChebAE) that combines three optimization objectives corresponding to three decoders, including topology reconstruction loss of cell graphs, zero-inflated negative binomial (ZINB) loss, and clustering loss, to learn cell-cell topology representation. Meanwhile, we employ a selective training strategy to train the GNN based on the features and entropy of nodes, and prune the difficult nodes based on their difficulty scores to keep a high-quality graph. Empirical results on a variety of gene expression datasets show that our model outperforms state-of-the-art methods.

Index Terms:
scRNA-seq Data, Graph Clustering, Curriculum Learning

I Introduction

The advent of single-cell RNA sequencing (scRNA-seq) technologies has enabled the measurement of gene expressions in a vast number of individual cells, offering the potential to deliver detailed and high-resolution understandings of the intricate cellular landscape. The analysis of scRNA-seq data plays a pivotal role in biomedical research, including identifying cell types and subtypes, studying developmental processes, investigating disease mechanisms, exploring immunological responses, and supporting drug development and personalized therapy. Cell annotation is the fundamental step in analyzing scRNA-seq data. In early research, various traditional clustering methods were applied, such as K-means, spectral clustering, hierarchical clustering and density-based clustering. However, scRNA-seq data are so sparse that most of the measurements are zeros, and traditional clustering algorithms therefore often produce suboptimal results.

Several clustering methods have been developed to address these limitations. CIDR [1], MAGIC [2], and SAVER [3] have been developed to initially address the issue of missing values, commonly referred to as dropouts, followed by the clustering of the imputed data. Despite the benefits of imputation, these methods encounter challenges in capturing the intricate inherent structure of scRNA-seq data. Alternative strategies, such as SIMLR [4] and MPSSC [5], utilize multi-kernel spectral clustering to acquire robust similarity measures. Nevertheless, the computational complexity associated with generating the Laplacian matrix hinders their application to large-scale datasets. Additionally, these techniques fail to account for crucial attributes of transcriptional data, including zero inflation and over-dispersion.

In recent years, deep learning has shown excellent performance in the fields of image recognition and processing, speech recognition, recommendation systems, and autonomous driving [6, 7, 8, 9]. Some deep learning clustering methods have effectively emerged to model the high-dimensional and sparse nature of scRNA-seq data such as scziDesk [10], scDCC [11], and scDeepCluster [12]. These models implement auto-encoding architectures. However, they often ignore the cell-cell relationships, which can make the clustering task more challenging. Recently, the emerging graph neural network (GNN) has deconvoluted node relationships in a graph through neighbor information propagation in a deep learning architecture. scGNN [13] and scGAE [14] combine deep autoencoder and graph clustering algorithms to preserve the neighborhood relationships. However, their training strategies largely ignore the importance of different nodes in the graph and how their orders can affect the optimization status, which may result in suboptimal performance of the graph learning models.

In particular, curriculum learning (CL) is an effective training strategy for gradually guiding model learning in tasks with obvious difficulty levels [15]. Curriculum learning has applications in natural language processing, computer vision, and other fields that require processing complex data. However, curriculum learning has not yet been studied for scRNA-seq data clustering, and the effect on the model of removing difficult samples, as opposed to retaining all data as traditional curriculum learning methods do, has not been explored.

Motivated by the above observations, we propose a single-cell curriculum learning-based deep graph embedding clustering method named scCLG, which simultaneously learns cell-cell topology representations and identifies cell clusters from an autoencoder following an easy-to-hard pattern (Fig. 1). We first propose a Chebyshev graph convolutional autoencoder with multi-decoder (ChebAE) to preserve the topological structure of the cells in the low-dimensional latent space (Fig. 2). Then, with the help of feature information, we design a hierarchical difficulty measurer, in which two difficulty measurers from local and global perspectives are proposed to measure the difficulty of training nodes. The local difficulty measurer computes the local feature distribution to identify nodes that are difficult because their neighbors have diverse labels; the global difficulty measurer identifies difficult nodes by calculating the node entropy and graph entropy. After that, the most difficult nodes are pruned to keep a high-quality graph. Finally, scCLG combines three optimization objectives corresponding to three decoders, including topology reconstruction loss of cell graphs, zero-inflated negative binomial (ZINB) loss, and clustering loss, to learn cell-cell topology representation, optimize cell clustering label allocation, and produce superior clustering results.

The main contributions of our work are summarized below:

  • We propose a single-cell curriculum learning-based deep graph embedding clustering called scCLG, which integrates the meaningful training order into a Chebyshev graph convolutional autoencoder to capture the global probabilistic structure of data.

  • scCLG constructs a cell graph and uses a Chebyshev graph convolutional autoencoder to collectively preserve the topological structural information and the cell-cell relationships in scRNA-seq data.

  • To the best of our knowledge, this is the first article to incorporate curriculum learning with data pruning into a graph convolutional autoencoder to model highly sparse and overdispersed scRNA-seq data.

  • We evaluate our model alongside state-of-the-art competitive methods on 7 real scRNA-seq datasets. The results demonstrate that scCLG outperforms all of the baseline methods.

II Related Work

scRNA-seq clustering. With the advent of deep learning (DL), more recent works have utilized deep neural networks to automatically extract features from scRNA-seq data for enhancing feature representation. scDeepCluster [12] simultaneously learns feature representation and clustering via explicit modeling of scRNA-seq data generation. In another work, scziDesk [10] combines deep learning with a denoising autoencoder to characterize scRNA-seq data while proposing a soft self-training K-means algorithm to cluster the cell population in the learned latent space. scDCC [11] integrates prior knowledge into the loss function through pairwise constraints. The high-order representation and topological relations can be naturally learned by graph neural networks. scGNN [13] introduces a multi-modal autoencoder framework. This framework formulates and aggregates cell-cell relationships with graph neural networks and models heterogeneous gene expression patterns using a left-truncated mixture Gaussian model. scGAE [14] builds a cell graph and uses a multitask-oriented graph autoencoder to preserve topological structure information and feature information in scRNA-seq data simultaneously. However, the above clustering methods overlook the learning difficulty of different samples or nodes.

Curriculum learning. Curriculum learning, which mimics the human learning process of learning data samples in a meaningful order, aims to enhance machine learning models by using a designed training curriculum, typically following an easy-to-hard pattern [15]. The CL framework consists of two components: a difficulty measurer, which measures the difficulty of samples, and a training scheduler, which arranges the ordered samples into training. The key to CL is how to define a suitable measurer. SPCL [16] takes into account both prior knowledge known before training and the learning progress during training. CLNode [17] measures the difficulty of training nodes based on the label information. SMMCL [18] assumes that different unlabeled samples have different difficulty levels for propagation, so it follows an easy-to-hard sequence with an updated curriculum for label propagation. scSPaC [19] utilizes an advanced NMF for scRNA-seq data clustering based on soft self-paced learning, which gradually adds cells from simple to complex to the model until the model converges. However, the above CL methods do not utilize the structural information of nodes in graph neural networks and do not consider the impact of difficult nodes on the graph.

Figure 1: Framework of scCLG. (A) Pre-training: the proposed ChebAE is pre-trained with the adjacency matrix decoder and the ZINB decoder. Node difficulty is then calculated using a hierarchical difficulty measurer, and the data are pruned. (B) Formal training: all three decoders are used to further optimize the model on the pruned data, following an easy-to-hard pattern.

III Preliminaries

In this section, we first introduce some notations, symbols, and necessary background. Then we present the Chebyshev graph convolution.

III-A Notations

Let $\mathcal{G}=(\mathcal{V},\mathcal{E},\mathbf{X})$ be an undirected cell graph, where $\mathcal{V}=\{v_{1},v_{2},\dots,v_{n_{c}}\}$ is a set of $n_{c}$ nodes associated with different cells; $e_{ij}\in\mathcal{E}$ specifies the existence of an edge between the $i^{th}$ and $j^{th}$ nodes; and $\mathbf{X}$ is the node feature matrix, whose element $x_{ij}$ is the count of the $j^{th}$ gene in the $i^{th}$ cell.
Let $\mathbf{A}\in\mathbb{R}^{n_{c}\times n_{c}}$ be the adjacency matrix of $\mathcal{G}$, where $a_{ij}=1$ if $v_{i}$ and $v_{j}$ are connected, and $a_{ij}=0$ otherwise. The graph Laplacian is $\mathbf{L}=\mathbf{D}-\mathbf{A}\in\mathbb{R}^{n_{c}\times n_{c}}$, where $\mathbf{D}\in\mathbb{R}^{n_{c}\times n_{c}}$ is the diagonal degree matrix with $\mathbf{D}_{ii}=\sum_{j}\mathbf{A}_{ij}$, and $\mathbf{I}$ denotes the identity matrix of the same size. The KNN algorithm is employed to construct the cell graph, and each node in the graph represents a cell [20].
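The notation above can be made concrete with a minimal sketch. It assumes Euclidean distance between the rows of $\mathbf{X}$ and a symmetrized KNN graph; neither choice is stated in the text, and the names `knn_adjacency`, `laplacian` and `k` are hypothetical.

```python
import numpy as np

def knn_adjacency(X, k):
    """Binary, symmetric KNN adjacency over the rows (cells) of X.

    Assumes Euclidean distance and k < number of cells.
    """
    n = X.shape[0]
    # Pairwise squared Euclidean distances between cells.
    diff = X[:, None, :] - X[None, :, :]
    dist = (diff ** 2).sum(axis=-1)
    np.fill_diagonal(dist, np.inf)           # a cell is not its own neighbor
    nearest = np.argsort(dist, axis=1)[:, :k]
    A = np.zeros((n, n))
    rows = np.repeat(np.arange(n), k)
    A[rows, nearest.ravel()] = 1.0
    return np.maximum(A, A.T)                # symmetrize: undirected graph

def laplacian(A):
    """Unnormalized graph Laplacian L = D - A."""
    D = np.diag(A.sum(axis=1))
    return D - A

# Toy data: four cells, two genes, forming two tight pairs.
X = np.array([[0.0, 0.0], [0.1, 0.0], [5.0, 5.0], [5.1, 5.0]])
A = knn_adjacency(X, k=1)
L = laplacian(A)
```

Each row of `L` sums to zero, which is a quick sanity check on any Laplacian built this way.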

III-B Chebyshev Graph Convolution

Chebyshev graph convolution (ChebConv) is a variant of graph convolutional networks that uses Chebyshev polynomials to approximate spectral filters defined through the eigendecomposition of the graph Laplacian matrix, thereby achieving convolution operations on graph data. The theoretical foundation of ChebConv is graph signal processing and spectral graph theory, which introduces the concept of graph signal processing into graph convolutional networks. The ChebConv layer is defined as follows:

$$\mathbf{H}=\sum_{k=1}^{K}\mathbf{Z}^{(k)}\cdot\Theta^{(k)} \tag{1}$$

where $K$ represents the order of Chebyshev polynomials used to approximate graph convolution kernels. $\Theta$ is the layer's trainable parameter and $\mathbf{Z}^{(k)}$ is computed recursively by:

$$\mathbf{Z}^{(1)}=\mathbf{X} \tag{2}$$
$$\mathbf{Z}^{(2)}=\hat{\mathbf{L}}\cdot\mathbf{X} \tag{3}$$
$$\mathbf{Z}^{(k)}=2\cdot\hat{\mathbf{L}}\cdot\mathbf{Z}^{(k-1)}-\mathbf{Z}^{(k-2)} \tag{4}$$

where $\hat{\mathbf{L}}$ denotes the scaled and normalized Laplacian $\frac{2\mathbf{L}}{\lambda_{max}}-\mathbf{I}$. $\lambda_{max}$ is the largest eigenvalue of $\mathbf{L}$ and $\mathbf{I}$ is the identity matrix.

Compared with basic GCN, ChebConv effectively reduces the model’s parameter count and computational complexity by transforming graph convolution operations into approximations of Chebyshev polynomials, while maintaining its ability to capture graph structures.
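The recursion in Eqs. (1)-(4) can be sketched in a few lines of NumPy. This is an illustrative forward pass only, not the authors' implementation: the function name `cheb_conv`, the list `thetas` (one weight matrix per polynomial order) and the use of a dense, symmetric Laplacian passed in as `L` are assumptions made for the example.

```python
import numpy as np

def cheb_conv(X, L, thetas):
    """Forward pass H = sum_k Z^(k) @ Theta^(k), following Eqs. (1)-(4).

    X      : (n, d_in) node features
    L      : (n, n) symmetric graph Laplacian (taken as given)
    thetas : list of K weight matrices, each (d_in, d_out)
    """
    n = L.shape[0]
    lam_max = np.linalg.eigvalsh(L).max()    # largest eigenvalue of L
    L_hat = 2.0 * L / lam_max - np.eye(n)    # scaled Laplacian
    Z_prev, Z_curr = None, X                 # Z^(1) = X
    H = Z_curr @ thetas[0]
    for k in range(1, len(thetas)):
        if k == 1:
            Z_next = L_hat @ X               # Z^(2) = L_hat X
        else:
            Z_next = 2.0 * L_hat @ Z_curr - Z_prev   # Eq. (4)
        H = H + Z_next @ thetas[k]
        Z_prev, Z_curr = Z_curr, Z_next
    return H

# Toy usage: a 3-node path graph with 2 features per node and K = 3.
A_path = np.array([[0.0, 1.0, 0.0], [1.0, 0.0, 1.0], [0.0, 1.0, 0.0]])
L_path = np.diag(A_path.sum(axis=1)) - A_path
X_path = np.arange(6.0).reshape(3, 2)
H_path = cheb_conv(X_path, L_path, [np.eye(2)] * 3)
```

Only matrix-vector style products with $\hat{\mathbf{L}}$ appear in the loop, so no eigenvectors are needed; the single eigenvalue computation is for $\lambda_{max}$.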

IV Proposed Approach

In this section, we first present the multi-decoder ChebConv graph autoencoder. Second, we introduce how the scCLG model parameters can be learned using a meaningful sample order. Finally, we elaborate on the proposed scRNA-seq data clustering algorithm, which combines the above two points.

IV-A Multi-Decoder ChebConv Graph Autoencoder

As shown in Fig. 2, to capture the cell graph structure and node relationships, we developed a variant of the graph convolution autoencoder that uses a stacked topology Chebyshev graph convolutional network as the graph encoder. We use three different decoders to map the encoded compressed vector from different perspectives and jointly optimize the modeling ability of the autoencoder. The gene expression matrix $\mathbf{X}$ and normalized adjacency matrix $\mathbf{A}$ are used as inputs. Through the graph encoder, the feature dimension of each node is compressed to a smaller size, and the compressed vector features are decoded by three decoders: the adjacency matrix decoder ($Dec_{A}$), the ZINB decoder ($Dec_{zinb}$), and the clustering decoder ($Dec_{cls}$). These decoders share encoder parameters, decomposing a single optimization objective into three objectives to better capture the cell-cell relationship:

$$L=L_{rec}+L_{zinb}+L_{cls} \tag{5}$$
Figure 2: The model architecture of multi-decoder ChebAE. ChebAE integrates three decoders, namely an adjacency matrix decoder, a ZINB decoder, and a clustering decoder, to learn the low-dimensional latent representation.

More detailed optimization information about $L_{rec}$, $L_{zinb}$ and $L_{cls}$ is shown below.

IV-A1 Reconstruction Loss

Given that the majority of the structure and information inherent in the scRNA-seq data $\mathbf{X}$ is conserved within the latent embedded representation generated by the scCLG encoder, the adjacency matrix decoder of the graph autoencoder can be defined as the inner product of the latent embeddings:

$$\mathbf{Z}=f_{E}(\mathbf{X}) \tag{6}$$
$$\mathbf{A}_{rec}=\sigma(\mathbf{Z}^{T}\mathbf{Z}) \tag{7}$$

where $f_{E}$ represents the scCLG encoder function and $\mathbf{A}_{rec}$ is the reconstructed adjacency matrix. Therefore, the reconstruction loss between $\mathbf{A}$ and $\mathbf{A}_{rec}$ should be minimized in the learning process as below:

$$L_{rec}=\|\mathbf{A}-\mathbf{A}_{rec}\|_{2}^{2} \tag{8}$$
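A minimal sketch of the inner-product decoder and the reconstruction loss follows. Two assumptions are made for illustration: the rows of `Z` are taken to be cells, so the cell-by-cell inner product is written `Z @ Z.T`, and the squared norm in Eq. (8) is read as the entrywise sum of squared differences. The function names are hypothetical.

```python
import numpy as np

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

def reconstruct_adjacency(Z):
    """Inner-product decoder, Eq. (7).

    Z is (n_cells, latent_dim); with cells as rows, the pairwise
    inner products form the (n_cells, n_cells) matrix Z @ Z.T.
    """
    return sigmoid(Z @ Z.T)

def reconstruction_loss(A, A_rec):
    """Eq. (8), read here as the sum of squared entrywise differences."""
    return float(np.sum((A - A_rec) ** 2))

# Toy usage: an all-zero embedding decodes every edge probability to 0.5.
Z_zero = np.zeros((2, 3))
A_true = np.array([[0.0, 1.0], [1.0, 0.0]])
loss_zero = reconstruction_loss(A_true, reconstruct_adjacency(Z_zero))
```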

IV-A2 ZINB Loss

In order to more effectively capture the structure of scRNA-seq data by decoding from the latent embedded representation $\mathbf{Z}$, we integrate the ZINB model into a Chebyshev graph convolutional autoencoder to capture the global probability structure of the scRNA-seq data. On this basis, we apply the ZINB distribution model to simulate the data distribution and capture the characteristics of scRNA-seq data as follows:

$$\mathrm{NB}(x|\mu,\theta)=\frac{\Gamma(x+\theta)}{x!\,\Gamma(\theta)}\left(\frac{\theta}{\theta+\mu}\right)^{\theta}\left(\frac{\mu}{\theta+\mu}\right)^{x} \tag{9}$$
$$\mathrm{ZINB}(x|\pi,\mu,\theta)=\pi\delta_{0}(x)+(1-\pi)\mathrm{NB}(x|\mu,\theta) \tag{10}$$

where $\mu$ and $\theta$ are the mean and dispersion in the negative binomial distribution, respectively. $\pi$ is the weight of the point mass at zero. The proportion $\frac{\theta}{\theta+\mu}$ replaces the probability $p$. After that, to model the ZINB distribution, the decoder network has three output layers to compute the three sets of parameters. The estimated parameters can be defined as follows:

$$\hat{\pi}=\mathrm{sigmoid}(\mathbf{W}_{\pi}f_{D_{zinb}}(\mathbf{Z})) \tag{11}$$
$$\hat{\mu}=\exp(\mathbf{W}_{\mu}f_{D_{zinb}}(\mathbf{Z})) \tag{12}$$
$$\hat{\theta}=\exp(\mathbf{W}_{\theta}f_{D_{zinb}}(\mathbf{Z})) \tag{13}$$

where $f_{D_{zinb}}$ is a three-layer fully connected neural network with hidden layers of 128, 256 and 512 nodes. $\mathbf{W}$ represents the learned weight matrices. $\hat{\pi},\hat{\mu}$ and $\hat{\theta}$ denote the estimates of $\pi,\mu$ and $\theta$, respectively. The choice of activation function depends on the range and definition of each parameter. For $\pi$, the suitable activation function is sigmoid because $\pi$ lies between 0 and 1. Because the mean $\mu$ and dispersion $\theta$ are non-negative, we choose the exponential activation function for them. The negative log-likelihood of the ZINB distribution can be used as the reconstruction loss function of the original data $\mathbf{X}$, which can be defined as below:

$$L_{\text{ZINB}}=-\log(\text{ZINB}(\mathbf{X}|\pi,\mu,\theta)) \tag{14}$$
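To make Eqs. (9), (10) and (14) concrete, the following standard-library sketch evaluates the ZINB negative log-likelihood entry by entry. It is illustrative rather than numerically hardened: it assumes $\mu>0$, $\theta>0$ and $0\le\pi<1$, and summing over all entries of the matrix is an assumption about the reduction, which the text does not specify. The function names are hypothetical.

```python
import math

def nb_log_pmf(x, mu, theta):
    """log NB(x | mu, theta) from Eq. (9), using log-gamma for stability."""
    return (math.lgamma(x + theta) - math.lgamma(x + 1.0) - math.lgamma(theta)
            + theta * math.log(theta / (theta + mu))
            + x * math.log(mu / (theta + mu)))

def zinb_nll(x, pi, mu, theta):
    """Negative log-likelihood of a single count under Eq. (10)."""
    nb = math.exp(nb_log_pmf(x, mu, theta))
    point_mass = pi if x == 0 else 0.0       # delta_0(x) contributes only at zero
    return -math.log(point_mass + (1.0 - pi) * nb)

def zinb_loss(X, Pi, Mu, Theta):
    """Eq. (14) summed over all entries of equally shaped nested lists."""
    total = 0.0
    for x_row, pi_row, mu_row, th_row in zip(X, Pi, Mu, Theta):
        for x, pi, mu, th in zip(x_row, pi_row, mu_row, th_row):
            total += zinb_nll(x, pi, mu, th)
    return total

# Toy usage: one cell, two genes (a zero count and a count of 3).
example_loss = zinb_loss([[0, 3]], [[0.5, 0.1]], [[1.0, 2.0]], [[1.0, 1.5]])
```

A zero count can be explained either by the point mass or by the negative binomial itself, which is why both terms appear inside the logarithm when `x == 0`.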

IV-A3 Clustering Loss

scRNA-seq clustering, as an unsupervised learning task, lacks guidance from labels, which makes it difficult to capture effective optimization signals during the training phase. To overcome this challenge, we apply a clustering module to guide the algorithm to adjust the cluster centers to ensure that the distribution of samples within each cluster is as consistent as possible while minimizing inter-cluster differences. The objective takes the form of Kullback–Leibler (KL) divergence and is formulated as follows:

$$L_{cls}=KL(P||Q)=\sum_{i}\sum_{u}p_{iu}\log\frac{p_{iu}}{q_{iu}} \tag{15}$$

where $q_{iu}$ is the soft label of the embedding node $z_{i}$, which is defined as the similarity between $z_{i}$ and cluster centre $\mu_{u}$ measured by Student's t-distribution. This can be described as follows:

$$q_{ij}=\frac{(1+\|z_{i}-\mu_{j}\|^{2})^{-1}}{\sum_{j'}(1+\|z_{i}-\mu_{j'}\|^{2})^{-1}} \tag{16}$$

Meanwhile, $p_{iu}$ is the auxiliary target distribution, which puts more emphasis on the similar data points assigned with high confidence on the basis of $q_{iu}$, as below:

$$p_{ij}=\frac{q_{ij}^{2}/\sum_{j}q_{ij}}{\sum_{j'}(q_{ij'}^{2}/\sum_{j'}q_{ij'})} \tag{17}$$

Since the target distribution $P$ is defined based on $Q$, the embedding learning of $Q$ is supervised in a self-optimizing way to enable it to be close to the target distribution $P$.
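The self-training loop described by Eqs. (15)-(17) can be sketched as follows. One point is an assumption: the normalizer inside the target distribution is taken to be the soft cluster frequency summed over cells, as is common in deep embedded clustering, whereas Eq. (17) as printed reuses the index $j$ for that sum. The function names and toy data are hypothetical.

```python
import numpy as np

def soft_assignment(Z, centers):
    """Student's t soft labels q, Eq. (16). Z: (n, d), centers: (c, d)."""
    d2 = ((Z[:, None, :] - centers[None, :, :]) ** 2).sum(axis=-1)
    q = 1.0 / (1.0 + d2)
    return q / q.sum(axis=1, keepdims=True)      # normalize over clusters

def target_distribution(Q):
    """Sharpened target p built from q (assumed reading of Eq. (17)).

    Squares q, divides by the soft cluster frequency (sum over cells),
    then renormalizes each row so it is a distribution over clusters.
    """
    weight = Q ** 2 / Q.sum(axis=0, keepdims=True)
    return weight / weight.sum(axis=1, keepdims=True)

def clustering_loss(P, Q):
    """KL(P || Q) summed over cells and clusters, Eq. (15)."""
    return float(np.sum(P * np.log(P / Q)))

# Toy usage: four embedded cells around two cluster centres.
Z_toy = np.array([[0.0, 0.0], [0.2, 0.0], [4.0, 0.0], [4.2, 0.0]])
centers_toy = np.array([[0.0, 0.0], [4.0, 0.0]])
Q_toy = soft_assignment(Z_toy, centers_toy)
P_toy = target_distribution(Q_toy)
loss_toy = clustering_loss(P_toy, Q_toy)
```

Because squaring amplifies the largest entry in each row, `P_toy` is more confident than `Q_toy`, and minimizing the KL term pulls the soft labels toward that sharper target.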

IV-B Curriculum Learning with Data Pruning

In this subsection, we first describe the proposed difficulty measurement method from both local and global perspectives and assign a difficulty score to each cell. Based on the difficulty score, we investigate the impact of nodes with higher difficulty on model optimization.

IV-B1 Hierarchical Difficulty Measurer

Our Hierarchical Difficulty Measurer consists of two difficulty measures from different perspectives. In this section, we present the definition of two difficulty measures and how to calculate them.

Local Difficulty Measurer. We introduce how to identify difficult nodes from a local perspective. Nodes located at the boundaries of multiple classes may reside in transitional regions within the feature space, leading to less distinct or consistent feature representations, thereby increasing the difficulty of classification. The first type of difficult node should have diverse neighbors that belong to multiple classes. Intuitively, features of nodes within the same class tend to be more similar. This is due to the influence of neighboring node features, resulting in nodes with similar connectivity patterns exhibiting comparable feature representations. In order to identify these difficult nodes, we calculate the diversity of the neighborhood’s features:

$$D_{local}(u)=\sum_{v\in\mathcal{N}(u)}S(u,v) \tag{18}$$
$$S(u,v)=\frac{u\cdot v}{\|u\|\cdot\|v\|} \tag{19}$$

where $S(u,v)$ denotes the similarity between cell $u$ and cell $v$. A larger $D_{local}(u)$ indicates a more diverse neighborhood. As a result, during neighborhood aggregation, these nodes aggregate their neighbors' features into an unclear representation, which makes them difficult for GNNs to learn. By paying less attention to these difficult nodes, scCLG learns more useful information and effectively improves the accuracy of backbone GNNs.
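As a minimal sketch of Eqs. (18) and (19), the local difficulty can be computed from node embeddings and a neighbor list. The function names, the dictionary-based data layout, and the handling of zero vectors are illustrative assumptions and are not taken from the scCLG implementation.

```python
import math

def cosine_similarity(u, v):
    """Eq. (19): cosine similarity between two embedding vectors."""
    dot = sum(a * b for a, b in zip(u, v))
    norm_u = math.sqrt(sum(a * a for a in u))
    norm_v = math.sqrt(sum(b * b for b in v))
    if norm_u == 0.0 or norm_v == 0.0:
        return 0.0  # assumption: a zero vector is treated as similar to nothing
    return dot / (norm_u * norm_v)

def local_difficulty(embeddings, neighbors):
    """Eq. (18): sum of similarities between each node and its neighbors.

    embeddings: dict mapping node -> list of floats (hypothetical layout)
    neighbors:  dict mapping node -> list of neighbor nodes
    """
    return {
        u: sum(cosine_similarity(embeddings[u], embeddings[v]) for v in neighbors[u])
        for u in embeddings
    }

# Toy graph: node 0 is linked to an identical node (1) and an orthogonal node (2).
embeddings = {0: [1.0, 0.0], 1: [1.0, 0.0], 2: [0.0, 1.0]}
neighbors = {0: [1, 2], 1: [0], 2: [0]}
d_local = local_difficulty(embeddings, neighbors)
```

In this toy graph the score of node 0 is the sum of one similarity of 1 and one similarity of 0.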

Global Difficulty Measurer. Next, we introduce how to identify difficult nodes from a global perspective. Entropy plays a pivotal role in feature selection as a metric from information theory used to quantify uncertainty. In the process of feature selection, we leverage entropy to assess a feature's contribution to the target variable. When a feature better distinguishes between different categories of the target variable, its entropy value tends to be relatively low, signifying that it provides more information and reduces overall uncertainty. Consequently, in feature selection, lower entropy values indicate features that offer greater discriminatory power, aiding in the differentiation of target variable categories. We assume that nodes with lower entropy contribute less to the graph and are therefore difficult to classify. Inspired by Entropy Variation [21], we consider the contribution of a node to be the variation of network entropy before and after its removal.

For a node $v$ in graph $\mathcal{G}$, we define the probability $p(v)$ as:

\begin{align}
p(v) &= \frac{D(v)}{\sum_{u \in \mathcal{V}} D(u)} \tag{20}
\end{align}

where $\sum_{v \in \mathcal{V}} p(v) = 1$.

The entropy of the graph is as follows:

\begin{align}
Ent(\mathcal{G}) &= -\sum_{v \in \mathcal{V}} p(v) \log p(v) \tag{21}\\
&= -\sum_{v \in \mathcal{V}} \frac{D(v)}{\sum_{u \in \mathcal{V}} D(u)} \log\left(\frac{D(v)}{\sum_{u \in \mathcal{V}} D(u)}\right) \tag{22}\\
&= \log\left(\sum_{v \in \mathcal{V}} D(v)\right) - \sum_{v \in \mathcal{V}} \frac{D(v)}{\sum_{u \in \mathcal{V}} D(u)} \log D(v) \tag{23}
\end{align}

where $D(v)$ is the degree of node $v$, and $Ent(\mathcal{G})$ is the entropy of graph $\mathcal{G}$ computed from the degree matrix.
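The equivalence of Eqs. (21) and (23) can be checked numerically. The following sketch assumes the natural logarithm (the base is not specified in the text) and represents the graph as a dictionary of neighbor sets; both function names are hypothetical.

```python
import math

def graph_entropy(adjacency):
    """Eq. (21): degree-based entropy, with p(v) = D(v) / sum_u D(u)."""
    degrees = [len(nbrs) for nbrs in adjacency.values()]
    total = sum(degrees)
    if total == 0:
        return 0.0  # assumption: a graph without edges has zero entropy
    # Isolated nodes (degree 0) contribute nothing to the sum.
    return -sum((d / total) * math.log(d / total) for d in degrees if d > 0)

def graph_entropy_expanded(adjacency):
    """Eq. (23): the same quantity after expanding the logarithm."""
    degrees = [len(nbrs) for nbrs in adjacency.values()]
    total = sum(degrees)
    if total == 0:
        return 0.0
    return math.log(total) - sum((d / total) * math.log(d) for d in degrees if d > 0)

# Path graph 0 - 1 - 2 with degrees 1, 2, 1.
adjacency = {0: {1}, 1: {0, 2}, 2: {1}}
h = graph_entropy(adjacency)
h_expanded = graph_entropy_expanded(adjacency)
```

For this path graph the probabilities are 1/4, 1/2 and 1/4, so both forms evaluate to $1.5 \log 2$.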

The global difficulty of the node is as follows:

\begin{align}
D_{global}(v) &= 1 - \frac{Ent(v)}{\sum_{u \in \mathcal{V}} Ent(u)} \tag{24}\\
Ent(v) &= Ent(\mathcal{G}) - Ent(\hat{\mathcal{G}}_v) \tag{25}
\end{align}

where $Ent(v)$ is the change in graph entropy when node $v$ and its connections are removed from the network, and $\hat{\mathcal{G}}_v$ is the modified graph after the removal of $v$. A lower $Ent(v)$ indicates a lower influence on the graph structure, and hence a more difficult node. The global difficulty of node $v$ is obtained by subtracting the normalized $Ent(v)$ from 1.
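A minimal sketch of Eqs. (24) and (25) follows. It recomputes the degree-based entropy after removing each node in turn, which is a brute-force approach chosen for clarity rather than efficiency. The helper names, the natural logarithm, and the fallback used when all variations sum to zero are assumptions.

```python
import math

def graph_entropy(adjacency):
    """Degree-based graph entropy, Eq. (21)."""
    degrees = [len(nbrs) for nbrs in adjacency.values()]
    total = sum(degrees)
    if total == 0:
        return 0.0  # assumption: a graph without edges has zero entropy
    return -sum((d / total) * math.log(d / total) for d in degrees if d > 0)

def remove_node(adjacency, v):
    """Return a copy of the graph without node v and its incident edges."""
    return {u: {w for w in nbrs if w != v} for u, nbrs in adjacency.items() if u != v}

def global_difficulty(adjacency):
    """Eqs. (24)-(25): one minus the normalized entropy variation of each node."""
    base = graph_entropy(adjacency)
    variation = {v: base - graph_entropy(remove_node(adjacency, v)) for v in adjacency}
    total = sum(variation.values())
    if total == 0:
        return {v: 1.0 for v in adjacency}  # assumption: no node changes the entropy
    return {v: 1.0 - variation[v] / total for v in adjacency}

# Star graph: hub 0 connected to leaves 1, 2, 3.
adjacency = {0: {1, 2, 3}, 1: {0}, 2: {0}, 3: {0}}
d_global = global_difficulty(adjacency)
```

In the star graph, removing the hub destroys every edge, so the hub has the largest entropy variation and the lowest global difficulty, while the three leaves share a higher score.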

Considering two difficulty measurers from local and global perspectives, we finally define the difficulty of $v$ as:

\begin{align}
D(v) &= \beta \cdot D_{local} + (1-\beta) \cdot D_{global} \tag{26}
\end{align}

where $\beta$ is the weight coefficient assigned to each difficulty measurer to control the balance of the total difficulty.

IV-B2 Data Pruning

With the hierarchical difficulty measurer, we obtain a list of nodes sorted in ascending order of difficulty. The nodes at the end of the list are a nuisance for overall model learning, so should they be retained? The sources of noise in graph neural networks are varied. First, the attribute information of the nodes may contain noise, which affects the representation of the node features and hence the learning of the GNN. Second, the presence of anomalous data may cause the spectral energy of the graph to be "right-shifted", that is, the distribution of spectral energy shifts from low to high frequencies. Such noise not only reduces the performance of the graph neural network but also propagates through the topology, affecting the prediction results of the whole network. To solve this problem, we design a data pruning strategy based on the calculated node difficulty. Specifically, we define a data discarding hyperparameter $\alpha$, which is set by balancing data integrity against model generalization performance. As shown in Fig. 4, the scRNA-seq clustering performance of scCLG improves after removing the nodes with the highest difficulty, which supports our hypothesis.
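The combination in Eq. (26) and the pruning step can be sketched as follows. The toy scores, the value $\beta = 0.5$, and the rule of dropping $\lfloor \alpha |\mathcal{V}| \rfloor$ nodes are illustrative assumptions; the text does not specify how the pruned count is rounded.

```python
def combined_difficulty(d_local, d_global, beta):
    """Eq. (26): weighted sum of the local and global difficulty scores."""
    return {v: beta * d_local[v] + (1.0 - beta) * d_global[v] for v in d_local}

def prune_hardest(difficulty, alpha):
    """Sort nodes by ascending difficulty and drop the last alpha fraction.

    Returns (kept, pruned); flooring alpha * n is an assumption of this sketch.
    """
    ordered = sorted(difficulty, key=difficulty.get)
    n_keep = len(ordered) - int(alpha * len(ordered))
    return ordered[:n_keep], ordered[n_keep:]

# Hypothetical difficulty scores for four cells.
d_local = {"a": 0.2, "b": 0.9, "c": 0.5, "d": 0.1}
d_global = {"a": 0.4, "b": 0.8, "c": 0.6, "d": 0.3}
difficulty = combined_difficulty(d_local, d_global, beta=0.5)
kept, pruned = prune_hardest(difficulty, alpha=0.25)
```

With these toy scores, cell "b" is the hardest node and is the only one pruned at $\alpha = 0.25$.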

IV-C The Proposed scCLG Algorithm

Our model undergoes a two-phase training process. In the first phase, we pretrain the proposed GNN model ChebAE for discriminative feature learning with the adjacency matrix decoder and the ZINB decoder for $T_1$ epochs. The output of the encoder is a low-dimensional vector, which is used to calculate node difficulty with the hierarchical difficulty measurer. We retain the top $1-\alpha$ of the data with high sample quality for subsequent training. In the formal training phase, we initialize the model with the pretrained parameters and train it for $T_2$ epochs on the pruned data. This phase learns the clustering task. Unlike the pre-training phase, it uses all three decoders to optimize the model. We use the pacing function $\min(1, 2^{\log_2 \lambda_0 - \log_2 \lambda_0 \cdot \frac{t}{\hat{T}}})$ from [17] to determine the size of the node subset. The details are given in Algorithm 1.
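The pacing function can be rewritten in a form that makes its behavior easier to read. The symbol $g(t)$ and the condition $0 < \lambda_0 \le 1$ are assumptions introduced here for illustration; the text does not state a range for $\lambda_0$.

```latex
\begin{aligned}
g(t) &= \min\!\left(1,\; 2^{\log_2 \lambda_0 - \log_2 \lambda_0 \cdot \frac{t}{\hat{T}}}\right)
      = \min\!\left(1,\; 2^{\left(1 - \frac{t}{\hat{T}}\right)\log_2 \lambda_0}\right)
      = \min\!\left(1,\; \lambda_0^{\,1 - \frac{t}{\hat{T}}}\right), \\
g(0) &= \lambda_0, \qquad g(\hat{T}) = 1, \qquad g(t) = 1 \ \text{for } t \ge \hat{T}
      \quad (0 < \lambda_0 \le 1).
\end{aligned}
```

Under this assumption, training starts from a fraction $\lambda_0$ of the sorted nodes and the fraction grows geometrically until all retained nodes are used at epoch $\hat{T}$.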

Algorithm 1 scCLG
Input: A scRNA-seq data graph $\mathcal{G} = (\mathcal{V}, \mathcal{E}, X)$, the GNN model ChebAE, pre-training epochs $T_1$, training epochs $T_2$, data pruning rate $\alpha$, hyper-parameters $\lambda_0$, $\hat{T}$.
Output: The cluster labels $Y$.
# Phase 1: pre-training
Initialize parameters of ChebAE
Train ChebAE with $Dec_A$ and $Dec_{zinb}$ on $\mathcal{G}$ for $T_1$ epochs
for $v \in \mathcal{V}$ do
     Calculate node difficulty $D(v) \leftarrow$ Eq. (26)
end for
Sort $\mathcal{V}$ according to node difficulty in ascending order
Prune ordered nodes at the end with a rate of $\alpha$
# Phase 2: formal training
while $t < T_2$ or not converged do
     $\beta_t \leftarrow \min(1, 2^{\log_2 \lambda_0 - \log_2 \lambda_0 \cdot \frac{t}{\hat{T}}})$ where $\beta_t < (1-\alpha)$
     Generate training subset $\mathcal{V}_t \leftarrow \mathcal{V}[1, \dots, \lfloor \beta_t \cdot |\mathcal{V}| \rfloor]$
     Train ChebAE with three decoders on $(\mathcal{V}, \mathcal{E}, X[\mathcal{V}_t])$
     $t \leftarrow t + 1$
end while
return Predict $Y$ with ChebAE.
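The subset schedule of Phase 2 can be sketched without a model. The sketch below reads the condition on $\beta_t$ as a cap at $1-\alpha$, so that the pruned tail of the sorted list is never selected; this reading, the function names, and the toy values of $\lambda_0$ and $\hat{T}$ are assumptions.

```python
import math

def pacing(t, lambda_0, t_hat):
    """Pacing function from Algorithm 1: fraction of sorted nodes at epoch t."""
    return min(1.0, 2.0 ** (math.log2(lambda_0) - math.log2(lambda_0) * t / t_hat))

def curriculum_subsets(sorted_nodes, alpha, lambda_0, t_hat, epochs):
    """Yield the training subset for each epoch of the formal training phase.

    sorted_nodes: all nodes in ascending order of difficulty (before pruning).
    """
    n = len(sorted_nodes)
    for t in range(epochs):
        # Assumption: beta_t is capped at 1 - alpha so pruned nodes stay excluded.
        beta_t = min(pacing(t, lambda_0, t_hat), 1.0 - alpha)
        yield sorted_nodes[: math.floor(beta_t * n)]

nodes = list(range(100))  # hypothetical node ids, already sorted by difficulty
sizes = [
    len(subset)
    for subset in curriculum_subsets(nodes, alpha=0.11, lambda_0=0.25, t_hat=10, epochs=12)
]
```

The subset starts with the easiest quarter of the nodes and grows monotonically until it reaches the retained share of the graph; a real training loop would run one optimization step of ChebAE on each subset.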

V Experiments

V-A Setup

Dataset. We collect 7 real scRNA-seq datasets from different organisms. The cell numbers range from 870 to 9519, and the cell type numbers vary from 2 to 11 (Table I).

TABLE I: Summary of the real scRNA-seq datasets.

| Dataset | Cells | Genes | Class | Platform |
|---|---|---|---|---|
| QS_Diaphragm | 870 | 23341 | 5 | Smart-seq2 |
| QS_Limb_Muscle | 1090 | 23341 | 6 | Smart-seq2 |
| QS_Lung | 1676 | 23341 | 11 | Smart-seq2 |
| Muraro | 2122 | 19046 | 9 | CEL-seq2 |
| QS_Heart | 4365 | 23341 | 8 | Smart-seq2 |
| Plasschaert | 6977 | 28205 | 8 | inDrop |
| Wang_Lung | 9519 | 14561 | 2 | 10x |

Baselines. The performance of scCLG was compared with two traditional clustering methods (Kmeans and Spectral) and several state-of-the-art scRNA-seq data clustering methods, including four single-cell deep embedded clustering methods (scziDesk, scDC, scDCC and scGMAI) and three single-cell deep graph embedded clustering methods (scTAG, scGAE and scGNN).

  • Deep soft K-means clustering with self-training for single-cell RNA sequence data (scziDesk) [10]: It combines a denoising autoencoder to characterize scRNA-seq data while proposing a soft self-training K-means algorithm to cluster the cell population in the learned latent space.

  • Model-based deep embedded clustering method (scDC) [12]: It simultaneously learns feature representation and clustering via explicit modeling of scRNA-seq data generation.

  • Model-based deep embedding for constrained clustering analysis of single cell RNA-seq data (scDCC) [11]: It integrates prior information into the modeling process to guide the deep learning model to simultaneously learn meaningful and desired latent representations and clusters.

  • scGMAI: a Gaussian mixture model for clustering single-cell RNA-Seq data based on deep autoencoder (scGMAI) [22]: It utilizes autoencoder networks to reconstruct gene expression values from scRNA-Seq data, and FastICA is used to reduce the dimensions of the reconstructed data.

  • scGNN is a novel graph neural network framework for single-cell RNA-Seq analyses (scGNN) [13]: It integrates three iterative multi-modal autoencoders and models heterogeneous gene expression patterns using a left-truncated mixture Gaussian model.

  • A topology-preserving dimensionality reduction method for single-cell RNA-seq data using graph autoencoder (scGAE) [14]: It builds a cell graph and uses a multitask-oriented graph autoencoder to preserve topological structure information and feature information in scRNA-seq data simultaneously.

  • Zinb-based graph embedding autoencoder for single-cell rna-seq interpretations (scTAG) [20]: It simultaneously learns cell–cell topology representations and identifies cell clusters based on a deep graph convolutional network integrating the ZINB model.

Implementation Details. In the proposed scCLG method, the cell graph was constructed using the KNN algorithm with the nearest neighbor parameter $k=20$. In the multi-decoder ChebConv graph autoencoder, the hidden fully connected layers in the ZINB decoder are set to sizes 128, 256 and 512. Our algorithm consists of pre-training and formal training, with 1000 and 500 epochs, respectively. Our model was optimized using the Adam optimizer, with a learning rate of 5e-4 during pre-training and 1e-4 during formal training. The pruning rate $\alpha$ is set to 0.11. For baseline methods, the parameters were set the same as in the original papers.
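As an illustration of the cell graph construction step, the following brute-force sketch builds a KNN graph on a toy point set. The Euclidean distance, the symmetrization of edges, and the function name are assumptions; the text only states that KNN with $k=20$ is used.

```python
import math

def knn_graph(points, k):
    """Build an undirected KNN graph as a dict mapping node -> set of neighbors.

    Each node is linked to its k nearest points under Euclidean distance
    (an assumed metric), and every edge is stored in both directions.
    """
    n = len(points)
    adjacency = {i: set() for i in range(n)}
    for i in range(n):
        dists = sorted((math.dist(points[i], points[j]), j) for j in range(n) if j != i)
        for _, j in dists[:k]:
            adjacency[i].add(j)
            adjacency[j].add(i)  # assumption: the graph is symmetrized
    return adjacency

# Toy data: two well-separated groups of points; k=2 is used only for readability.
points = [(0.0, 0.0), (0.1, 0.0), (0.0, 0.1), (5.0, 5.0), (5.1, 5.0)]
graph = knn_graph(points, k=2)
```

The quadratic cost of this sketch is acceptable only for small examples; datasets of the size listed in Table I would typically rely on an optimized nearest-neighbor routine.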

TABLE II: Performance comparison between various baselines on seven real datasets.

| Metric | Dataset | scCLG | scTAG | scGAE | scGNN | scziDesk | scDC | scDCC | scGMAI | Kmeans | Spectral |
|---|---|---|---|---|---|---|---|---|---|---|---|
| ARI | QS_Diaphragm | 0.9836 | 0.9628 | 0.5638 | 0.5646 | 0.9517 | 0.6479 | 0.8895 | 0.4111 | 0.9110 | 0.9170 |
| ARI | QS_Limb_Muscle | 0.9828 | 0.9813 | 0.5419 | 0.6399 | 0.9743 | 0.5384 | 0.3449 | 0.4899 | 0.8922 | 0.9615 |
| ARI | QS_Lung | 0.7946 | 0.6526 | 0.2797 | 0.3631 | 0.7401 | 0.4504 | 0.2908 | 0.4622 | 0.7329 | 0.7559 |
| ARI | Muraro | 0.8959 | 0.8878 | 0.6413 | 0.5080 | 0.6784 | 0.6609 | 0.7100 | 0.5132 | 0.8452 | 0.8741 |
| ARI | QS_Heart | 0.9503 | 0.9371 | 0.2497 | 0.5222 | 0.9324 | 0.4673 | 0.2584 | 0.4368 | 0.8376 | 0.8757 |
| ARI | Plasschaert | 0.7907 | 0.7697 | 0.3540 | 0.4272 | 0.4867 | 0.4070 | 0.4668 | 0.5711 | 0.7352 | 0.2916 |
| ARI | Wang_Lung | 0.9527 | 0.9004 | 0.1035 | 0.1771 | 0.8975 | 0.2520 | 0.5998 | 0.1325 | 0.7995 | 0.0345 |
| NMI | QS_Diaphragm | 0.9670 | 0.9346 | 0.7351 | 0.7608 | 0.9210 | 0.7807 | 0.8223 | 0.6836 | 0.8846 | 0.8881 |
| NMI | QS_Limb_Muscle | 0.9682 | 0.9616 | 0.7398 | 0.7726 | 0.9468 | 0.7048 | 0.4624 | 0.7198 | 0.8911 | 0.9389 |
| NMI | QS_Lung | 0.8318 | 0.8038 | 0.6766 | 0.6642 | 0.7543 | 0.6840 | 0.4982 | 0.7312 | 0.7785 | 0.7976 |
| NMI | Muraro | 0.8506 | 0.8399 | 0.7619 | 0.6294 | 0.7349 | 0.7549 | 0.8347 | 0.7168 | 0.8194 | 0.8291 |
| NMI | QS_Heart | 0.9064 | 0.8857 | 0.6039 | 0.6540 | 0.8723 | 0.6531 | 0.4242 | 0.6941 | 0.8299 | 0.8454 |
| NMI | Plasschaert | 0.7696 | 0.7379 | 0.5563 | 0.5856 | 0.6469 | 0.6122 | 0.5786 | 0.5711 | 0.6915 | 0.5216 |
| NMI | Wang_Lung | 0.8942 | 0.8210 | 0.3150 | 0.3975 | 0.7965 | 0.1511 | 0.5862 | 0.3432 | 0.7167 | 0.0367 |

V-B Clustering Result

Table II compares the clustering performance of our method with multiple state-of-the-art methods. scCLG outperforms all baseline clustering methods, achieving the best NMI and ARI on all 7 scRNA-seq datasets. Meanwhile, we observe that the general deep graph embedded models have no clear advantage and their clustering performance is not stable. For example, scGNN performs poorly on Wang_Lung. The main reason is that the structural information preserved by the cell graph alone cannot address the particularities of scRNA-seq data well, and a meaningful ordering of the training data is also necessary, which again supports the design of scCLG. The performance of the deep clustering methods and the traditional clustering methods fluctuates significantly across datasets, whereas scCLG retains its advantage. This is because scCLG learns the key representations of the scRNA-seq data in a meaningful order, so that the model exhibits a smooth learning trajectory. In summary, scCLG performs better than the other methods under both clustering evaluation metrics.
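For readers who wish to reproduce the ARI column, the standard pair-counting form of the adjusted Rand index can be sketched in a few lines. This is a generic illustration and not the evaluation code used for Table II; the function name and the handling of degenerate inputs are assumptions.

```python
from collections import Counter
from math import comb

def adjusted_rand_index(labels_true, labels_pred):
    """Adjusted Rand index computed from the contingency table of two labelings."""
    total_pairs = comb(len(labels_true), 2)
    if total_pairs == 0:
        return 1.0  # assumption: fewer than two samples counts as perfect agreement
    # Pairs that fall in the same cell, same true class, or same predicted cluster.
    sum_cells = sum(comb(c, 2) for c in Counter(zip(labels_true, labels_pred)).values())
    sum_rows = sum(comb(c, 2) for c in Counter(labels_true).values())
    sum_cols = sum(comb(c, 2) for c in Counter(labels_pred).values())
    expected = sum_rows * sum_cols / total_pairs
    max_index = 0.5 * (sum_rows + sum_cols)
    if max_index == expected:
        return 1.0  # degenerate case, e.g. both labelings put everything in one cluster
    return (sum_cells - expected) / (max_index - expected)

# Cluster ids are arbitrary: a relabeled but identical partition scores 1.
ari_same = adjusted_rand_index([0, 0, 1, 1], [1, 1, 0, 0])
ari_crossed = adjusted_rand_index([0, 0, 1, 1], [0, 1, 0, 1])
```

The index is invariant to permutations of the cluster ids, which is why predicted cluster labels can be compared directly with annotated cell types.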

V-C Parameter Analysis

Figure 3: Parameter analysis. (A) Comparison of the average ARI and NMI values with different neighbor parameters $k$. (B) Comparison of the average ARI and NMI values with different numbers of genes.

V-C1 Different Neighbor Parameter $k$ Analysis

$k$ represents the number of nearest neighbors considered when constructing the cell graph. To investigate the impact of $k$, we ran our model with $k$ set to 5, 10, 15, 20 and 25. Fig. 3 (A) shows the NMI and ARI values for different values of $k$. As depicted in Fig. 3 (A), the two metrics first increase rapidly from $k=5$ to $k=10$, reach their best values at $k=20$, and then decrease slowly from $k=20$ to $k=25$. Therefore, we set the neighbor parameter $k$ to 20 in our scCLG model.

V-C2 Different Numbers of Variable Genes Analysis

In single-cell data analysis, highly variable genes vary significantly among different cells, which helps to reveal the heterogeneity within the cell population and to identify cell subpopulations more accurately. To explore the impact of the number of selected highly variable genes, we apply scCLG to the real datasets with gene numbers from 300 to 1500. Fig. 3 (B) shows the line plot of the average NMI and ARI on the 7 datasets when selecting 300, 500, 1000 and 1500 highly variable genes, respectively. The performance with 500 highly variable genes is better, while the performance with 300 genes is much worse than the others. Therefore, to save computational resources and reduce running time, we set the number of selected highly variable genes in the model to 500.

Figure 4: Comparison of the average ARI and NMI values with different data pruning rates and pruning strategies.

V-C3 Different Data Pruning Rate Analysis

In single-cell data analysis, data quality can be improved by pruning lower-quality samples, which in turn affects the generalization ability of the model. To explore the impact of the selected data, we run our model with pruning rates from 0.06 to 0.21 to drop difficult nodes. We also compare our pruning strategy with two alternatives, namely pruning easy nodes and randomly pruning nodes. Fig. 4 shows the ARI and NMI values for different values of $\alpha$ and different pruning strategies. As depicted in Fig. 4, the best performance is achieved when $\alpha$ is 0.11 and difficult nodes are pruned. This indicates that improving data quality can significantly improve the performance of the model. Compared to pruning easy nodes and randomly pruning nodes, pruning difficult nodes yields larger gains because difficult nodes have a negative impact on the representation of the graph. Furthermore, randomly pruning nodes is better than pruning easy nodes, which indicates that our difficulty measurer assigns reasonable difficulty scores to nodes.
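The three pruning strategies compared in Fig. 4 differ only in which part of the difficulty-sorted node list is discarded. A minimal sketch, with a hypothetical `prune` helper, toy difficulty scores, and a fixed random seed as assumptions:

```python
import random

def prune(difficulty, alpha, strategy, seed=0):
    """Return the nodes kept after dropping a fraction alpha of them.

    strategy: "hard" drops the most difficult nodes, "easy" drops the
    easiest ones, and "random" drops a uniform random sample.
    """
    ordered = sorted(difficulty, key=difficulty.get)  # easy -> hard
    n_drop = int(alpha * len(ordered))  # assumption: round the count down
    if strategy == "hard":
        dropped = ordered[len(ordered) - n_drop:]
    elif strategy == "easy":
        dropped = ordered[:n_drop]
    elif strategy == "random":
        dropped = random.Random(seed).sample(ordered, n_drop)
    else:
        raise ValueError(f"unknown strategy: {strategy}")
    dropped = set(dropped)
    return [v for v in ordered if v not in dropped]

# Ten hypothetical nodes whose difficulty grows with their id.
difficulty = {i: i / 10 for i in range(10)}
kept_hard = prune(difficulty, alpha=0.2, strategy="hard")
kept_easy = prune(difficulty, alpha=0.2, strategy="easy")
kept_random = prune(difficulty, alpha=0.2, strategy="random")
```

All three variants keep the same number of nodes, so differences in clustering quality can be attributed to which nodes are removed rather than to the size of the training set.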

V-D Ablation Study

TABLE III: Ablation study measured by ARI and NMI values.

| Metric | Dataset | scCLG | Without CL |
|---|---|---|---|
| ARI | QS_Diaphragm | 0.9836 | 0.9778 |
| ARI | QS_Limb_Muscle | 0.9828 | 0.9791 |
| ARI | QS_Lung | 0.7946 | 0.7947 |
| ARI | Muraro | 0.8959 | 0.8897 |
| ARI | QS_Heart | 0.9503 | 0.9530 |
| ARI | Plasschaert | 0.7907 | 0.7903 |
| ARI | Wang_Lung | 0.9527 | 0.9527 |
| NMI | QS_Diaphragm | 0.9670 | 0.9579 |
| NMI | QS_Limb_Muscle | 0.9682 | 0.9613 |
| NMI | QS_Lung | 0.8318 | 0.8321 |
| NMI | Muraro | 0.8506 | 0.8468 |
| NMI | QS_Heart | 0.9064 | 0.9088 |
| NMI | Plasschaert | 0.7696 | 0.7693 |
| NMI | Wang_Lung | 0.8942 | 0.8942 |

In this experiment, we analyze the effect of the curriculum learning component of scCLG. Specifically, we remove the hierarchical difficulty measurer and refer to the resulting variant as Without CL. Table III reports the ARI and NMI values on the 7 datasets. As shown in Table III, learning from scRNA-seq data in an easy-to-hard order improves the clustering performance on most datasets. scCLG achieves the best or tied ARI and NMI on 5 of the 7 datasets, while the variant without CL is marginally better on QS_Lung and QS_Heart. In summary, the results indicate that the components of the scCLG method are reasonable and effective.

VI Conclusion

In this research, we propose a single-cell curriculum learning-based deep graph embedding clustering method. Our approach first utilizes the Chebyshev graph convolutional autoencoder to learn a low-dimensional feature representation that preserves the cell–cell topological structure. We then define two types of difficult nodes and rank the nodes in the graph by their measured difficulty so that they are trained in a meaningful order. Meanwhile, we prune the difficult nodes to keep the node features at high quality. Our method shows higher clustering performance than state-of-the-art approaches on scRNA-seq data. Empirical results provide strong evidence that this performance is attributable to the proposed mechanisms, and particularly to their ability to tackle the difficult nodes.

References

  • [1] P. Lin, M. Troup, and J. W. Ho, “Cidr: Ultrafast and accurate clustering through imputation for single-cell rna-seq data,” Genome biology, vol. 18, no. 1, pp. 1–11, 2017.
  • [2] D. v. Dijk, J. Nainys, R. Sharma, P. Kaithail, A. J. Carr, K. R. Moon, L. Mazutis, G. Wolf, S. Krishnaswamy, and D. Pe’er, “Magic: A diffusion-based imputation method reveals gene-gene interactions in single-cell rna-sequencing data,” BioRxiv, p. 111591, 2017.
  • [3] M. Huang, J. Wang, E. Torre, H. Dueck, S. Shaffer, R. Bonasio, J. I. Murray, A. Raj, M. Li, and N. R. Zhang, “Saver: gene expression recovery for single-cell rna sequencing,” Nature methods, vol. 15, no. 7, pp. 539–542, 2018.
  • [4] B. Wang, J. Zhu, E. Pierson, D. Ramazzotti, and S. Batzoglou, “Visualization and analysis of single-cell rna-seq data by kernel-based similarity learning,” Nature methods, vol. 14, no. 4, pp. 414–416, 2017.
  • [5] S. Park and H. Zhao, “Spectral clustering based on learning similarity matrix,” Bioinformatics, vol. 34, no. 12, pp. 2069–2076, 2018.
  • [6] Y. LeCun, L. Bottou, Y. Bengio, and P. Haffner, “Gradient-based learning applied to document recognition,” Proceedings of the IEEE, vol. 86, no. 11, pp. 2278–2324, 1998.
  • [7] S. Hochreiter and J. Schmidhuber, “Long short-term memory,” Neural computation, vol. 9, no. 8, pp. 1735–1780, 1997.
  • [8] J. Fu, Z. Chen, and X. Han, “Adap dp-fl: Differentially private federated learning with adaptive noise,” in 2022 IEEE International Conference on Trust, Security and Privacy in Computing and Communications (TrustCom). IEEE, 2022, pp. 656–663.
  • [9] J. Fu, Y. Hong, X. Ling, L. Wang, X. Ran, Z. Sun, W. H. Wang, Z. Chen, and Y. Cao, “Differentially private federated learning: A systematic review,” arXiv preprint arXiv:2405.08299, 2024.
  • [10] L. Chen, W. Wang, Y. Zhai, and M. Deng, “Deep soft k-means clustering with self-training for single-cell rna sequence data,” NAR genomics and bioinformatics, vol. 2, no. 2, p. lqaa039, 2020.
  • [11] T. Tian, J. Zhang, X. Lin, Z. Wei, and H. Hakonarson, “Model-based deep embedding for constrained clustering analysis of single cell rna-seq data,” Nature communications, vol. 12, no. 1, p. 1873, 2021.
  • [12] T. Tian, J. Wan, Q. Song, and Z. Wei, “Clustering single-cell rna-seq data with a model-based deep learning approach,” Nature Machine Intelligence, vol. 1, no. 4, pp. 191–198, 2019.
  • [13] J. Wang, A. Ma, Y. Chang, J. Gong, Y. Jiang, R. Qi, C. Wang, H. Fu, Q. Ma, and D. Xu, “scgnn is a novel graph neural network framework for single-cell rna-seq analyses,” Nature communications, vol. 12, no. 1, p. 1882, 2021.
  • [14] Z. Luo, C. Xu, Z. Zhang, and W. Jin, “A topology-preserving dimensionality reduction method for single-cell rna-seq data using graph autoencoder,” Scientific reports, vol. 11, no. 1, p. 20028, 2021.
  • [15] Y. Bengio, J. Louradour, R. Collobert, and J. Weston, “Curriculum learning,” in Proceedings of the 26th annual international conference on machine learning, 2009, pp. 41–48.
  • [16] L. Jiang, D. Meng, Q. Zhao, S. Shan, and A. Hauptmann, “Self-paced curriculum learning,” in Proceedings of the AAAI Conference on Artificial Intelligence, vol. 29, no. 1, 2015.
  • [17] X. Wei, X. Gong, Y. Zhan, B. Du, Y. Luo, and W. Hu, “Clnode: Curriculum learning for node classification,” in Proceedings of the Sixteenth ACM International Conference on Web Search and Data Mining, 2023, pp. 670–678.
  • [18] C. Gong, J. Yang, and D. Tao, “Multi-modal curriculum learning over graphs,” ACM Transactions on Intelligent Systems and Technology (TIST), vol. 10, no. 4, pp. 1–25, 2019.
  • [19] P. Zhao, Y. Sheng, and X. Zhan, “An advanced nmf-based approach for single cell data clustering,” in 2022 IEEE 2nd International Conference on Information Communication and Software Engineering (ICICSE). IEEE, 2022, pp. 1–5.
  • [20] Z. Yu, Y. Lu, Y. Wang, F. Tang, K.-C. Wong, and X. Li, “Zinb-based graph embedding autoencoder for single-cell rna-seq interpretations,” in Proceedings of the AAAI conference on artificial intelligence, vol. 36, no. 4, 2022, pp. 4671–4679.
  • [21] X. Ai, “Node importance ranking of complex networks with entropy variation,” Entropy, vol. 19, no. 7, p. 303, 2017.
  • [22] B. Yu, C. Chen, R. Qi, R. Zheng, P. J. Skillman-Lawrence, X. Wang, A. Ma, and H. Gu, “scgmai: a gaussian mixture model for clustering single-cell rna-seq data based on deep autoencoder,” Briefings in bioinformatics, vol. 22, no. 4, p. bbaa316, 2021.