Knowledge-driven Subspace Fusion and Gradient Coordination for Multi-modal Learning

Yupei Zhang* (1), Xiaofei Wang* (2), Fangliangzi Meng (3), Jin Tang (4), Chao Li🖂 (2, 5, 6)

(1) Department of Pathology, The University of Hong Kong
(2) Department of Clinical Neurosciences, University of Cambridge, UK
(3) School of Life Sciences and Technology, Tongji University, China
(4) Zhejiang Lab, China
(5) School of Science and Engineering, University of Dundee, UK
(6) Department of Applied Mathematics and Theoretical Physics, University of Cambridge, UK
🖂 cl647@cam.ac.uk
Abstract

Multi-modal learning plays a crucial role in cancer diagnosis and prognosis. Current deep learning-based multi-modal approaches are often limited in their ability to model the complex correlations between genomics and histology data, and hence to address the intrinsic complexity of the tumour ecosystem, where both tumour and microenvironment contribute to malignancy. We propose a biologically interpretable and robust multi-modal learning framework that efficiently integrates histology images and genomics by decomposing the feature subspace of histology images and genomics, reflecting distinct tumour and microenvironment features. To enhance cross-modal interactions, we design a knowledge-driven subspace fusion scheme, consisting of a cross-modal deformable attention module and a gene-guided consistency strategy. Additionally, to dynamically optimize the subspace knowledge, we further propose a novel gradient coordination learning strategy. Extensive experiments demonstrate the effectiveness of the proposed method, which outperforms state-of-the-art techniques in three downstream tasks: glioma diagnosis, tumour grading, and survival analysis. Our code is available at https://github.com/helenypzhang/Subspace-Multimodal-Learning.

* Equal contribution.

Keywords: Multi-modal learning · Molecular pathology · Cancer diagnosis and prognosis

1 Introduction

Multi-modal integration of genomics and histology is becoming increasingly important for cancer diagnosis, as evidenced by the recent shift of cancer taxonomy criteria towards integrating molecular markers with histology features [2, 5, 18]. However, joint analysis of multi-modal data in the clinic remains challenging [19]. Automatic algorithms that effectively integrate multi-modal data of genomics and histology promise to offer rapid diagnosis and aid precise cancer treatment.

Deep learning-based digital pathology [3, 20] holds promise for rapid cancer diagnosis based on whole slide images (WSIs) derived from tissue sections [22]. Research [10, 17] has shown evidence of associations between genomics and WSIs, indicating that morphology features from WSIs may mirror genomic information. However, it remains a challenge to effectively integrate WSIs with genomic information for cancer diagnosis due to (1) the complexity of multi-modal data, i.e., gigapixel WSIs and genomic profiles of tens of thousands of genes, and (2) the intrinsic tumour heterogeneity and complexity of cancers.

Previous studies have developed multi-modal approaches to integrate WSIs with genomics for cancer diagnosis. Among them, late-fusion methods integrate modality-specific features at the prediction layer. For instance, Chen et al. [3] integrated genomics and WSIs using the Kronecker Product [16] for cancer diagnosis and prognosis. However, late-fusion models ignore the interactions of the multi-modal data at earlier learning stages and therefore cannot model the modality interaction for robust feature extraction and utilisation.

By contrast, intermediate-fusion techniques show promise to fuse modality-specific features at various levels before prediction. For instance, Chen et al. [4] introduced a co-attention method that maps correlations between genomics and WSIs for survival analysis, enhancing survival predictions by learning dense co-attention mappings between genomics and bag representations of WSIs. Additionally, Zhou et al. [23] developed a multi-modal learning framework that delves into cross-modal correlations through inter-modality translation and alignment, offering complementary insights from different modalities for survival analysis.

Despite encouraging results, these models typically map WSIs to genomics in a single embedding space, which may not fully reflect the complexity of cancer. In tumorigenesis, both tumour cells and the tumour microenvironment (TME) contribute to malignancy [9], providing essential insights with distinct morphological features for histology assessment. The TME is a complex ecosystem surrounding tumour cells, mainly composed of immune cells, together with other stromal cells and vessels. Further, evidence [21, 24] shows that genetic profiles are associated with tumour and microenvironment characteristics. For example, isocitrate dehydrogenase (IDH) wildtype tumours show marked necrosis and microvascular proliferation in WSIs. Therefore, decoding tumour- and TME-related morphology features from WSIs and genes from genomics promises to advance multi-modal integration. However, it remains challenging to effectively model the regional interactions between WSIs and genomics.

Figure 1: Overview of the proposed framework. Left: Knowledge-driven subspace fusion scheme. Right: Confidence-guided Gradient Coordination.

To tackle this challenge, we propose a novel multi-modal learning framework with a decoupled gene-to-histology integration strategy for accurate and interpretable automatic cancer diagnosis. Specifically, our contribution is threefold:

  • For the first time, we explicitly decompose the genomics data into tumour- and TME-related genes for effective integration with histological features.

  • We propose a novel knowledge-driven subspace fusion (KS-Fusion) scheme to effectively enhance the multi-modal interactions. Specifically, in KS-Fusion, we present a cross-modal deformable attention (CM-Deform) module with a gene-guided consistency strategy to capture specific morphological features of the corresponding tumour- and TME-related genes.

  • We propose a confidence-guided gradient coordination (CG-Coord) scheme to regulate the multi-modal learning process, which promotes the optimal global performance through dynamic optimization.

Our extensive experiments on the three downstream tasks, i.e., cancer diagnosis, grading, and patient survival prediction, in two public datasets demonstrate that our method outperforms other state-of-the-art (SOTA) methods.

2 Methodology

2.1 Framework

Fig. 1 illustrates the proposed multi-modal learning framework. Overall, to decompose the modal subspace features from genomics and WSIs and model their interactions, we propose a KS-Fusion scheme (left) and a CG-Coord scheme (right). Specifically, as shown in the left part of Fig. 1, categorized genomics features and histology features are first fused through a linear layer to generate a multi-modal teacher, which is used as the query in the deformable attention module. To regulate the deformed offsets, we introduce a batch consistency constraint. Then, as illustrated in the right part of Fig. 1, the concatenated multi-modal feature is fed to the classifier for decision prediction, where the proposed CG-Coord scheme alleviates the conflicts introduced by subspace gradients in training the classifier.

2.2 Knowledge-driven Subspace Fusion

Existing multi-modal fusion methods lack comprehensive modelling of feature subspaces for genomic profiles and WSI morphological features. To enhance the biological guidance from niche-specific genomic profiles, we propose a KS-Fusion scheme to capture informative subspace features from both tumour- and TME-related gene profiles. Specifically, in the KS-Fusion scheme, the CM-Deform module, along with the gene-guided consistency strategy, is designed to efficiently extract the histological features corresponding to tumour- and TME-related genes simultaneously.

Cross-modal Deformable Attention Module. Overall, we build a two-stream neural network to model the cross-modal subspace interactions within 1) tumour-related genomics $x_g^t$ and the WSIs features $x_p$; 2) TME-related genomics $x_g^e$ and the WSIs features $x_p$, where $t$ signifies the tumour niche, $e$ denotes the TME niche, $g$ stands for genomics and $p$ represents WSIs. Specifically, the genomic profile encoder $\mathcal{G}$ is a spiking neural network (SNN) [8]. The genomics data $x_g$ is divided into two groups, with $x_g^t$ representing the tumour-related genes and $x_g^e$ representing the TME-related genes.

To realize the cross-modal deformable attention, we first apply a linear layer to obtain the multi-modal teacher features $x_{pg}^t$ and $x_{pg}^e$, which are then used to guide the offset generation through an offset generation network $\psi$, with two convolution layers and a scaler. The original reference points are a uniform grid of points $p^u \in \mathbb{R}^{H_G \times W_G \times 2}$, given the input feature map $x_{pg}^u \in \mathbb{R}^{H \times W \times C}$, where $u$ represents the subspace and can be uniformly denoted as $u \in \{t, e\}$. For each stream, the multi-head cross-attention can be denoted as:

$$q^u = x_{pg}^u W_q^u, \quad \hat{k}^u = \hat{x}_p^u W_k^u, \quad \hat{v}^u = \hat{x}_p^u W_v^u, \tag{1}$$

where $q^u$, $\hat{k}^u$, and $\hat{v}^u$ represent the query, deformed key, and deformed value, respectively, and $W_q^u$, $W_k^u$, $W_v^u$ are projection networks. The deformed WSI features are $\hat{x}_p^u = F(x_p^u;\ {\rm norm}(p^u + \Delta p^u))$ with $\Delta p^u = \psi(x_{pg}^u)$, where $F$ represents a bilinear interpolation function. For each stream, the output of an attention head is formulated as:

$$z^{m;u} = {\rm softmax}\left(q^{(m);u}\, \hat{k}^{(m);u\top} / \sqrt{d}\right) \hat{v}^{(m);u}, \tag{2}$$

where $m$ represents the index of the attention head, ranging from 1 to $M$, and the final output is calculated as:

$$z^u = {\rm concat}(z^{1;u}, \ldots, z^{M;u})\, W_o^u, \tag{3}$$

where $W_o^u$ is a projection network. In this way, the informative morphological features are dominated by the multi-modal features, enhancing the correlations and interactions between subspace genomic and histology embeddings.
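As a rough illustration of Eqs. 1 and 2, the following Python sketch samples a WSI feature map at the deformed reference points $p^u + \Delta p^u$ by bilinear interpolation and then applies one scaled dot-product attention head. It is a simplified sketch under stated assumptions, not the released implementation: the projections $W_q^u$, $W_k^u$, $W_v^u$ are taken to be identities, the offsets are passed in directly instead of being predicted by $\psi$, coordinates are in pixel units rather than normalised, $d$ is assumed to be the query dimension, and all function names are hypothetical.

```python
import math

def bilinear_sample(feature_map, y, x):
    """Interpolate an H x W x C map (nested lists) at a continuous point (y, x)."""
    height, width = len(feature_map), len(feature_map[0])
    # Clamp so that points pushed outside the map by an offset stay defined.
    y = min(max(y, 0.0), height - 1.0)
    x = min(max(x, 0.0), width - 1.0)
    y0, x0 = int(y), int(x)  # floor, since y and x are non-negative here
    y1, x1 = min(y0 + 1, height - 1), min(x0 + 1, width - 1)
    wy, wx = y - y0, x - x0
    return [
        (1 - wy) * ((1 - wx) * a + wx * b) + wy * ((1 - wx) * c + wx * d)
        for a, b, c, d in zip(
            feature_map[y0][x0], feature_map[y0][x1],
            feature_map[y1][x0], feature_map[y1][x1],
        )
    ]

def softmax(row):
    """Numerically stable softmax over a list of scores."""
    peak = max(row)
    exps = [math.exp(v - peak) for v in row]
    total = sum(exps)
    return [e / total for e in exps]

def deformed_attention_head(queries, feature_map, reference_points, offsets):
    """One attention head as in Eq. 2, with identity projections (assumption)."""
    # x_hat: features sampled at p + delta_p; used as both keys and values.
    sampled = [
        bilinear_sample(feature_map, py + dy, px + dx)
        for (py, px), (dy, dx) in zip(reference_points, offsets)
    ]
    d = len(queries[0])
    outputs = []
    for q in queries:
        scores = [sum(a * b for a, b in zip(q, key)) / math.sqrt(d) for key in sampled]
        weights = softmax(scores)
        outputs.append([
            sum(w * value[c] for w, value in zip(weights, sampled))
            for c in range(len(sampled[0]))
        ])
    return outputs
```

In a multi-head setting, the outputs of the $M$ heads would be concatenated along the channel axis and passed through $W_o^u$ as in Eq. 3; that step is omitted here to keep the sketch short.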

Figure 2: Illustration of gene-guided consistency strategy, which will regulate the deformation by the sample-wise similarity constraint.

2.2.1 Gene-guided consistency strategy.

Furthermore, to ensure the intrinsic consistency between genomics and the deformed WSIs features, we introduce a batch consistency, denoted as gene-guided consistency strategy (Ge-Con), to further regulate offset adjustments in each subspace, which can be formulated as:

$$\mathcal{L}^u_{\rm batch} = \frac{1}{B} \left\| \mathcal{S}_b(\mathcal{G}(x_g^u)) - \mathcal{S}_b(p^u + \Delta p^u) \right\|_2, \tag{4}$$

where $\mathcal{S}_b(x_1)$ denotes the Gram matrix of feature $x_1$, representing the correlations among individuals, and $B$ represents the batch size. As shown in Fig. 2, the sample-wise similarity further enhances the guidance of genomics. This gene-knowledge penetration towards WSI features further enhances the interaction of informative subspace features, which ultimately facilitates multi-modal subspace fusion.
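The batch consistency of Eq. 4 can be sketched in a few lines of Python. The sketch assumes that each sample is flattened to a vector, that the Gram matrix holds plain (unnormalised) inner products between samples, and that the norm is the Frobenius norm of the matrix difference; the paper does not spell out these details, and the function names are hypothetical.

```python
import math

def gram_matrix(features):
    """B x B matrix of inner products between the samples of one batch."""
    return [[sum(a * b for a, b in zip(fi, fj)) for fj in features] for fi in features]

def batch_consistency_loss(gene_features, deformed_points):
    """(1 / B) * || S_b(gene) - S_b(points) ||_2, with a Frobenius norm (assumption)."""
    batch_size = len(gene_features)
    gram_gene = gram_matrix(gene_features)
    gram_points = gram_matrix(deformed_points)
    squared = sum(
        (g - p) ** 2
        for row_g, row_p in zip(gram_gene, gram_points)
        for g, p in zip(row_g, row_p)
    )
    return math.sqrt(squared) / batch_size
```

Because both Gram matrices have shape $B \times B$, the two inputs may have different feature dimensions; only the pairwise similarity structure within the batch is compared.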

2.3 Confidence-guided Gradient Coordination

In training the classifier with multi-modal features, it is challenging to obtain a global optimum performance because the subspace gradients of tumour-gene and TME-gene may conflict when applying joint optimization. Therefore, to boost the downstream task by combining these two types of domain knowledge, we design the CG-Coord scheme to obtain the global training optimum via dynamic gradient regulation.

Specifically, we regard $g(\theta^t)$ and $g(\theta^e)$ as conflicting when ${\rm cosine}(g(\theta^t), g(\theta^e))$ is smaller than zero, and in that case we adjust the gradient of the branch with the lower prediction confidence score. The adjustment process is formulated as:

$$\begin{cases} \tilde{g}(\theta^t) = \Gamma(g(\theta^t),\ g(\theta^e)), & \sum s^t < \sum s^e, \\ \tilde{g}(\theta^e) = \Gamma(g(\theta^e),\ g(\theta^t)), & \sum s^e < \sum s^t, \end{cases} \tag{5}$$

where $s^t = {\rm softmax}(\mathcal{D}(z^t))[k]$, $s^e = {\rm softmax}(\mathcal{D}(z^e))[k]$, and $\mathcal{D}$ is the decoder for the diagnosis and grading tasks. The confidence score for survival is represented by the corresponding C-Index values. In Eq. 5, $\Gamma(\vec{x}_1, \vec{x}_2)$ denotes the projection of the vector $\vec{x}_1$ onto the direction perpendicular to the vector $\vec{x}_2$. Besides, $\sum s^t$ and $\sum s^e$ represent the sums of the corresponding prediction scores on the $k$-th label from the tumour and TME branches over a mini-batch. In this way, the gradients of different subspace features are modulated dynamically to avoid conflicts, thereby achieving harmonious training.
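The coordination rule of Eq. 5 can be illustrated with flattened gradient vectors. The sketch below is a minimal reading of the equation, not the authors' code: it takes $\Gamma(\vec{x}_1, \vec{x}_2) = \vec{x}_1 - \frac{\vec{x}_1 \cdot \vec{x}_2}{\|\vec{x}_2\|^2} \vec{x}_2$, assumes that $k$ is the ground-truth label of each sample, and leaves both gradients untouched when the confidences tie (a case Eq. 5 does not cover). All function names are hypothetical.

```python
import math

def dot(a, b):
    return sum(x * y for x, y in zip(a, b))

def project_perpendicular(g1, g2):
    """Gamma(g1, g2): remove from g1 its component along g2 (g2 must be non-zero)."""
    scale = dot(g1, g2) / dot(g2, g2)
    return [x - scale * y for x, y in zip(g1, g2)]

def confidence_sum(batch_logits, labels):
    """Sum over a mini-batch of the softmax probability on each sample's label k."""
    total = 0.0
    for logits, k in zip(batch_logits, labels):
        peak = max(logits)
        exps = [math.exp(v - peak) for v in logits]
        total += exps[k] / sum(exps)
    return total

def coordinate_gradients(grad_t, grad_e, conf_t, conf_e):
    """Return the (possibly adjusted) tumour and TME gradients."""
    # A negative dot product has the same sign as a negative cosine similarity.
    if dot(grad_t, grad_e) >= 0:
        return grad_t, grad_e
    if conf_t < conf_e:
        return project_perpendicular(grad_t, grad_e), grad_e
    if conf_e < conf_t:
        return grad_t, project_perpendicular(grad_e, grad_t)
    return grad_t, grad_e  # tie: not specified by Eq. 5, left unchanged (assumption)
```

After the adjustment, the modified gradient is orthogonal to the more confident one, so an update along it no longer opposes the more confident branch to first order.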

2.4 Training Objectives of Downstream Tasks

In our framework, the training loss functions are tailored for varied downstream tasks, including cancer diagnosis, tumour grading, and prognosis prediction. Specifically, for the cancer diagnosis and grading tasks, we apply the cross-entropy loss as the task-specific objective, and the total training objectives can be formulated as:

$$\mathcal{L}_{\rm diag} = \mathcal{L}_{\rm CE}(\mathcal{D}(z^t, z^e;\ \theta_{\rm diag}),\ Y_{\rm diag}) + \alpha \mathcal{L}_{\rm batch}^t + (1 - \alpha) \mathcal{L}_{\rm batch}^e, \tag{6}$$
$$\mathcal{L}_{\rm grad} = \mathcal{L}_{\rm CE}(\mathcal{D}(z^t, z^e;\ \theta_{\rm grad}),\ Y_{\rm grad}) + \alpha \mathcal{L}_{\rm batch}^t + (1 - \alpha) \mathcal{L}_{\rm batch}^e, \tag{7}$$

where $\mathcal{L}_{\rm CE}$ represents the cross-entropy loss, $\mathcal{D}$ represents a classifier, and $\theta_{\rm diag}$ and $\theta_{\rm grad}$ represent the classifier parameters for the diagnosis and grading tasks, respectively. $Y_{\rm diag}$ and $Y_{\rm grad}$ represent the diagnosis and grading labels, respectively. $\alpha$ is a hyper-parameter for balancing subspace gene knowledge penetration. Notably, the hyper-parameter sensitivity analysis can be found in Section 3.2.

Besides, for the prognosis prediction task, we adopt the NLL (negative log-likelihood) survival loss [23], denoted as $\mathcal{L}_{\rm NLL}$, as the task-specific objective for the survival outcome prediction, and the total training objective is:

$$\mathcal{L}_{\rm surv} = \mathcal{L}_{\rm NLL} + \alpha \mathcal{L}_{\rm batch}^t + (1 - \alpha) \mathcal{L}_{\rm batch}^e. \tag{8}$$

In this way, our method can benefit downstream tasks by promoting subspace cross-modal fusion and harmonious training.
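The objectives in Eqs. 6 to 8 share one structure: a task loss plus a convex combination of the two batch consistency terms. The following sketch shows that structure for a single sample; restricting $\alpha$ to $[0, 1]$ and the default value 0.5 are assumptions of this illustration, and the function names are hypothetical.

```python
import math

def cross_entropy(logits, label):
    """Cross-entropy of one sample, computed with a stable log-sum-exp."""
    peak = max(logits)
    log_norm = peak + math.log(sum(math.exp(v - peak) for v in logits))
    return log_norm - logits[label]

def total_objective(task_loss, batch_loss_t, batch_loss_e, alpha=0.5):
    """task_loss + alpha * L_batch^t + (1 - alpha) * L_batch^e."""
    if not 0.0 <= alpha <= 1.0:
        raise ValueError("alpha must lie in [0, 1]")
    return task_loss + alpha * batch_loss_t + (1.0 - alpha) * batch_loss_e
```

For the survival task, `task_loss` would be the NLL survival loss instead of the cross-entropy; the weighting of the two consistency terms is unchanged.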

3 Experiments & Results

3.1 Datasets & Implementation Details

Datasets. We conduct experiments on two public datasets, i.e., the TCGA GBM-LGG [15] and IvyGAP [12, 13] datasets. The two datasets, both focusing on gliomas, are merged into a meta dataset for better performance, comprising 2,387 tissue samples (668 cases) with paired WSIs and genomic profiles. We randomly split it into training (534 cases), testing (68 cases), and validation (66 cases) sets.

WSIs are cropped into patches of $224\,{\rm px} \times 224\,{\rm px}$ at $0.5\,\mu{\rm m\ px^{-1}}$. Following [18], for each WSI, we extract 2500 patches with the biological repeat strategy. For genomics processing, following previous studies [1], we first sort the gene signatures shared by the TCGA GBM-LGG [15] and IvyGAP [12, 13] datasets by expression variance and then select the top 30% to capture important biological information in genes. The resulting genomic profile comprises 420 features, encompassing 59 tumour-related genes and 361 TME-related genes.
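The variance-based gene selection described above can be sketched as follows. This is an illustrative sketch only: the use of the population variance, rounding the number of kept genes down, and the gene names in the usage note are assumptions, since the paper does not specify them.

```python
from statistics import pvariance

def select_top_variance_genes(expression, fraction=0.3):
    """Return the genes with the highest expression variance.

    `expression` maps a gene name to its expression values across samples.
    """
    ranked = sorted(expression, key=lambda gene: pvariance(expression[gene]), reverse=True)
    keep = max(1, int(len(ranked) * fraction))  # rounding down is an assumption
    return ranked[:keep]
```

For example, with four hypothetical genes `{"A": [1, 1, 1], "B": [0, 10, 20], "C": [1, 2, 3], "D": [5, 5, 6]}` and `fraction=0.5`, the function keeps the two most variable genes, `"B"` and `"C"`.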

Implementation Details. All experiments were conducted using PyTorch [11] on two NVIDIA RTX A5000 GPUs, with a batch size of 8. Our method was trained for 20 epochs for diagnosis and grading, and 10 epochs for survival prediction. Network optimization was performed using the Adam optimizer [7]. The key hyper-parameters can be found in Supplementary Table I. Each hyper-parameter was tuned to achieve optimal performance on the validation set.

3.2 Performance Evaluation

Table 1: Comparison with SOTA methods on three tasks. p. and g. represent the modality of pathology and genomics, respectively. Best and second results are highlighted with bold and underline.

| Methods | p. | g. | Diagnosis AUC, % | Diagnosis Acc., % | Diagnosis Sen., % | Diagnosis Spec., % | Diagnosis F1-score, % | Grading AUC, % | Grading Acc., % | Grading Sen., % | Grading Spec., % | Grading F1-score, % | Survival C-Index, % |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| AttMIL [6] | ✓ | | 80.25 | 53.81 | 38.77 | 81.99 | 33.30 | 79.70 | 61.02 | 61.36 | 80.77 | 60.30 | 55.59 |
| TransMIL [14] | ✓ | | 74.90 | 51.69 | 44.06 | 84.11 | 39.24 | 83.14 | 68.64 | 67.69 | 84.21 | 66.83 | 67.71 |
| SNN [8] | | ✓ | 90.87 | 76.27 | 64.79 | 92.80 | 64.94 | 90.52 | 83.05 | 81.92 | 91.18 | 81.72 | 77.04 |
| Concat | ✓ | ✓ | 90.89 | 74.58 | 63.01 | 92.35 | 63.64 | 89.32 | 78.39 | 76.76 | 88.70 | 75.00 | 75.06 |
| Add | ✓ | ✓ | 91.67 | 75.85 | 63.90 | 92.67 | 63.56 | 90.36 | 82.63 | 81.60 | 91.01 | 81.39 | 73.42 |
| Bilinear | ✓ | ✓ | 90.25 | 79.24 | 70.33 | 93.50 | 69.87 | 82.49 | 70.34 | 68.27 | 84.55 | 62.82 | 73.71 |
| MCAT [4] | ✓ | ✓ | 80.70 | 55.08 | 39.85 | 84.68 | 35.81 | 87.40 | 59.75 | 62.63 | 80.54 | 50.93 | 69.64 |
| CMAT [23] | ✓ | ✓ | 88.14 | 68.64 | 59.01 | 89.87 | 53.61 | 87.54 | 54.66 | 57.36 | 78.07 | 46.85 | 65.96 |
| Ours | ✓ | ✓ | 95.28 | 81.36 | 72.92 | 94.44 | 70.97 | 91.53 | 83.05 | 82.29 | 91.29 | 82.43 | 79.78 |

Baselines and SOTA Comparison Methods. For each task, we compare our model with eight SOTA methods: three uni-modal methods, namely AttMIL (histology only) [6], TransMIL (histology only) [14], and SNN (genomics only) [8], and five multi-modal fusion algorithms, namely Concat (AttMIL with SNN), Add (AttMIL with SNN), Bilinear (ResNet with SNN), MCAT [4], and CMAT [23].

Downstream Task I: Glioma Diagnosis. As shown in Table 1, in the glioma diagnosis task, our framework is superior to all SOTA models, achieving an AUC of 95.28%, Acc. of 81.36%, Sen. of 72.92%, Spec. of 94.44%, and F1-score of 70.97%. In terms of AUC and Acc., our framework outperforms the others by at least 3.61% and 2.12%, respectively, indicating the superior multi-modal learning ability of our method. Furthermore, ROCs can be found in Fig. 3 (top).

Downstream Task II: Glioma Grading. Gliomas can be categorized into four severity grades. As shown in the middle columns of Table 1, our framework outperforms all SOTA models, achieving an AUC of 91.53%, Acc. of 83.05%, Sen. of 82.29%, Spec. of 91.29%, and F1-score of 82.43%. In terms of AUC and F1-score, our framework outperforms the others by at least 1.01% and 0.71%, respectively, suggesting its strong ability in the grading task.

Downstream Task III: Survival Analysis. Following previous studies [3, 4], we segment the overall survival time into four intervals based on the quartiles of event times of uncensored patients to compute the discretized-survival C-index. As shown in Table 1, our framework outperforms state-of-the-art models, achieving a C-Index of 79.78%, which exceeds the second-best method (marked in underline) by 2.74%. The results show that our proposed multi-modal learning framework is also powerful in the prognosis prediction task.
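The survival evaluation described above can be sketched in two steps: discretising survival times with quartile cut points taken from uncensored patients, and scoring predicted risks with a concordance index. The sketch uses the default ("exclusive") method of `statistics.quantiles`, assigns a time that falls exactly on a cut point to the upper interval, and computes Harrell's C-index with ties in risk counted as 0.5; these conventions are assumptions, not details confirmed by the paper, and the function names are hypothetical.

```python
from bisect import bisect_right
from statistics import quantiles

def quartile_cuts(times, events):
    """Three cut points from the event times of uncensored patients (event == 1)."""
    uncensored = [t for t, e in zip(times, events) if e == 1]
    return quantiles(uncensored, n=4)

def discretize(times, cuts):
    """Map each survival time to an interval index in {0, 1, 2, 3}."""
    return [bisect_right(cuts, t) for t in times]

def concordance_index(times, events, risks):
    """Harrell's C-index: fraction of comparable pairs ranked correctly by risk."""
    concordant, comparable = 0.0, 0
    for i in range(len(times)):
        if events[i] != 1:
            continue  # a pair is comparable only if the earlier time is an event
        for j in range(len(times)):
            if times[i] < times[j]:
                comparable += 1
                if risks[i] > risks[j]:
                    concordant += 1.0
                elif risks[i] == risks[j]:
                    concordant += 0.5
    return concordant / comparable  # raises ZeroDivisionError if no pair is comparable
```

A higher predicted risk for the patient with the shorter survival time counts as concordant, so a perfectly ordered risk score yields 1.0 and a perfectly reversed one yields 0.0.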

Hyper-parameter Sensitivity Analysis. An additional hyper-parameter sensitivity analysis is conducted to investigate the trade-off between tumour and TME subspaces in the proposed KS-Fusion scheme. Specifically, we conduct experiments on the hyper-parameter $\alpha$ in Eq. 6 and Eq. 7. As shown in Fig. 3 (bottom), our method achieves the best performance when $\alpha$ equals 0.5, indicating that we can efficiently utilise both tumour- and TME-related gene features.

Ablation Study

Table 2: Ablation studies on the Diagnosis, Grading, and Survival tasks. All values are in %. The best results are highlighted in bold.

| Methods | Diagnosis AUC | Diagnosis Acc. | Diagnosis Sen. | Diagnosis Spec. | Diagnosis F1-score | Grading AUC | Grading Acc. | Grading Sen. | Grading Spec. | Grading F1-score | Survival C-Index |
|---|---|---|---|---|---|---|---|---|---|---|---|
| w/o Ge-Con | 91.55 | 79.66 | 69.43 | 93.89 | 67.07 | 90.49 | 71.61 | 71.82 | 85.77 | 72.04 | 72.16 |
| w/o CG-Coord | 91.27 | 77.97 | 67.47 | 93.45 | 67.70 | 89.08 | 73.73 | 73.90 | 86.82 | 74.10 | 76.54 |
| Ours | **95.28** | **81.36** | **72.92** | **94.44** | **70.97** | **91.53** | **83.05** | **82.29** | **91.29** | **82.43** | **79.78** |

To quantitatively evaluate the effectiveness of our proposed components, we conduct ablation studies for each component on two downstream tasks, i.e., cancer diagnosis and grading. In particular, we construct two ablative baselines of the proposed framework by disabling the gene-guided consistency strategy (denoted as w/o Ge-Con) and the Confidence-guided Gradient Coordination (denoted as w/o CG-Coord). As shown in Table 2, on the glioma diagnosis task Ge-Con and CG-Coord improve the AUC by 3.73% and 4.01%, respectively. On the grading task, the AUC increases by 1.04% and 2.45% for Ge-Con and CG-Coord, respectively, indicating the effectiveness of both the gene-guided consistency strategy and the CG-Coord scheme.
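The gains quoted above are absolute differences between the "Ours" row and each ablated row of Table 2. The short check below recomputes them from the table values; the dictionary layout and the helper name are illustrative only.

```python
# AUC (%) and C-Index (%) values copied from Table 2.
table2 = {
    "w/o Ge-Con":   {"diagnosis_auc": 91.55, "grading_auc": 90.49, "c_index": 72.16},
    "w/o CG-Coord": {"diagnosis_auc": 91.27, "grading_auc": 89.08, "c_index": 76.54},
    "Ours":         {"diagnosis_auc": 95.28, "grading_auc": 91.53, "c_index": 79.78},
}


def gain_over(ablated, metric):
    """Absolute gain of the full model over an ablated baseline, rounded to 2 decimals."""
    return round(table2["Ours"][metric] - table2[ablated][metric], 2)


for name in ("w/o Ge-Con", "w/o CG-Coord"):
    for metric in ("diagnosis_auc", "grading_auc", "c_index"):
        print(f"{name:13s} {metric:14s} +{gain_over(name, metric):.2f}")
```

The same subtraction applied to the C-Index column shows that both components also contribute on the survival task, which Table 2 reports but the paragraph above does not discuss.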

Figure 3: Top: ROC curves of the comparison and ablation studies on the glioma diagnosis task. Bottom: Hyper-parameter analysis of α in the diagnosis and grading tasks.

4 Conclusion

Multi-modal data plays an increasingly important role in recent cancer diagnosis criteria. To effectively model the multi-modal data of histology and genomics, we propose a multi-modal learning framework with a KS-Fusion scheme that reflects the intrinsic cancer mechanisms of the tumour and the TME. Within the KS-Fusion scheme, we propose the CM-Deform module and a gene-guided consistency strategy to enhance the multi-modal interaction among tumour genes, TME genes and histology images. In addition, we design a CG-Coord scheme that stabilizes the multi-modal learning process by dynamically adjusting the optimization of the subspace features. Extensive experiments on three tasks (glioma diagnosis, grading, and survival analysis) demonstrate the effectiveness of our method, which improves over other SOTA methods on all three and shows promise for advancing precision oncology.

References

  • [1] Bhattacharya, S., Dunn, P., Thomas, C.G., Smith, B., Schaefer, H., Chen, J., Hu, Z., Zalocusky, K.A., Shankar, R.D., Shen-Orr, S.S., et al.: Immport, toward repurposing of open access immunological assay data for translational and clinical research. Scientific data 5(1), 1–9 (2018)
  • [2] Bombonati, A., Sgroi, D.C.: The molecular pathology of breast cancer progression. The Journal of pathology 223(2), 308–318 (2011)
  • [3] Chen, R.J., Lu, M.Y., Wang, J., Williamson, D.F., Rodig, S.J., Lindeman, N.I., Mahmood, F.: Pathomic fusion: an integrated framework for fusing histopathology and genomic features for cancer diagnosis and prognosis. IEEE Transactions on Medical Imaging 41(4), 757–770 (2020)
  • [4] Chen, R.J., Lu, M.Y., Weng, W.H., Chen, T.Y., Williamson, D.F., Manz, T., Shady, M., Mahmood, F.: Multimodal co-attention transformer for survival prediction in gigapixel whole slide images. In: Proceedings of the IEEE/CVF International Conference on Computer Vision. pp. 4015–4025 (2021)
  • [5] Harris, T.J., McCormick, F.: The molecular pathology of cancer. Nature reviews Clinical oncology 7(5), 251–265 (2010)
  • [6] Ilse, M., Tomczak, J., Welling, M.: Attention-based deep multiple instance learning. In: International conference on machine learning. pp. 2127–2136. PMLR (2018)
  • [7] Kingma, D.P., Ba, J.: Adam: A method for stochastic optimization. arXiv preprint arXiv:1412.6980 (2014)
  • [8] Klambauer, G., Unterthiner, T., Mayr, A., Hochreiter, S.: Self-normalizing neural networks. Advances in neural information processing systems 30 (2017)
  • [9] Li, C., Wang, S., Liu, P., Torheim, T., Boonzaier, N.R., van Dijken, B.R., Schönlieb, C.B., Markowetz, F., Price, S.J.: Decoding the interdependence of multiparametric magnetic resonance imaging to reveal patient subgroups correlated with survivals. Neoplasia 21(5), 442–449 (2019)
  • [10] Lin, T.C., Yeh, Y.M., Fan, W.L., Chang, Y.C., Lin, W.M., Yang, T.Y., Hsiao, M.: Ghrelin upregulates oncogenic aurora a to promote renal cell carcinoma invasion. Cancers 11(3), 303 (2019)
  • [11] Paszke, A., Gross, S., Massa, F., Lerer, A., Bradbury, J., Chanan, G., Killeen, T., Lin, Z., Gimelshein, N., Antiga, L., et al.: Pytorch: An imperative style, high-performance deep learning library. Advances in neural information processing systems 32 (2019)
  • [12] Puchalski, R.B., Shah, N., Miller, J., Dalley, R., Nomura, S.R., Yoon, J.G., Smith, K.A., Lankerovich, M., Bertagnolli, D., Bickley, K., et al.: An anatomic transcriptional atlas of human glioblastoma. Science 360(6389), 660–663 (2018)
  • [13] Shah, N., Feng, X., Lankerovich, M., Puchalski, R.B., Keogh, B.: Data from ivy gap. The Cancer Imaging Archive 10, K9 (2016)
  • [14] Shao, Z., Bian, H., Chen, Y., Wang, Y., Zhang, J., Ji, X., et al.: Transmil: Transformer based correlated multiple instance learning for whole slide image classification. Advances in neural information processing systems 34, 2136–2147 (2021)
  • [15] Tomczak, K., Czerwińska, P., Wiznerowicz, M.: Review the cancer genome atlas (tcga): an immeasurable source of knowledge. Contemporary Oncology/Współczesna Onkologia 2015(1), 68–77 (2015)
  • [16] Van Loan, C.F.: The ubiquitous kronecker product. Journal of computational and applied mathematics 123(1-2), 85–100 (2000)
  • [17] Wang, Q.M., Lv, L., Tang, Y., Zhang, L., Wang, L.F.: Mmp-1 is overexpressed in triple-negative breast cancer tissues and the knockdown of mmp-1 expression inhibits tumor cell malignant behaviors in vitro. Oncology letters 17(2), 1732–1740 (2019)
  • [18] Wang, X., Price, S., Li, C.: Multi-task learning of histology and molecular markers for classifying diffuse glioma. arXiv preprint arXiv:2303.14845 (2023)
  • [19] Wei, Y., Chen, X., Zhu, L., Zhang, L., Schönlieb, C.B., Price, S., Li, C.: Multi-modal learning for predicting the genotype of glioma. IEEE Transactions on Medical Imaging (2023)
  • [20] Xing, X., Chen, Z., Zhu, M., Hou, Y., Gao, Z., Yuan, Y.: Discrepancy and gradient-guided multi-modal knowledge distillation for pathological glioma grading. In: International Conference on Medical Image Computing and Computer-Assisted Intervention. pp. 636–646. Springer (2022)
  • [21] Zeng, Y., Zeng, Y., Yin, H., Chen, F., Wang, Q., Yu, X., Zhou, Y.: Exploration of the immune cell infiltration-related gene signature in the prognosis of melanoma. Aging (Albany NY) 13(3), 3459 (2021)
  • [22] Zhang, L., Wei, Y., Fu, Y., Price, S., Schönlieb, C.B., Li, C.: Mutual contrastive low-rank learning to disentangle whole slide image representations for glioma grading. arXiv preprint arXiv:2203.04013 (2022)
  • [23] Zhou, F., Chen, H.: Cross-modal translation and alignment for survival analysis. In: Proceedings of the IEEE/CVF International Conference on Computer Vision. pp. 21485–21494 (2023)
  • [24] Zhou, M., Zhang, Z., Bao, S., Hou, P., Yan, C., Su, J., Sun, J.: Computational recognition of lncrna signature of tumor-infiltrating b lymphocytes with potential implications in prognosis and immunotherapy of bladder cancer. Briefings in Bioinformatics 22(3), bbaa047 (2021)