Predicting Probabilities of Error to Combine Quantization and Early Exiting: QuEE

Florence Regol
McGill University, ILLS, MILA
florence.robert-regol@mail.mcgill.ca

Joud Chataoui
McGill University, ILLS, MILA
joud.chataoui@mail.mcgill.ca

Bertrand Charpentier
Technical University of Munich
charpent@in.tum.de

Mark Coates
McGill University, ILLS, MILA
mark.coates@mcgill.ca

Pablo Piantanida
École de technologie supérieure, ILLS, MILA
pablo.piantanida@mila.quebec

Stephan Gunnemann
Technical University of Munich
guennemann@in.tum.de

ILLS: International Laboratory on Learning Systems. MILA: Quebec Institute for Learning Algorithms.

Abstract

Machine learning models can solve complex tasks but often require significant computational resources during inference. This has led to the development of various post-training computation reduction methods that tackle this issue in different ways, such as quantization, which reduces the precision of weights and arithmetic operations, and dynamic networks, which adapt computation to the sample at hand. In this work, we propose a more general dynamic network that combines quantization with early exiting: QuEE. Our algorithm can be seen as a form of soft early exiting or input-dependent compression. Rather than a binary decision between exiting or continuing, we introduce the possibility of continuing with reduced computation. This complicates the traditionally considered early exiting problem, which we solve through a principled formulation. The crucial factor of our approach is accurate prediction of the potential accuracy improvement achievable through further computation. We demonstrate the effectiveness of our method through empirical evaluation and explore the conditions for its success on 4 classification datasets.

1 Introduction

Large models, paired with transfer learning or fine-tuning, are becoming established as a dominant approach in machine learning [4, 8, 39, 7]. Consequently, reducing the inference cost of large pretrained models without (or with minimal) retraining is becoming increasingly important [7, 33, 1].

Figure 1: Early exit decreases inference costs along the depth, while quantization decreases inference costs along the width. QuEE can integrate both, and learn to employ different computation reduction methods to classify different classes. On CIFAR-10, quantization is mainly used for automobiles and ships, while early exiting is mainly used for dogs and frogs.

There is a wide array of approaches to post-training computation reduction: quantization reduces the precision of stored weights and activations [36, 51, 1], distillation trains a smaller model to imitate a larger one [21, 13], pruning removes weights or units entirely [36], and dynamic networks can adapt the computation to the sample at hand [16]. These computation reduction methods vary significantly and offer different benefits. Our aim in this work is to develop a method that can combine different approaches to leverage their complementary strengths.

Although there have been proposals to combine approaches, including the application of fixed quantization to early exit networks [42, 28] and sample-aware quantization [35, 22, 43], there have been few efforts towards selecting and combining, on a per-sample basis, multiple post-training cost reduction techniques. Our goal is to design a unified framework that determines which combination of computation reduction methods should be employed for each sample, in a fully post-training setting. In our proposed dynamic network, computation reduction can be performed in two ways: both in depth and in width (see Figure 1). We achieve this by combining early exiting for depth adaptation with quantization for width adaptation. We opt for quantization as it currently offers more practical and superior compression results post-training without fine-tuning [32]. In contrast, pruning and distillation have received more attention in the fine-tuning setting [46, 10].

This allowance for adaptive per-sample quantization/early exiting significantly complicates the problem. Traditionally, in early exiting, the decision that must be made for each sample at each candidate exit point is whether to exit or continue. This permits the application of a simple binary threshold-based rule [23, 18, 2]. Alternatively, we can train exit-controller modules [40, 5, 25]. Such an approach is infeasible in the new setting due to the increased number of options (exit or quantize at multiple different levels). We address the problem with a principled approach to solve the more general dynamic network problem, drawing on insights from theoretical works [45, 27].

In particular, we observe that solving the dynamic network problem for fixed classifiers can be addressed by solving the task of predicting the probabilities of error of the candidate downstream (more computationally expensive) classifiers. Experimentally, we demonstrate that the proposed procedure can generate sufficiently accurate estimates of the probabilities of error, allowing the per-sample selection of appropriate levels of quantization and effective exit points.

Our paper makes the following contributions:

  1. We introduce a novel and principled view of the dynamic network problem formulation, and reframe it as a task of predicting error probabilities.

  2. To learn to predict the inaccessible error probabilities, we introduce a new method that involves discretizing the feature space and using empirical error approximations as our training targets.

  3. Our unifying formulation allows us to combine different computation reduction methods in a post-training setting, leading to the introduction of a new learning framework: QuEE.

  4. We empirically demonstrate the efficacy of our method, and we explore necessary conditions for its effectiveness, using 4 classification datasets.

2 Related Work

Dynamic neural networks.

Dynamic architectures adapt their computational graphs to the input [16, 36]. By adapting depth (a subset of layers is executed) or width (a subset of channels or neurons is executed) for each sample, dynamic architectures can reduce computation during inference [16].

Early Exit (EE) networks are a class of dynamic depth architecture where a prediction is obtained at an intermediate layer and subsequent layers are skipped [2, 23, 18, 40, 25]. This is done by augmenting the network with intermediate inference modules at various layers. Early works propose architectures tailored for EE that are trained end-to-end, often paired with a simple thresholding mechanism as an exit rule [2, 23, 18]. This approach can provide significant efficiency savings, but end-to-end training can be impractical for large models [1, 40]. To address this, post-training EE methods that rely on a fixed pre-trained backbone have been introduced – they instead explore effective ways to train the inference modules and design more sophisticated gating mechanisms for the exit rule [40, 5, 25]. However, these often come with the drawback of having to repeat the training process for every operating point, making it impractical for use-cases where the computation budget changes over time.

Width-wise sample-adaptation can be achieved by selecting a subset of channels in CNNs [24, 20]. A more general approach, compatible with transformer-based architectures, is SuperNets [34, 38, 19], where samples are dynamically routed through a subset of neurons at inference. Closer to our algorithm, works such as Wang et al. [47] and Xia et al. [48] perform both depth and width adaptation via layer-skipping and channel-selection in convolution-based architectures. These approaches rely on trainable controllers that are optimized jointly with the underlying network [47, 48]. This makes them incompatible with large pre-trained models [29, 44].

Quantization.

Quantization is another effective way of speeding up inference. Weights, gradients and activations of a model are represented at lower bit resolutions [36]. Quantization-aware training (QAT) techniques quantize the network during training [41, 11], while post-training quantization (PTQ) is performed on a trained model [41, 11, 7]. This makes PTQ particularly appealing as a width-compression technique for our setting, i.e., working with large pre-trained models [1, 7]. However, quantization is typically not input-adaptive [41, 35, 22]. There are a few exceptions [22, 43, 35].  Hong et al. [22] and  Tian et al. [43] consider dynamic mixed-precision quantization for the specific problem of image super-resolution. Better suited for our more general setting, DQNet [35] explores dynamic mixed-precision quantization for image classification. In DQNet, the network is augmented with a small neural network called the bit controller whose role is to determine the bit resolution of each layer for a given sample. While Liu et al. [35], Hong et al. [22], and Tian et al. [43] all dynamically adapt bit precision on a per-sample basis, they encode the computational budget in their loss formulation, meaning that the algorithm needs to be retrained for every operating point.

Quantization of early exit networks.

Several works combine the adaptability of early exit networks with the efficacy of quantization by quantizing early exit networks [42, 28]. In [28], a pre-trained early exit network is first split into sections. Each section is then quantized separately using weight-clustering. The quantized network is fully retrained using knowledge distillation. Saxena and Roy [42] use a QAT approach, where the optimal per-layer quantization parameters are learnt during training. While these works combine quantization with early exiting, they both propose QAT-like approaches, and are thus unsuitable for large models. They are also not sample-adaptive along their widths, as a single static mixed-precision quantization is learnt for all samples [28, 42].

3 Problem Setting - Dynamic Networks

We consider a dynamic network setting with fixed classifiers. Suppose we have access to $M$ computational units: $v_1, \dots, v_M$. These units can be composed to form $L$ classifiers, $f_1, \dots, f_L$, where $f: \mathcal{X} \to \mathcal{Y}$. For example, we may have $f_1 = v_{M-1} \circ v_2 \circ v_1$ and $f_2 = v_M \circ v_2 \circ v_1$. The classifiers have associated costs of evaluation $(c_1, \dots, c_L)$. The cost for classifier $f_l$ is equal to the sum of the costs of its constituent computational units; we assume these costs to be unaffected by $x$.

To select which classifier will perform the inference for a given sample $x$, we employ a classifier-selector function $S(x): \mathcal{X} \to [L]$. The classifier-selector function has a sample-dependent cost of evaluation, but it is negligible compared to the cost of the classifiers. We note that $S(x)$ can be the result of a sequence of decisions, interspersed with computation (as is the case in an early exit dynamic network setting). For example, $S(x)$ may first choose to apply computational unit $v_1$, and then make a second decision, deciding whether to apply computational unit $v_2$ or $v_3$. The results of intermediate computation can be used when making the later decisions.

The goal is to learn the classifier-selector function $S$ that gives the best performance/computation cost trade-off, which can be quantified in various ways. In this work, we consider a cost-based 0-1 loss:

$$\ell_{01c}(x, y, S) = \mathbbm{1}[f_{S(x)}(x) \neq y] + c_{S(x)}. \tag{1}$$

This loss is attractive because it is interpretable; we assign a cost of 1 to classification error and a cost of 0 to correct classification, and then $c_{S(x)}$ controls the penalty associated with computation using the classifiers selected by $S(x)$.

Our objective is then to parameterize $S$ and find the parameters that minimize this loss in expectation:

$$\theta^* = \operatorname*{arg\,min}_{\theta} \mathbb{E}_{XY}\left[\mathbbm{1}[f_{S_\theta(X)}(X) \neq Y] + c_{S_\theta(X)}\right]. \tag{2}$$
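To make Equations (1) and (2) concrete, the following minimal sketch evaluates the cost-based 0-1 loss on a toy dataset and averages it to estimate the expectation in Equation (2). Everything in it is an illustrative assumption rather than part of the paper: the function names `cost_01_loss` and `empirical_risk`, the two toy classifiers, the costs, and the data. Classifiers are indexed from 0 here, whereas the text uses $[L] = \{1, \dots, L\}$.

```python
def cost_01_loss(x, y, selector, classifiers, costs):
    """Equation (1): 0-1 error of the selected classifier plus its cost."""
    l = selector(x)  # index of the classifier chosen for this sample
    return float(classifiers[l](x) != y) + costs[l]


def empirical_risk(data, selector, classifiers, costs):
    """Sample average of Equation (1), an estimate of the expectation in (2)."""
    return sum(cost_01_loss(x, y, selector, classifiers, costs)
               for x, y in data) / len(data)


# Hypothetical setup: a free but crude classifier and a costly accurate one.
classifiers = [lambda x: 0, lambda x: int(x > 0.5)]
costs = [0.0, 0.25]
data = [(0.2, 0), (0.9, 1), (0.7, 1), (0.1, 0)]

always_cheap = lambda x: 0                   # never pay for computation
always_full = lambda x: 1                    # always pay for computation
adaptive = lambda x: 1 if x > 0.5 else 0     # pay only where it helps

risk_cheap = empirical_risk(data, always_cheap, classifiers, costs)
risk_full = empirical_risk(data, always_full, classifiers, costs)
risk_adaptive = empirical_risk(data, adaptive, classifiers, costs)
```

On this toy data the adaptive selector attains a lower average loss than either fixed choice, which is the kind of trade-off the objective in Equation (2) is meant to capture.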

We can show from a straightforward extension of results in [27] and [45] that the optimal classifier-selector function $S^*(x)$ selects the classifier $f_l$ that has the smallest sum of evaluation cost $c_l$ and probability of making an error, which we denote by $PE(f_l(x) \mid x)$:

\begin{align}
S^*(x) &= \operatorname*{arg\,min}_{S_\theta(x)} \mathbb{E}_{Y|x}\left[\mathbbm{1}[f_{S_\theta(x)}(x) \neq Y] + c_{S_\theta(x)}\right] \tag{3} \\
&= \operatorname*{arg\,min}_{l \in [L]} \big\{c_l + 1 - \Pr(Y = f_l(x) \mid x)\big\} \tag{4} \\
S^*(x) &= \operatorname*{arg\,min}_{l \in [L]} \{c_l + PE(f_l(x) \mid x)\}. \tag{5}
\end{align}

This solution provides insight into what the classifier-selector function $S$ should achieve:

Since the $c_l$ are known, solving the dynamic network problem with fixed classifiers and associated $\ell_{01c}(x, y, S)$ loss can be addressed by accurately predicting $\{PE(f_l(x) \mid x)\}\ \forall l \in [L]$, the probability of making an error for each classifier. (Footnote 2: For each pair of classifiers, $f_{l_1}$ and $f_{l_2}$, with $c_{l_2} > c_{l_1}$, we must determine for each sample whether the differential of error probabilities, $PE(f_{l_1}(x) \mid x) - PE(f_{l_2}(x) \mid x)$, exceeds that of computational costs, $c_{l_2} - c_{l_1}$.)
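The selection rule of Equation (5), together with the pairwise comparison described in the footnote, can be sketched in a few lines. This is only an illustration under assumed inputs: the helper names `select_classifier` and `worth_upgrading`, the costs, and the error probabilities are made up for the example, and the error probabilities are taken as given rather than predicted.

```python
def select_classifier(costs, error_probs):
    """Equation (5): index (0-based) minimizing c_l + PE(f_l(x)|x).

    Ties are broken in favour of the lowest index.
    """
    totals = [c + pe for c, pe in zip(costs, error_probs)]
    return min(range(len(totals)), key=totals.__getitem__)


def worth_upgrading(c1, pe1, c2, pe2):
    """Pairwise test: does the drop in error probability exceed the extra cost?"""
    return (pe1 - pe2) > (c2 - c1)


# Hypothetical costs for three classifiers of increasing size.
costs = [0.05, 0.10, 0.20]

# Hypothetical per-sample error probabilities PE(f_l(x)|x).
easy_sample = [0.10, 0.08, 0.07]   # extra computation barely helps
hard_sample = [0.60, 0.35, 0.15]   # extra computation helps a lot

choice_easy = select_classifier(costs, easy_sample)
choice_hard = select_classifier(costs, hard_sample)
```

For the easy sample the cheapest classifier is selected, while the hard sample is routed to the most expensive one, because only there does the reduction in error probability outweigh the additional cost.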

4 Methodology - QuEE

Now that we have introduced the general framework for dynamic networks with fixed classifiers, we specify how each of the components, $f_l$, $c_l$, $S_\theta(x)$, and $\mathcal{I}$, is defined for our proposed architecture, QuEE, in which we mix different computation levels (through quantization) with early exiting.

Classifiers and costs $(f_l, c_l)$.

In our proposed dynamic network architecture, we can control both the number of blocks (a block consists of one or more network layers) that are evaluated (using early exiting) and the amount of computation conducted within each block. Each classifier is therefore defined by a “path” $\pi$ that traverses blocks of the network, where the number of steps in the path corresponds to the number of blocks evaluated. Denoting by $b_i$ the computation level at block $i$, we can express a path that traverses $e$ blocks before exiting as:

$$\pi = b_1 \rightarrow \dots \rightarrow b_e. \tag{6}$$

We denote by $\pi[:j]$ the first $j$ steps of this path ($b_1 \rightarrow \dots \rightarrow b_j$). Given a possible number of exits $E$ and per-block computation levels $B$, the number of paths we can take in the network is $\sum_{e=1}^{E} B^e$. This corresponds to the total number of different classifiers that we could use for a given sample. We denote the set of all possible paths by $\mathcal{P}$.
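As a quick check of the path count, the sketch below enumerates every path for small, arbitrarily chosen values of $E$ and $B$ and compares the result with the closed-form sum. Representing a path as a tuple of integer levels and the helper names `enumerate_paths` and `count_paths` are assumptions made for illustration.

```python
from itertools import product


def enumerate_paths(num_exits, num_levels):
    """All paths b_1 -> ... -> b_e for e = 1..E, as tuples of levels in 1..B."""
    levels = range(1, num_levels + 1)
    return [path
            for e in range(1, num_exits + 1)
            for path in product(levels, repeat=e)]


def count_paths(num_exits, num_levels):
    """Closed form from the text: sum over e = 1..E of B**e."""
    return sum(num_levels ** e for e in range(1, num_exits + 1))


paths = enumerate_paths(num_exits=3, num_levels=2)   # E = 3, B = 2
example = (1, 2, 2)       # low precision, then high precision twice, exit at block 3
prefix = example[:2]      # pi[:2], the first two steps of the path
```

With $E = 3$ and $B = 2$ there are $2 + 4 + 8 = 14$ paths, so the number of candidate classifiers grows exponentially in the number of exits.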

We introduce the notation $\hat{\mathbf{p}}_{\pi i} \in [0, 1]^{|\mathcal{Y}|}$ for the predicted probability vector of classifier $f_\pi(x_i)$. If it is not necessary, the index of the input $x_i$ is omitted and we only write $\hat{\mathbf{p}}_\pi$.

Available pre-computed quantities $\mathcal{I}$.

At each candidate exit point, a gate $g_\theta^j(\cdot)$ is used to make the decision about how to process the sample. Before obtaining the decision of gate $g_\theta^j(\cdot)$ for $j > 1$, we evaluate the classifier $f_{\pi[:j-1]}(x_i)$ that would be used if the sample were to exit. We also place a gate $g_\theta^1(\cdot)$ at the very beginning of the network that can select the computation level of the first block of layers, but that cannot exit ($g_\theta^1(\cdot) \in \{b_1, \dots, b_B\}$).

Therefore, the information $\mathcal{I}_j(x_i)$ available to use as input for each gate $g^j_\theta(\cdot)$ includes all the predicted probabilities that were evaluated at previous classifiers and the path followed so far ($\pi[:j-1]$):

$$\mathcal{I}_j(x_i) = \begin{cases} \{\hat{\mathbf{p}}_{\pi_i[:1]}, \hat{\mathbf{p}}_{\pi_i[:2]}, \dots, \hat{\mathbf{p}}_{\pi_i[:j-1]}, \pi_i[:j-1]\}, & j > 1 \\ \varnothing, & j = 1 \quad \text{(no processed information is available to the first gate)} \end{cases} \tag{7}$$

Classifier-selector function $S_\theta(x)$.

In our setting, the classifier-selector function $S(\cdot)$ is decomposed into $E-1$ gating functions, $g^j_\theta(\cdot)$, for $j = 1, \dots, E-1$. The $j$-th gate can decide to either 1) exit at block $j$, or 2) choose one of the $B$ levels of computation for the next block, i.e., $g^j_\theta(\cdot): \mathcal{F}_j \to \{0, \dots, b_B\}$, where $0$ indicates exit. We do not yet specify the input space $\mathcal{F}_j$ of the gating functions, as it relates to the available information; we present it in the next section.

The gate $g^j_\theta(\cdot)$ decides the action taken at the $j$-th step:

$$b_{j+1} = \begin{cases} g^j_\theta(\cdot) & \text{if } g^j_\theta(\cdot) \in \{b_1, \dots, b_B\} \\ \text{None} & \text{if } g^j_\theta(\cdot) = 0 \text{ (exit)} \end{cases} \tag{8}$$

The path $\pi_i$ selected for a given sample $x_i$ is then determined by this sequence of decisions, with termination when one of the gates decides to exit: $\pi_i = g^1_\theta(\cdot) \to \dots \to g^e_\theta(\cdot) = 0$, with $g^j_\theta(\cdot) \neq 0 \ \forall\, j < e$. The selected path is then iteratively constructed as we successively evaluate $g_\theta(\cdot)$.
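The iterative construction of a path from successive gate decisions can be sketched as a simple loop. This is a toy illustration under several assumptions that are not taken from the paper: one gate is placed in front of every block (a simplified indexing), the network blocks and exit heads are replaced by hypothetical arithmetic stand-ins (`run_block`, `exit_head`), and the later gates use a plain confidence threshold instead of QuEE's learned gates. Only the control flow, i.e. accumulate the information $\mathcal{I}_j$, ask the gate, then either exit or run the next block at the chosen level, mirrors the description above.

```python
def build_path(x, gates, run_block, exit_head):
    """Follow gate decisions; return the chosen path and the predictions seen.

    Each gate maps the available information to 0 (exit) or a level in 1..B.
    Assumption: the first gate sits before the first block and never exits.
    """
    path, history, hidden = [], [], x
    for gate in gates:
        # Information available to this gate: past predictions and path so far.
        info = {"probs": list(history), "path": tuple(path)}
        decision = gate(info)
        if decision == 0:
            if not path:
                raise ValueError("the first gate is not allowed to exit")
            break
        hidden = run_block(len(path), decision, hidden)   # next block at chosen level
        path.append(decision)
        history.append(exit_head(len(path) - 1, hidden))  # prediction if we exited here
    return tuple(path), history


# Hypothetical stand-ins for quantized blocks and exit classifiers.
def run_block(block_idx, level, hidden):
    return hidden + 0.3 * level            # higher level = more "progress"


def exit_head(block_idx, hidden):
    p = min(hidden, 1.0)
    return [p, 1.0 - p]                    # toy two-class probability vector


def first_gate(info):
    return 1                               # start at the lowest computation level


def later_gate(info):
    # Stand-in rule: exit once the latest prediction is confident enough.
    return 0 if max(info["probs"][-1]) >= 0.8 else 2


path, history = build_path(0.0, [first_gate, later_gate, later_gate],
                           run_block, exit_head)
```

In this run the sample passes through the first block at level 1, the second block at level 2, and then exits, giving the path `(1, 2)`.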

QuEE $S_\theta(x)$.
Figure 2: Depiction of how QuEE predicts the next step at inference. $h_\theta$ is evaluated for each considered path (2 paths shown in grey were not sampled), then the algorithm takes a step towards the path $\pi^*$ minimizing the predicted loss $\ell_{01c}$. In this example, at gate $j$, $\pi^2$ has been identified as the best path.

As presented in the previous section, a successful classifier-selector will accurately predict the probability of future classifiers $f_\pi$ being incorrect ($PE(f_\pi \mid x)$) and select the classifier with the lowest combined cost and probability of error (see Eqn 5).

Therefore, we propose to explicitly learn to predict the probability of error $\widehat{pe}_\theta(f_{\pi'} \mid x) \overset{\text{m}}{=} PE(f_{\pi'} \mid x)$ of each potential path $\pi' \in \mathcal{P}_{\pi[:j-1]}$ at each decision gate $g^j_\theta(\mathcal{I}_j)$. Each gate is therefore equipped with a predictor module $h^j_\theta(\mathcal{I}_j(x), \pi') = \widehat{pe}_\theta(f_{\pi'} \mid x)$ that takes as input $\mathcal{I}_j$ and the future path to predict $\pi'$. We can then select the step that would lead to the predicted optimal path, i.e., $g^j_\theta(x) = \hat{\pi}^*[j-1:j]$, where

π^(x)superscript^𝜋𝑥\displaystyle\hat{\pi}^{*}(x)over^ start_ARG italic_π end_ARG start_POSTSUPERSCRIPT ∗ end_POSTSUPERSCRIPT ( italic_x ) =argminπ𝒫π[:j1]{cπ+pe^θ(fπ|x)},absentsubscriptargminsuperscript𝜋subscript𝒫𝜋delimited-[]:absent𝑗1subscript𝑐superscript𝜋subscript^𝑝𝑒𝜃conditionalsubscript𝑓superscript𝜋𝑥\displaystyle=\operatorname*{arg\,min}_{\pi^{\prime}\in\mathcal{P}_{\pi[:j-1]}% }\,\big{\{}c_{\pi^{\prime}}+\widehat{pe}_{\theta}(f_{\pi^{\prime}}|x)\big{\}},= start_OPERATOR roman_arg roman_min end_OPERATOR start_POSTSUBSCRIPT italic_π start_POSTSUPERSCRIPT ′ end_POSTSUPERSCRIPT ∈ caligraphic_P start_POSTSUBSCRIPT italic_π [ : italic_j - 1 ] end_POSTSUBSCRIPT end_POSTSUBSCRIPT { italic_c start_POSTSUBSCRIPT italic_π start_POSTSUPERSCRIPT ′ end_POSTSUPERSCRIPT end_POSTSUBSCRIPT + over^ start_ARG italic_p italic_e end_ARG start_POSTSUBSCRIPT italic_θ end_POSTSUBSCRIPT ( italic_f start_POSTSUBSCRIPT italic_π start_POSTSUPERSCRIPT ′ end_POSTSUPERSCRIPT end_POSTSUBSCRIPT | italic_x ) } ,
for pe^θ(fπ|x)=hθj(j(x),π).subscript^𝑝𝑒𝜃conditionalsubscript𝑓superscript𝜋𝑥subscriptsuperscript𝑗𝜃subscript𝑗𝑥superscript𝜋\displaystyle\,\,\widehat{pe}_{\theta}(f_{\pi^{\prime}}|x)=h^{j}_{\theta}(% \mathcal{I}_{j}(x),\pi^{\prime})\,.over^ start_ARG italic_p italic_e end_ARG start_POSTSUBSCRIPT italic_θ end_POSTSUBSCRIPT ( italic_f start_POSTSUBSCRIPT italic_π start_POSTSUPERSCRIPT ′ end_POSTSUPERSCRIPT end_POSTSUBSCRIPT | italic_x ) = italic_h start_POSTSUPERSCRIPT italic_j end_POSTSUPERSCRIPT start_POSTSUBSCRIPT italic_θ end_POSTSUBSCRIPT ( caligraphic_I start_POSTSUBSCRIPT italic_j end_POSTSUBSCRIPT ( italic_x ) , italic_π start_POSTSUPERSCRIPT ′ end_POSTSUPERSCRIPT ) . (9)

This has a complexity of $O(EB^{E})$ because we must infer for each potential path $\pi' \in \mathcal{P}_{\pi}$. If this is prohibitive, we can sample a subset of paths to evaluate, with $\pi' \sim U[\mathcal{P}_{\pi}]$. An alternative approach is to directly predict the next step to avoid the exponential complexity in inference, but this also has some drawbacks. We explore the trade-off in Appendix 7.3. A notable advantage of our approach is the complete decoupling of the learning procedure from the assignment of costs to classifiers. The costs $c_{\pi'}$ are determined at inference and can be modified without necessitating any retraining or adjustments.
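As an illustration of the gate decision in equation 9, the following minimal Python sketch selects the candidate path with the lowest sum of cost and predicted error, then reads off the next step. It is not the authors' implementation: the tuple encoding of paths, the dictionaries `cost` and `predicted_error`, and the function names are all hypothetical, and an early exit is represented here by a candidate that does not extend the current prefix.

```python
import random
from typing import Dict, List, Optional, Sequence, Tuple

# Hypothetical encoding: a path is the tuple of bit-widths used at each block.
Path = Tuple[int, ...]


def sample_candidates(paths: Sequence[Path], max_paths: int, seed: int = 0) -> List[Path]:
    """Uniformly subsample candidate paths when scoring all of them is too costly."""
    if len(paths) <= max_paths:
        return list(paths)
    return random.Random(seed).sample(list(paths), max_paths)


def select_next_step(
    prefix: Path,
    candidates: Sequence[Path],
    cost: Dict[Path, float],
    predicted_error: Dict[Path, float],
) -> Tuple[Path, Optional[int]]:
    """Return the path minimising cost + predicted error, and its next step.

    The next step is None when the best path does not extend the prefix,
    which this sketch interprets as "exit now".
    """
    best = min(candidates, key=lambda p: cost[p] + predicted_error[p])
    next_step = best[len(prefix)] if len(best) > len(prefix) else None
    return best, next_step


# Toy values (assumed, for illustration only).
prefix = (8,)
candidates = [(8,), (8, 8), (8, 4)]
cost = {(8,): 0.30, (8, 8): 0.60, (8, 4): 0.45}
predicted_error = {(8,): 0.50, (8, 8): 0.10, (8, 4): 0.20}
best_path, next_step = select_next_step(prefix, candidates, cost, predicted_error)
```

With these toy numbers, continuing at the lower bit-width is preferred over both exiting and continuing at full precision, which is the "soft early exit" behaviour the method is designed to allow.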

4.1 Approximating the probability of error $PE(f_{\pi}|x)$

Our methodology involves learning to predict the performance we could achieve given more computational power. This would motivate using $PE(f_{\pi}|x)$ as the target values during training, but we do not have access to these quantities. Therefore, we must rely on approximations to set the learning targets for $\widehat{pe}_{\theta}(f_{\pi}|x)$. There is existing work on estimating these quantities [14, 12], but the proposed procedures tend to be only reliable for high-performing classifiers (we verified this through experiment). The intermediate classifiers at lower blocks are often relatively inaccurate.

We base our approximations on the assumption that $PE(f_{\pi}|x_{i})$ is smooth over some transformation of the space $t_{\pi}(x)$. This allows us to discretize the space into $K$ partitions $\{Q_{1}, \dots, Q_{K}\}$ for which, under our assumption, the resulting distances $\big|PE(f_{\pi}|x_{i}) - \frac{1}{p_{Q}} \int_{x \in Q} PE(f_{\pi}|x)\,P(dx)\big|$ are small for all $x_{i} \in Q$. Here we have introduced $p_{Q} \triangleq \int_{Q} P(dx)$.

We can then compute the empirical estimator of $\frac{1}{p_{Q}} \int_{Q} PE(f_{\pi}|x)\,P(dx)$ for each partition using a validation set. Assume that we have $m$ samples $\{x_{j}, y_{j}\}_{j=1}^{m}$ in the validation set and denote by $m_{Q}$ the number of samples in $Q$, i.e. $m_{Q} = \sum_{j=1}^{m} \mathbb{1}[x_{j} \in Q]$. Then:

$$\begin{aligned}
PE(f_{\pi}|x_{i}) &\approx \frac{1}{\Pr(Q)} \int_{Q} PE(f_{\pi}|x)\,P(dx), \quad \text{for } x_{i} \in Q, && (10)\\
&\approx \frac{1}{m_{Q}} \sum_{j=1}^{m} \mathbb{1}[f_{\pi}(x_{j}) \neq y_{j},\, x_{j} \in Q], && (11)\\
&\triangleq \widetilde{pe}(f_{\pi}|x_{i}), \text{ our approximation used to train } \widehat{pe}_{\theta}(f_{\pi}|x). && (12)
\end{aligned}$$

In practice, we discretize the space using a clustering algorithm and we use the predicted probability vectors as the clustering features ($t_{\pi}(x) = \hat{\mathbf{p}}_{\pi}$). We provide empirical results and qualitative analysis to motivate this choice in Appendix 7.2 and in the results section.
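The empirical estimator in equation 11 can be sketched in a few lines of Python. This is a hedged illustration rather than the authors' code: it assumes that cluster assignments for the validation samples have already been produced by some clustering algorithm, and the names `cluster_error_rates` and `targets_for` are hypothetical.

```python
from collections import defaultdict
from typing import Dict, Hashable, List, Sequence


def cluster_error_rates(
    cluster_ids: Sequence[int],
    predictions: Sequence[Hashable],
    labels: Sequence[Hashable],
) -> Dict[int, float]:
    """Per-cluster misclassification rate of one classifier f_pi on a validation set."""
    errors: Dict[int, int] = defaultdict(int)
    counts: Dict[int, int] = defaultdict(int)
    for q, pred, y in zip(cluster_ids, predictions, labels):
        counts[q] += 1                 # m_Q: number of validation samples in cluster Q
        errors[q] += int(pred != y)    # indicator of an error inside cluster Q
    return {q: errors[q] / counts[q] for q in counts}


def targets_for(cluster_ids: Sequence[int], rates: Dict[int, float]) -> List[float]:
    """Every sample inherits the error rate of its cluster as its regression target."""
    return [rates[q] for q in cluster_ids]


# Toy validation data (assumed): five samples falling into two clusters.
cluster_ids = [0, 0, 0, 1, 1]
predictions = [1, 2, 3, 1, 1]
labels = [1, 2, 0, 0, 1]
rates = cluster_error_rates(cluster_ids, predictions, labels)
targets = targets_for(cluster_ids, rates)
```

All samples in the same cluster share one target, which is what makes the targets discretized: the predictor regresses onto at most $K$ distinct values per path.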

4.2 Modeling the probability of error with $h^{j}_{\theta}(\mathcal{I}_{j}(x), \pi')$

The learning task of $h^{j}_{\theta}(\mathcal{I}_{j}(x), \pi')$ is to predict the ground truth approximation $\widetilde{pe}(f_{\pi}|x)$ of future classifiers $f_{\pi'}$. Each classifier is encoded by its associated path $\pi'$, and we have access to past predicted probability vectors $\{\hat{\mathbf{p}}_{\pi_{i}[:1]}, \hat{\mathbf{p}}_{\pi_{i}[:2]}, \dots, \hat{\mathbf{p}}_{\pi_{i}[:j]}\}$ through $\mathcal{I}_{j}(x_{i})$.
We cannot employ a complex architecture for $h^{j}_{\theta}(\mathcal{I}_{j}(x), \pi')$ because the goal is to reduce computational overhead.

We base the decision on the current probability vector $\hat{\mathbf{p}}_{\pi_{i}[:j]}$ and the previous one $\hat{\mathbf{p}}_{\pi_{i}[:j-1]}$. We additionally extract helpful statistics such as entropy, $H(\hat{\mathbf{p}}_{\pi_{i}[:j]}), H(\hat{\mathbf{p}}_{\pi_{i}[:j-1]})$, and maximum probability, $\max(\hat{\mathbf{p}}_{\pi_{i}[:j]}), \max(\hat{\mathbf{p}}_{\pi_{i}[:j-1]})$. These are often used in early exiting algorithms [2, 18, 23, 49]. The final vector $u_{i} \in \mathbb{R}^{2|\mathcal{Y}|+4}$ that encodes all the input features is:

$$u_{i} = \hat{\mathbf{p}}_{\pi_{i}[:j]} \,||\, \hat{\mathbf{p}}_{\pi_{i}[:j-1]} \,||\, H(\hat{\mathbf{p}}_{\pi_{i}[:j]}) \,||\, H(\hat{\mathbf{p}}_{\pi_{i}[:j-1]}) \,||\, \max(\hat{\mathbf{p}}_{\pi_{i}[:j-1]}) \,||\, \max(\hat{\mathbf{p}}_{\pi_{i}[:j]}). \tag{13}$$
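The feature vector of equation 13 can be assembled as in the sketch below, which is illustrative only. The function names are hypothetical, and the use of the natural logarithm for the entropy is an assumption, since the base is not specified here.

```python
import math
from typing import List, Sequence


def entropy(p: Sequence[float]) -> float:
    """Shannon entropy in nats (assumed base); zero-probability entries contribute 0."""
    return -sum(q * math.log(q) for q in p if q > 0.0)


def gate_features(p_curr: Sequence[float], p_prev: Sequence[float]) -> List[float]:
    """Concatenate both probability vectors with their entropies and max probabilities.

    The order follows equation 13: current vector, previous vector,
    H(current), H(previous), max(previous), max(current).
    """
    return (
        list(p_curr)
        + list(p_prev)
        + [entropy(p_curr), entropy(p_prev), max(p_prev), max(p_curr)]
    )


# Toy two-class example (assumed values).
u = gate_features([0.5, 0.5], [1.0, 0.0])
```

For $|\mathcal{Y}|$ classes the result has $2|\mathcal{Y}|+4$ entries, so the toy two-class example yields a vector of length 8.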

As for the path $\pi'$, we encode it in a fixed length vector $p_{i}$ of size $E{+}1$ by padding shorter paths with 0s, and add the number of layers evaluated at the end:

$$p_{i} = [b_{1}, b_{2}, \dots, b_{e}, 0, 0, e] \in \mathbb{R}^{E+1}. \tag{14}$$
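A minimal sketch of the path encoding in equation 14 follows. The function name `encode_path` and the argument `num_exits` (standing in for $E$) are hypothetical.

```python
from typing import List, Sequence


def encode_path(path: Sequence[int], num_exits: int) -> List[int]:
    """Zero-pad the bit-widths of a path to length E and append its length e."""
    e = len(path)
    if e > num_exits:
        raise ValueError("a path cannot be longer than the number of exits")
    return list(path) + [0] * (num_exits - e) + [e]


# A path that uses 8 bits, then 6 bits, and exits after two blocks, with E = 4.
p = encode_path((8, 6), num_exits=4)
```

Appending the length $e$ lets the predictor distinguish a short path from a longer one without having to infer the length from the padding.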

Our decision architecture is a simple 2-layer network:

$$\widehat{pe}_{\theta}(f_{\pi'}|x) = \mathrm{sigmoid}(NN(NN(u_{i} \,||\, p_{i}))), \tag{15}$$

and we train each predictor module $h^{j}_{\theta}(\mathcal{I}_{j}(x), \pi')$ with a standard MSE loss:

$$\mathcal{L}(x, \pi') = ||h^{j}_{\theta}(\mathcal{I}_{j}(x), \pi') - \widetilde{pe}(f_{\pi'}|x)||_{2}. \tag{16}$$
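To make the shape of the predictor concrete, the sketch below implements a forward pass of a two-layer network with a sigmoid output, together with a per-sample squared error. Several details are assumptions not stated in the text: the ReLU hidden activation, the hidden width, the parameter layout in `params`, and the use of the squared difference as the per-sample term of the MSE.

```python
import math
from typing import Dict, List, Sequence


def sigmoid(z: float) -> float:
    """Numerically stable logistic function."""
    if z >= 0.0:
        return 1.0 / (1.0 + math.exp(-z))
    ez = math.exp(z)
    return ez / (1.0 + ez)


def linear(x: Sequence[float], weights: List[List[float]], bias: List[float]) -> List[float]:
    """Dense layer: one output per weight row."""
    return [sum(w * xi for w, xi in zip(row, x)) + b for row, b in zip(weights, bias)]


def predict_error(u: Sequence[float], p: Sequence[float], params: Dict[str, list]) -> float:
    """Two-layer predictor applied to the concatenation u || p."""
    z = list(u) + list(p)
    hidden = [max(0.0, v) for v in linear(z, params["W1"], params["b1"])]  # ReLU (assumed)
    (logit,) = linear(hidden, params["W2"], params["b2"])
    return sigmoid(logit)


def squared_error(prediction: float, target: float) -> float:
    """Per-sample squared error; averaging it over samples gives the MSE."""
    return (prediction - target) ** 2


# Toy parameters (all zeros), so the untrained predictor outputs sigmoid(0).
in_dim, hidden_dim = 5, 3
params = {
    "W1": [[0.0] * in_dim for _ in range(hidden_dim)],
    "b1": [0.0] * hidden_dim,
    "W2": [[0.0] * hidden_dim],
    "b2": [0.0],
}
pred = predict_error([0.7, 0.3], [8, 0, 1], params)
loss = squared_error(pred, 0.25)
```

The sigmoid keeps the output in $(0, 1)$, matching the range of the cluster-level error rates used as targets.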

5 Experiments

We present results on image classification tasks: ImageNet [6], CIFAR10 and CIFAR100 [31], and SVHN [37], for different backbone architectures: T2T-ViT [50] pre-trained on ImageNet and transfer-learnt on the datasets, and ViT-14 [9], pre-trained as a foundation model using DinoV2 [39].

5.1 Obtaining the computation units for the classifiers $f_{\pi}$

We start by augmenting the pre-trained backbone with randomly initialized classifiers at predefined layers (see App. 7.1 for details about exit placement). The classifiers have the same architecture as the final inference head. We train intermediate classifiers until convergence, keeping the backbone frozen. We then quantize the network at various bit-widths corresponding to our choices of computation levels, $\{b_{1}, \dots, b_{B}\}$, as follows. The highest bit-width in $b_{B}$ is chosen to be the lowest value that maintains the accuracy of the final classifier (usually around 8 bits). We then add decreasing bit-widths to $b_{B}$ as long as the accuracy of the final classifier is within 5% of the original performance. Typically, we obtained at most $B=4$. We perform the quantization at each bit-width using PTQ4ViT [51], a state-of-the-art PTQ algorithm for transformers (see Appendix 7.1.2 for an in-depth description). The network obtained is thus a multi-quantization architecture composed of $L$ blocks augmented with an inference head, where each block consists of the backbone layer quantized at the $B$ levels.
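The bit-width selection rule described above can be sketched as follows. This is one possible reading under stated assumptions, not the authors' procedure: the tolerance `keep_tol` that defines "maintains the accuracy", the interpretation of "within 5%" as an absolute drop of 0.05 in accuracy, and the accuracy table are all hypothetical.

```python
from typing import Dict, List


def choose_bit_widths(
    acc_by_bits: Dict[int, float],
    reference_acc: float,
    keep_tol: float = 0.005,   # assumed tolerance for "maintains the accuracy"
    max_drop: float = 0.05,    # assumed absolute reading of "within 5%"
) -> List[int]:
    """Return computation levels from the highest bit-width downwards."""
    maintaining = [b for b, acc in acc_by_bits.items() if reference_acc - acc <= keep_tol]
    if not maintaining:
        raise ValueError("no bit-width maintains the reference accuracy")
    highest = min(maintaining)  # lowest bit-width that still maintains accuracy
    levels = [highest]
    for b in sorted((b for b in acc_by_bits if b < highest), reverse=True):
        if reference_acc - acc_by_bits[b] > max_drop:
            break               # stop at the first bit-width that degrades too much
        levels.append(b)
    return levels


# Hypothetical final-classifier accuracies after quantization at each bit-width.
acc_by_bits = {4: 0.60, 5: 0.74, 6: 0.78, 7: 0.79, 8: 0.80, 9: 0.80}
levels = choose_bit_widths(acc_by_bits, reference_acc=0.80)
```

With this made-up table the rule keeps three levels; a different tolerance or a relative reading of the 5% criterion would change the outcome.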

QuEE components. We finally augment the multi-quantization network with gates $\{g_{\theta}^{j}\}_{j \in E}$ at every exit. Each gate $g_{\theta}^{j}$ contains a small learnable predictor module $h_{\theta}^{j}$. The first gate $g_{\theta}^{1}$ is placed at the start of the network and always routes to the highest computation level.

5.2 Training the predictor modules $h_{\theta}^{j}$

Before training the predictor modules, we fit a K-means algorithm for every path in the network in order to obtain discretized predictor targets, where $K$ is a hyperparameter. In practice, we observe that paths that move from a low bit resolution to a higher bit resolution result in subpar performance. We thus do not consider these paths during training and inference. This effectively reduces the set of paths $\mathcal{P}$ by a factor of two. If the resulting $|\mathcal{P}| > 50$, we randomly sample 50 paths, both during training and at inference. For each valid path $\pi$, we obtain $K$ clusters by clustering the predicted probability vectors $\hat{\mathbf{p}}_{\pi} \in [0,1]^{|\mathcal{Y}|}$ using the validation set. For each cluster $Q$ we compute $\widetilde{pe}(f_{\pi,Q}|x_{i})$ using equation 11.
At training, we obtain the discretized targets by predicting the appropriate cluster $Q^{i}$ for $\hat{\mathbf{p}}_{\pi'}$ and using $\widetilde{pe}(f_{\pi,Q^{i}}|x_{i})$ as the target in the MSE loss (equation 16). This is optimized with the Adam optimizer and early stopping (using the validation set). The complete set of optimization hyperparameters is provided in Appendix 7.1.
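The path pruning and subsampling described in this subsection can be illustrated with the sketch below. It is an assumption-laden example: paths are encoded as tuples of bit-widths, every path length from 1 to the number of exits is treated as a possible exit point, and the function names and the fixed seed are hypothetical.

```python
import itertools
import random
from typing import List, Sequence, Tuple

Path = Tuple[int, ...]


def valid_paths(bit_widths: Sequence[int], num_exits: int) -> List[Path]:
    """Enumerate paths whose bit-width never increases from one block to the next."""
    paths: List[Path] = []
    for length in range(1, num_exits + 1):
        for p in itertools.product(bit_widths, repeat=length):
            if all(a >= b for a, b in zip(p, p[1:])):
                paths.append(p)
    return paths


def subsample(paths: Sequence[Path], max_paths: int = 50, seed: int = 0) -> List[Path]:
    """Keep at most max_paths paths, drawn uniformly without replacement."""
    if len(paths) <= max_paths:
        return list(paths)
    return random.Random(seed).sample(list(paths), max_paths)


small = valid_paths([8, 4], num_exits=2)
larger = valid_paths([8, 6, 4], num_exits=4)
kept = subsample(valid_paths([8, 7, 6, 4], num_exits=5), max_paths=50)
```

Using the same seed during training and at inference would keep the sampled subset identical across both phases; whether the original work fixes the subset in this way is not stated here.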

5.3 Routing at inference

At inference, when the sample $x$ reaches $g^{2}$, the probability of future error $\widehat{pe}_{\theta}(f_{\pi'}|x)$ for every path $\pi' \in \mathcal{P}_{\pi[1]}$ is computed and $x$ is forwarded along the optimal path, as stated in equation 9. The cost $c_{\pi'}$ of each path in equation 9 is the normalized cost in BitOPS [35, 43, 26], computed as follows. For each layer $l$ in $\pi' = b_{1} \rightarrow \dots \rightarrow b_{e}$, we compute the layer cost $c_{l} = \text{BitOPS}(l) = \text{FLOPS}(l) \times b_{l}$.
The unnormalized cost of the path is $c_{\pi',\text{unnorm}} = \sum_{l=1}^{e} c_{l}$. We normalize all path costs by dividing them by the cost of the most costly path, $\pi_{\text{max}} = b_{\text{max}} \rightarrow b_{\text{max}} \dots$, which is incurred when evaluating the entire network using the highest bit-width $b_{\text{max}}$. Therefore:

$$c_{\pi} = \frac{c_{\pi',\text{unnorm}}}{c_{\pi_{\text{max}}}} = \frac{\sum_{l=1}^{e} \text{FLOPS}(l) \times b_{l}}{\sum_{l=1}^{L} \text{FLOPS}(l) \times b_{\text{max}}}. \tag{17}$$
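Equation 17 translates directly into code. In the sketch below, the per-layer FLOPS values are made up and the function name is hypothetical; the path entries are assumed to be aligned one-to-one with the layers whose FLOPS are listed.

```python
from typing import Sequence


def normalized_path_cost(path: Sequence[int], flops_per_layer: Sequence[float], b_max: int) -> float:
    """BitOPS of the evaluated layers divided by BitOPS of the full network at b_max."""
    if len(path) > len(flops_per_layer):
        raise ValueError("path is longer than the network")
    unnormalized = sum(flops * bits for flops, bits in zip(flops_per_layer, path))
    full_cost = sum(flops * b_max for flops in flops_per_layer)
    return unnormalized / full_cost


# Hypothetical FLOPS for a four-layer network.
flops = [10.0, 10.0, 20.0, 20.0]
cost_early = normalized_path_cost((8, 8, 4), flops, b_max=8)     # exit after three layers
cost_full = normalized_path_cost((8, 8, 8, 8), flops, b_max=8)   # most costly path
```

Because of the normalization, every path cost lies in $(0, 1]$ and is on the same scale as a probability of error, which is what allows the two terms to be summed in equation 9.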

At inference, we can obtain different cost-accuracy operating points without retraining QuEE by incorporating a cost-importance hyper-parameter $\lambda \in \mathbb{R}^{+}$ in equation 9. This allows us to prioritize efficiency or accuracy:

$$\hat{\pi}_{\text{adaptive}}^{*}(x) = \operatorname*{arg\,min}_{\pi' \in \mathcal{P}_{\pi[:j-1]}} \big\{ c_{\pi'} \lambda + \widehat{pe}_{\theta}(f_{\pi'}|x) \big\}. \tag{18}$$
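The effect of the cost-importance parameter in equation 18 can be seen on a toy example. The costs and predicted errors below are invented for illustration, and `select_path` is a hypothetical helper; only the value of $\lambda$ changes between calls, with no retraining involved.

```python
from typing import Dict, Sequence, Tuple

Path = Tuple[int, ...]


def select_path(
    paths: Sequence[Path],
    cost: Dict[Path, float],
    predicted_error: Dict[Path, float],
    lam: float,
) -> Path:
    """Return the path minimising lam * cost + predicted error."""
    return min(paths, key=lambda p: lam * cost[p] + predicted_error[p])


# Toy values (assumed): longer paths cost more but are predicted to err less.
paths = [(8,), (8, 8), (8, 8, 8)]
cost = {(8,): 0.25, (8, 8): 0.5, (8, 8, 8): 1.0}
predicted_error = {(8,): 0.4, (8, 8): 0.2, (8, 8, 8): 0.1}

choices = {lam: select_path(paths, cost, predicted_error, lam) for lam in (0.0, 0.5, 1.0)}
```

In this toy setting, $\lambda = 0$ ignores cost and selects the longest path, while larger values of $\lambda$ progressively favour earlier exits, tracing out a cost-accuracy curve from a single set of predictors.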

5.4 Baselines

As our proposal combines early exiting and quantization, we include baselines from both settings.

  • DQNet-gate [35] is a data-adaptive mixed-precision quantization architecture. It is a modified version of DQNet [35], where the parameters of the backbone are fixed instead of being trained alongside the bit-selector module. We achieve different accuracy/cost points by changing the target bit-width and cost importance parameter $\alpha$.

  • PTQ4ViT [51] is a SOTA PTQ method for ViT models. We report the accuracy obtained at various levels of quantization.

  • JEIDNN [40] is a SOTA EE method for frozen backbones. As a backbone model, we use the backbone quantized at the highest precision bit-width we consider.

  • Thresholding applies the most popular gating decision used in the EE literature, where we exit if the max probability exceeds a certain threshold.
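For reference, the thresholding baseline in the last item amounts to the rule sketched below. The function name and the choice to fall back to the final classifier when no exit is confident enough are assumptions of this sketch.

```python
from typing import Sequence


def threshold_exit(prob_vectors: Sequence[Sequence[float]], threshold: float) -> int:
    """Index of the first exit whose max probability exceeds the threshold.

    Falls back to the last exit when no intermediate classifier is confident enough.
    """
    for j, p in enumerate(prob_vectors):
        if max(p) > threshold:
            return j
    return len(prob_vectors) - 1


# Predicted probability vectors at three successive exits (toy values).
probs = [[0.5, 0.5], [0.7, 0.3], [0.95, 0.05]]
exit_strict = threshold_exit(probs, threshold=0.9)
exit_loose = threshold_exit(probs, threshold=0.6)
```

Varying the threshold yields different cost-accuracy operating points for this baseline.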

5.5 Results

Figure 3: Accuracy vs BitFLOPS. A) ImageNet T2T-ViT-14, B) ImageNet ViT-14 with DinoV2, C) EE vs quantization performance, D) CIFAR100 T2T-ViT-14, E) CIFAR10 T2T-ViT-7, F) SVHN T2T-ViT-7.

The performance ordering of each baseline is maintained across datasets and backbone architectures, as presented in Figure 3. Overall, QuEE outperforms the baselines, especially for lower-cost regimes, with the exception of the SVHN dataset, which we will discuss later. JEIDNN closely follows, performing particularly well in higher-cost regimes, where it outperforms both QuEE and the initial accuracy of the backbone due to its ability to train the IMs jointly with the gating mechanism. DQNet is less competitive, as it is limited in the range of costs it can reach, since it cannot early exit. It generally outperforms simple quantization (PTQ4ViT), maintaining the same accuracy at a lower cost. We emphasize that both JEIDNN and DQNet require retraining, with adjustment of the trade-off hyperparameters, to obtain each point along the operating curve. This leads to instabilities that are apparent in the results, making these methods less reliable. In practice, this means that in order to obtain a specific target cost or performance, the models would need to be retrained many times until the desired values are attained. In contrast, QuEE only has to be trained once to obtain a full curve, because the cost can be modified at inference to obtain a different trade-off. This makes QuEE more efficient, stable, and flexible, and therefore more practical.

As we stated in our methodology, the strength of QuEE lies in its ability to combine quantization with early exit. For this to be an advantageous strategy, both computation reduction methods need to perform relatively well independently — there is no benefit in combining a weak method with a strong one. As a result, we expect QuEE to perform the best when both quantization and early exit are approximately equally effective. This is confirmed by our experimental results. Figure 3C) compares both computation reduction methods for the various experiments. When the methods are on par with each other, as is the case for CIFAR-10, ImageNet with T2T-ViT, and CIFAR-100, QuEE obtains its best performance. In contrast, for SVHN, the quantization methods are significantly outperformed by early exit, and the early exit method JEIDNN outperforms QuEE on this dataset. SVHN is a relatively simple dataset, and the T2T-ViT-7 architecture is much deeper than necessary for this task, making it particularly suitable for early exit. Conversely, for ImageNet with ViT pre-trained with Dino, we observe the inverse trend, and quantization methods outperform early exit. This indicates that the intermediate representations of ViT are not well suited to be used as input to perform inference, which could be a side effect of the self-supervised objective in DinoV2 [39]. This could also explain the instabilities of JEIDNN, because it jointly trains the inference heads with the gating mechanisms.

Accurately predicting the probability of error leads to better performance in the high-cost regime
Figure 4: Accuracy vs cost curves of predictors with varying RMSE performance on the SVHN dataset.

Our algorithm involves prediction of the approximated probabilities of errors $\widetilde{pe}(f_{\pi}|x_{i})$, as defined in Eqn 12. We assess the accuracy of these predictions by calculating the root mean square error on test samples $\textrm{rmse}=\|\widetilde{pe}(f_{\pi}|x_{i})-h^{j}_{\theta}\|^{2}_{2}$, and analyze how prediction accuracy correlates with performance on the primary task. In Figure 4, we present the accuracy/cost curves of our QuEE algorithm for different hyperparameters on one dataset. As expected, lower RMSE results in better performance. Moreover, the gap in performance is notable in the higher cost regime. This outcome is a direct consequence of the fact that the gating decisions are based on $c_{\pi}+\hat{pe}_{\theta}(f_{\pi}|x)$. In a low-cost regime, the $c_{\pi}$ are given more relative weight, and the $\hat{pe}_{\theta}(f_{\pi}|x)$ have less influence on the gating decision. Consequently, the overall performance is less affected by poor predictions. However, in a high-cost regime, the $\hat{pe}_{\theta}(f_{\pi}|x)$ become more important and we can observe a widening performance gap. QuEE with poor predictors cannot even reach higher cost operating points as it struggles to learn to use the more costly classifiers.
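To make the interplay between the two terms concrete, the following minimal sketch selects, for a single sample, the path that minimizes a weighted sum of cost and predicted probability of error. The weight `lam`, the function name `select_path`, and the toy numbers are illustrative assumptions; the actual gating rule is the one given in the methodology, not this code.

```python
def select_path(costs, pe_hat, lam):
    """Return the index of the path minimizing lam * cost + predicted error.

    costs  : list of normalized path costs c_pi in [0, 1]
    pe_hat : list of predicted probabilities of error, one per path
    lam    : hypothetical trade-off weight (larger = cheaper operating point)
    """
    scores = [lam * c + p for c, p in zip(costs, pe_hat)]
    return min(range(len(scores)), key=scores.__getitem__)


costs = [0.2, 0.5, 1.0]        # toy costs of three candidate paths
good_pe = [0.30, 0.12, 0.05]   # accurate predictor: clearly separates the paths
bad_pe = [0.30, 0.28, 0.29]    # poor predictor: barely separates the paths

# Low-cost regime: the cost term dominates the decision.
low_cost = [select_path(costs, pe, lam=2.0) for pe in (good_pe, bad_pe)]
# High-cost regime: the predicted errors dominate the decision.
high_cost = [select_path(costs, pe, lam=0.05) for pe in (good_pe, bad_pe)]
```

In this toy setting both predictors pick the cheapest path when the cost term dominates, whereas only the accurate predictor moves to the most expensive path when the cost weight is small, which mirrors the widening gap described above.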

Clustering analysis

In this section, we verify that the choices we made to construct our approximation of the probability of error $\tilde{pe}$ are sensible, and that they lead to a reasonable approximation. First, to validate our approach, we observe the grouping capability of the clusters for certain metrics that are likely related to the ground truth probability of error: the predicted probability of the ground truth class $\hat{\mathbf{p}}_{y}$ and the entropy of $\hat{\mathbf{p}}$. We collect these metrics for test samples with the cluster assigned to each for various $K$. In Figure 5, we can see that some clusters (clusters 0, 1, 2) contain samples that predominantly have low entropy and high predicted $\hat{\mathbf{p}}_{y}$, while other clusters (clusters 17, 18, 19) consist almost entirely of samples with relatively high entropy (>0.5) and lower $\hat{\mathbf{p}}_{y}$. Next, in Figure 5 B), we compute the calibration error for our assigned $\tilde{pe}$ on test samples. On average, increasing the number of clusters $K$ improves the calibration error until we reach $K=50$. Beyond this, a deteriorating calibration performance is presumably caused by the sampling error. Lastly, in Figure 5 C), we can also observe how $K$ affects the efficiency performance. For $K=1$, the prediction task is very easy: the modules only need to learn to predict a single fixed value for each path. Therefore, we see a decrease of performance as we increase $K$ at first, reaching its lowest at $K=5$; then the accuracy of the target starts to improve, and we ultimately reach the best performance using larger $K$.
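As a rough illustration of the quantities discussed above, the sketch below assigns to each cluster its empirical error rate on held-out samples and then measures a size-weighted calibration gap on test samples. Cluster assignments are taken as given, the function names and toy data are assumptions, and the calibration measure is a simple cluster-wise variant that is not necessarily the ECE computed for Figure 5.

```python
from collections import defaultdict


def cluster_error_rates(cluster_ids, is_error):
    """Empirical error rate of each cluster: a stand-in for the approximated pe."""
    counts, errors = defaultdict(int), defaultdict(int)
    for k, e in zip(cluster_ids, is_error):
        counts[k] += 1
        errors[k] += int(e)
    return {k: errors[k] / counts[k] for k in counts}


def clusterwise_calibration_error(pe_tilde, cluster_ids, is_error):
    """Size-weighted mean of |assigned pe - observed error rate| over clusters."""
    observed = cluster_error_rates(cluster_ids, is_error)
    n = len(cluster_ids)
    sizes = defaultdict(int)
    for k in cluster_ids:
        sizes[k] += 1
    return sum(sizes[k] / n * abs(pe_tilde[k] - observed[k]) for k in observed)


# Toy held-out data: cluster 0 is mostly correct, cluster 1 is mostly wrong.
val_clusters = [0, 0, 0, 0, 1, 1, 1, 1]
val_errors = [0, 0, 0, 1, 1, 1, 1, 0]
pe_tilde = cluster_error_rates(val_clusters, val_errors)

# Toy test data used to check how well the assigned values hold up.
test_clusters = [0, 0, 1, 1]
test_errors = [0, 0, 1, 1]
gap = clusterwise_calibration_error(pe_tilde, test_clusters, test_errors)
```

With many clusters each error rate is estimated from fewer samples, so the per-cluster estimates become noisier; this is one way to read the sampling-error effect mentioned above.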

Figure 5: A) Predicted probability of the ground truth label and entropy of a few clusters for $K=20$ on the SVHN dataset. We can see that there is a correlation between the clustering and those two metrics. B) Expected calibration error of $\tilde{pe}$ generated from varying cluster numbers $K$ for the SVHN dataset on T2T-ViT-7 with a fitted second order polynomial. Increasing the number of clusters (up to 50) reduces the calibration error. We discuss this further in Appendix 7.2. C) Accuracy-cost trade-off for different values of $K$ for the SVHN dataset on T2T-ViT-7.

6 Conclusions and Limitations

We have introduced a post-training method that can adaptively combine different computation reduction methods such as quantization and early exiting. For future work, we will consider integrating other post-training computation reduction methods that were not included in this work, such as pruning or distillation. We will also extend the approach to allow for joint training of the inference heads alongside our gating mechanism, as in [40].

Limitations. The main limitation of our approach is scaling; the algorithm has an inference complexity of $O(B^{E}E)$. Although we have found that this is actually not that prohibitive in practice as the relative cost of each gate is extremely small compared to the actual network, such complexity can obviously be a scaling issue. We found that a simple sampling approach can mitigate the issue without significantly impacting the accuracy-cost trade-off. A second limitation of our work is the cost computation. The BitOPS metric that we are using can be viewed as a theoretical metric. The real efficiency benefit in terms of speed and memory remains untested. The practical application of quantized networks requires careful handling and a dedicated kernel implementation. We view this as beyond the scope of this work, but it is clearly an important aspect to tackle for implementation of our proposal in practice. Finally, constructing an approximation of the target $PE$ for training the decision modules relies on unverifiable assumptions. We offer experimental analysis indicating the reasonableness of our approximation. However, this is not a significant limitation, as it serves merely as an intermediate step toward our main objective, which is measurable: efficiency improvement.
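To give a sense of how quickly the number of candidate paths grows, the sketch below enumerates every path for small values of $B$ and $E$ and compares the count with $L=\sum_{e=1}^{E}B^{e}$ from the notation table. The helper name `enumerate_paths` and the chosen values are assumptions made purely for illustration.

```python
from itertools import product


def enumerate_paths(num_levels, num_exits):
    """All paths b_1 -> ... -> b_e with e <= num_exits and b_i in range(num_levels)."""
    paths = []
    for e in range(1, num_exits + 1):
        # Each of the e layers traversed so far can use any computation level.
        paths.extend(product(range(num_levels), repeat=e))
    return paths


paths = enumerate_paths(num_levels=2, num_exits=4)
closed_form = sum(2 ** e for e in range(1, 5))  # 2 + 4 + 8 + 16
```

Even with two computation levels, adding exits roughly doubles the number of paths each time, which is why a sampling strategy over paths becomes attractive for deeper networks.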

References

  • Bai et al. [2022] H. Bai, L. Hou, L. Shang, X. Jiang, I. King, and M. R. Lyu, “Towards efficient post-training quantization of pre-trained language models,” in Proc. Adv. Neural Info. Process. Syst. (NeurIPS), 2022.
  • Bolukbasi et al. [2017] T. Bolukbasi, J. Wang, O. Dekel, and V. Saligrama, “Adaptive neural networks for efficient inference,” in Proc. Int. Conf. Mach Learn. (ICML), 2017.
  • Bommasani et al. [2022] R. Bommasani, D. A. Hudson, E. Adeli, R. Altman, S. Arora, S. von Arx, M. S. Bernstein, J. Bohg, A. Bosselut, E. Brunskill, E. Brynjolfsson, S. Buch, D. Card, R. Castellon, N. Chatterji, A. Chen, K. Creel, J. Q. Davis, D. Demszky, C. Donahue, M. Doumbouya, E. Durmus, S. Ermon, J. Etchemendy, K. Ethayarajh, L. Fei-Fei, C. Finn, T. Gale, L. Gillespie, K. Goel, N. Goodman, S. Grossman, N. Guha, T. Hashimoto, P. Henderson, J. Hewitt, D. E. Ho, J. Hong, K. Hsu, J. Huang, T. Icard, S. Jain, D. Jurafsky, P. Kalluri, S. Karamcheti, G. Keeling, F. Khani, O. Khattab, P. W. Koh, M. Krass, R. Krishna, R. Kuditipudi, A. Kumar, F. Ladhak, M. Lee, T. Lee, J. Leskovec, I. Levent, X. L. Li, X. Li, T. Ma, A. Malik, C. D. Manning, S. Mirchandani, E. Mitchell, Z. Munyikwa, S. Nair, A. Narayan, D. Narayanan, B. Newman, A. Nie, J. C. Niebles, H. Nilforoshan, J. Nyarko, G. Ogut, L. Orr, I. Papadimitriou, J. S. Park, C. Piech, E. Portelance, C. Potts, A. Raghunathan, R. Reich, H. Ren, F. Rong, Y. Roohani, C. Ruiz, J. Ryan, C. Ré, D. Sadigh, S. Sagawa, K. Santhanam, A. Shih, K. Srinivasan, A. Tamkin, R. Taori, A. W. Thomas, F. Tramèr, R. E. Wang, W. Wang, B. Wu, J. Wu, Y. Wu, S. M. Xie, M. Yasunaga, J. You, M. Zaharia, M. Zhang, T. Zhang, X. Zhang, Y. Zhang, L. Zheng, K. Zhou, and P. Liang, “On the opportunities and risks of foundation models,” 2022, arXiv preprint: arXiv 2108.07258.
  • Brown et al. [2020] T. Brown, B. Mann, N. Ryder, M. Subbiah, J. D. Kaplan, P. Dhariwal, A. Neelakantan, P. Shyam, G. Sastry, A. Askell, S. Agarwal, A. Herbert-Voss, G. Krueger, T. Henighan, R. Child, A. Ramesh, D. Ziegler, J. Wu, C. Winter, C. Hesse, M. Chen, E. Sigler, M. Litwin, S. Gray, B. Chess, J. Clark, C. Berner, S. McCandlish, A. Radford, I. Sutskever, and D. Amodei, “Language models are few-shot learners,” in Proc. Adv. Neural Info. Process. Syst. (NeurIPS), 2020.
  • Dai et al. [2020] X. Dai, X. Kong, and T. Guo, “EPNet: Learning to exit with flexible multi-branch network,” in Proc. Int. Conf. on Inf. & Knowl. Manage. (CIKM), 2020.
  • Deng et al. [2009] J. Deng, W. Dong, R. Socher, L.-J. Li, K. Li, and L. Fei-Fei, “Imagenet: A large-scale hierarchical image database,” in Proc. IEEE/CVF Conf. on Comput. Vision and Pattern Recognit. (CVPR), 2009.
  • Dettmers et al. [2022] T. Dettmers, M. Lewis, Y. Belkada, and L. Zettlemoyer, “Llm.int8(): 8-bit matrix multiplication for transformers at scale,” in Proc. Adv. Neural Info. Process. Syst. (NeurIPS), 2022.
  • Devlin et al. [2019] J. Devlin, M.-W. Chang, K. Lee, and K. Toutanova, “BERT: Pre-training of deep bidirectional transformers for language understanding,” in Proc. Conf. of the North Amer. Chapter of the Assoc. for Comput. Linguistics: Human Lang. Technologies, 2019.
  • Dosovitskiy et al. [2021] A. Dosovitskiy, L. Beyer, A. Kolesnikov, D. Weissenborn, X. Zhai, T. Unterthiner, M. Dehghani, M. Minderer, G. Heigold, S. Gelly, J. Uszkoreit, and N. Houlsby, “An image is worth 16x16 words: Transformers for image recognition at scale,” in Proc. Int. Conf. on Learn. Representations (ICLR), 2021.
  • Frantar and Alistarh [2023] E. Frantar and D. Alistarh, “SparseGPT: Massive language models can be accurately pruned in one-shot,” in Proc. Int. Conf. Mach Learn. (ICML), 2023.
  • Gholami et al. [2021] A. Gholami, S. Kim, Z. Dong, Z. Yao, M. W. Mahoney, and K. Keutzer, “A survey of quantization methods for efficient neural network inference,” 2021, arXiv preprint: arXiv 2103.13630.
  • Gomes et al. [2024] E. D. C. Gomes, M. Romanelli, G. Pichler, and P. Piantanida, “A data-driven measure of relative uncertainty for misclassification detection,” in Proc. Int. Conf. on Learn. Representations (ICLR), 2024.
  • Gou et al. [2021] J. Gou, B. Yu, S. J. Maybank, and D. Tao, “Knowledge distillation: A survey,” Int. J. Comput. Vision, vol. 129, no. 6, p. 1789–1819, jun 2021.
  • Granese et al. [2021] F. Granese, M. Romanelli, D. Gorla, C. Palamidessi, and P. Piantanida, “DOCTOR: A simple method for detecting misclassification errors,” in Proc. Adv. Neural Info. Process. Syst. (NeurIPS), 2021.
  • Guo et al. [2017] C. Guo, G. Pleiss, Y. Sun, and K. Q. Weinberger, “On calibration of modern neural networks,” in Proc. Int. Conf. Machine Learning. (ICML), 2017.
  • Han et al. [2022a] Y. Han, G. Huang, S. Song, L. Yang, H. Wang, and Y. Wang, “Dynamic neural networks: A survey,” IEEE Trans. Pattern Anal. Mach. Intell., vol. 44, no. 11, pp. 7436–7456, Nov. 2022.
  • Han et al. [2022b] Y. Han, Y. Pu, Z. Lai, C. Wang, S. Song, J. Cao, W. Huang, C. Deng, and G. Huang, “Learning to weight samples for dynamic early-exiting networks,” in Proc. European Conf. on Computer Vision (ECCV), 2022.
  • Han et al. [2023] Y. Han, D. Han, Z. Liu, Y. Wang, X. Pan, Y. Pu, C. Deng, J. Feng, S. Song, and G. Huang, “Dynamic perceiver for efficient visual recognition,” in Proc. IEEE Int. Conf. on Comput. Vision (ICCV), 2023.
  • Hazimeh et al. [2020] H. Hazimeh, N. Ponomareva, P. Mol, Z. Tan, and R. Mazumder, “The tree ensemble layer: Differentiability meets conditional computation,” in Proc. Int. Conf. Mach Learn. (ICML), 2020.
  • Herrmann et al. [2020] C. Herrmann, R. S. Bowen, and R. Zabih, “Channel selection using gumbel softmax,” in Proc. Eur Conf. Comput. Vision (ECCV), 2020.
  • Hinton et al. [2015] G. Hinton, O. Vinyals, and J. Dean, “Distilling the knowledge in a neural network,” in Proc. Adv. Neural Info. Process. Syst. (NeurIPS), 2015.
  • Hong et al. [2022] C. Hong, S. Baik, H. Kim, S. Nah, and K. M. Lee, “CADyQ: Content-aware dynamic quantization for image super-resolution,” in Proc. Eur. Conf. Comput. Vision (ECCV), 2022.
  • Huang et al. [2018a] G. Huang, D. Chen, T. Li, F. Wu, L. van der Maaten, and K. Weinberger, “Multi-scale dense networks for resource efficient image classification,” in Proc. Int. Conf. on Learning Representations (ICLR), 2018.
  • Huang et al. [2018b] G. Huang, S. Liu, L. van der Maaten, and K. Q. Weinberger, “Condensenet: An efficient densenet using learned group convolutions,” in Proc. Conf. on Comput. Vision and Pattern Recognit. (CVPR), 2018.
  • Ilhan et al. [2024] F. Ilhan, L. Liu, K.-H. Chow, W. Wei, Y. Wu, M. Lee, R. R. Kompella, H. Latapie, and G. Liu, “Adaptive deep neural network inference optimization with EENet,” in Proc. Eur. Conf. Comput. Vision (ECCV), 2024.
  • Jin et al. [2020] Q. Jin, L. Yang, and Z. Liao, “Adabits: Neural network quantization with adaptive bit-widths,” in Proc. Conf. on Comput. Vision and Pattern Recognit. (CVPR), 2020.
  • Jitkrittum et al. [2023] W. Jitkrittum, N. Gupta, A. K. Menon, H. Narasimhan, A. S. Rawat, and S. Kumar, “When does confidence-based cascade deferral suffice?” in Proc. Adv. Neural Info. Process. Syst. (NeurIPS), 2023.
  • Khalilian Gourtani and Meratnia [2023] S. Khalilian Gourtani and N. Meratnia, “Escepe: Early-exit network section-wise model compression using self-distillation and weight clustering,” in Proc. Int. Workshop Edge Syst., Anal. and Network., ser. EdgeSys ’23, 2023, p. 48–53.
  • Khan et al. [2022] S. Khan, M. Naseer, M. Hayat, S. W. Zamir, F. S. Khan, and M. Shah, “Transformers in vision: A survey,” ACM Comput. Surv., vol. 54, no. 10s, 2022.
  • Kingma and Ba [2017] D. P. Kingma and J. Ba, “Adam: A method for stochastic optimization,” in Proc. Int. Conf. on Learn. Representations (ICLR), 2017.
  • Krizhevsky [2009] A. Krizhevsky, “Learning multiple layers of features from tiny images,” University of Toronto, Tech. Rep., 2009.
  • Kuzmin et al. [2023] A. Kuzmin, M. Nagel, M. V. Baalen, A. Behboodi, and T. Blankevoort, “Pruning vs quantization: Which is better?” in Proc. Adv. Neural Info. Process. Syst. (NeurIPS), 2023.
  • Li et al. [2024] Y. Li, Y. Yu, C. Liang, P. He, N. Karampatziakis, W. Chen, and T. Zhao, “Loftq: Lora-fine-tuning-aware quantization for large language models,” in Proc. Int. Conf. on Learn. Representations (ICLR), 2024.
  • Liu and Deng [2018] L. Liu and J. Deng, “Dynamic deep neural networks: Optimizing accuracy-efficiency trade-offs by selective execution,” in Proc. AAAI Conf. on Artif. Intell., 2018.
  • Liu et al. [2022] Z. Liu, Y. Wang, K. Han, S. Ma, and W. Gao, “Instance-aware dynamic neural network quantization,” in Proc. Conf. on Comput. Vision and Pattern Recognit. (CVPR), 2022.
  • Marinó et al. [2023] G. C. Marinó, A. Petrini, D. Malchiodi, and M. Frasca, “Deep neural networks compression: A comparative survey and choice recommendations,” Neurocomput., vol. 520, p. 152–170, Feb. 2023.
  • Netzer et al. [2011] Y. Netzer, T. Wang, A. Coates, A. Bissacco, B. Wu, and A. Y. Ng, “Reading digits in natural images with unsupervised feature learning,” in Proc. NIPS Workshop on Deep Learning and Unsupervised Feature Learning, 2011.
  • Odena et al. [2017] A. Odena, D. Lawson, and C. Olah, “Changing model behavior at test-time using reinforcement learning,” in Proc. Int. Conf. on Learn. Representations Workshop, 2017.
  • Oquab et al. [2024] M. Oquab, T. Darcet, T. Moutakanni, H. Vo, M. Szafraniec, V. Khalidov, P. Fernandez, D. Haziza, F. Massa, A. El-Nouby, M. Assran, N. Ballas, W. Galuba, R. Howes, P.-Y. Huang, S.-W. Li, I. Misra, M. Rabbat, V. Sharma, G. Synnaeve, H. Xu, H. Jegou, J. Mairal, P. Labatut, A. Joulin, and P. Bojanowski, “DINOv2: Learning robust visual features without supervision,” Trans. on Mach. Learn. Res. (TMLR), Jan. 2024.
  • Regol et al. [2024] F. Regol, J. Chataoui, and M. Coates, “Jointly-learned exit and inference for a dynamic neural network : JEI-DNN,” in Proc. Int. Conf. on Learn. Representations (ICLR), 2024.
  • Rokh et al. [2023] B. Rokh, A. Azarpeyvand, and A. Khanteymoori, “A comprehensive survey on model quantization for deep neural networks in image classification,” ACM Trans. Intell. Syst. Technol., vol. 14, no. 6, nov 2023.
  • Saxena and Roy [2023] U. Saxena and K. Roy, “McQueen: Mixed precision quantization of early exit networks,” in Proc. Brit. Mach. Vision Conf. (BMVC), 2023.
  • Tian et al. [2023] S. Tian, M. Lu, J. Liu, Y. Guo, Y. Chen, and S. Zhang, “CABM: Content-aware bit mapping for single image super-resolution network with large input,” in Proc. Conf. on Comput. Vision and Pattern Recognit. (CVPR), 2023.
  • Touvron et al. [2023] H. Touvron, L. Martin, K. Stone, P. Albert, A. Almahairi, Y. Babaei, N. Bashlykov, S. Batra, P. Bhargava, S. Bhosale, D. Bikel, L. Blecher, C. C. Ferrer, M. Chen, G. Cucurull, D. Esiobu, J. Fernandes, J. Fu, W. Fu, B. Fuller, C. Gao, V. Goswami, N. Goyal, A. Hartshorn, S. Hosseini, R. Hou, H. Inan, M. Kardas, V. Kerkez, M. Khabsa, I. Kloumann, A. Korenev, P. S. Koura, M.-A. Lachaux, T. Lavril, J. Lee, D. Liskovich, Y. Lu, Y. Mao, X. Martinet, T. Mihaylov, P. Mishra, I. Molybog, Y. Nie, A. Poulton, J. Reizenstein, R. Rungta, K. Saladi, A. Schelten, R. Silva, E. M. Smith, R. Subramanian, X. E. Tan, B. Tang, R. Taylor, A. Williams, J. X. Kuan, P. Xu, Z. Yan, I. Zarov, Y. Zhang, A. Fan, M. Kambadur, S. Narang, A. Rodriguez, R. Stojnic, S. Edunov, and T. Scialom, “Llama 2: Open foundation and fine-tuned chat models,” 2023, arXiv:2307.09288.
  • Verma et al. [2023] R. Verma, D. Barrejon, and E. Nalisnick, “Learning to defer to multiple experts: Consistent surrogate losses, confidence calibration, and conformal ensembles,” in Proc. Int. Conf. on Artif. Intell. and Statist. (AISTAT), 2023.
  • Wang et al. [2024] X. Wang, J. Rachwan, S. Günnemann, and B. Charpentier, “Structurally prune anything: Any architecture, any framework, any time,” 2024, arXiv:2403.18955.
  • Wang et al. [2020] Y. Wang, J. Shen, T.-K. Hu, P. Xu, T. Nguyen, R. Baraniuk, Z. Wang, and Y. Lin, “Dual dynamic inference: Enabling more efficient, adaptive, and controllable deep inference,” IEEE J. of Sel. Topics in Signal Process., vol. 14, no. 4, pp. 623–633, 2020.
  • Xia et al. [2022] W. Xia, H. Yin, X. Dai, and N. Jha, “Fully dynamic inference with deep neural networks,” IEEE Trans. on Emerg. Topics in Comput., vol. 10, no. 2, pp. 962–972, 2022.
  • Yu et al. [2022] H. Yu, H. Li, G. Hua, G. Huang, and H. Shi, “Boosted dynamic neural networks,” in Proc. AAAI Conf. on Artif. Intell., 2022.
  • Yuan et al. [2021] L. Yuan, Y. Chen, T. Wang, W. Yu, Y. Shi, Z. Jiang, F. E. Tay, J. Feng, and S. Yan, “Tokens-to-token ViT: Training vision transformers from scratch on imagenet,” in Proc. IEEE Int. Conf. on Computer Vision (ICCV), 2021.
  • Yuan et al. [2022] Z. Yuan, C. Xue, Y. Chen, Q. Wu, and G. Sun, “PTQ4ViT: Post-training quantization for vision transformers with twin uniform quantization,” in Proc. Eur Conf. Comput. Vision (ECCV), 2022.

7 Appendix

Table 1: Notation table
Symbol  Description
$\mathcal{X}$  Input space.
$\mathcal{Y}=\{1,2,\dots,|\mathcal{Y}|\}$  Label space.
$P(Y|X)$  Ground truth conditional distribution.
$PE(f(x)|x)$  Probability of making an error with prediction $y$ given $x$; equal to $1-\Pr(Y=y|X)$.
General dynamic network notation
$f_{l}$  The $l$-th classifier.
$c_{l}$  The cost associated with the $l$-th classifier.
$S(x)\in[L]$  A classifier-selector function (with negligible evaluation cost compared to $c_{l}$).
QuEE notation
$E$  Number of early exits in the network.
$B$  Number of levels of computation (quantization) that can be used per layer.
$L=\sum^{E}_{e=1}B^{e}$  Total number of classifiers/paths that we can select.
$\pi=b_{1}\to\dots\to b_{e}$  A path. A path contains the level of computation used by each layer.
$\mathcal{P}$  The set of all paths. $|\mathcal{P}|=L$.
$\mathcal{P}_{\pi}$  The set of all paths starting with $\pi$.
$f_{\pi}:\mathcal{X}\to\mathcal{Y}$  The classifier corresponding to $\pi$.
$c_{\pi}\in[0,1]$  The cost of evaluating $f_{\pi}$.
$\hat{\mathbf{p}}_{\pi}\in[0,1]^{|\mathcal{Y}|}$  The predicted prob. vector of $f_{\pi}$.
$g^{e}_{\theta}(\cdot)$  Gate placed at exit $e$. Decides on whether we should exit or use computation level $b$.

7.1 Experiment details

7.1.1 Backbones

In our experiments we used the T2T-ViT [50] and ViT [9] vision transformers pre-trained with DinoV2 [39]. We now provide more details about these two architectures.

ViT with DinoV2

The vision transformer proposed by Dosovitskiy et al. splits an input image into patches which are projected as embeddings and combined with positional encoding [9]. The architecture also uses a special “classification” token that is fed as input to a classification head at the last layer. The embedding size is set as a hyper-parameter. We use ViT-S-14 throughout our experiments. S (“small”) corresponds to an embedding size of 384 while 14 is the patch size. We use a 12-layer ViT in all our experiments. The ViT backbones were all pre-trained using the DinoV2 procedure [39], which uses self-supervision on a variety of large datasets to train a foundation model that can be used on downstream vision tasks with minimal fine-tuning [39]. Specifically, the backbone of the foundation model does not need to be retrained and only the inference head needs to be adapted to the dataset and task at hand [39].

T2T-ViT

The token-to-token vision transformer is a more data-efficient vision transformer architecture [50]. T2T-ViT uses a progressive tokenization scheme to fuse neighbouring embeddings into a single embedding at the input layer, capturing local structure and reducing the token size. This greatly reduces the number of trainable parameters, making the model more data-efficient [50]. In our experiments we used the 7-layer and 14-layer T2T-ViT models. Yuan et al. provide the parameters for these two architectures pre-trained on ImageNet [6], which we used as-is for ImageNet experiments on T2T-ViT. For other datasets, we use the transfer-learning procedure also provided by Yuan et al. to fine-tune the backbone.

7.1.2 Baselines

In this section we discuss the different baselines we use in greater depth.

PTQ4ViT

PTQ4ViT is a SOTA post-training quantization algorithm specialized for vision transformers [51]. Yuan et al. observe that the activations in vision transformers do not exhibit the usual Gaussian distributions reported in CNN-based models. For example, the values after the GeLU cover a very large positive range and a restricted negative range. To address this, they propose a twin-uniform quantization scheme where uniform quantization is applied after splitting the activations into two disjoint regions (positive and negative). This makes it possible to find optimal quantization parameters for each region separately [51]. The proposed algorithm uses the Hessian of the loss with respect to the parameters as an indicator of sensitivity to determine the optimal quantization parameters [51]. We obtain the different operating points by quantizing the entire network with PTQ4ViT at different bit resolutions.
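The idea of quantizing the two sign regions separately can be sketched as follows. This is a simplified illustration, not the PTQ4ViT implementation: the function name, the hand-picked step sizes, and the assumption that one bit selects the region (leaving $2^{\text{bits}-1}-1$ non-zero levels per side) are ours, and the Hessian-guided search for the quantization parameters is omitted entirely.

```python
def twin_uniform_quantize(values, bits, step_neg, step_pos):
    """Quantize then dequantize with a separate uniform grid per sign region.

    Assumption: one bit selects the region, leaving 2**(bits - 1) - 1
    non-zero levels on each side.
    """
    levels = 2 ** (bits - 1) - 1
    out = []
    for v in values:
        step = step_neg if v < 0 else step_pos
        q = round(abs(v) / step)   # nearest grid index in the chosen region
        q = min(q, levels)         # clip to the representable range
        out.append(-q * step if v < 0 else q * step)
    return out


# Post-GeLU-like activations: narrow negative range, wide positive range.
acts = [-0.17, -0.02, 0.0, 0.9, 6.3]
deq = twin_uniform_quantize(acts, bits=4, step_neg=0.025, step_pos=1.0)
```

A single shared step size would have to be coarse enough to cover the large positive values, which would collapse the small negative activations to zero; the two step sizes avoid that in this toy case.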

DQNet

DQNet proposes a data-adaptive mixed-precision quantization scheme for image classification [35]. A lightweight bit-controller module is attached to the network to be quantized. The bit-controller’s role is to determine the optimal bit width of each layer for a given sample [35]. It is placed at a user-specified layer and uses the feature map output by that layer to determine the optimal bit path. The DQNet data-adaptive quantization procedure can be applied to any architecture and is compatible with other quantization schemes [35], making it a perfect candidate for our setup. For fairer comparison, we integrate DQNet with PTQ4ViT [51], which provides better quantization results on vision transformers. As such, we slightly modify the training procedure to work in a post-training way, in line with our pre-trained model setting. The training procedure first quantizes the trained network using PTQ4ViT and then trains the bit-controller separately. The algorithm uses a cost-aware loss which allows the user to emphasize efficiency or accuracy. This is done by introducing a bit-loss term $\mathcal{L}_{\text{bit}}=\sum_{l=1}^{L}b_{l}-b_{\text{target}}$ which measures the distance of the selected bit widths to a user-specified target bit [35]. The overall DQNet loss is $\mathcal{L}=\mathcal{L}_{\text{CE}}+\alpha\mathcal{L}_{\text{bit}}$ where $\alpha$ is a hyper-parameter. We obtain various operating points by varying the target bit width as well as $\alpha$.
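The structure of this cost-aware loss can be illustrated with a short sketch for a single sample. It follows the bit-loss expression as printed above and assumes that the target is expressed on the same scale as the summed per-layer bit widths; the function names and numbers are hypothetical, and the original DQNet definition may differ in its details.

```python
import math


def cross_entropy(probs, label):
    """Negative log-probability assigned to the true class."""
    return -math.log(probs[label])


def dqnet_style_loss(probs, label, bit_widths, b_target, alpha):
    """L = L_CE + alpha * L_bit, with L_bit = sum_l b_l - b_target.

    Assumption: b_target is on the same scale as the summed bit widths.
    """
    bit_loss = sum(bit_widths) - b_target
    return cross_entropy(probs, label) + alpha * bit_loss


probs = [0.7, 0.2, 0.1]  # toy predicted class probabilities
loss_no_penalty = dqnet_style_loss(probs, 0, [8, 6, 6], b_target=18, alpha=0.0)
loss_penalized = dqnet_style_loss(probs, 0, [8, 6, 6], b_target=18, alpha=0.1)
```

Setting `alpha` to zero recovers the plain cross-entropy, while larger values penalize bit paths that exceed the target more heavily, which is how sweeping the target and $\alpha$ yields different operating points.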

JEIDNN

JEIDNN is an early exit framework that is compatible with frozen backbones. Regol et al. augment a pre-trained backbone network with classifiers and controllers. The role of the controllers is to determine whether a sample can be exited or should be propagated along the network for further processing. The controllers and classifiers are jointly trained to allow for the “specialization” of the classifiers on points that they will effectively handle at inference time [40]. JEIDNN also uses a cost-aware loss, controlled by a hyper-parameter $\lambda$ which encodes the cost-accuracy trade-off. We obtain different operating points in our experiments by running JEIDNN with varying $\lambda$ values.

7.1.3 Datasets

ImageNet

consists of 1.2 million training images representing 1,000 balanced classes [6]. The images are resized to $256\times 256$ pixels. We use 1.1 million images for training as well as 50,000 images as a validation set and 50,000 images for testing. The validation set is used for early-stopping as well as hyper-parameter grid-search. We also use 600 images for generating the K-means clusters used to discretize the training targets of the predictors $h_{\theta}^{j}$.

CIFAR

[31] datasets consist of 60,000 small $32\times 32$ coloured images. CIFAR10 consists of 10 classes while CIFAR100 spans 100 classes. CIFAR10 images were resized to $70\times 70$ pixels when used with ViT-DinoV2 [9, 39] models (the image size needs to be a multiple of the patch size of 14) and $64\times 64$ for T2T-ViT models [50]. CIFAR100 images were resized to $224\times 224$ for T2T-ViT following Yuan et al. [50] and $70\times 70$ for ViT-DinoV2. We reserve 5,000 images for the validation set (hyper-parameter search, early-stopping and discretization of targets) and 10,000 images for testing.

SVHN

The SVHN [37] dataset consists of 73,257 training images and 26,032 testing images. We reserve 5,000 images from the training set as validation data. The images are $32\times 32$ coloured images of house numbers. The version of SVHN we use spans 10 classes where the label corresponds to the central digit of the house number while the remaining digits are noise.

The results reported correspond to 95% confidence intervals on the means of our results. These intervals are obtained by using a bootstrap procedure of 10 trials where the test set is split into 10 subsets.
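One possible reading of this procedure is sketched below: the test set is shuffled and split into 10 disjoint subsets, accuracy is computed on each, and an interval is formed around the mean. The normal-approximation multiplier of 1.96, the function name, and the synthetic correctness indicators are assumptions; the text does not specify how the interval is derived from the 10 trials.

```python
import random
import statistics


def subset_confidence_interval(correct, num_subsets=10, z=1.96, seed=0):
    """Mean accuracy over disjoint test subsets with a normal-approximation interval."""
    idx = list(range(len(correct)))
    random.Random(seed).shuffle(idx)
    # Strided slicing yields disjoint subsets of (nearly) equal size.
    subsets = [idx[i::num_subsets] for i in range(num_subsets)]
    accs = [sum(correct[j] for j in s) / len(s) for s in subsets]
    mean = statistics.mean(accs)
    half_width = z * statistics.stdev(accs) / num_subsets ** 0.5
    return mean, half_width


# Synthetic per-sample correctness indicators (1 = correct prediction).
rng = random.Random(1)
correct = [1 if rng.random() < 0.9 else 0 for _ in range(1000)]
mean_acc, half_width = subset_confidence_interval(correct)
```

The reported interval would then be `mean_acc` plus or minus `half_width`; with only 10 trials, a Student-t multiplier would give a slightly wider interval than the normal approximation used here.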

7.1.4 Experiment parameters

We share experiment parameters in this section, describing each step of the experiment in a paragraph. Table 2 summarizes shared hyper-parameters for all algorithms (QuEE and baselines). The highest bit resolution was chosen as the smallest bit width that maintains the overall network accuracy. The lower resolutions were chosen so that the performance drop of the last prediction head never exceeds 10%.

Intermediate classifiers training

For algorithms that use early-exiting, we augment the pre-trained frozen network with intermediate inference modules (IMs) at user-specified layers. This step is skipped for quantization-only algorithms (DQNet [35], PTQ4ViT [51]). Table 2 shows the exits used for each dataset-architecture pair. These IMs are trained jointly until convergence using a simple scaled loss as was done in Regol et al. [40]. Table 2 indicates the maximum number of IM training epochs, which we use in conjunction with early-stopping. In practice SVHN converges at around 7 epochs, CIFAR10 at 10, CIFAR100 at 13 and ImageNet at around 15. We use the ADAM optimizer [30] with an initial learning rate of $10^{-3}$ and a weight-decay of $10^{-4}$. We note that the backbone parameters are frozen as these backbones have been trained extensively on large datasets [39, 33]. This is well-aligned with the setup of working with foundation models where only the inference head needs to be optimized on a specific dataset.

Quantization

We quantize the pre-trained backbones using PTQ4ViT [51]. As a post-training quantization algorithm, it requires a small amount of data for calibration. We use 128 calibration samples for all datasets and all models. This is sufficient as Yuan et al. [51] achieve SOTA quantization using only 32 images. The bit resolutions per dataset are summarized in Table 2. For early-exiting baselines (JEIDNN and EE with thresholding), we use the largest bit-width for each dataset (8 for all datasets but SVHN where we used 7). This is to ensure a fair comparison.
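To make the roles of the bit width and of the calibration data concrete, the sketch below applies plain uniform quantization with a range estimated from calibration values. It is not PTQ4ViT, whose quantizers are more elaborate; the function names and the toy values are hypothetical and serve only as an illustration.

```python
def calibrate_range(calibration_values):
    """Estimate the quantization range from a few calibration values."""
    return min(calibration_values), max(calibration_values)


def quantize_dequantize(values, n_bits, lo, hi):
    """Uniformly quantize values to n_bits over [lo, hi], then map back to floats."""
    levels = 2 ** n_bits - 1
    scale = (hi - lo) / levels  # width of one quantization step
    out = []
    for v in values:
        clipped = min(max(v, lo), hi)        # values outside the range saturate
        code = round((clipped - lo) / scale)  # integer code in [0, levels]
        out.append(lo + code * scale)
    return out


lo, hi = calibrate_range([-1.0, 0.2, 1.0])  # toy calibration values
weights = [-0.9, 0.0, 0.33, 0.8]
for bits in (5, 8):
    approx = quantize_dequantize(weights, bits, lo, hi)
    worst = max(abs(a - w) for a, w in zip(approx, weights))
    print(bits, "bits, worst rounding error:", worst)
```

Inside the calibrated range, the rounding error is bounded by half a quantization step, so each removed bit roughly doubles the worst-case error.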

Algorithm-specific information
  • QuEE Starting from a multi-quantization network with trained IMs, we train the QuEE predictors on the discretized probability of error (obtained via K-means clustering on the predicted probability $\hat{\mathbf{p}}_{\pi}$ for each path $\pi$). Table 3 summarizes QuEE-specific hyper-parameters. The number of clusters was chosen based on the ECE using $\widetilde{pe}$ computed on 600 samples from the validation set. We used scikit-learn's K-Means algorithm using the L2 norm as a distance metric, where the cluster centroids are randomly initialized 10 times and the run with the lowest inertia is kept. We then train the gates using the ADAM optimizer [30] with an initial learning rate of $1e^{-3}$ and a weight-decay of $1e^{-4}$.

  • DQNet As mentioned earlier, we use PTQ4ViT [51] as a quantization algorithm on the DQNet [35] architecture since it is compatible with any quantization algorithm. We place the bit-controller at the earliest exit as indicated in Table 2. We iterate over the bit resolutions listed in Table 2 while sweeping values of $\alpha\in\{0, 0.01, 0.05, 0.1, 0.3, 0.5, 2, 10\}$. We use the same number of training epochs as QuEE (see Table 3). We also use the ADAM optimizer [30] with an initial learning rate of $1e^{-3}$ and a weight-decay of $1e^{-4}$.

  • PTQ4ViT We obtain different operating points for the accuracy-cost curves using PTQ4ViT [51] by setting the bit resolution to each bit resolution listed in Table 2.

  • JEIDNN As a base JEIDNN network we use the highest bit resolution listed in Table 2 for each dataset. We augment the network with gating modules following the structure of the original algorithm [40]. We use a bi-level batch count of 200 unless the training data loader had fewer than 200 batches, in which case we set the bi-level batch count to half the number of batches in the data loader. We obtain different operating points by sweeping the cost-accuracy trade-off hyper-parameter $\lambda\in\{0, 0.01, 0.05, 0.1, 0.3, 0.5, 0.8, 1, 1.5, 2, 2.5, 3, 5\}$. We also use the ADAM optimizer [30] with an initial learning rate of $1e^{-3}$ and a weight-decay of $1e^{-4}$.
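The clustering step described in the QuEE item can be sketched as follows. To stay dependency-free, the sketch re-implements Lloyd's algorithm with random restarts in pure Python; with scikit-learn, the corresponding call would be `KMeans(n_clusters=K, init="random", n_init=10)`. The delegate value is assumed here to be the empirical error rate of the samples in each cluster, which is an illustrative reading rather than a restatement of equation 12, and all data are made up.

```python
import random


def sq_dist(a, b):
    """Squared L2 distance between two vectors."""
    return sum((x - y) ** 2 for x, y in zip(a, b))


def assign(points, centroids):
    """Index of the nearest centroid for every point."""
    return [min(range(len(centroids)), key=lambda k: sq_dist(p, centroids[k]))
            for p in points]


def kmeans_once(points, k, rng, n_iter=50):
    """One run of Lloyd's algorithm from randomly chosen initial centroids."""
    centroids = [list(p) for p in rng.sample(points, k)]
    for _ in range(n_iter):
        labels = assign(points, centroids)
        for c in range(k):
            members = [p for p, lab in zip(points, labels) if lab == c]
            if members:  # keep the old centroid if a cluster becomes empty
                centroids[c] = [sum(col) / len(members) for col in zip(*members)]
    labels = assign(points, centroids)
    inertia = sum(sq_dist(p, centroids[lab]) for p, lab in zip(points, labels))
    return centroids, labels, inertia


def fit_clusters(points, k, n_init=10, seed=0):
    """Run n_init random restarts and keep the run with the lowest inertia."""
    rng = random.Random(seed)
    return min((kmeans_once(points, k, rng) for _ in range(n_init)),
               key=lambda run: run[2])


def delegate_values(labels, is_error, k):
    """Empirical error rate per cluster (assumed form of the delegate value)."""
    rates = []
    for c in range(k):
        errs = [e for lab, e in zip(labels, is_error) if lab == c]
        rates.append(sum(errs) / len(errs) if errs else None)
    return rates


# Made-up predicted probability vectors and error flags for one path.
probs = [[0.97, 0.03], [0.95, 0.05], [0.99, 0.01],
         [0.55, 0.45], [0.50, 0.50], [0.45, 0.55]]
errors = [0, 0, 0, 1, 0, 1]
_, labels, _ = fit_clusters(probs, k=2)
print(delegate_values(labels, errors, k=2))
```

In this toy example the confident and the uncertain predictions end up in separate clusters, so one delegate value is 0 and the other is 2/3.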

Table 2: Hyper-parameters for IM pre-training and quantization for all algorithms

| Dataset | Arch | Input size | IM train. max. epoch | Bit res. | Batch size | Exits |
|---|---|---|---|---|---|---|
| SVHN | T2T-ViT-7 | 32x32 | 10 | 5,6,7 | 512 | 2,4,7 |
| CIFAR10 | T2T-ViT-7 | 64x64 | 15 | 5,6,7,8 | 512 | 2,4,7 |
| CIFAR100 | T2T-ViT-14 | 224x224 | 20 | 6,8 | 128 | 5,7,11,14 |
| ImageNet | T2T-ViT-14 | 224x224 | 20 | 6,8 | 128 | 5,7,11,14 |
| ImageNet | ViTs14 | 256x256 | 20 | 6,7,8 | 128 | 4,6,8,10,12 |

Table 3: Hyper-parameters for QuEE

| Dataset | Arch | Num clusters (K) | Batch size | QuEE training epochs | Emb. size | Hid. dim. size | Avg. training time (NVIDIA GeForce GTX 2070) |
|---|---|---|---|---|---|---|---|
| SVHN | T2T-ViT-7 | 70 | 512 | 10 | 8 | 16 | 5 min. |
| CIFAR10 | T2T-ViT-7 | 20 | 512 | 10 | 8 | 16 | 13 min. |
| CIFAR100 | T2T-ViT-14 | 70 | 128 | 20 | 8 | 32 | 1.6 hours |
| ImageNet | ViTs14 | 40 | 128 | 20 | 8 | 32 | 4.5 hours |

7.2 Additional clustering discussion

As mentioned in Section 4.1, we assume that the probability of error $PE(f_{\pi}|x_{i})$ is smooth over a transformation of the predicted probability space $\hat{\mathbf{p}}_{\pi}(x)$ for a given path $\pi$. We thus discretize the space into K partitions denoted $\{Q_{1},\dots,Q_{K}\}$ and compute a “delegate” value $\widetilde{pe}(f_{\pi,Q})$ for each cluster Q. During training of the predictor modules when

We now provide additional experimental results for the clustering of target probabilities of error discussed in Section 4.1.

Clustering effectively separates predictions along entropy and probability of ground truth

We demonstrate that running K-means with large enough values of K effectively separates predictions across clusters corresponding to high/low entropies as well as clusters corresponding to high/low predicted probability of ground-truth. Figure 6 shows entropy box-plots on SVHN using T2T-ViT-7 with K = 6 and K = 20. We can see that while the clusters with K = 6 have roughly the same average entropy, with K = 20 we start seeing more distinct clusters. A similar pattern can be seen when looking at the predicted probability of the ground-truth in figure 7. We show entropy per cluster on CIFAR100 on T2T-ViT-14 with K = 40 in figure 8 and the ECE as a function of K in figure 9.
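The per-cluster entropy summarized in these box-plots can be computed along the following lines. This is a minimal sketch: it reports only the median per cluster rather than the full box-plot statistics, and the function names and toy predictions are hypothetical.

```python
import math
import statistics


def entropy(probs):
    """Shannon entropy (in nats) of one predicted probability vector."""
    return -sum(p * math.log(p) for p in probs if p > 0)


def median_entropy_by_cluster(pred_probs, labels):
    """Median prediction entropy of the samples assigned to each cluster."""
    groups = {}
    for probs, label in zip(pred_probs, labels):
        groups.setdefault(label, []).append(entropy(probs))
    return {label: statistics.median(vals) for label, vals in sorted(groups.items())}


# Made-up predictions: cluster 0 is confident, cluster 1 is uncertain.
pred_probs = [[0.98, 0.02], [0.90, 0.10], [0.50, 0.50], [0.60, 0.40]]
labels = [0, 0, 1, 1]
print(median_entropy_by_cluster(pred_probs, labels))
```

A clustering that separates easy from hard samples shows up as clearly different medians across clusters, which is the pattern described above for larger K.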

Figure 6: Box-plots showing the entropy of each cluster computed for SVHN on T2T-ViT-7 at path $\pi=7\rightarrow 6\rightarrow 6$ with K=6 (left) and K=20 (right).

Figure 7: Box-plots showing the predicted probability at the ground-truth class of each cluster computed for SVHN on T2T-ViT-7 at path $\pi=7\rightarrow 6\rightarrow 6$ with K=6 (left) and K=20 (right).

Figure 8: Box-plots showing the entropy of each cluster computed for CIFAR100 on T2T-ViT-14 at path $\pi=8\rightarrow 7$ with K=40.

Figure 9: ECE of $\widetilde{pe}$ generated from varying cluster numbers $K$ for the CIFAR100 dataset on T2T-ViT-14 with a fitted second order polynomial.

ECE using $\widetilde{pe}$

In figure 5 we showed the expected calibration error computed on a separate test set when using $\widetilde{pe}$ as a function of K. Here, $\widetilde{pe}$ corresponds to the “delegate” value of a given cluster at a path $\pi$ as defined in equation 12.

To compute the ECE over the test set $\mathcal{S}_{\text{test}}$, we first compute the ECE at each path $\pi$, denoted $\text{ECE}_{\pi}$, as follows:

  1. Obtain the prediction $\hat{\mathbf{p}}_{\pi}(x)$ for each sample $(x,y)\in\mathcal{S}_{\text{test}}$ at classifier $f_{\pi}$.

  2. Assign $x$ to the appropriate cluster $Q^{*}$.

  3. Approximate the probability at $y$ as $\tilde{p}(y|x)=1-\widetilde{pe}(f_{\pi,Q^{*}}|x)$.

  4. Following Guo et al. [15], sort all $\tilde{p}(y|x)$ in ascending order and split them into 15 bins denoted $B_{1},B_{2},\dots,B_{15}$.

  5. Compute the average predicted probability using $\tilde{p}(y|x)$ at each bin $B$ as: $\bar{P}(B)=\frac{1}{|B|}\sum_{i\in B}\tilde{p}(y_{i}|x_{i})$.

  6. Compute the accuracy at each bin $B$: $ACC(B)=\frac{1}{|B|}\sum_{i\in B}\mathbb{1}[\hat{y}_{i}=y_{i}]$.

  7. Compute the calibration error $\Delta(B)$ at each bin $B$ as the difference between the average predicted probability and the accuracy: $\Delta(B)=|ACC(B)-\bar{P}(B)|$.

  8. The ECE for path $\pi$ is then calculated as $\text{ECE}_{\pi}=\sum_{b=1}^{15}\frac{|B_{b}|}{|\mathcal{S}_{\text{test}}|}\Delta(B_{b})$.

We repeat the above procedure for each path $\pi\in\mathcal{P}$. The total ECE for a given $K$ is then computed as the average of $\text{ECE}_{\pi}$ for all paths:

$\text{ECE}=\frac{1}{|\mathcal{P}|}\sum_{\pi\in\mathcal{P}}\text{ECE}_{\pi}$
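The procedure above can be sketched in a few lines. This is an illustration under one stated assumption: the bins are taken to hold equal numbers of samples after sorting, which is one reading of "sort and split into 15 bins". The function names and the toy inputs are hypothetical; `p_tilde[i]` stands for one minus the delegate value of the cluster of sample i, and `correct[i]` is 1 when the prediction of the path's classifier is right.

```python
def ece_for_path(p_tilde, correct, n_bins=15):
    """ECE of one path from approximated probabilities and correctness flags."""
    n = len(p_tilde)
    order = sorted(range(n), key=lambda i: p_tilde[i])  # ascending p_tilde
    ece = 0.0
    for b in range(n_bins):
        # Equal-count bins over the sorted samples (assumption).
        members = order[b * n // n_bins:(b + 1) * n // n_bins]
        if not members:
            continue
        avg_p = sum(p_tilde[i] for i in members) / len(members)
        acc = sum(correct[i] for i in members) / len(members)
        # Weight each bin by its share of the test set.
        ece += len(members) / n * abs(acc - avg_p)
    return ece


def total_ece(per_path, n_bins=15):
    """Average the per-path ECE over all paths; per_path maps path -> (p_tilde, correct)."""
    return sum(ece_for_path(p, c, n_bins) for p, c in per_path.values()) / len(per_path)


# Made-up example: two paths with ten test samples each.
paths = {
    "path_a": ([0.8] * 10, [1] * 8 + [0] * 2),
    "path_b": ([0.6] * 10, [1] * 6 + [0] * 4),
}
print(total_ece(paths, n_bins=1))
```

With a single bin, both toy paths are perfectly calibrated on average, so the printed value is zero up to floating-point error; using more bins exposes calibration differences within a path.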

7.3 Only predict the best step

Instead of predicting the $pe(f_{\pi'}|x)$ of each possible future path $\pi'$ and taking one step towards the best path, we can predict only one value per step. This would greatly reduce the complexity at inference from $O(B^{E}E)$ to $O(B)$. To do so, we replace our target $\tilde{pe}(f_{\pi'}|x)$ with the best possible value of the whole loss for a given step $b$, which is defined as follows:

$\ell^{*}_{b}=\operatorname*{arg\,min}_{\pi'\in\mathcal{P}_{\pi\rightarrow b}}c_{\pi'}+\tilde{pe}(f_{\pi'}|x)$, the best loss value we can get for step $b$. (19)

This requires us to search amongst all future paths $\mathcal{P}_{\pi\rightarrow b}$ and train our gate to directly predict that loss value $\widehat{\ell}^{*}_{b}=h^{j}_{\theta}(\mathcal{I}_{j}(x),b)$. This moves the $O(B^{E}E)$ inference time complexity to train time. At inference, the module predicts one value per step and the decision is taken accordingly:

$g^{j}_{\theta}()=\operatorname*{arg\,min}_{b}\widehat{\ell}^{*}_{b}$. (20)
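The target construction and the resulting decision can be sketched as follows. This is a toy illustration, not the trained gate: the exact targets are used in place of the predictions of $h^{j}_{\theta}$, the target is taken to be the minimum loss value as the text describes, and the weight `lam` stands for the cost/error trade-off that is folded into the cost term. The step names, costs and error probabilities are made up.

```python
def best_loss_per_step(future_paths, lam=1.0):
    """Best achievable loss for each next step b.

    future_paths: {step b: [(cost, pe_tilde), ...]}, one pair per future path
                  that starts with step b.
    lam:          assumed weight of the cost term in the loss.
    """
    return {b: min(lam * cost + pe for cost, pe in paths)
            for b, paths in future_paths.items()}


def gate_decision(losses):
    """Choose the step with the smallest (predicted) best loss."""
    return min(losses, key=losses.get)


# Made-up future paths after the current gate.
future_paths = {
    "exit": [(0.0, 0.30)],
    "6bit": [(0.20, 0.15), (0.35, 0.12)],
    "8bit": [(0.40, 0.10), (0.60, 0.05)],
}
for lam in (1.0, 0.1):
    losses = best_loss_per_step(future_paths, lam)
    print(lam, gate_decision(losses))
```

With these numbers the decision is to exit when the cost weight is 1.0 and to continue at 8 bits when it is 0.1, which illustrates why the targets, and hence the trained module, depend on the chosen trade-off.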

Moreover, to train this model, we have to commit to a fixed cost/01 tradeoff ratio as it is integrated into the training through $c_{\pi'}$. This means that we have to retrain a module for each acc/cost point.

RMSE of the next best step approach

In the “next best step” approach, we directly predict a loss value which contains a combination of the cost and the probability of error instead of only predicting the probability of error.

This changes the domain of our target:

previous target: $\tilde{pe}(f_{\pi'}|x)\in[0,1]$ (21)
new target: $\ell^{*}_{b}=\operatorname*{arg\,min}_{\pi'\in\mathcal{P}_{\pi\rightarrow b}}\lambda c_{\pi'}+\tilde{pe}(f_{\pi'}|x)\in[-\lambda\min(c),\lambda\max(c)]$. (22)

For $h$, we use the same prediction function as before (see Eqn. 15), but without the sigmoid layer as the target is no longer a probability.

[Uncaptioned image]

The results for this approach on one dataset are presented in Figure 7.3. This proposal is able to reach similar performance at a higher cost, but is quickly outperformed. In the very high-cost regime, the corresponding $\lambda$ values are set to a very low value in order to give all the importance to the performance term in the loss. This means that in this setting, the next best method is almost predicting the next best $\tilde{pe}(f_{\pi'}|x)$:

$\ell^{*}_{b}=\operatorname*{arg\,min}_{\pi'\in\mathcal{P}_{\pi\rightarrow b}}\lambda c_{\pi'}+pe(f_{\pi'}|x)$ (23)
$\ell^{*}_{b}\approx\operatorname*{arg\,min}_{\pi'\in\mathcal{P}_{\pi\rightarrow b}}pe(f_{\pi'}|x)$ if $\lambda$ is small. (24)

However, the performance quickly degrades as we raise the importance of the cost through $\lambda$. This could indicate that the learning task of predicting the value of the next best step is considerably harder and that a simple NN is unable to solve it.

7.4 Expanded Literature Review

Dynamic neural networks

Dynamic architectures adapt their computational graphs to the input being processed [16, 36]. By adapting their depth (number of layers executed) or width (number of channels or neurons executed) for each sample, dynamic architectures can significantly reduce computation during inference [16].
Early-exit (EE) networks are a widely studied class of dynamic depth architecture where a prediction is obtained at an intermediate layer and subsequent layers are skipped [2, 23, 18, 40, 25]. This is done by augmenting the network with intermediate inference modules at various layers. The idea was first presented in Bolukbasi et al. [2] where the augmented network is trained end-to-end. Huang et al. [23] address common issues arising from end-to-end training, such as the interference of early inference modules on the performance of later ones, by introducing architectural changes such as dense connections. Han et al. [18] further decouple feature extraction from classification by augmenting the network with an entire parallel stream for early classification while using the original backbone network for feature extraction. The two streams interact with each other via various attention mechanisms [18]. However, end-to-end training is inconvenient when working with large foundation models [1, 40] and using a simple threshold is susceptible to miscalibration of earlier inference modules [40]. Post-training EE approaches instead rely on a fixed pre-trained backbone, and explore effective ways to train the inference modules and design more sophisticated gating mechanisms for the exit rule [40, 5, 25]. One of the challenges of early exiting is that the gating mechanism influences the samples reaching each inference module. Failure to account for this effect can lead to a distribution shift at inference. This particular issue was first studied by Han et al. [17] and Yu et al. [49] using a threshold-based exit mechanism. Regol et al. [40] propose a fixed-backbone EE procedure that uses a trainable gating mechanism leading to performance gains. However, one important drawback of Regol et al. [40] is that the training procedure needs to be repeated for every operating point, making it impractical for use-cases where the computation budget changes over time.

Width-wise sample-adaptation can be achieved by selecting a subset of channels in CNN architectures as in Huang et al. [24] and Herrmann et al. [20]. A more general approach that is compatible with transformer-based architectures is SuperNets [34, 38, 19] where a sample is dynamically routed through a subset of neurons at inference. Hazimeh et al. [19] insert a differentiable decision tree layer in any neural network to benefit from the inherent conditional computation of decision trees while still being able to train the network end-to-end with backpropagation. Odena et al. [38] and Liu and Deng [34] use reinforcement learning and gradient-based optimization to jointly train the network augmented with controller modules that select the computational route for a sample. Closer to our algorithm, works such as Wang et al. [47] and Xia et al. [48] perform both depth and width adaptation via layer-skipping and channel-selection in convolution-based architectures. All these works rely on trainable controllers that are optimized jointly with the underlying network [47, 48] since the sub-networks need to be optimized for the inputs they handle. This makes them incompatible with large pre-trained foundation models. Foundation models are typically trained with very large proprietary datasets for extended periods of time [39, 4, 3] in a self-supervised fashion. The idea of foundation models is to build a dataset-agnostic model that is trained once and can be readily reused on downstream datasets with only minor fine-tuning (typically of the inference head) [39, 3]. For that reason, it is crucial to develop post-training efficiency techniques that do not require training of the backbone.

Quantization

Quantization is an effective way of speeding up inference where weights, gradients and activations of a model are represented at lower bit resolutions [36]. Quantization-aware training (QAT) techniques quantize the network during training [41, 11] while post-training quantization (PTQ) is performed on a trained model with only a small amount of data [41, 11, 7]. This makes PTQ particularly appealing as a width-compression technique when working with foundation models [1, 7]. However, unlike the previously discussed width-adaptive techniques, quantization is typically not input-adaptive [41, 35, 22]. Once a network has been quantized, the same structure is used for both “easy” and “hard” samples, employing unnecessarily high precision for easier samples [35, 22, 43]. Hong et al. [22] and Tian et al. [43] consider dynamic mixed-precision quantization for image super-resolution. In that setting, a low-resolution image is split into patches which are super-resolved individually using a neural network. Hong et al. [22] introduce a bit selector that uses metrics of quantization sensitivity for each patch to determine the optimal bit width per layer. Tian et al. [43] replace the bit-selector at inference with a look-up table indexed by the edge score, arguing that the edge score is a reliable measure for patch complexity in super-resolution tasks. Closer to our specific setting, DQNet [35] explores dynamic mixed-precision quantization for image classification. In DQNet, the network is quantized at different resolutions and is then augmented with a small neural network called the bit controller whose function is to determine the bit resolution for subsequent layer blocks. The bit controller is fed with the intermediate feature map of a CNN [35]. While Liu et al. [35], Hong et al. [22] and Tian et al. [43] all dynamically adapt the bit precision on a per-sample basis, the fact that they encode the budget in their loss formulation means that the algorithm needs to be retrained for every operating point. They also do not exploit depth-adaptation.

Quantization of early-exit networks

Other works have combined the adaptability of early-exit networks with the efficacy of quantization by quantizing early-exit networks [42, 28]. In Khalilian Gourtani and Meratnia [28], a pre-trained early-exit network is first split into sections and each section is quantized separately using weight-clustering. The quantized network is then fully retrained using knowledge distillation to recover the intermediate classifiers' performance. Saxena and Roy [42] use a QAT approach where the optimal per-layer weight and activation quantization parameters are learnt during training. While both of these works combine quantization with early-exiting, they both propose QAT-like approaches and are thus unsuitable for foundation models. They are also not sample-adaptive along their widths, as a single static mixed-precision quantization is learnt for all samples [28, 42].