1 Introduction

Deep convolutional networks have achieved great success on large-scale vision tasks such as ImageNet (Russakovsky et al., 2015) and Places365 (Zhou et al., 2017). In addition to their notable accuracy improvements, deep representations learned by modern CNNs have been shown to transfer across related tasks (Yosinski et al., 2014). This is fortunate for many real-world applications with insufficient labeled examples. Transfer learning aims to obtain good performance on such tasks by leveraging knowledge learned from related large-scale datasets. The auxiliary and desired tasks are called the source and target tasks, respectively. Following Pan and Yang (2009), we focus on inductive transfer learning, in which the source and target tasks have different label spaces.

In the context of deep learning, the most popular practice of inductive transfer learning is to fine-tune a model pre-trained on the source task with \(\mathrm {L^2}\) regularization, which constrains the parameters around the origin. Li et al. (2018) point out that, since the parameters may be driven far from the starting point given by the pre-trained model, a major disadvantage of naive fine-tuning is the risk of catastrophic forgetting (Kirkpatrick et al., 2017) of the knowledge learned from the source. They recommend using the \(\mathrm {L^2\text {-}SP}\) regularizer instead of the popular \(\mathrm {L^2}\). In parallel, knowledge distillation, originally designed to compress the knowledge in a complex model into a simple one (Hinton et al., 2015), has proved useful for inductive transfer learning, even though the knowledge is distilled from a different (but related) dataset (Zagoruyko & Komodakis, 2016; Yim et al., 2017). Recent work (Li et al., 2019) formulates knowledge distillation in transfer learning as a regularizer on features and further improves it by reusing unactivated channels to better fit the training samples. These methods share a common assumption: the relevant knowledge contained in the source model is expected to facilitate the generalization of the target model. This motivates regularizing the fine-tuned model with the starting point as the reference (SPAR), which underlies a mainstream group of inductive transfer learning algorithms (Li et al., 2018; Zagoruyko & Komodakis, 2016; Li et al., 2019; Gouk et al., 2020; Li & Zhang, 2021).

Although advanced SPAR methods have succeeded in preserving the knowledge contained in the source model, fine-tuning also carries an obvious risk of negative transfer (Torrey & Shavlik, 2010). Intuitively, if the source and target data distributions are dissimilar to some extent, not all the knowledge from the source is transferable to the target, and an indiscriminate transfer could be detrimental to the target model. However, the impact of negative transfer has rarely been considered in inductive transfer learning studies. The most related work, Chen et al. (2019), proposed the Batch Spectral Shrinkage (\(\textrm{BSS}\)) regularizer to inhibit negative transfer, which suppresses the small singular values of feature representations because they are regarded as not transferable. Yet it is hard to adaptively determine how many small singular values to suppress for different target tasks. Moreover, \(\textrm{BSS}\) does not account for the risk of catastrophic forgetting, which means it has to be combined with other fine-tuning techniques (e.g., \(\mathrm {L^2\text {-}SP}\) (Li et al., 2018), Attention-Transfer (Zagoruyko & Komodakis, 2016), \(\textrm{DELTA}\) (Li et al., 2019), Distance-based Regularization (Gouk et al., 2020; Li & Zhang, 2021), etc.) to achieve competitive performance.

The above analysis motivates a solution that simultaneously preserves relevant knowledge and avoids negative transfer. In this paper, we improve the standard fine-tuning paradigm through accurate knowledge transfer. Assuming that the knowledge contained in the source model consists of one part relevant to the target task and another part that is irrelevant, we explicitly disentangle the former from the source model. A target-specific starting point is thus used as the reference instead of the original one. Specifically, we design a novel regularizer for deep transfer learning through Target-awareness REpresentation Disentanglement (\(\textrm{TRED}\)). The algorithm consists of two steps. First, we use a lightweight disentangler to separate intermediate representations of the pre-trained source model into positive and negative parts. The core component of the disentangler is a differentiable module capable of separating the positive and negative groups of features. To implement the desired disentanglement, the classical Maximum Mean Discrepancy (MMD) metric is a natural choice due to its simplicity and robustness. Specifically, the relevant and irrelevant parts of the noisy features are separated by maximizing the MMD along the spatial dimension, i.e., between their visual attention (Max-MMD). Other feature disentanglement approaches are also applicable in our algorithm. To verify this, we also propose to minimize the Mutual Information (Min-MI) along the channel dimension, as well as its combination with Max-MMD. With more trainable parameters and more relevant information preserved, the fine-tuning accuracy can be further improved.

Reliable disentanglement is achieved by optimizing the disentangling module while requiring that the original representation can be reconstructed from the disentangled parts. Supervision from labeled target examples is used to distinguish the positive part from the negative part. The second step is to perform fine-tuning using the disentangled positive part of the representations as the reference. In summary, our main contributions are as follows:

  • We are the first to introduce representation disentanglement to improve inductive transfer learning.

  • Our algorithm, which aims at accurate knowledge transfer, contributes to the study of negative transfer.

  • Our proposed \(\textrm{TRED}\) significantly outperforms state-of-the-art transfer learning regularizers, including \(\mathrm {L^2}\), \(\mathrm {L^2\text {-}SP}\), \(\textrm{AT}\), \(\textrm{DELTA}\), \(\textrm{BSS}\) and \(\textrm{DistReg}\), on various real-world datasets.

Table 1 Comparison among different fine-tuning approaches
Fig. 1 The architecture of deep transfer learning through Target-awareness Representation Disentanglement (\(\textrm{TRED}\))

2 Related work

2.1 Shrinkage towards chosen parameters

Regularization techniques date back to Stein’s paradox (Stein, 1956; Efron & Morris, 1977), which shows that shrinking towards chosen parameters yields an estimate more accurate than simply using observed averages. The most common choices, such as Lasso and Ridge Regression, pull the model towards zero, while it is widely believed that shrinking towards “true parameters” is more effective. In transfer learning, models pre-trained on related source tasks with sufficient labeled examples are often regarded as “true parameters”. Earlier works demonstrate the effectiveness of this strategy for Maximum Entropy (Chelba & Acero, 2006) and Support Vector Machine models (Yang et al., 2007; Li, 2007; Aytar & Zisserman, 2011).

Although the notion of “true parameters” is intractable for deep neural networks, several studies (Bengio, 2012; Yosinski et al., 2014; Oquab et al., 2014) have demonstrated the strong transferability of representations trained on large-scale datasets for general purposes. Liu et al. (2019) theoretically study properties of the pre-trained model and explain why it outperforms training from scratch.

2.2 Deep inductive transfer learning

Despite this transferability, Li et al. (2018) point out that naively fine-tuning the pre-trained model with \(\mathrm {L^2}\) regularization may cause the loss of the original knowledge. To overcome over-fitting, various transfer learning regularizers have been proposed. According to the type of regularized objective, they can be categorized as parameter-based (Li et al., 2018; Gouk et al., 2020; Li & Zhang, 2021), feature-based (Zagoruyko & Komodakis, 2016; Yim et al., 2017; Li et al., 2019) or spectral-based (Chen et al., 2019) methods.

Chen et al. (2019) improve the regularization of transfer learning from another angle. They propose Batch Spectral Shrinkage (BSS) to regularize the spectral components of deep representations by penalizing small singular values. BSS is complementary to other regularizers; however, it does not address knowledge preservation.

Our paper adopts the general idea of preserving knowledge by regularizing features of the source model. Unlike previous methods, however, we do not directly use the original knowledge provided by the source model. Instead, we disentangle the useful part and use it as the reference to avoid negative transfer. The main differences are summarized in Table 1.

Studies from other angles, such as sample selection (Ge & Yu, 2017; Ngiam et al., 2018; Jeon et al., 2020), dynamic computing (Guo et al., 2019), sparse transfer (Wang et al., 2019) and cross-modality transfer (Hu et al., 2020), are also important topics but are beyond the scope of this paper.

2.3 Representation disentanglement

The key assumption of representation disentanglement is that a satisfactory representation should separate the underlying factors of variation, which are compact, explanatory and independent of each other (Goodfellow et al., 2009; Bengio et al., 2013). Representation disentanglement has been widely applied in advanced generative models such as Generative Adversarial Networks (Goodfellow et al., 2014) and Variational Autoencoders (Kingma & Welling, 2013). Some works investigate general disentangling methods for data generation, including InfoGAN (Chen et al., 2016), AC-GAN (Odena et al., 2017), \(\beta\)-VAE (Higgins et al., 2017), FactorVAE (Kim & Mnih, 2018) and so on.

Recently, representation disentanglement has also been shown to be useful in unsupervised image-to-image translation and domain adaptation, which are more closely related to our work. Liu et al. (2018) propose to perform joint representation disentanglement and domain adaptation with only attribute supervision available in the source domain. Liu et al. (2018) further generalize the previous study into a unified feature disentanglement network. In that work, the data domain is first regarded as an underlying factor of interest to be disentangled. Peng et al. (2019) improve on these studies by employing class disentanglement and minimizing the mutual information between disentangled features to further enhance the disentanglement.

Our work is inspired by the progress of domain information disentanglement (Liu et al., 2018) and disentangling techniques (Belghazi et al., 2018; Peng et al., 2019). However, we are the first to use disentanglement methods to improve inductive transfer learning, aiming at (target) task-specific feature disentanglement rather than the domain-invariant feature extraction of unsupervised domain adaptation.

2.4 Connections to relevant approaches

Among deep inductive transfer learning approaches, our method belongs to the direction of SPAR on features (Zagoruyko & Komodakis, 2016; Li et al., 2019), which aims to reuse transferable features on the target task. In the existing literature, the method most relevant to ours is DELTA (Li et al., 2019). Specifically, DELTA preserves useful features by evaluating the contribution (to the target task) of each feature channel independently. In contrast, our method uses a more powerful disentangler to separate the transferable ingredients. An important difference is that DELTA assumes independence among feature channels, which is unrealistic for DNNs, whereas our method can take feature interactions into account.

3 Preliminaries

3.1 Problem definition

In inductive transfer learning, we are given a model pre-trained on the source task, with parameter vector \(\varvec{\omega ^0}\). For the desired task, the training set contains n tuples, each denoted as \((\varvec{x}_i,y_i)\), where \(\varvec{x}_i\) is the i-th example and \(y_i\) its label.

Further, let z denote the function computed by the neural network and \(\varvec{\omega }\) the parameter vector of the target network. The objective of structural risk minimization is

$$\begin{aligned} \min _{\varvec{\omega }}\ \sum _{i=1}^n L_{{\textrm{ERM}}}(z( \varvec{x}_i, \varvec{\omega }), y_{i}) + \lambda \cdot \Omega (\varvec{\omega },\varvec{\omega }^0), \end{aligned}$$
(1)

where the first term is the empirical loss, the second is the regularization term, and \(\lambda\) is the coefficient that balances data fitting against reducing over-fitting. We implement \(L_{\textrm{ERM}}\) with the cross-entropy loss.
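As a minimal sketch of Eq. (1), assuming NumPy, integer class labels, and a network that has already produced the logits \(z(\varvec{x}_i, \varvec{\omega })\), the objective can be evaluated as below. The helper names (`cross_entropy_sum`, `structural_risk`) and the `regularizer` callable are hypothetical illustrations, not the authors' code.

```python
import numpy as np

def cross_entropy_sum(logits, labels):
    """Sum of per-example cross-entropy losses; logits: (n, K), labels: (n,) ints."""
    shifted = logits - logits.max(axis=1, keepdims=True)  # stabilise the softmax
    log_probs = shifted - np.log(np.exp(shifted).sum(axis=1, keepdims=True))
    return -log_probs[np.arange(len(labels)), labels].sum()

def structural_risk(logits, labels, omega, omega0, regularizer, lam):
    """Eq. (1): empirical loss summed over the n examples plus lam * Omega(omega, omega0)."""
    return cross_entropy_sum(logits, labels) + lam * regularizer(omega, omega0)

# Usage with a plain L2 penalty, which ignores the starting point omega0.
rng = np.random.default_rng(0)
logits = rng.normal(size=(4, 3))
labels = np.array([0, 2, 1, 0])
l2 = lambda w, w0: np.sum(w ** 2)
print(structural_risk(logits, labels, np.ones(5), np.zeros(5), l2, lam=1e-2))
```

Every regularizer discussed below fits this interface; they differ only in how \(\Omega\) uses \(\varvec{\omega }^0\).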

3.2 Regularizers for transfer learning

Recent studies show that SGD itself has an implicit regularization effect that helps generalization in the over-parameterized regime (Soltanolkotabi et al., 2018). In addition, since fine-tuning is usually performed with a smaller learning rate and fewer epochs, it can be regarded as a form of implicit regularization towards the initial solution, which has good generalization properties (Liu et al., 2019). Below, we briefly introduce state-of-the-art explicit regularizers for deep transfer learning.

\(\mathbf {L^2\ Penalty}\). The most common choice is the \(\mathrm {L^2}\) penalty of the form \(\Vert \varvec{\omega }\Vert ^2_2\), also called weight decay in deep learning. From a Bayesian perspective, it corresponds to a zero-mean Gaussian prior on the parameters. Its shortcoming is that the meaningful initial point \(\varvec{\omega ^0}\) is ignored.

\(\mathbf {L^2}\)-\(\textbf{SP}\). Li et al. (2018) follow the idea of shrinking towards chosen targets instead of zero. They propose to use the starting point as the reference

$$\begin{aligned} \Omega (\varvec{\omega }) = \frac{\alpha }{2}\Vert \varvec{\omega _s}-\varvec{\omega ^0_s}\Vert ^2_2 + \frac{\beta }{2}\Vert \varvec{\omega _{{\overline{s}}}}\Vert ^2_2, \end{aligned}$$
(2)

where the first term constrains the parameters of the part responsible for representation learning, \(\varvec{\omega _s}\), around the starting point, and the second term applies weight decay to the remaining, task-specific part \(\varvec{\omega _{{\overline{s}}}}\). Since \(\varvec{\omega _{{\overline{s}}}}\) is treated identically in all the methods discussed here, we omit it from the following formulas.
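For concreteness, here is a minimal NumPy sketch of the plain \(\mathrm {L^2}\) penalty and of Eq. (2), assuming the parameters have been flattened into vectors. The function names and the argument `omega_sbar` (the new task-specific head) are hypothetical.

```python
import numpy as np

def l2_penalty(omega):
    """Weight decay ||omega||_2^2: shrinks towards the origin, ignoring omega^0."""
    return np.sum(omega ** 2)

def l2_sp(omega_s, omega0_s, omega_sbar, alpha, beta):
    """Eq. (2): keep shared weights near the pre-trained start, decay the new head."""
    start_point_term = (alpha / 2) * np.sum((omega_s - omega0_s) ** 2)
    head_decay_term = (beta / 2) * np.sum(omega_sbar ** 2)
    return start_point_term + head_decay_term
```

The only change from weight decay is the reference point: replacing `omega0_s` with zeros turns the first term back into a (scaled) plain \(\mathrm {L^2}\) penalty on \(\varvec{\omega _s}\).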

\(\textbf{DistReg}\). Building on the idea of constraining the parameters around the pre-trained model, Gouk et al. (2020) and Li and Zhang (2021) propose improved regularizers. Specifically, Gouk et al. (2020) use a projected stochastic subgradient method to enforce the parameter constraints. They also suggest the maximum absolute row sum (MARS) distance as the measure instead of the commonly used Euclidean distance. The constraint can be formalized as

$$\begin{aligned} \Vert \varvec{\omega _s}-\varvec{\omega ^0_s}\Vert _{\infty } \le \gamma , \end{aligned}$$
(3)

where \(\gamma\) is the pre-defined maximum allowable distance.
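A minimal sketch of the MARS constraint in Eq. (3), assuming NumPy and a 2-D weight matrix per layer (a convolution kernel would first be reshaped to `(out_channels, -1)`). After each SGD step, the update is shrunk back towards \(\varvec{\omega ^0_s}\) whenever the constraint is violated. This uniform rescaling is our illustrative choice of a step that satisfies Eq. (3); it is not claimed to be identical to the projection used by Gouk et al. (2020), and the function names are hypothetical.

```python
import numpy as np

def mars_distance(w, w0):
    """Maximum absolute row sum of (w - w0), i.e. the induced infinity-norm."""
    return np.abs(w - w0).sum(axis=1).max()

def constrain_to_mars_ball(w, w0, gamma):
    """Shrink (w - w0) uniformly so that mars_distance(w, w0) <= gamma."""
    d = mars_distance(w, w0)
    return w0 + (w - w0) / max(1.0, d / gamma)
```

Weights already inside the ball are returned unchanged, so the step only acts when an update has moved a layer too far from its pre-trained value.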

\(\textbf{DELTA}\). Li et al. (2019) extend the framework of feature distillation (Romero et al., 2014; Zagoruyko & Komodakis, 2016) by incorporating an attention mechanism. They constrain the 2-D activation maps of different channels with different strengths according to their value for the target task. Given a training example \((\varvec{x}_i, y_i)\) and a distance metric \(\varvec{D}\) between activation maps, the regularization is formulated as

$$\begin{aligned} \begin{aligned} \Omega {^i}(\varvec{\omega _s}) = \frac{\alpha }{2}\sum _{j=1}^C \textrm{W}_{j}(\varvec{\omega ^0_s}, \varvec{x}_i, y_i)\cdot \varvec{D}{^i_j}, \end{aligned} \end{aligned}$$
(4)

where C is the number of channels and \(\textrm{W}_{j}(\cdot )\) is the regularization weight assigned to the j-th channel. Each weight is evaluated independently from the performance drop caused by disabling that channel.
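The following NumPy sketch evaluates Eq. (4) for one example, assuming \(\varvec{D}{^i_j}\) is the squared Euclidean distance between the j-th activation maps of the target and source models. How the weights are normalized is not specified above, so `weights_from_loss_increase` uses a softmax over the per-channel performance drop purely as a hypothetical choice; all function names are illustrative.

```python
import numpy as np

def channel_distances(fm, fm0):
    """D_j for every channel: squared L2 distance between (C, H, W) maps -> (C,)."""
    return ((fm - fm0) ** 2).reshape(fm.shape[0], -1).sum(axis=1)

def weights_from_loss_increase(loss_increase):
    """Hypothetical normalisation: softmax over the loss increase from disabling each channel."""
    e = np.exp(loss_increase - loss_increase.max())
    return e / e.sum()

def delta_regularizer(fm, fm0, weights, alpha):
    """Eq. (4): (alpha / 2) * sum_j W_j * D_j for a single example."""
    return (alpha / 2) * np.dot(weights, channel_distances(fm, fm0))
```

Because each \(\textrm{W}_j\) is computed for one channel at a time, interactions between channels are not captured, which is the limitation discussed in Section 2.4.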

\(\textbf{BSS}\). Chen et al. (2019) propose Batch Spectral Shrinkage (\(\textrm{BSS}\)) to penalize untransferable spectral components of deep representations. They find that the less transferable spectral components are those corresponding to relatively small singular values. They apply differentiable SVD to compute all singular values of a feature matrix and penalize the smallest k of them:

$$\begin{aligned} \begin{aligned} \Omega (\varvec{\omega _s}) = \alpha \sum _{i=1}^k \sigma ^2_{b+1-i}, \end{aligned} \end{aligned}$$
(5)

where the singular values \([\sigma _1,\sigma _2,\ldots ,\sigma _b]\) are sorted in descending order. \(\varvec{\omega ^0}\) is not involved, as \(\textrm{BSS}\) does not consider preserving the knowledge in the source model.
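Eq. (5) can be checked numerically with a short NumPy sketch. NumPy's SVD is not differentiable, so this only evaluates the value of the penalty; training would rely on the differentiable SVD of a deep learning framework. The function name and the `(b, d)` layout of the feature matrix (one row per example in the batch) are assumptions.

```python
import numpy as np

def bss_penalty(features, k, alpha):
    """Eq. (5): alpha times the sum of the k smallest squared singular values."""
    if k <= 0:
        return 0.0  # guard: sigma[-0:] would otherwise select every singular value
    sigma = np.linalg.svd(features, compute_uv=False)  # sorted in descending order
    return alpha * np.sum(sigma[-k:] ** 2)
```

`np.linalg.svd` returns min(b, d) singular values, so the index b in Eq. (5) implicitly assumes the batch size does not exceed the feature dimension.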

3.3 Learning disentangled features

Conceptually, disentangling is a special kind of feature transformation in the context of deep representation learning (Belghazi et al., 2018; Peng et al., 2019). Specifically, it involves a disentangler D, which is applied to the original feature \(f \in {\mathbb {R}}^d\) to obtain a set of disentangled features \(\{f_i \in {\mathbb {R}}^k\}_{i=1}^{M}\) as

$$\begin{aligned} \begin{aligned} {f_1, f_2,\ldots , f_M = D(f),} \\ {\text {subject to } \textrm{SUM}(\{f_i\}_{i=1}^{M})=f,} \end{aligned} \end{aligned}$$
(6)

where \(\textrm{SUM}\) is an abstract operator that can be instantiated by numerical summation, concatenation and so on. Accordingly, the dimension k of the disentangled features can be either equal to or smaller than the original dimension d, depending on the specific application.
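To make the abstract \(\textrm{SUM}\) operator concrete, the sketch below instantiates Eq. (6) with M = 2 for summation and general M for concatenation, assuming NumPy. The fixed matrix `W` is a hypothetical stand-in for a learned disentangler. With summation, defining the second part as a residual satisfies the constraint exactly (k = d); with concatenation, the parts have dimension k = d / M.

```python
import numpy as np

def disentangle_by_sum(f, W):
    """SUM = addition (k = d): the second part is the residual, so f1 + f2 == f."""
    f1 = W @ f
    return f1, f - f1

def disentangle_by_concat(f, M):
    """SUM = concatenation: split f of dimension d into M parts of dimension d / M.
    np.split raises an error if d is not divisible by M."""
    return np.split(f, M)
```

TRED (Section 4.3) uses the summation form, but enforces it softly through a reconstruction loss rather than by construction.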

The purpose of learning disentangled features is to extract, from an original feature that is usually noisy, the components of interest for the target task. For example, to achieve domain generalization, Peng et al. (2019) aim to extract domain-invariant information from the original features learned from a source domain. In contrast, our work aims to extract target-relevant information from the noisy source features.


4 Target-awareness disentanglement

This section presents our proposed fine-tuning method, which follows the general framework of Eq. (1), composed of the empirical loss \(L_{\textrm{ERM}}\) and the regularizer \(\Omega\). When training samples are scarce, a reasonable regularizer is usually essential to mitigate the risk of over-fitting. The main purpose of our work is to introduce a proper inductive bias into the regularizer \(\Omega\).

Features extracted from the source model, which is usually pre-trained on a large-scale dataset with diverse categories, are often noisy for a specific target task. Ingredients of this general knowledge that are irrelevant to the target task may lead to negative transfer. Our aim is to disentangle the positive ingredients (those relevant to the target task) from the entire representation. To achieve this goal, three conditions should be satisfied: the disentangled parts should be distinguishable, discriminative, and recoverable.

  • The positive part should be distinguishable from the negative part. In other words, two features with similar patterns or semantic relations should not be disentangled into different parts. This is the most crucial requirement in this framework.

  • The positive part should be discriminative on the target task. The disentangler is usually symmetric, i.e., without external supervision either part could be defined as the “positive” one. Therefore, an extra signal is needed to identify the positive part.

  • The original representation should be recoverable from the disentangled parts. For a disentangling operation, neither disentangled part should represent knowledge beyond the original representation. Otherwise, the transformation may produce non-generalizable features.

4.1 General definitions and notations

We first specify the general definitions and notations used in our algorithm. Unlike mainstream disentanglement studies, which try to separate atomic attributes such as color or angle, we aim to disentangle the components relevant to the target task from the whole representation produced by the source model. Formally, we disentangle the original representation \(FM_{ori} \in R^{C \times H \times W}\) obtained from the pre-trained model into a positive and a negative part with the disentangler module D:

$$\begin{aligned} FM_{pos}, FM_{neg} = D(FM_{ori}), \end{aligned}$$
(7)

where \(FM_{pos}\) and \(FM_{neg}\) have the same shape as \(FM_{ori}\). For efficient estimation and optimization of the disentanglement, we further define the mapping functions \({\mathcal{F}}^{\mathcal{C}}: R^{C \times H \times W} \rightarrow R^{C}\) and \({\mathcal{F}}^{\mathcal{S}}: R^{C \times H \times W} \rightarrow R^{H \times W}\), which reduce the dimension along the spatial and channel directions, respectively. We thus obtain

$$\begin{aligned} f^{c}_{*} = {\mathcal{F}}^{\mathcal{C}}(FM_*), f^s_* = {\mathcal{F}}^{\mathcal{S}}(FM_*), \end{aligned}$$
(8)

where \(*\) refers to either pos or neg.
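The concrete form of the mapping functions is not given above, so the sketch below makes an explicit assumption: \({\mathcal{F}}^{\mathcal{C}}\) is global average pooling over the spatial positions, and \({\mathcal{F}}^{\mathcal{S}}\) is an attention-transfer-style map obtained by averaging squared activations over channels. Both choices and the function names are hypothetical illustrations in NumPy.

```python
import numpy as np

def f_channel(fm):
    """F^C: (C, H, W) -> (C,), assumed to be global average pooling over H x W."""
    return fm.mean(axis=(1, 2))

def f_spatial(fm):
    """F^S: (C, H, W) -> (H, W), assumed to be the channel-mean of squared activations."""
    return (fm ** 2).mean(axis=0)

fm_pos = np.random.default_rng(0).normal(size=(64, 7, 7))
print(f_channel(fm_pos).shape, f_spatial(fm_pos).shape)  # (64,) (7, 7)
```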

4.2 Implementation of disentanglers

The design of our disentanglers is motivated by the following two assumptions, to which the ideally disentangled relevant (positive) and irrelevant (negative) parts are expected to conform.

  • Non-overlapped Visual Attention. Visual attention can be regarded as one of the most interpretable kinds of knowledge learned by DNNs (Zagoruyko & Komodakis, 2016). Intuitively, the two disentangled parts should attend to different visual regions. Otherwise, similar patterns are likely to exist concurrently in both parts, implying an incomplete disentanglement.

  • Independent Semantic Representation. Another common assumption is that patterns relevant to the same target object are likely to be dependent on each other, while patterns of unrelated objects tend not to be. For example, when recognizing different dogs, a pattern on the eyes often implies a pattern on the nose, but is unlikely to imply a pattern on indoor objects such as cups or floors.

Both assumptions are illustrated by the example in Fig. 1. Next, we describe the two methods in detail.

Non-overlapped visual attention with Max-MMD. Imitating the visual attention mechanism of humans, we force the two parts to attend to different spatial regions of the input image. This is achieved by enlarging the distance between their distributions along the spatial dimension, which we measure with the Maximum Mean Discrepancy (MMD). MMD was originally designed to test whether two distributions are the same (Gretton et al., 2012). It is also widely used as a criterion to measure the distance between two distributions in domain adaptation (Pan et al., 2010; Tzeng et al., 2014; Long et al., 2017) and generative adversarial networks (Sutherland et al., 2016; Arbel et al., 2018). When the distributions are embedded in a Reproducing Kernel Hilbert Space (RKHS), the MMD can be estimated empirically in kernel form as follows.

Denoting \(X_s=\{x_s^1, x_s^2,\ldots ,x_s^{{m}}\}\) and \(X_t=\{x_t^1, x_t^2,\ldots ,x_t^{{n}}\}\) as sets of samples drawn from distributions P and Q, an empirical estimate (Tzeng et al., 2014; Long et al., 2015) of the MMD between P and Q is the squared distance between the empirical kernel mean embeddings:

$$\begin{aligned} \begin{aligned} \textbf{MMD}(P, Q)&= \Vert \frac{1}{m}\sum _{i=1}^m \phi (x_s^i) - \frac{1}{n}\sum _{j=1}^n \phi (x_t^j) \Vert ^2, \end{aligned} \end{aligned}$$
(9)

where \(\phi\) is the feature map into the RKHS induced by a kernel \(k(x, x') = \langle \phi (x), \phi (x')\rangle\); in practice, a Gaussian radial basis function (RBF) kernel is usually used (Long et al., 2015; Louizos et al., 2016).

Our objective is to enlarge the MMD between the disentangled positive and negative parts along the spatial dimension, which explicitly encourages the two parts to recognize different regions of the input image. For more stable optimization, we minimize the exponential of the negative MMD:

$$\begin{aligned} \begin{aligned} L^D_{di}=\lambda _{di}e^{-\textbf{MMD}(f^s_{pos}, f^s_{neg})}. \end{aligned} \end{aligned}$$
(10)
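A minimal NumPy sketch of Eqs. (9) and (10), using the kernel trick to expand the squared distance between mean embeddings into averages of Gaussian RBF kernel values. We assume each row of `fs_pos` and `fs_neg` is one example's flattened spatial map \(f^s_*\) from the mini-batch; the bandwidth `sigma` and the function names are illustrative assumptions. Because the expansion keeps the diagonal kernel terms, this is the biased (V-statistic) estimate.

```python
import numpy as np

def rbf_kernel(a, b, sigma):
    """Gaussian RBF kernel matrix between the rows of a (m, p) and b (n, p)."""
    sq_dists = (a ** 2).sum(1)[:, None] + (b ** 2).sum(1)[None, :] - 2 * a @ b.T
    return np.exp(-np.maximum(sq_dists, 0.0) / (2 * sigma ** 2))

def mmd(xs, xt, sigma=1.0):
    """Eq. (9): ||mean phi(xs) - mean phi(xt)||^2 via k(x, x') = <phi(x), phi(x')>."""
    return (rbf_kernel(xs, xs, sigma).mean() + rbf_kernel(xt, xt, sigma).mean()
            - 2 * rbf_kernel(xs, xt, sigma).mean())

def max_mmd_loss(fs_pos, fs_neg, lam_di=1e-2, sigma=1.0):
    """Eq. (10): lam_di * exp(-MMD); minimising it pushes the two parts apart."""
    return lam_di * np.exp(-mmd(fs_pos, fs_neg, sigma))
```

The loss is bounded by `lam_di` (reached when the two parts are indistinguishable) and decays towards zero as they separate, which is what makes the exponential form easier to optimize than maximizing an unbounded MMD. In practice the bandwidth would be chosen relative to typical distances between maps (for example, with a median heuristic); the default here is arbitrary.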

Independent semantic representation with Min-MI. In information theory, mutual information (MI) between two random variables quantifies the amount of information obtained about one through observing the other. Formally, the mutual information between random variables X and Z is defined as

$$\begin{aligned} \begin{aligned} I(X;Z) = D_{KL}({\mathbb {P}}_{XZ}\Vert {\mathbb {P}}_X \otimes {\mathbb {P}}_Z) \end{aligned} \end{aligned}$$
(11)

where \({\mathbb {P}}_{X}\) and \({\mathbb {P}}_{Z}\) are the marginal distributions and \({\mathbb {P}}_{XZ}\) is their joint distribution. I(X;Z) is always non-negative, and it equals zero if and only if X and Z are independent.
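As a small worked check of Eq. (11), the NumPy sketch below computes the KL divergence between a discrete joint distribution and the product of its marginals. The function name is hypothetical, and this exact computation only applies to small discrete tables; it does not scale to continuous, high-dimensional features such as \(f^c_{pos}\) and \(f^c_{neg}\).

```python
import numpy as np

def mutual_information(p_xz):
    """Eq. (11) in nats for a discrete joint table p_xz[i, j] = P(X = i, Z = j)."""
    p_x = p_xz.sum(axis=1, keepdims=True)
    p_z = p_xz.sum(axis=0, keepdims=True)
    product = p_x * p_z  # product of the marginals, P_X (x) P_Z
    nz = p_xz > 0  # terms with P_XZ = 0 contribute 0 by convention
    return np.sum(p_xz[nz] * np.log(p_xz[nz] / product[nz]))
```

An independent table gives zero, while two perfectly coupled binary variables give log 2, the entropy of either variable.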

A recent study (Higgins et al., 2017) showed that the disentanglement of factors of interest can be enhanced by encouraging independence between them. Further, Peng et al. (2019) minimized the mutual information between disentangled features to strengthen class/domain disentanglement in practice. Inspired by Peng et al. (2019), to disentangle the original representation into two parts that are both meaningful and complementary, we minimize the mutual information between the distributions of \(f^c_{pos}\) and \(f^c_{neg}\).

Since exact computation of mutual information for high-dimensional data is hard, we adopt Mutual Information Neural Estimation (MINE; Belghazi et al., 2018), which efficiently estimates mutual information with a neural network \(T_\theta\):

$$\begin{aligned} \begin{aligned} \widehat{I(X;Z)}_n = \sup _{\theta \in \Theta }{\mathbb {E}}_{{\mathbb {P}}_{XZ}^{(n)}}[T_\theta ]-\textrm{log}\left({\mathbb {E}}_{{\mathbb {P}}_X^{(n)} \otimes \mathbb {{\hat{P}}}_Z^{(n)}}[e^{T_\theta }]\right). \end{aligned} \end{aligned}$$
(12)

Denoting the optimal solution as \({\hat{\theta }}\), the mutual information can be computed by Monte-Carlo approximation for a mini-batch of b examples as

$$\begin{aligned} \begin{aligned} I(X;Z;{\hat{\theta }})_b = \frac{1}{b}\sum _{i=1}^b T(x^i,z^i,{\hat{\theta }}) - \textrm{log} \left( \frac{1}{b}\sum _{i=1}^b e^{T(x^i,{\overline{z}}^i,{\hat{\theta }})} \right) \end{aligned} \end{aligned}$$
(13)

where each \((x^i, z^i)\) is drawn from the joint distribution \({\mathbb {P}}_{XZ}\), and \({\overline{z}}^i\) is drawn from the marginal distribution \({\mathbb {P}}_Z\).
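The Monte-Carlo estimate of Eq. (13) can be sketched as follows, assuming NumPy and a fixed statistics function `T` in place of the trained network \(T_{\hat{\theta }}\). Samples from the marginal of Z are obtained by shuffling `z` within the mini-batch, which breaks its pairing with `x`; the function name and this shuffling scheme are our assumptions. In TRED, X and Z would be \(f^c_{pos}\) and \(f^c_{neg}\): \(T_\theta\) is trained to maximize this quantity, and D is then updated to minimize it.

```python
import numpy as np

def mine_estimate(T, x, z, rng):
    """Eq. (13) on a mini-batch of b paired rows (x[i], z[i])."""
    b = x.shape[0]
    joint = np.mean([T(x[i], z[i]) for i in range(b)])
    z_bar = z[rng.permutation(b)]  # break the pairing: z_bar[i] follows the marginal P_Z
    marginal = np.array([T(x[i], z_bar[i]) for i in range(b)])
    m = marginal.max()  # stable log-mean-exp
    return joint - (m + np.log(np.mean(np.exp(marginal - m))))

rng = np.random.default_rng(0)
x = rng.normal(size=(128, 4))
T = lambda a, c: 0.1 * float(a @ c)  # hypothetical bilinear critic
print(mine_estimate(T, x, x, rng))
```

Any constant critic yields an estimate of exactly zero, since the two terms cancel; a useful \(T_\theta\) must score paired samples higher than shuffled ones.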

Estimating mutual information with Eq. (12) is a maximization problem over \(T_{\theta }\), whereas our objective is to minimize the mutual information by learning the disentangled representations. Therefore, following common practice, we alternately update \(T_{\theta }\) and D. Given \(T_{{\hat{\theta }}}\) estimated for the current distributions of \(f^c_{pos}\) and \(f^c_{neg}\), we update D with

$$\begin{aligned} \begin{aligned} L^D_{di}=\lambda _{di} I(f^c_{pos};f^c_{neg};{\hat{\theta }})_b \end{aligned} \end{aligned}$$
(14)

4.3 Reconstruction requirement

As both the positive and the negative part are produced by a flexible disentangler, optimizing only the disentangling objective (Max-MMD or Min-MI) could easily yield two meaningless representations. To ensure that the disentanglement is restricted to the knowledge of the source model rather than being an arbitrary transformation, we add a reconstruction requirement that constrains the disentangled results. Specifically, the disentangled positive and negative parts are required to recover the original representation by point-wise addition:

$$\begin{aligned} \begin{aligned} L^D_{re}=\lambda _{re}\Vert FM_{pos} + FM_{neg} - FM_{ori} \Vert ^2_2. \end{aligned} \end{aligned}$$
(15)
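Eq. (15) is a plain sum of squared errors over all \(C \times H \times W\) entries. A NumPy sketch, with \(\lambda _{re}\) defaulting to the value reported in Section 5.1.2 (the function name is hypothetical):

```python
import numpy as np

def reconstruction_loss(fm_pos, fm_neg, fm_ori, lam_re=1e-2):
    """Eq. (15): lam_re * ||FM_pos + FM_neg - FM_ori||_2^2 over every entry."""
    return lam_re * np.sum((fm_pos + fm_neg - fm_ori) ** 2)
```

The loss is zero for any split whose parts add up to \(FM_{ori}\), so it restricts the disentanglement to the source knowledge without, by itself, deciding which part is positive.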

4.4 Distinguishing the positive part

Since the above representation disentanglement is symmetric with respect to the two parts, an explicit signal is required to identify the features that are useful for the target task. In particular, the layer selected for representation transfer is followed by a classifier consisting of a global pooling layer and a fully connected layer. A standard cross-entropy loss is added to explicitly drive the disentangler to place the components that are discriminative for the target task into the positive part. Denoting this classifier as C, we have

$$\begin{aligned} \begin{aligned} L^D_{ce}=\lambda _{ce}\textbf{CrossEntropy}(C(f^c_{pos}), y_i). \end{aligned} \end{aligned}$$
(16)
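A single-example NumPy sketch of Eq. (16), assuming the global pooling is average pooling (which yields \(f^c_{pos}\) directly) followed by a linear layer with hypothetical weights `W` of shape (K, C) and bias `b`; \(\lambda _{ce}\) defaults to the value in Section 5.1.2.

```python
import numpy as np

def positive_ce_loss(fm_pos, W, b, label, lam_ce=1e-3):
    """Eq. (16): lam_ce * cross-entropy of the classifier C on the positive part."""
    f_c = fm_pos.mean(axis=(1, 2))  # global average pooling -> f^c_pos, shape (C,)
    logits = W @ f_c + b  # fully connected layer, shape (K,)
    logits = logits - logits.max()  # numerical stability
    log_probs = logits - np.log(np.exp(logits).sum())
    return -lam_ce * log_probs[label]
```

Because this term sees only the positive branch, it is what breaks the symmetry between the two parts.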

4.5 Regularizing the disentangled representation

After the representation disentanglement step, we perform fine-tuning on the target task, regularizing the distance between a feature map and its corresponding starting point. Unlike previous feature-map-based regularizers (Romero et al., 2014; Zagoruyko & Komodakis, 2016; Li et al., 2019), the starting point here is the disentangled positive part of the original representation. The regularization term for an example (\(\varvec{x}_i, y_i\)) becomes:

$$\begin{aligned} \Omega {^i}(\varvec{\omega _s}) = \frac{\alpha }{2}\Vert FM(\varvec{\omega _s}, \varvec{x}_i)- FM_{pos}(\varvec{\omega _s^0}, \varvec{\omega _{di}}, \varvec{x}_i) \Vert _2^2, \end{aligned}$$
(17)

where \(\varvec{\omega _{di}}\) denotes the parameters of the disentangler D, which are frozen during the fine-tuning stage. The complete training procedure is presented in Algorithm 1.
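The fine-tuning stage can be sketched in NumPy as below. The reference \(FM_{pos}\) is produced by the frozen source backbone and frozen disentangler and is treated as a constant. `toy_disentangler` is a hypothetical stand-in for the trained D. In the first stage, D would presumably be trained on the sum of the disentangling loss (Eq. 10 or 14), the reconstruction loss (Eq. 15) and the cross-entropy loss (Eq. 16); this sketch does not reproduce Algorithm 1.

```python
import numpy as np

def tred_regularizer(fm_target, fm_pos_ref, alpha):
    """Eq. (17): (alpha / 2) * ||FM(w_s, x) - FM_pos(w_s^0, w_di, x)||_2^2."""
    return (alpha / 2) * np.sum((fm_target - fm_pos_ref) ** 2)

def positive_reference(fm_source, disentangler):
    """Run the frozen disentangler on source features and keep only the positive part."""
    fm_pos, _fm_neg = disentangler(fm_source)
    return fm_pos.copy()  # a constant target: nothing here is updated during fine-tuning

# Hypothetical stand-in for a trained disentangler that splits features 30/70.
toy_disentangler = lambda fm: (0.3 * fm, 0.7 * fm)

fm_source = np.ones((4, 2, 2))  # FM_ori from the pre-trained model
ref = positive_reference(fm_source, toy_disentangler)
fm_target = fm_source.copy()  # the target model starts from the source weights
print(tred_regularizer(fm_target, ref, alpha=0.1))
```

Unlike \(\mathrm {L^2\text {-}SP}\) or plain feature distillation, this penalty is generally non-zero at initialization, since the reference is the positive part rather than the original representation.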

4.6 Comparison between Max-MMD and Min-MI

Here we compare the two methods for achieving disentanglement, Max-MMD and Min-MI. Although they aim at the same goal, they differ in the following two aspects.

Assumption. The two methods rest on different assumptions: Max-MMD favors exclusion, whereas Min-MI favors independence. Although both have the effect of disentangling features, exclusion is generally a stronger assumption than independence for feature disentanglement. For example, a pattern responsible for recognizing a specific color may be useful for both the foreground and the background of an image. In that case, the non-overlap requirement is over-strict, whereas Min-MI only requires statistical independence of their occurrences.

Implementation. In terms of implementation, Max-MMD is more practical than Min-MI. This is because MMD has been intensively studied for distribution comparison and shows stable performance with commonly used kernels such as RBF. In contrast, computing mutual information between high-dimensional variables remains challenging from both theoretical and empirical perspectives (Lin et al., 2019). We find that Min-MI requires non-trivial trial and error to obtain reasonable performance, regarding both the architecture of the MI estimation network and the corresponding optimization strategy.

Practical consideration. Based on the above analyses, both Max-MMD and Min-MI are general (i.e., data-agnostic) approaches, and they can be combined for higher accuracy. In applications where computational resources are limited, however, we recommend Max-MMD as an efficient solution that requires much less effort in performance tuning.

5 Experiments

5.1 Image classification

5.1.1 Datasets

We select several popular transfer learning datasets to evaluate the effectiveness of our method.

Stanford dogs The Stanford Dogs (Khosla et al., 2011) dataset consists of images of 120 breeds of dogs, each of which containing 100 examples used for training and 72 for testing. It’s a subset of ImageNet.

MIT Indoor-67. MIT Indoor-67 (Quattoni & Torralba, 2009) is an indoor scene classification task consisting of 67 categories, each with 80 training images and 20 test images.

CUB-200-2011. Caltech-UCSD Birds-200-2011 (Welinder et al., 2010) contains 11,788 images of 200 bird species from around the world. Each species is associated with a Wikipedia article and organized by scientific classification.

Flower-102. Flower-102 (Nilsback & Zisserman, 2008) consists of 102 flower categories, with 1020 images for training and 6149 images for testing. Only 10 training samples are provided for each category.

Stanford Cars. The Stanford Cars (Krause et al., 2013) dataset contains 16,185 images of 196 classes of cars. The data is split into 8,144 training and 8,041 testing images.

Textures. The Describable Textures Dataset (Cimpoi et al., 2014) is a texture database containing 5640 images organized into 47 categories according to perceptual properties of textures.

5.1.2 Settings and hyperparameters

We implement transfer learning experiments based on ResNet (He et al., 2016). For MIT Indoor-67, we use a ResNet-50 pre-trained on the large-scale scene recognition dataset Places365 (Zhou et al., 2017) as the source model. For the remaining datasets, we use an ImageNet pre-trained ResNet-50 as the source model. During training, input images are resized so that the shorter edge is 256 pixels and then randomly cropped to \(224\times 224\).
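The training-time preprocessing can be sketched as follows. This plain-Python version only computes the resized size and a random crop box, to make the geometry explicit; rounding to the nearest pixel is our assumption. In a PyTorch pipeline the same steps would typically be `torchvision.transforms.Resize(256)` followed by `torchvision.transforms.RandomCrop(224)`.

```python
import random

def resize_shorter_edge(width, height, target=256):
    """Return (new_width, new_height) with the shorter edge scaled to `target`."""
    scale = target / min(width, height)
    # Aspect ratio is preserved; sizes are rounded to whole pixels.
    return round(width * scale), round(height * scale)

def random_crop_box(width, height, crop=224, rng=random):
    """Return a (left, top, right, bottom) box of size crop x crop inside the image."""
    # randint is inclusive on both ends, so the box never leaves the image.
    left = rng.randint(0, width - crop)
    top = rng.randint(0, height - crop)
    return left, top, left + crop, top + crop

# Example: a 500 x 375 image is resized to 341 x 256, then a 224 x 224 patch is cropped.
w, h = resize_shorter_edge(500, 375)
box = random_crop_box(w, h, rng=random.Random(0))
```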

For optimization, we first train the disentangler for 5 epochs using Adam with a learning rate of 0.01. All involved hyperparameters are set to the default values \(\lambda _{di}=10^{-2}, \lambda _{ce}=10^{-3}, \lambda _{re}=10^{-2}\), which were selected by 3-fold cross-validation on the CUB-200-2011 dataset. We find that these values generally perform well on the other datasets and therefore use them as the default choices. We then fine-tune the target model using SGD with a momentum of 0.9, a batch size of 64 and an initial learning rate of 0.01. We train for 40 epochs on each dataset and divide the learning rate by 10 after 25 epochs. We run each experiment three times and report the average top-1 accuracy.
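The optimization settings above can be summarized as a small configuration plus a step learning-rate function. This sketch only restates the reported values: the dictionary keys are hypothetical names, epochs are counted from 0, and we read "after 25 epochs" as applying the reduced rate from the 26th epoch onward. In PyTorch, the same step decay would normally be expressed with `torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=[25], gamma=0.1)`.

```python
# Reported settings; key names are illustrative, not taken from the authors' code.
DISENTANGLER = {"optimizer": "Adam", "lr": 0.01, "epochs": 5,
                "lambda_di": 1e-2, "lambda_ce": 1e-3, "lambda_re": 1e-2}
FINETUNE = {"optimizer": "SGD", "momentum": 0.9, "batch_size": 64,
            "base_lr": 0.01, "epochs": 40, "decay_epoch": 25, "decay_factor": 10.0}

def finetune_lr(epoch, cfg=FINETUNE):
    """Learning rate for a 0-indexed epoch under the step schedule."""
    if epoch < cfg["decay_epoch"]:
        return cfg["base_lr"]
    return cfg["base_lr"] / cfg["decay_factor"]

# Learning rate for each of the 40 fine-tuning epochs.
schedule = [finetune_lr(e) for e in range(FINETUNE["epochs"])]
```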

\(\textrm{TRED}\) (both with Max-MMD and Min-MI) is compared with state-of-the-art transfer learning regularizers including \(\mathrm {L^2\text {-}SP}\) (Li et al., 2018), \(\textrm{AT}\) (Zagoruyko & Komodakis, 2016), \(\textrm{DELTA}\) (Li et al., 2019), \(\textrm{BSS}\) (Chen et al., 2019) and \(\textrm{DistReg}\) (Li & Zhang, 2021). We perform 3-fold cross-validation to search for the best hyperparameter \(\alpha\) in each experiment. For \(\mathrm {L^2\text {-}SP}\), \(\textrm{DELTA}\) and \(\textrm{TRED}\), the search space is [\(10^{-3}\), \(10^{-2}\), \(10^{-1}\)]. Although the authors of \(\textrm{AT}\) and \(\textrm{BSS}\) recommend fixed values of \(\alpha\) (\(10^3\) for \(\textrm{AT}\) and \(10^{-3}\) for \(\textrm{BSS}\)), we extend the search space to [\(10^2\), \(10^3\), \(10^4\)] for \(\textrm{AT}\) and [\(10^{-4}\), \(10^{-3}\), \(10^{-2}\)] for \(\textrm{BSS}\). For \(\textrm{DistReg}\), we adopt the hyperparameters recommended in its published code (Li & Zhang, 2021).
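The hyperparameter protocol can be sketched as a grid search over the method-specific search spaces listed above, scored by mean validation accuracy over 3 folds. The `evaluate` callback is hypothetical (it stands for fine-tuning with a given \(\alpha\) and measuring accuracy), and the contiguous, unshuffled folds are a simplification; in practice one would shuffle and stratify by class first. DistReg is omitted because its hyperparameters follow the published recommendations.

```python
SEARCH_SPACES = {
    "L2-SP": [1e-3, 1e-2, 1e-1],
    "DELTA": [1e-3, 1e-2, 1e-1],
    "TRED": [1e-3, 1e-2, 1e-1],
    "AT": [1e2, 1e3, 1e4],
    "BSS": [1e-4, 1e-3, 1e-2],
}

def k_fold_indices(n, k=3):
    """Split indices 0..n-1 into k contiguous folds whose sizes differ by at most one."""
    folds, start = [], 0
    for i in range(k):
        size = n // k + (1 if i < n % k else 0)
        folds.append(list(range(start, start + size)))
        start += size
    return folds

def select_alpha(method, n_train, evaluate, k=3):
    """Return the alpha with the best mean validation score over k folds.

    `evaluate(alpha, train_idx, val_idx)` is a hypothetical callback returning accuracy.
    """
    folds = k_fold_indices(n_train, k)
    best_alpha, best_score = None, float("-inf")
    for alpha in SEARCH_SPACES[method]:
        scores = []
        for i, val_idx in enumerate(folds):
            # Train on every fold except the held-out one.
            train_idx = [j for f, fold in enumerate(folds) if f != i for j in fold]
            scores.append(evaluate(alpha, train_idx, val_idx))
        mean = sum(scores) / k
        if mean > best_score:
            best_alpha, best_score = alpha, mean
    return best_alpha
```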

Table 2 Comparison of top-1 accuracy (%) with different methods.

5.1.3 Results

Since recent theoretical studies have shown that weight decay has no actual regularization effect when combined with the commonly used batch normalization (Van Laarhoven, 2017; Golatkar et al., 2019), we reevaluate \(\mathrm {L^2}\) against the most naive baseline, fine-tuning with No Regularization. From Table 2, we observe that \(\mathrm {L^2}\) does not outperform fine-tuning without any regularization. This may imply that deep transfer learning hardly benefits from regularizers with non-informative priors.

Advanced works (Zagoruyko & Komodakis, 2016; Li et al., 2018, 2019) adopt regularizers that use the starting point as the reference for knowledge preservation. From a Bayesian perspective, these are equivalent to informative priors that trust the knowledge contained in the source model, in the form of either parameters or behavior. Table 2 shows that these algorithms obtain significant improvements on some datasets such as Stanford Dogs and MIT Indoor-67, where the target dataset is very similar to the source dataset. However, the benefits are much smaller on other datasets such as CUB-200-2011, Flower-102 and Stanford Cars.
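The difference between a non-informative and an informative prior is easiest to see by writing the two penalties side by side. Below is a minimal sketch over flat lists of parameters: standard \(\mathrm {L^2}\) shrinks every weight towards zero, whereas \(\mathrm {L^2\text {-}SP}\) (Li et al., 2018) shrinks the shared weights towards their pre-trained values \(w^0\) and applies plain \(\mathrm {L^2}\) only to newly added layers. Representing parameters as lists and the function names are our assumptions; in a real model the sums run over all parameter tensors.

```python
def l2_penalty(weights, alpha):
    """Weight decay: (alpha / 2) * ||w||^2, i.e. a Gaussian prior centred at zero."""
    return 0.5 * alpha * sum(w * w for w in weights)

def l2_sp_penalty(weights, start_weights, alpha, new_weights=(), beta=0.0):
    """L2-SP: (alpha / 2) * ||w - w0||^2 + (beta / 2) * ||w_new||^2.

    `weights` and `start_weights` are the shared (pre-trained) parameters;
    `new_weights` belong to layers that have no pre-trained counterpart.
    """
    shared = sum((w - w0) ** 2 for w, w0 in zip(weights, start_weights))
    new = sum(w * w for w in new_weights)
    return 0.5 * alpha * shared + 0.5 * beta * new

w0 = [1.0, -2.0]                       # hypothetical pre-trained weights
decay = l2_penalty(w0, 0.1)            # > 0: L2 pulls w0 towards the origin
anchored = l2_sp_penalty(w0, w0, 0.1)  # = 0: L2-SP leaves w0 untouched
```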

Table 2 shows that \(\textrm{TRED}\) consistently outperforms all the above baselines on all evaluated datasets. It outperforms the naive fine-tuning regularizer \(\mathrm {L^2}\) by more than 3% on average. Except on Stanford Dogs and MIT Indoor-67, the improvements remain clear even over the state-of-the-art regularizers \(\mathrm {L^2\text {-}SP}\), \(\textrm{AT}\), \(\textrm{DELTA}\), \(\textrm{BSS}\) and \(\textrm{DistReg}\).

As for the choice of disentangling strategy, although \(\textrm{TRED}\) is generally superior to both \(\textrm{TRED}_{\textrm{MMD}}\) and \(\textrm{TRED}_{\textrm{MI}}\), the improvement is not significant, implying that both strategies may achieve nearly optimal results (in our scenario) and largely overlap with each other. We speculate that there are two reasons. First, the required disentanglement is very coarse-grained and consequently much easier than in general representation learning. In the transfer learning scenario, we only require the disentangled representation as a whole to be relevant to the target task, while its individual semantic units may remain entangled. Second, unlike purely unsupervised learning, supervised information is available in our case, which further eases the learning of disentangled features. Therefore, the disentangled features obtained by Max-MMD and Min-MI may largely overlap, and combining them brings only marginal further benefits.

To evaluate how our algorithm scales to more limited data, we conduct additional experiments on subsets of the standard CUB-200-2011 dataset. Baseline methods include \(\mathrm {L^2}\), \(\textrm{BSS}\) (Chen et al., 2019), \(\mathrm {L^2\text {-}SP}\) (Li et al., 2018), \(\textrm{AT}\) (Zagoruyko & Komodakis, 2016), \(\textrm{DELTA}\) (Li et al., 2019) and \(\textrm{DistReg}\) (Li & Zhang, 2021). Specifically, we randomly sample 50%, 30% and 15% of the training examples of each category to construct new training sets. As presented in Table 3, our proposed \(\textrm{TRED}\) achieves remarkable improvements over all competitors.
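The per-category subsampling can be sketched as below. The rounding rule (round down, keep at least one example per category) and the fixed seed are our assumptions, since the exact procedure is not specified; sampling within each category keeps the class balance of the original training set.

```python
import random
from collections import defaultdict

def subsample_per_class(labels, percent, seed=0):
    """Return sorted indices that keep `percent`% of the examples of each class.

    Per-class counts are rounded down, with at least one example kept (an assumption).
    """
    by_class = defaultdict(list)
    for idx, y in enumerate(labels):
        by_class[y].append(idx)
    rng = random.Random(seed)
    kept = []
    for y in sorted(by_class):
        idxs = by_class[y]
        k = max(1, len(idxs) * percent // 100)  # integer arithmetic avoids float rounding
        kept.extend(rng.sample(idxs, k))
    return sorted(kept)
```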

Table 3 Comparison of top-1 accuracy (%) on CUB-200-2011 with respect to different sampling rates

5.2 Semantic segmentation

5.2.1 Datasets

PASCAL VOC (Everingham et al., 2012) and Cityscapes (Cordts et al., 2016) are used for evaluation on semantic segmentation. The two datasets are briefly introduced as follows.

  • PASCAL VOC-2012 (Everingham et al., 2012). The dataset contains 20 diverse object categories, such as humans, animals and vehicles. The official data consists of 1464 training images and 1449 validation images.

  • Cityscapes (Cordts et al., 2016). The dataset is a large-scale collection of video sequences of street scenes from 50 different cities. 2975 annotated images are available for training and 500 for validation.

We perform bidirectional transfer learning tasks between the two datasets. To mimic a realistic low-data setting, we use all training examples of one dataset for pre-training and 20% randomly selected training examples of the other for fine-tuning.

5.2.2 Settings and hyperparameters

DeepLab-v3 (Chen et al., 2017) is employed as the architecture for the semantic segmentation tasks. Following common practice, the ResNet-50 backbone of DeepLab-v3 is initialized with a checkpoint pre-trained on ImageNet classification. For both datasets, input images and annotations are resized to \(513 \times 513\). For pre-training on a full dataset, we optimize the model for 30000 iterations with an initial learning rate of 0.1, while for fine-tuning on a 20% subset, the model is updated for 15000 iterations with an initial learning rate of 0.01.

To adapt our method to segmentation models, we employ a pixel-level classifier to calculate \(L_{ce}^D\) in our disentangler. The classifier consists of a lightweight Atrous Spatial Pyramid Pooling (ASPP) module, a \(3 \times 3\) convolutional layer, a BatchNorm layer, a ReLU layer and finally a \(1 \times 1\) convolutional layer, i.e., the same architecture as the one used in DeepLab-v3. For the remaining components and hyperparameters, we follow the same configurations as for classification. Max-MMD and Min-MI are combined with equal coefficients.

5.2.3 Results

The segmentation results are evaluated by mean IoU (mIoU) on the validation set. As shown in Table 4, \(\textrm{TRED}\) outperforms the fine-tuning baseline in both experiments.
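For reference, a minimal sketch of mean IoU follows: the per-class IoU is TP / (TP + FP + FN), computed from a confusion matrix accumulated over all validation pixels and then averaged over classes. Flattened label lists, the ignore label 255 (a common convention for unlabeled pixels in PASCAL VOC) and skipping classes absent from both prediction and ground truth are assumptions of this sketch, not details given above.

```python
def mean_iou(preds, targets, num_classes, ignore_index=255):
    """Mean IoU over classes that occur in the predictions or the ground truth."""
    # Confusion matrix: rows index the ground truth, columns the prediction.
    conf = [[0] * num_classes for _ in range(num_classes)]
    for p, t in zip(preds, targets):
        if t == ignore_index:
            continue  # unlabeled pixels do not count
        conf[t][p] += 1
    ious = []
    for c in range(num_classes):
        tp = conf[c][c]
        fp = sum(conf[r][c] for r in range(num_classes)) - tp
        fn = sum(conf[c]) - tp
        if tp + fp + fn > 0:
            ious.append(tp / (tp + fp + fn))
    return sum(ious) / len(ious)
```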

Table 4 Experimental results on semantic segmentation tasks

6 Discussions

In this section, we dive deeper into the mechanism and experimental results to explain why target-awareness disentanglement provides a better reference. Our analyses are based on the implementation of \(\textrm{TRED}\)-MMD. In subsection Representation Visualization, we show the effect of our method by visualizing attention maps and feature embeddings. In subsection Shrinking Towards True Behavior, we briefly discuss the theoretical connection to shrinkage estimation and then provide further statistical evidence for the advantage of the disentangled positive representation. In subsection Mechanism Analysis, we empirically analyze why the disentanglement component is essential.

6.1 Representation visualization

Show Cases. Zagoruyko and Komodakis (2016) show that the spatial attention map plays a critical role in knowledge transfer. We demonstrate the effect of representation disentanglement by visualizing attention maps in Fig. 2. As observed in typical cases from CUB-200-2011 and Stanford Cars, the original representations generated by the ImageNet pre-trained model usually contain a wide range of semantic features, such as other objects or backgrounds, in addition to the parts of birds. Our proposed disentangler is able to “purify” the concepts of interest into the positive part, while the negative part attends more to the complementary constituents.
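The attention maps visualized here follow the activation-based definition of Zagoruyko and Komodakis (2016), which aggregates a \(C \times H \times W\) feature tensor into an \(H \times W\) map by summing \(|A_c|^p\) over channels. The sketch below uses \(p = 2\) and min-max normalization for display; both choices and the nested-list representation are our assumptions, and upsampling the map to the image resolution for overlaying is omitted.

```python
def spatial_attention(feature, p=2):
    """Activation-based attention map from a C x H x W nested list.

    Returns an H x W map of sum_c |A_c|^p, min-max normalized to [0, 1] for display.
    """
    C, H, W = len(feature), len(feature[0]), len(feature[0][0])
    amap = [[sum(abs(feature[c][i][j]) ** p for c in range(C)) for j in range(W)]
            for i in range(H)]
    lo = min(min(row) for row in amap)
    hi = max(max(row) for row in amap)
    span = hi - lo if hi > lo else 1.0  # avoid division by zero on flat maps
    return [[(v - lo) / span for v in row] for row in amap]
```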

Fig. 2

The effectiveness of representation disentanglement on CUB-200-2011 (left) and Stanford Cars (right). For each dataset, we select three typical cases for demonstration. Column (a) shows the input image; columns (b), (c) and (d) overlay spatial attention maps, computed from the representation of the last convolutional layer of ResNet-50, onto the input image. Specifically, (b) uses the original representation generated by the ImageNet pre-trained model, while (c) and (d) use the disentangled positive and negative parts produced by \(\textrm{TRED}\)-MMD

Embedding Visualization. Since the key change in our method is to use the disentangled rather than the original representation as the reference, we compare the properties of these two representations on the target task. We visualize the original and disentangled feature representations of Flower-102 and MIT Indoor-67: the features are first reduced along the spatial dimensions and then embedded in 2D space using t-SNE. As illustrated in Fig. 3, the representations derived by our proposed disentangler are separated more clearly across categories than the original ones, and samples of the same category are clustered more tightly.

Fig. 3

Visualization of the original (a, c) and disentangled (b, d) feature representations by t-SNE. Different colors and markers are used to denote different categories

6.2 Shrinking towards true behavior

Recent work (Li et al., 2018) discusses the connection between the proposed \(\mathrm {L^2\text {-}SP}\) and the classical statistical theory of shrinkage estimation (Efron & Morris, 1977). The key hypothesis is that shrinking towards a value close to the “true parameters” is more effective than shrinking towards an arbitrary one. Li et al. (2018) argue that the starting point should be closer to the “true parameters” than zero. Zagoruyko and Komodakis (2016) and Li et al. (2019) regularize features rather than parameters, which can be interpreted as shrinking towards the “true behavior”. Our proposed \(\textrm{TRED}\) further improves on them by explicitly disentangling a “truer behavior” using the global distribution and supervision information of the target dataset. To support this claim, we provide additional evidence as follows.
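The shrinkage argument can be illustrated with a toy James-Stein experiment that is not part of our method. For \(x \sim \mathcal{N}(\theta, I)\) in \(p = 10\) dimensions, the positive-part estimator \(\hat\theta = \nu + \max(0, 1 - (p-2)/\|x-\nu\|^2)(x-\nu)\) shrinks \(x\) towards a chosen target \(\nu\), and its risk \(\mathbb{E}\|\hat\theta-\theta\|^2\) is much lower when \(\nu\) is close to \(\theta\) than when \(\nu = 0\), by analogy with shrinking towards the starting point rather than towards zero. The dimension, the true mean and the targets are arbitrary illustrative choices.

```python
import random

def james_stein(x, target, sigma2=1.0):
    """Positive-part James-Stein estimate of the mean, shrinking x towards `target`."""
    p = len(x)
    diff = [xi - ti for xi, ti in zip(x, target)]
    norm2 = sum(d * d for d in diff)
    factor = max(0.0, 1.0 - (p - 2) * sigma2 / norm2) if norm2 > 0 else 0.0
    return [ti + factor * d for ti, d in zip(target, diff)]

def mc_risk(target, theta, trials=2000, seed=0):
    """Monte Carlo estimate of E||theta_hat - theta||^2 with x ~ N(theta, I)."""
    rng = random.Random(seed)  # same seed => same draws for every target
    total = 0.0
    for _ in range(trials):
        x = [t + rng.gauss(0.0, 1.0) for t in theta]
        est = james_stein(x, target)
        total += sum((e - t) ** 2 for e, t in zip(est, theta))
    return total / trials

theta = [5.0] * 10                       # hypothetical "true parameters"
risk_zero = mc_risk([0.0] * 10, theta)   # shrink towards the origin
risk_near = mc_risk([4.8] * 10, theta)   # shrink towards a nearby point
```

The unshrunk estimate \(x\) has risk \(p = 10\) here; in expectation, shrinking towards zero helps only marginally, while shrinking towards the nearby target removes most of the error.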

Reducing untransferable components. We compute the singular vectors and singular values of the deep representation by SVD. All singular values are sorted in descending order and plotted in Fig. 4. Chen et al. (2019) demonstrate that the spectral components corresponding to smaller singular values are less transferable, and find that these less transferable components can be suppressed by involving more training examples. Interestingly, we find a similar trend with the proposed representation disentanglement. As observed in Fig. 4, the smaller singular values of the disentangled positive representation are further reduced compared with those of the original representation. Fig. 4 also shows that the spectral components corresponding to larger singular values are increased, a phenomenon not observed in Chen et al. (2019). This is intuitively consistent with the hypothesis that features relevant to the target task are disentangled and strengthened.
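The spectrum analysis can be sketched without external libraries. Assuming the representation is arranged as a matrix \(F\) with one row per example and one column per feature dimension, its singular values are the square roots of the eigenvalues of \(F^\top F\), which the cyclic Jacobi method below diagonalizes. This is only suitable for toy sizes; on real ResNet-50 features one would call `numpy.linalg.svd(F, compute_uv=False)`, which returns the singular values in descending order.

```python
import math

def singular_values(matrix, sweeps=50, tol=1e-20):
    """Singular values of a small dense matrix (list of rows), sorted descending."""
    n = len(matrix[0])
    # Gram matrix G = F^T F, symmetric positive semi-definite (n x n).
    g = [[sum(row[i] * row[j] for row in matrix) for j in range(n)] for i in range(n)]
    for _ in range(sweeps):
        off = sum(g[i][j] ** 2 for i in range(n) for j in range(n) if i != j)
        if off < tol:
            break
        for p in range(n - 1):
            for q in range(p + 1, n):
                if g[p][q] == 0.0:
                    continue
                # Jacobi rotation angle that zeroes the off-diagonal entry g[p][q].
                theta = 0.5 * math.atan2(2.0 * g[p][q], g[q][q] - g[p][p])
                c, s = math.cos(theta), math.sin(theta)
                for k in range(n):  # G <- G J (rotate columns p and q)
                    gkp, gkq = g[k][p], g[k][q]
                    g[k][p] = c * gkp - s * gkq
                    g[k][q] = s * gkp + c * gkq
                for k in range(n):  # G <- J^T G (rotate rows p and q)
                    gpk, gqk = g[p][k], g[q][k]
                    g[p][k] = c * gpk - s * gqk
                    g[q][k] = s * gpk + c * gqk
    # Eigenvalues of G sit on the diagonal; clip tiny negatives from round-off.
    return sorted((math.sqrt(max(g[i][i], 0.0)) for i in range(n)), reverse=True)
```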

Fig. 4

Singular values of the original and disentangled deep representation

Robustness to regularization strength. We also provide empirical evidence of the effect of the “truer behavior” obtained by our proposed disentangler. The intuition is straightforward: if the behavior (representation) used as the reference is “truer”, the regularizer should be more robust to larger regularization strengths. We compare with \(\textrm{DELTA}\), which uses the original representation as the reference, on three transfer learning tasks: Places365 \(\rightarrow\) MIT Indoor-67, ImageNet \(\rightarrow\) Stanford Cars and Places365 \(\rightarrow\) Stanford Dogs. The regularization strength \(\alpha\) is gradually increased from 0.001 to 1. As illustrated in Fig. 5, the performance of \(\textrm{DELTA}\) falls rapidly as \(\alpha\) increases, especially on ImageNet \(\rightarrow\) Stanford Cars and Places365 \(\rightarrow\) Stanford Dogs, indicating that a regularizer using the original representation as the reference suffers seriously from negative transfer. In contrast, \(\textrm{TRED}\) is much more robust to increasing \(\alpha\).

Fig. 5

Top-1 accuracy of transfer learning tasks corresponding to different regularization strength \(\alpha\)

6.3 Mechanism analysis

In this part, we briefly discuss the necessity of and relationship between the main components of our method. Since our purpose is to disentangle the knowledge relevant to the target task, supervision from the target dataset is clearly necessary. We first discuss whether simple supervision is enough to “disentangle” the relevant knowledge from the source model, and then explain why the reconstruction component is essential to ensure effectiveness.

Why disentanglement is useful. It may seem reasonable to obtain the discriminative representation using only the classifier corresponding to \(L^D_{ce}\). This is equivalent to performing a pre-adaptation of the source model before fine-tuning. Unfortunately, such straightforward, purely supervised adaptation is prone to over-fitting the limited target examples, so the resulting representation is not adequate to serve as prior knowledge. Encouraging disentanglement between the relevant and irrelevant parts, in contrast, provides a distribution-level guarantee that simultaneously preserves the generalization capacity of the source model and adapts it to the new task. This is achieved by maintaining the integrity of the underlying data structure in a self-supervised way, i.e., through disentanglement and reconstruction. To verify this hypothesis, we conduct an ablation study with a simpler framework that omits the disentanglement part and directly transforms the original representation. This version is denoted by \(\textrm{TRED}\)-.

We observe in Table 5 that all evaluated tasks suffer a significant performance drop on the target task without disentanglement. A plausible explanation is that disentangling helps preserve the knowledge in the source model and prevents the representation transformation from over-fitting the classifier C. To verify this hypothesis, we compare the top-1 accuracy of \(\textrm{TRED}\) and \(\textrm{TRED}\)- on ImageNet classification (the source task). Specifically, we train a randomly initialized classifier to recognize ImageNet categories, using the fixed transformed representation as input. The top-1 accuracy of the pre-trained ResNet-50 model is 75.99%. As shown in Table 5, \(\textrm{TRED}\)- suffers a larger performance drop on ImageNet than \(\textrm{TRED}\), indicating that representation disentanglement better preserves the general knowledge learned on the source task.

Table 5 Top-1 accuracy (%) of the target and source task. \(\textrm{TRED}\)- refers to \(\textrm{TRED}\) without disentanglement

The interdependence between disentanglement and reconstruction. Disentanglement and reconstruction together form a complete self-supervised requirement. Without the disentanglement constraint between the positive and negative parts, the transformation reduces to the case of no disentanglement; the reconstruction requirement alone therefore cannot ensure that the learned positive part preserves the source knowledge. Removing the reconstruction requirement leads to an analogous situation: the transformation can then be arbitrary, enlarging the discrepancy or independence between the two parts with no constraint from the source knowledge. Intuitively, given that the positive part focuses on the supervision information, the negative part can easily adapt to satisfy either one of the two requirements alone. As a result, both cases are almost equivalent to removing the entire disentangler.

Given the reconstruction guarantee, maximizing the attention discrepancy or minimizing the mutual information between the semantic features of the positive and negative parts encourages the two parts to focus on patterns related to different semantic concepts, with as little overlap or interaction as possible. This is exactly a form of representation disentanglement. For example, as shown in Fig. 2, the positive part is activated on the heads and feathers of birds, while the negative part is activated on objects such as trees and wires. Moreover, this is achieved at the distribution level, so that the semantic separation is consistent across the entire dataset.

6.4 Influences of hyperparameters

The hyperparameters \(\lambda _{di}\), \(\lambda _{ce}\) and \(\lambda _{re}\) should be chosen properly to ensure successful disentanglement; otherwise, the disentangled features, which act as the reference for fine-tuning, yield a smaller performance gain. To investigate how these hyperparameters affect transfer learning, we individually set each hyperparameter to \(100\times\) or \(0.01\times\) its optimal value and evaluate on CUB-200-2011. As shown in Table 6, values that are too large or too small lead to lower accuracy than the optimal choice. Among these, a too-small \(\lambda _{ce}\) leads to the worst performance, as the disentangled features may not be discriminative on the target data and are prone to negative transfer.
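The sensitivity protocol is a one-at-a-time perturbation: each \(\lambda\) is scaled by 100 or 0.01 while the other two stay fixed, giving six configurations. The minimal sketch below takes the default values reported in the experimental settings as the optimal values referred to here, which is our assumption; the key names are hypothetical.

```python
DEFAULTS = {"lambda_di": 1e-2, "lambda_ce": 1e-3, "lambda_re": 1e-2}

def one_at_a_time(defaults, factors=(100.0, 0.01)):
    """Yield (name, factor, config) where only `name` is rescaled by `factor`."""
    for name, value in defaults.items():
        for factor in factors:
            config = dict(defaults)  # copy so the defaults stay untouched
            config[name] = value * factor
            yield name, factor, config

runs = list(one_at_a_time(DEFAULTS))
```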

Table 6 Top-1 accuracy (%) on CUB-200-2011 with respect to different hyperparameters for Max-MMD

7 Conclusion

In this paper, we extend the study of catastrophic forgetting and negative transfer in inductive transfer learning. Specifically, we propose a novel approach, \(\textrm{TRED}\), that regularizes fine-tuning with the disentangled deep representation as the reference, achieving accurate knowledge transfer. We implement target-awareness disentanglement by maximizing the Maximum Mean Discrepancy (MMD) on visual attentions and minimizing the Mutual Information (MI) on semantic features. Extensive experimental results on various real-world transfer learning datasets show that \(\textrm{TRED}\) (Max-MMD+Min-MI) significantly outperforms state-of-the-art transfer learning regularizers. When computational resources are limited, Max-MMD can be used as an efficient strategy. Moreover, we provide empirical analysis verifying that the disentangled target-awareness representation is closer to the expected “true behavior” of the target task.