
License: arXiv.org perpetual non-exclusive license
arXiv:2312.17199v1 [stat.ML] 28 Dec 2023

Tractable Function-Space Variational Inference in Bayesian Neural Networks

Tim G. J. Rudner, University of Oxford
Zonghao Chen, University College London
Yee Whye Teh, University of Oxford
Yarin Gal, University of Oxford
Corresponding author. Email: <tim.rudner@cs.ox.ac.uk>.
Abstract

Reliable predictive uncertainty estimation plays an important role in enabling the deployment of neural networks in safety-critical settings. A popular approach for estimating the predictive uncertainty of neural networks is to define a prior distribution over the network parameters, infer an approximate posterior distribution, and use it to make stochastic predictions. However, explicit inference over neural network parameters makes it difficult to incorporate meaningful prior information about the data-generating process into the model. In this paper, we pursue an alternative approach. Recognizing that the primary object of interest in most settings is the distribution over functions induced by the posterior distribution over neural network parameters, we frame Bayesian inference in neural networks explicitly as inferring a posterior distribution over functions and propose a scalable function-space variational inference method that allows incorporating prior information and results in reliable predictive uncertainty estimates. We show that the proposed method leads to state-of-the-art uncertainty estimation and predictive performance on a range of prediction tasks and demonstrate that it performs well on a challenging safety-critical medical diagnosis task in which reliable uncertainty estimation is essential.

1 Introduction

Machine learning models succeed at an increasingly wide range of narrowly defined tasks (Krizhevsky et al., 2012; Mnih et al., 2013; Silver et al., 2016; Jumper et al., 2021) but may fail without warning when used on inputs that are meaningfully different from the data they were trained on (Amodei et al., 2016; Hendrycks et al., 2021; Rudner and Toner, 2021a, b). To deploy machine learning models in safety-critical environments where failures are costly or may endanger human lives, machine learning methods must be reliable and have the ability to ‘fail gracefully.’ Predictive uncertainty quantification, a promising tool for incorporating fail-safe mechanisms into machine learning systems, allows machine learning models to express their confidence in the correctness of their predictions.

In this paper, we develop a method for obtaining reliable uncertainty estimates in Bayesian neural networks (bnns; Neal, 1996). While bnns promise to combine the advantages of deep learning and Bayesian inference, existing approaches for approximate inference in bnns fall short of this promise: they have been shown to yield approximate posterior predictive distributions that underperform ‘non-Bayesian’ methods in terms of both predictive accuracy and uncertainty quantification, which limits their usefulness in practice (Ovadia et al., 2019; Foong et al., 2019; Farquhar et al., 2020a; Band et al., 2021). A potential reason for this shortcoming is that commonly used parameter-space inference methods make it difficult to define meaningful priors that effectively incorporate information about the data-generating process into inference.

To avoid this limitation, we follow Sun et al. (2019) and consider a variational objective defined explicitly in terms of distributions over functions induced by distributions over parameters. In contrast to prior works, which rely on approximation techniques that prevent such function-space variational objectives from being used with high-dimensional inputs and highly overparameterized neural networks, we propose a simple estimator of the Kullback-Leibler divergence between distributions over functions that enables us to perform stochastic variational inference. The proposed estimation procedure allows defining priors that explicitly encourage high predictive uncertainty away from the training data as well as priors that reflect relevant information about the task at hand.

We demonstrate that this approach leads to posterior approximations with significantly improved predictive uncertainty estimates compared to a wide array of state-of-the-art Bayesian and non-Bayesian methods. Figure 1 shows examples of predictive distributions obtained via function-space variational inference on low-dimensional, easy-to-visualize datasets. As the figure shows, the predictive distributions fit the training data well while also exhibiting a high degree of predictive uncertainty in parts of the input space far away from the training data, as desired.

Contributions.  We propose a simple estimation procedure for performing function-space variational inference in bnns. The variational method allows for the incorporation of meaningful prior information about the data-generating process into the inference and produces reliable predictive uncertainty estimates. We perform a thorough empirical evaluation in which we compare the proposed approach to a wide array of competing methods and show that it consistently results in high predictive performance and reliable predictive uncertainty estimates, outperforming other methods in terms of predictive accuracy, robustness to distribution shifts, and uncertainty-based detection of distributionally shifted data samples. We evaluate the proposed method on standard benchmarking datasets as well as on a safety-critical medical diagnosis task in which reliable uncertainty estimation is essential. (Our code is available at https://github.com/timrudner/FSVI.)

[Figure 1 panels: (a) Predictive Distribution; (b) Predictive Mean; (c) Predictive Variance]
Figure 1: 1D regression on the Snelson dataset and binary classification on the Two Moons dataset. The plots show the predictive distributions of a bnn obtained via function-space variational inference (fsvi). For further illustrative examples and comparisons to deep ensembles and to bnns learned via parameter-space variational inference, see Appendix B.

2 Preliminaries

We consider supervised learning tasks on data $\mathcal{D} \,\dot{=}\, \{(\mathbf{x}_{n}, \mathbf{y}_{n})\}_{n=1}^{N} = (\mathbf{X}_{\mathcal{D}}, \mathbf{y}_{\mathcal{D}})$ with inputs $\mathbf{x}_{n} \in \mathcal{X} \subseteq \mathbb{R}^{D}$ and targets $\mathbf{y}_{n} \in \mathcal{Y}$, where $\mathcal{Y} \subseteq \mathbb{R}^{Q}$ for regression and $\mathcal{Y} \subseteq \{0,1\}^{Q}$ for classification tasks. Bayesian neural networks (bnns) are stochastic neural networks trained using (approximate) Bayesian inference. Denoting the parameters of such a stochastic neural network by the multivariate random variable $\bm{\Theta} \in \mathbb{R}^{P}$ and letting the function mapping defined by a neural network architecture be given by $f: \mathcal{X} \times \mathbb{R}^{P} \rightarrow \mathbb{R}^{Q}$, we obtain a random function $f(\cdot\,;\bm{\Theta})$.
For a parameter realization $\bm{\theta}$, we obtain a corresponding function realization, $f(\cdot\,;\bm{\theta})$. When evaluated at a finite collection of points $\mathbf{X} = \{\mathbf{x}_{i}\}_{i=1}^{m}$, $f(\mathbf{X};\bm{\Theta})$ is a multivariate random variable and $f(\mathbf{X};\bm{\theta})$ is a vector.
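The distinction between the random function $f(\cdot\,;\bm{\Theta})$, a realization $f(\cdot\,;\bm{\theta})$, and the random vector $f(\mathbf{X};\bm{\Theta})$ can be made concrete with a short NumPy sketch. The one-hidden-layer tanh architecture, the dimensions, and the standard-normal parameter distribution are illustrative assumptions rather than the paper's setup, and the function names are hypothetical.

```python
import numpy as np

# Hypothetical toy architecture: one tanh hidden layer with a scalar output (Q = 1).
D, H = 2, 8                      # input dimension and hidden width (assumed)
P = D * H + H + H + 1            # total number of parameters: W1, b1, W2, b2


def unflatten(theta):
    """Split a flat parameter vector theta in R^P into weights and biases."""
    i = 0
    W1 = theta[i:i + D * H].reshape(D, H)
    i += D * H
    b1 = theta[i:i + H]
    i += H
    W2 = theta[i:i + H].reshape(H, 1)
    i += H
    b2 = theta[i:i + 1]
    return W1, b1, W2, b2


def f(X, theta):
    """Evaluate f(X; theta) at m points X of shape (m, D); returns shape (m,)."""
    W1, b1, W2, b2 = unflatten(theta)
    return (np.tanh(X @ W1 + b1) @ W2 + b2)[:, 0]


rng = np.random.default_rng(0)
X = rng.normal(size=(5, D))            # a finite collection of m = 5 points
mean, std = np.zeros(P), np.ones(P)    # assumed diagonal Gaussian over Theta

# Each draw theta gives one function realization f(.; theta); evaluated at X,
# it is an ordinary vector in R^m.
thetas = mean + std * rng.standard_normal((1000, P))
samples = np.stack([f(X, t) for t in thetas])   # draws of the random vector f(X; Theta)
```

The empirical mean and covariance of `samples` (for example, `samples.mean(axis=0)` and `np.cov(samples, rowvar=False)`) summarize the induced distribution over function values at `X`, but its density is generally unavailable, which is one of the obstacles discussed in Section 3.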

Letting $p_{\mathbf{y}|f(\mathbf{X};\bm{\Theta})}$ be a likelihood function and $p_{\mathbf{y}|f(\mathbf{X};\bm{\Theta})}(\mathbf{y}_{\mathcal{D}} \,|\, f(\mathbf{X}_{\mathcal{D}};\bm{\theta}))$ be the likelihood of observing the targets $\mathbf{y}_{\mathcal{D}}$ under the stochastic function $f(\cdot\,;\bm{\Theta})$ evaluated at inputs $\mathbf{X}_{\mathcal{D}}$, and letting $p_{\bm{\Theta}}$ be a prior distribution over the stochastic network parameters $\bm{\Theta}$, we can use Bayes’ theorem to find the posterior distribution $p_{\bm{\Theta}|\mathcal{D}}$ (MacKay, 1992; Neal, 1996). However, since the mapping $f$ is a nonlinear function of the stochastic parameters $\bm{\Theta}$, exact inference is analytically intractable.
Variational inference seeks to sidestep this intractability by framing posterior inference as a variational optimization problem: the goal is to find a distribution $q_{\bm{\Theta}}$ in a variational family $\mathcal{Q}_{q_{\bm{\Theta}}}$ that solves $\min_{q_{\bm{\Theta}} \in \mathcal{Q}_{q_{\bm{\Theta}}}} \mathbb{D}_{\mathrm{KL}}(q_{\bm{\Theta}} \,\|\, p_{\bm{\Theta}|\mathcal{D}})$ (Wainwright and Jordan, 2008). If $\mathcal{Q}_{q_{\bm{\Theta}}}$ is the family of mean-field Gaussian distributions and the prior distribution over parameters $p_{\bm{\Theta}}$ is a diagonal Gaussian distribution, the resulting variational objective is amenable to stochastic variational inference and can be optimized using gradient-based methods (Hinton and van Camp, 1993; Graves, 2011; Hoffman et al., 2013; Blundell et al., 2015).
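As a reference point for the function-space objective developed below, the following sketch estimates the parameter-space objective (the evidence lower bound) for a mean-field Gaussian $q_{\bm{\Theta}}$ and a diagonal Gaussian prior. It uses the closed-form KL divergence between diagonal Gaussians and the reparameterization $\bm{\theta} = \mathbf{m} + \mathbf{s} \odot \bm{\epsilon}$ with $\bm{\epsilon} \sim \mathcal{N}(\mathbf{0}, \mathbf{I})$. The Gaussian likelihood with fixed noise variance, the linear toy model, and all names are assumptions made for illustration.

```python
import numpy as np


def kl_diag_gaussians(mu_q, var_q, mu_p, var_p):
    """Closed-form KL(N(mu_q, diag(var_q)) || N(mu_p, diag(var_p)))."""
    return 0.5 * np.sum(np.log(var_p / var_q) + (var_q + (mu_q - mu_p) ** 2) / var_p - 1.0)


def elbo_estimate(mu_q, log_var_q, X, y, f, noise_var=0.1, prior_var=1.0, n_samples=32, seed=0):
    """Monte Carlo estimate of E_q[log p(y | f(X; theta))] - KL(q_Theta || p_Theta).

    Assumes a Gaussian likelihood with fixed noise variance and a zero-mean
    isotropic Gaussian prior; samples use theta = m + sqrt(var) * eps.
    """
    rng = np.random.default_rng(seed)
    var_q = np.exp(log_var_q)
    eps = rng.standard_normal((n_samples, mu_q.size))
    thetas = mu_q + np.sqrt(var_q) * eps
    log_liks = [
        -0.5 * np.sum((y - f(X, t)) ** 2 / noise_var + np.log(2 * np.pi * noise_var))
        for t in thetas
    ]
    kl = kl_diag_gaussians(mu_q, var_q, np.zeros_like(mu_q), np.full_like(mu_q, prior_var))
    return np.mean(log_liks) - kl


def linear(X, theta):
    """Hypothetical linear model f(X; theta) = X theta."""
    return X @ theta


# Usage on a hypothetical linear-regression toy problem.
rng = np.random.default_rng(1)
X = rng.normal(size=(20, 3))
w_true = np.array([1.0, -2.0, 0.5])
y = X @ w_true + 0.1 * rng.normal(size=20)

elbo_good = elbo_estimate(w_true, np.full(3, -10.0), X, y, linear)  # q concentrated near w_true
elbo_bad = elbo_estimate(np.zeros(3), np.zeros(3), X, y, linear)    # q equal to the prior
```

In practice one would maximize this estimate with respect to the mean and log-variance parameters using automatic differentiation; the sketch only evaluates it.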

2.1 A Function-Space Perspective on Variational Inference in Bayesian Neural Networks

Instead of seeking to infer an approximate posterior distribution over parameters, we frame variational inference in stochastic neural networks as inferring an approximation to the posterior distribution over functions $p_{f(\cdot\,;\bm{\Theta})|\mathcal{D}}$ induced by the posterior distribution over parameters $p_{\bm{\Theta}|\mathcal{D}}$, that is,

$$p_{f(\cdot\,;\bm{\Theta})|\mathcal{D}}(f(\cdot\,;\bm{\theta}) \,|\, \mathcal{D}) = \int_{\mathbb{R}^{P}} p_{\bm{\Theta}|\mathcal{D}}(\bm{\theta}' \,|\, \mathcal{D})\, \delta(f(\cdot\,;\bm{\theta}) - f(\cdot\,;\bm{\theta}'))\, \mathrm{d}\bm{\theta}', \tag{1}$$

where $\delta(\cdot)$ is the Dirac delta function (Wolpert, 1993). Considering the prior distribution over functions $p_{f(\cdot\,;\bm{\Theta})}$ induced by a prior distribution over parameters $p_{\bm{\Theta}}$,

$$p_{f(\cdot\,;\bm{\Theta})}(f(\cdot\,;\bm{\theta})) = \int_{\mathbb{R}^{P}} p_{\bm{\Theta}}(\bm{\theta}')\, \delta(f(\cdot\,;\bm{\theta}) - f(\cdot\,;\bm{\theta}'))\, \mathrm{d}\bm{\theta}', \tag{2}$$

and the variational distribution over functions $q_{f(\cdot\,;\bm{\Theta})}$ induced by a variational distribution over parameters $q_{\bm{\Theta}}$,

$$q_{f(\cdot\,;\bm{\Theta})}(f(\cdot\,;\bm{\theta})) = \int_{\mathbb{R}^{P}} q_{\bm{\Theta}}(\bm{\theta}')\, \delta(f(\cdot\,;\bm{\theta}) - f(\cdot\,;\bm{\theta}'))\, \mathrm{d}\bm{\theta}', \tag{3}$$

we can express the problem of finding a posterior distribution over functions variationally as

$$\min_{q_{\bm{\Theta}} \in \mathcal{Q}_{q_{\bm{\Theta}}}} \mathbb{D}_{\mathrm{KL}}(q_{f(\cdot\,;\bm{\Theta})} \,\|\, p_{f(\cdot\,;\bm{\Theta})|\mathcal{D}}), \tag{4}$$

which allows us to effectively incorporate meaningful prior information about the underlying data-generating process into training. As discussed by Burt et al. (2021), this variational objective is guaranteed to be well-defined for suitably chosen prior distributions over functions. Specifically, the KL divergence between two distributions over functions generated from different distributions over parameters applied to the same mapping (e.g., the same neural network architecture) is well-defined (i.e., finite) if the KL divergence between the distributions over parameters is finite, since, by the strong data processing inequality (Polyanskiy and Wu, 2017),

$$\mathbb{D}_{\mathrm{KL}}(q_{f(\cdot\,;\bm{\Theta})} \,\|\, p_{f(\cdot\,;\bm{\Theta})}) \leq \mathbb{D}_{\mathrm{KL}}(q_{\bm{\Theta}} \,\|\, p_{\bm{\Theta}}). \tag{5}$$

As a result, if $\mathbb{D}_{\mathrm{KL}}(q_{\bm{\Theta}} \,\|\, p_{\bm{\Theta}}) < \infty$, which is the case for finite-dimensional parameter vectors $\bm{\Theta}$ and $q_{\bm{\Theta}}$ absolutely continuous with respect to $p_{\bm{\Theta}}$, then the function-space KL divergence is finite and thus well-defined as a variational objective.
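A finite-set version of the inequality in Equation 5 can be checked numerically in a special case where both sides are available in closed form: if $f(\mathbf{X};\bm{\theta}) = \mathbf{A}\bm{\theta}$ is linear in the parameters and $q_{\bm{\Theta}}$ and $p_{\bm{\Theta}}$ are Gaussian with means $\bm{\mu}$ and covariances $\bm{\Sigma}$, then $f(\mathbf{X};\bm{\Theta})$ is Gaussian with mean $\mathbf{A}\bm{\mu}$ and covariance $\mathbf{A}\bm{\Sigma}\mathbf{A}^\top$ under each distribution. The sketch below assumes this linear-Gaussian setting, with a random feature matrix standing in for the network; it is an illustration, not a proof for nonlinear networks.

```python
import numpy as np


def kl_gaussians(mu_q, cov_q, mu_p, cov_p):
    """Closed-form KL(N(mu_q, cov_q) || N(mu_p, cov_p)) for full covariance matrices."""
    k = mu_q.size
    diff = mu_p - mu_q
    _, logdet_q = np.linalg.slogdet(cov_q)
    _, logdet_p = np.linalg.slogdet(cov_p)
    trace_term = np.trace(np.linalg.solve(cov_p, cov_q))
    mahalanobis = diff @ np.linalg.solve(cov_p, diff)
    return 0.5 * (trace_term + mahalanobis - k + logdet_p - logdet_q)


rng = np.random.default_rng(0)
P, m = 10, 4                      # number of parameters and evaluation points (assumed)
A = rng.normal(size=(m, P))       # hypothetical linear model: f(X; theta) = A @ theta

# Gaussian variational and prior distributions over parameters.
mu_q, cov_q = rng.normal(size=P), np.diag(rng.uniform(0.1, 2.0, size=P))
mu_p, cov_p = np.zeros(P), np.eye(P)

kl_params = kl_gaussians(mu_q, cov_q, mu_p, cov_p)
# Under theta -> A theta, N(mu, cov) is pushed forward to N(A mu, A cov A^T).
kl_funcs = kl_gaussians(A @ mu_q, A @ cov_q @ A.T, A @ mu_p, A @ cov_p @ A.T)
```

Because $m < P$ here, the map discards information about the parameters, and `kl_funcs` is typically strictly smaller than `kl_params`.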

Hence, for a likelihood function defined on a finite set of training targets $\mathbf{y}_{\mathcal{D}}$ and a suitably defined prior distribution over functions, we can express the variational problem above equivalently as the well-defined maximization problem $\max_{q_{\bm{\Theta}} \in \mathcal{Q}_{q_{\bm{\Theta}}}} \mathcal{F}(q_{\bm{\Theta}})$ with

$$\mathcal{F}(q_{\bm{\Theta}}) \,\dot{=}\, \mathbb{E}_{q_{f(\mathbf{X}_{\mathcal{D}};\bm{\Theta})}}\big[\log p_{\mathbf{y}|f(\mathbf{X};\bm{\Theta})}(\mathbf{y}_{\mathcal{D}} \,|\, f(\mathbf{X}_{\mathcal{D}};\bm{\theta}))\big] - \mathbb{D}_{\mathrm{KL}}(q_{f(\cdot\,;\bm{\Theta})} \,\|\, p_{f(\cdot\,;\bm{\Theta})}), \tag{6}$$

where $\mathbb{D}_{\mathrm{KL}}(q_{f(\cdot\,;\bm{\Theta})} \,\|\, p_{f(\cdot\,;\bm{\Theta})})$ is the KL divergence from the variational distribution over functions (Equation 3) to the prior distribution over functions (Equation 2).

Unfortunately, evaluating the KL divergence in Equation 6 is intractable for general mappings $f$. Toward a tractable objective, Sun et al. (2019) showed that $\mathbb{D}_{\mathrm{KL}}(q_{f(\cdot\,;\bm{\Theta})} \,\|\, p_{f(\cdot\,;\bm{\Theta})})$ can be expressed as the supremum, over all finite sets of evaluation points $\mathbf{X}$, of the KL divergence from $q_{f(\mathbf{X};\bm{\Theta})}$ to $p_{f(\mathbf{X};\bm{\Theta})}$, resulting in the objective function

$$\mathcal{F}(q_{\bm{\Theta}}) = \mathbb{E}_{q_{f(\mathbf{X}_{\mathcal{D}};\bm{\Theta})}}\big[\log p_{\mathbf{y}|f(\mathbf{X};\bm{\Theta})}(\mathbf{y}_{\mathcal{D}} \,|\, f(\mathbf{X}_{\mathcal{D}};\bm{\theta}))\big] - \sup_{\mathbf{X} \in \mathcal{X}_{\mathbb{N}}} \mathbb{D}_{\mathrm{KL}}(q_{f(\mathbf{X};\bm{\Theta})} \,\|\, p_{f(\mathbf{X};\bm{\Theta})}), \tag{7}$$

where $\mathcal{X}_{\mathbb{N}} \,\dot{=}\, \bigcup_{n \in \mathbb{N}} \{\mathbf{X} \in \mathcal{X}_{n} \,|\, \mathcal{X}_{n} \subseteq \mathbb{R}^{n \times D}\}$ is the collection of all finite sets of evaluation points. However, this objective function is still challenging to optimize in practice: the supremum cannot be obtained analytically, and the KL divergence term itself is analytically intractable and difficult to estimate in high dimensions, even for a single evaluation point.
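The supremum over $\mathcal{X}_{\mathbb{N}}$ is natural because restricting to a subset of evaluation points is a marginalization, which by the same data processing argument cannot increase the KL divergence; adding evaluation points can therefore only increase $\mathbb{D}_{\mathrm{KL}}(q_{f(\mathbf{X};\bm{\Theta})} \,\|\, p_{f(\mathbf{X};\bm{\Theta})})$. The sketch below illustrates this in a hypothetical linear-Gaussian setting, in which rows of a random feature matrix play the role of evaluation points; it does not address how the supremum is estimated for neural networks.

```python
import numpy as np


def kl_gaussians(mu_q, cov_q, mu_p, cov_p):
    """Closed-form KL(N(mu_q, cov_q) || N(mu_p, cov_p)) for full covariance matrices."""
    k = mu_q.size
    diff = mu_p - mu_q
    _, logdet_q = np.linalg.slogdet(cov_q)
    _, logdet_p = np.linalg.slogdet(cov_p)
    return 0.5 * (np.trace(np.linalg.solve(cov_p, cov_q))
                  + diff @ np.linalg.solve(cov_p, diff) - k + logdet_p - logdet_q)


rng = np.random.default_rng(0)
P = 12
Phi = rng.normal(size=(8, P))     # row i: features of candidate evaluation point i (hypothetical)
mu_q, var_q = rng.normal(size=P), rng.uniform(0.1, 2.0, size=P)
mu_p, var_p = np.zeros(P), np.ones(P)


def finite_set_kl(idx):
    """KL between q_{f(X; Theta)} and p_{f(X; Theta)} for the points selected by idx,
    assuming the linear model f(X; theta) = Phi[idx] @ theta."""
    A = Phi[idx]
    return kl_gaussians(A @ mu_q, (A * var_q) @ A.T, A @ mu_p, (A * var_p) @ A.T)


# Nested sets X_1 subset X_2 subset ... subset X_8, adding one point at a time.
kls = [finite_set_kl(np.arange(n)) for n in range(1, 9)]
```

The sequence `kls` is nondecreasing, and every entry is bounded by the parameter-space KL divergence, consistent with Equation 5.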

In the next section, we will describe an approximation and estimation procedure that allows scaling function-space variational inference to large neural networks and high-dimensional input data.

3 Deriving a Tractable Function-Space Variational Objective

The primary obstacle to computing the objective in Equation 6 is the KL divergence from $q_{f(\cdot\,;\bm{\Theta})}$ to $p_{f(\cdot\,;\bm{\Theta})}$. There are two reasons why the KL divergence in Equation 7 is intractable: first, for bnns or other nonlinear models, we do not have access to the probability density functions of the multivariate distributions $q_{f(\mathbf{X};\bm{\Theta})}$ and $p_{f(\mathbf{X};\bm{\Theta})}$; second, for all but extremely simple input spaces, we are unable to compute the supremum over all possible finite sets of evaluation points. In the remainder of this section, we outline an approach for obtaining an estimator of a locally accurate approximation to the KL divergence that allows for scalable gradient-based optimization of Equation 7.

We begin with the problem of computing the KL divergence between two bnns evaluated at a finite set of points. To do so, we first derive tractable approximations to the distributions over functions $q_{f(\mathbf{X};\bm{\Theta})}$ and $p_{f(\mathbf{X};\bm{\Theta})}$. Next, we show that under these approximations we can obtain a closed-form approximation to the KL divergence, and we describe a simple Monte Carlo estimator of the supremum in the function-space KL divergence.

3.1 Approximating Distributions over Functions via Local Linearization

To approximate the distributions $q_{f(\mathbf{X};\bm{\Theta})}$ and $p_{f(\mathbf{X};\bm{\Theta})}$, we use a first-order Taylor expansion of the mapping $f$ about the mean parameters of $q_{\bm{\Theta}}$ and $p_{\bm{\Theta}}$, respectively, and derive the induced distributions under the linearized mapping.

For a stochastic function $f(\cdot\,;\bm{\Theta})$ defined in terms of stochastic parameters $\bm{\Theta}$ distributed according to a distribution $g_{\bm{\Theta}}$ with $\mathbf{m} \,\dot{=}\, \mathbb{E}_{g_{\bm{\Theta}}}[\bm{\Theta}]$ and $\mathbf{S} \,\dot{=}\, \mathrm{Cov}_{g_{\bm{\Theta}}}[\bm{\Theta}]$, we denote the linearization of the stochastic function $f(\cdot\,;\bm{\Theta})$ about $\mathbf{m}$ by

$$f(\cdot\,;\bm{\Theta}) \approx \tilde{f}(\cdot\,;\mathbf{m},\bm{\Theta}) \,\dot{=}\, f(\cdot\,;\mathbf{m}) + \mathcal{J}(\cdot\,;\mathbf{m})(\bm{\Theta} - \mathbf{m}), \tag{8}$$

where $\mathcal{J}(\cdot\,;\mathbf{m}) \,\dot{=}\, (\partial f(\cdot\,;\bm{\Theta}) / \partial \bm{\Theta})|_{\bm{\Theta}=\mathbf{m}}$ is the Jacobian of $f(\cdot\,;\bm{\Theta})$ evaluated at $\bm{\Theta} = \mathbf{m}$, and the mean and covariance of the distribution over the linearized mapping $\tilde{f}$ at $\mathbf{X}, \mathbf{X}' \in \mathcal{X}$ are given by

$$\mathbb{E}[\tilde{f}(\mathbf{X};\mathbf{m},\bm{\Theta})] = f(\mathbf{X};\mathbf{m}), \tag{9}$$
$$\mathrm{Cov}[\tilde{f}(\mathbf{X};\mathbf{m},\bm{\Theta}), \tilde{f}(\mathbf{X}';\mathbf{m},\bm{\Theta})] = \mathcal{J}(\mathbf{X};\mathbf{m})\,\mathbf{S}\,\mathcal{J}(\mathbf{X}';\mathbf{m})^{\top}. \tag{10}$$
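Equations 9 and 10 only require the network output and its Jacobian at the mean parameters. The following sketch computes both for a toy network, using central finite differences as a stand-in for the automatic differentiation one would use in practice, and compares the linearized moments (with $\mathbf{X}' = \mathbf{X}$) to Monte Carlo moments of the nonlinear mapping under a small diagonal covariance. The architecture, finite-difference step, and covariance scale are assumptions.

```python
import numpy as np


def jacobian_fd(f, X, m, eps=1e-6):
    """Central finite-difference Jacobian of theta -> f(X; theta) at theta = m.

    Returns an array of shape (len(X), P); a stand-in for automatic differentiation.
    """
    cols = []
    for p in range(m.size):
        e = np.zeros_like(m)
        e[p] = eps
        cols.append((f(X, m + e) - f(X, m - e)) / (2 * eps))
    return np.stack(cols, axis=1)


def linearized_moments(f, X, m, S):
    """Equations 9 and 10 with X' = X: mean f(X; m) and covariance J S J^T."""
    J = jacobian_fd(f, X, m)
    return f(X, m), J @ S @ J.T


# Hypothetical toy network: f(X; theta) = tanh(X W) v with theta = (vec(W), v).
D, H = 2, 4


def mlp(X, theta):
    W, v = theta[:D * H].reshape(D, H), theta[D * H:]
    return np.tanh(X @ W) @ v


rng = np.random.default_rng(0)
X = rng.normal(size=(3, D))
m = rng.normal(size=D * H + H)
S = np.diag(np.full(m.size, 1e-3))      # small diagonal covariance (assumed)
mean_lin, cov_lin = linearized_moments(mlp, X, m, S)

# Monte Carlo moments of the nonlinear mapping under Theta ~ N(m, S), for comparison.
thetas = m + np.sqrt(np.diag(S)) * rng.standard_normal((20000, m.size))
outs = np.stack([mlp(X, t) for t in thetas])
mean_mc, cov_mc = outs.mean(axis=0), np.cov(outs, rowvar=False)
```

For small $\mathbf{S}$ the two sets of moments should be close; as $\mathbf{S}$ grows, the linearization error grows with the curvature of $f$.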

For a derivation of Equations 9 and 10, see Appendix A. Since Gaussianity is preserved under affine transformations, if $g_{\bm{\Theta}}$ is a multivariate Gaussian distribution with mean $\mathbf{m}$ and diagonal covariance $\mathbf{S}$, then the distribution $\tilde{g}$ over $\tilde{f}(\mathbf{X};\mathbf{m},\bm{\Theta})$ is given by

$$\tilde{g}_{\tilde{f}(\mathbf{X};\mathbf{m},\bm{\Theta})} = \mathcal{N}\big(f(\mathbf{X};\mathbf{m}),\, \mathcal{J}(\mathbf{X};\mathbf{m})\,\mathbf{S}\,\mathcal{J}(\mathbf{X};\mathbf{m})^{\top}\big). \tag{11}$$
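Once $q_{f(\mathbf{X};\bm{\Theta})}$ and $p_{f(\mathbf{X};\bm{\Theta})}$ are both replaced by their linearized Gaussian counterparts from Equation 11, each linearized about its own mean as described at the start of this subsection, the KL divergence at a finite set of points $\mathbf{X}$ has the standard closed form for multivariate Gaussians. The sketch below assumes diagonal parameter covariances, a finite-difference Jacobian, and a small diagonal jitter added for numerical stability; the jitter, the toy network, and all names are our assumptions, not details taken from the paper.

```python
import numpy as np


def jacobian_fd(f, X, m, eps=1e-6):
    """Central finite-difference Jacobian of theta -> f(X; theta) at m, shape (len(X), P)."""
    cols = []
    for p in range(m.size):
        e = np.zeros_like(m)
        e[p] = eps
        cols.append((f(X, m + e) - f(X, m - e)) / (2 * eps))
    return np.stack(cols, axis=1)


def linearized_gaussian(f, X, m, s_diag, jitter=1e-6):
    """Equation 11: N(f(X; m), J diag(s) J^T), plus a small diagonal jitter (assumed)."""
    J = jacobian_fd(f, X, m)
    cov = (J * s_diag) @ J.T + jitter * np.eye(len(X))
    return f(X, m), cov


def kl_gaussians(mu_q, cov_q, mu_p, cov_p):
    """Closed-form KL(N(mu_q, cov_q) || N(mu_p, cov_p))."""
    k = mu_q.size
    diff = mu_p - mu_q
    _, logdet_q = np.linalg.slogdet(cov_q)
    _, logdet_p = np.linalg.slogdet(cov_p)
    return 0.5 * (np.trace(np.linalg.solve(cov_p, cov_q))
                  + diff @ np.linalg.solve(cov_p, diff) - k + logdet_p - logdet_q)


def linearized_kl(f, X, m_q, s_q, m_p, s_p):
    """Approximate KL(q_{f(X;Theta)} || p_{f(X;Theta)}), each side linearized about its own mean."""
    mu_q, cov_q = linearized_gaussian(f, X, m_q, s_q)
    mu_p, cov_p = linearized_gaussian(f, X, m_p, s_p)
    return kl_gaussians(mu_q, cov_q, mu_p, cov_p)


# Hypothetical toy network and parameter distributions.
D, H = 2, 4


def mlp(X, theta):
    W, v = theta[:D * H].reshape(D, H), theta[D * H:]
    return np.tanh(X @ W) @ v


rng = np.random.default_rng(0)
X = rng.normal(size=(5, D))                      # finite set of evaluation points
P = D * H + H
m_q, s_q = rng.normal(size=P), np.full(P, 0.05)  # variational mean and diagonal variances
# Random prior mean: for this toy network, all-zero parameters give a zero Jacobian.
m_p, s_p = rng.normal(size=P), np.full(P, 1.0)
kl_value = linearized_kl(mlp, X, m_q, s_q, m_p, s_p)
```

Note that the prior-mean caveat is specific to this toy parameterization; it illustrates that the quality of the linearized prior depends on where it is linearized.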

For stochastic functions parameterized by many millions of parameters, obtaining the covariance of $\tilde{g}_{\tilde{f}(\mathbf{X};\bm{\Theta})}$, which requires computing an inner product of two Jacobian matrices, can be computationally expensive. Instead of computing the distribution over the linearized mapping exactly, we can construct a suitable Monte Carlo estimator. To do so, we partition the set of parameters into sets $\alpha$ and $\beta$ (with $|\beta| \ll |\alpha|$) and note that the linearized mapping can then be expressed as

$$\tilde{f}(\cdot\,;\mathbf{m},\bm{\Theta}) = f(\cdot\,;\mathbf{m}) + \tilde{f}_{\alpha}(\cdot\,;\mathbf{m},\bm{\Theta}_{\alpha}) + \mathcal{J}_{\beta}(\cdot\,;\mathbf{m})(\bm{\Theta}_{\beta}-\mathbf{m}_{\beta}), \tag{12}$$

with

$$\tilde{f}_{\alpha}(\cdot\,;\mathbf{m},\bm{\Theta}_{\alpha}) \,\dot{=}\, \mathcal{J}_{\alpha}(\cdot\,;\mathbf{m})(\bm{\Theta}_{\alpha}-\mathbf{m}_{\alpha}), \tag{13}$$

where $\mathcal{J}_{\alpha}(\cdot\,;\mathbf{m})$ and $\mathcal{J}_{\beta}(\cdot\,;\mathbf{m})$ are the columns of the Jacobian matrix corresponding to the parameter sets $\alpha$ and $\beta$, respectively, and $\bm{\Theta}_{\alpha}$ and $\bm{\Theta}_{\beta}$ are the corresponding random parameter vectors. Because Equation 12 expresses $\tilde{f}$ as a sum of (affine transformations of) random variables, we can approximate the distribution over $\tilde{f}$ using the fact that, for independent Gaussian random variables $\mathbf{X}$ and $\mathbf{Y}$, the distribution $h_{\mathbf{Z}}$ of $\mathbf{Z} = \mathbf{X} + \mathbf{Y}$ is the convolution of the distributions $h_{\mathbf{X}}$ and $h_{\mathbf{Y}}$. In particular, we can show that if $g_{\bm{\Theta}}$ is a multivariate Gaussian distribution with $\bm{\Theta}_{\alpha} \perp \bm{\Theta}_{\beta}$, the distribution $\tilde{g}_{\tilde{f}(\mathbf{X};\bm{\Theta})}$ can be approximated by the Monte Carlo estimator

$$\hat{\tilde{g}}_{\tilde{f}(\mathbf{X};\mathbf{m},\bm{\Theta})} = \frac{1}{R}\sum_{j=1}^{R}\mathcal{N}\Big(f(\mathbf{X};\mathbf{m}) + \tilde{f}_{\alpha}(\mathbf{X};\mathbf{m},\bm{\Theta}_{\alpha})^{(j)},\, \mathcal{J}_{\beta}(\mathbf{X};\mathbf{m})\,\mathbf{S}_{\beta}\,\mathcal{J}_{\beta}(\mathbf{X};\mathbf{m})^{\top}\Big), \tag{14}$$

where $g_{\bm{\Theta}_{\beta}} = \mathcal{N}(\mathbf{m}_{\beta},\mathbf{S}_{\beta})$ and the samples $\tilde{f}_{\alpha}(\mathbf{X};\mathbf{m},\bm{\Theta}_{\alpha})^{(j)}$ are obtained by sampling parameters from the distribution $g_{\bm{\Theta}_{\alpha}} = \mathcal{N}(\mathbf{m}_{\alpha},\mathbf{S}_{\alpha})$. For a derivation of this result, see Appendix A. This estimator is biased for finite $R$ but converges to $\tilde{g}_{\tilde{f}(\mathbf{X};\mathbf{m},\bm{\Theta})}$ as $R \rightarrow \infty$. Similarly, for finite $R$, the smaller the variances $[\mathbf{S}_{\alpha}]_{ii}$, the more accurate and less biased the estimator will be.
In our empirical evaluation, we use a single Monte Carlo sample ($R = 1$) to preserve Gaussianity, choose $\alpha$ to be the set of parameters in neural network layers $1{:}L-1$, and choose $\beta$ to be the set of parameters in the final neural network layer.
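The following sketch builds the components of the mixture in Equation 14, assuming mean-field (diagonal) $\mathbf{S}_{\alpha}$ and $\mathbf{S}_{\beta}$ and explicitly available Jacobian blocks: $R$ component means obtained by sampling $\bm{\Theta}_{\alpha}$, and one covariance $\mathcal{J}_{\beta}\mathbf{S}_{\beta}\mathcal{J}_{\beta}^{\top}$ shared by all components. The helper `mixture_moments` is a hypothetical check, not part of the method: as $R$ grows, the mixture's covariance approaches the full linearized covariance $\mathcal{J}\mathbf{S}\mathcal{J}^{\top} = \mathcal{J}_{\alpha}\mathbf{S}_{\alpha}\mathcal{J}_{\alpha}^{\top} + \mathcal{J}_{\beta}\mathbf{S}_{\beta}\mathcal{J}_{\beta}^{\top}$, which holds because $\mathbf{S}$ is block-diagonal when $\bm{\Theta}_{\alpha} \perp \bm{\Theta}_{\beta}$. In a real network, samples of $\tilde{f}_{\alpha}$ could be computed with Jacobian-vector products instead of an explicit $\mathcal{J}_{\alpha}$; the explicit matrices here are for illustration only.

```python
import numpy as np


def mc_linearized_mixture(f_m, J_alpha, J_beta, s_alpha, s_beta, R, rng):
    """Components of the R-sample mixture estimator in Eq. (14).

    Returns the R component means f(X; m) + J_alpha (Theta_alpha^(j) - m_alpha)
    and the covariance J_beta S_beta J_beta^T shared by all components.
    s_alpha and s_beta are the diagonals of S_alpha and S_beta (mean-field).
    """
    P_alpha = J_alpha.shape[1]
    # Theta_alpha^(j) - m_alpha ~ N(0, S_alpha), one row per sample j = 1..R
    delta_alpha = rng.normal(size=(R, P_alpha)) * np.sqrt(s_alpha)
    means = f_m + delta_alpha @ J_alpha.T          # shape (R, N)
    shared_cov = (J_beta * s_beta) @ J_beta.T      # shape (N, N)
    return means, shared_cov


def mixture_moments(means, shared_cov):
    """Mean and covariance of the equally weighted Gaussian mixture."""
    mix_mean = means.mean(axis=0)
    centered = means - mix_mean
    mix_cov = shared_cov + centered.T @ centered / means.shape[0]
    return mix_mean, mix_cov


rng = np.random.default_rng(0)
J_alpha, J_beta = rng.normal(size=(3, 20)), rng.normal(size=(3, 4))
s_alpha, s_beta = np.full(20, 0.1), np.full(4, 0.5)
f_m = rng.normal(size=3)
# R = 1, as in the experiments: a single Gaussian N(means[0], shared_cov)
means, shared_cov = mc_linearized_mixture(f_m, J_alpha, J_beta, s_alpha, s_beta, 1, rng)
```

With $R = 1$, only the small block $\mathcal{J}_{\beta}$ (final-layer parameters) enters a covariance computation, which is where the savings over Equation 11 come from.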

3.2 Approximating the Function-Space Kullback-Leibler Divergence

From Section 3.1, we know that if $q_{\bm{\Theta}}$ and $p_{\bm{\Theta}}$ are both Gaussian distributions, then the distributions they induce under the linearized mapping $\tilde{f}$, evaluated at a finite set of evaluation points, will be Gaussian as well. This means that for Gaussian variational and prior distributions over $\bm{\Theta}$, we can obtain locally accurate approximations to the induced distributions $q_{f(\cdot;\bm{\Theta})}$ and $p_{f(\cdot;\bm{\Theta})}$ and use them to approximate the KL divergence in the variational objective by $\mathbb{D}_{\textrm{KL}}(\tilde{q}_{\tilde{f}(\mathbf{X};\bm{\Theta})} \,\|\, \tilde{p}_{\tilde{f}(\mathbf{X};\bm{\Theta})})$.
Moreover, for an isotropic Gaussian prior and a mean-field Gaussian variational distribution, $\mathbb{D}_{\textrm{KL}}(\tilde{q}_{\tilde{f}(\mathbf{X};\bm{\Theta})} \,\|\, \tilde{p}_{\tilde{f}(\mathbf{X};\bm{\Theta})})$ is a KL divergence between two multivariate Gaussian distributions and can be computed analytically.
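For reference, for two $N$-dimensional Gaussians $\mathcal{N}(\mathbf{a},\mathbf{A})$ and $\mathcal{N}(\mathbf{b},\mathbf{B})$, the closed form is $\mathbb{D}_{\textrm{KL}} = \tfrac{1}{2}\big[\operatorname{tr}(\mathbf{B}^{-1}\mathbf{A}) + (\mathbf{b}-\mathbf{a})^{\top}\mathbf{B}^{-1}(\mathbf{b}-\mathbf{a}) - N + \log\det\mathbf{B} - \log\det\mathbf{A}\big]$. The NumPy sketch below evaluates it for full covariance matrices such as $\mathcal{J}\mathbf{S}\mathcal{J}^{\top}$. The optional `jitter` added to both covariances is our own assumption for numerical stability (linearized covariances can be ill-conditioned) and is not part of the method as described.

```python
import numpy as np


def gaussian_kl(mu_q, cov_q, mu_p, cov_p, jitter=0.0):
    """KL( N(mu_q, cov_q) || N(mu_p, cov_p) ) for full covariance matrices."""
    n = mu_q.shape[0]
    eye = np.eye(n)
    cov_q = cov_q + jitter * eye   # optional diagonal jitter (assumption)
    cov_p = cov_p + jitter * eye
    diff = mu_p - mu_q
    trace_term = np.trace(np.linalg.solve(cov_p, cov_q))   # tr(B^{-1} A)
    maha = diff @ np.linalg.solve(cov_p, diff)             # (b - a)^T B^{-1} (b - a)
    _, logdet_p = np.linalg.slogdet(cov_p)
    _, logdet_q = np.linalg.slogdet(cov_q)
    return 0.5 * (trace_term + maha - n + logdet_p - logdet_q)


# Example: KL between two 3-dimensional Gaussians with random SPD covariances
rng = np.random.default_rng(0)
A, B = rng.normal(size=(3, 3)), rng.normal(size=(3, 3))
kl = gaussian_kl(np.zeros(3), A @ A.T + np.eye(3), np.ones(3), B @ B.T + np.eye(3))
```

Using `slogdet` rather than `det` avoids overflow or underflow of the determinants when the number of evaluation points is large.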

Using this approximation, we obtain an estimator of the variational objective given by

$$\tilde{\mathcal{F}}(q_{\bm{\Theta}}) \,\dot{=}\, \mathbb{E}_{q_{f(\mathbf{X}_{\mathcal{D}};\bm{\Theta})}}\big[\log p_{\mathbf{y}|f(\mathbf{X};\bm{\Theta})}(\mathbf{y}_{\mathcal{D}} \,|\, f(\mathbf{X}_{\mathcal{D}};\bm{\theta}))\big] - \sup_{\mathbf{X}\in\mathcal{X}_{\mathbb{N}}} \mathbb{D}_{\textrm{KL}}\big(\tilde{q}_{\tilde{f}(\mathbf{X};\bm{\Theta})} \,\|\, \tilde{p}_{\tilde{f}(\mathbf{X};\bm{\Theta})}\big), \tag{15}$$

where the arguments of the KL divergence have been replaced by the (locally accurate) approximations to the variational and prior distributions over functions evaluated at $\mathbf{X}$, respectively. The smaller the variances of $q_{\bm{\Theta}}$ and $p_{\bm{\Theta}}$, the closer the stochastic functions $\tilde{f}(\cdot\,;\bm{\Theta})$ that they induce under the linearized mapping will be to the corresponding stochastic functions under $f$. Consequently, the approximation to the KL divergence becomes more accurate as the variances of $q_{\bm{\Theta}}$ and $p_{\bm{\Theta}}$ decrease.
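The dependence of the linearization's accuracy on the parameter variance can be illustrated numerically. The sketch below uses a hypothetical one-parameter model $f(x;\theta) = \tanh(\theta x)$, chosen only because its Jacobian is available in closed form, and compares the Monte Carlo covariance of $f(\mathbf{x};\theta)$ under $\theta \sim \mathcal{N}(m, s)$ with the linearized covariance $s\,\mathcal{J}\mathcal{J}^{\top}$ for decreasing $s$.

```python
import numpy as np


def linearization_error(m, s, x, n_samples=100_000, seed=0):
    """Relative Frobenius-norm gap between the Monte Carlo covariance of
    f(x; theta) = tanh(theta * x), theta ~ N(m, s), and the linearized
    covariance J s J^T (a hypothetical one-parameter model)."""
    rng = np.random.default_rng(seed)
    thetas = m + np.sqrt(s) * rng.normal(size=n_samples)
    outputs = np.tanh(np.outer(thetas, x))       # shape (n_samples, N)
    cov_mc = np.cov(outputs, rowvar=False)       # empirical covariance of f(x; theta)
    J = x * (1.0 - np.tanh(m * x) ** 2)          # df/dtheta evaluated at theta = m
    cov_lin = s * np.outer(J, J)
    return np.linalg.norm(cov_mc - cov_lin) / np.linalg.norm(cov_mc)


x = np.array([-1.0, 0.5, 1.0, 2.0])
for s in (1e-4, 1e-2, 1.0):
    print(f"variance {s:g}: relative covariance error {linearization_error(0.5, s, x):.3f}")
```

At small variances the remaining gap is dominated by Monte Carlo noise, whereas at large variances the saturation of $\tanh$ makes the linearized covariance a poor approximation.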

Next, we turn to computing the supremum. Unlike Sun et al. (2019), who treat the supremum as a separate optimization problem, we do not compute the supremum by searching over points $\mathbf{X} \in \mathcal{X}_{\mathbb{N}}$ but instead estimate it at every gradient step with a simple finite-sample estimator. Specifically, letting $I(\mathbf{X}) \,\dot{=}\, \mathbb{D}_{\textrm{KL}}(\tilde{q}_{\tilde{f}(\mathbf{X};\bm{\Theta})} \,\|\, \tilde{p}_{\tilde{f}(\mathbf{X};\bm{\Theta})})$, we estimate $G = \sup_{\mathbf{X}\in\mathcal{X}_{\mathbb{N}}} I(\mathbf{X})$ using the Monte Carlo estimator

$$\hat{G}(\mathcal{X}_{\mathcal{C}}^{S}) = \max_{\mathbf{X}\in\mathcal{X}_{\mathcal{C}}^{S}} I(\mathbf{X}), \tag{16}$$

where $\mathcal{X}_{\mathcal{C}}^{S} \,\dot{=}\, \{\mathbf{X}_{\mathcal{C}}^{(i)}\}_{i=1}^{S}$ is a collection of $S$ sets of context points $\mathbf{X}_{\mathcal{C}}^{(i)} \,\dot{=}\, \{\mathbf{x}^{(j)}\}_{j=1}^{K}$ jointly sampled from a context distribution $p_{\mathcal{X}_{\mathcal{C}}}$. Each context set $\mathbf{X}_{\mathcal{C}}^{(i)}$ can be viewed as a single Monte Carlo sample from the input space, so the estimator $\hat{G}(\mathcal{X}_{\mathcal{C}}^{S})$ provides an $S$-sample Monte Carlo estimate of the supremum.
This estimator is crude: it only provides a rough approximation to the true supremum and, because it takes a maximum over a finite subset of $\mathcal{X}_{\mathbb{N}}$, it can never exceed it. Nevertheless, it encourages the variational distribution over functions to match the prior distribution over functions on the sets of context points. The choice of the context distribution $p_{\mathcal{X}_{\mathcal{C}}}$ can be informed by knowledge about the prediction task and should be viewed as a problem-specific modeling choice. Similarly, the numbers of samples $S$ and $K$ are hyperparameters to be tuned on a validation set. For details on how $p_{\mathcal{X}_{\mathcal{C}}}$ is chosen for the empirical evaluation in Section 5, see Appendix D.
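A minimal sketch of the estimator in Equation 16 is given below. The callable `kl_at` is a hypothetical stand-in for $I(\mathbf{X})$ (in FSVI, the Gaussian KL divergence between the linearized variational and prior distributions at $\mathbf{X}$), and `sampler` is a hypothetical context distribution that returns one set of $K$ context points per call; the toy quadratic $I$ is chosen only so that the behavior is easy to check.

```python
import numpy as np


def estimate_sup_kl(kl_at, sample_context_set, S, rng):
    """Finite-sample estimate of sup_X I(X) as in Eq. (16): draw S context
    sets from the context distribution and return the largest I(X) among them."""
    values = [kl_at(sample_context_set(rng)) for _ in range(S)]
    return max(values), values


rng = np.random.default_rng(1)
toy_kl = lambda X: float(np.sum(X ** 2))                       # stand-in for I(X)
sampler = lambda r: r.uniform(-1.0, 1.0, size=(4, 2))          # K = 4 points in 2-D
g_hat, values = estimate_sup_kl(toy_kl, sampler, S=8, rng=rng)
```

Because the estimate is a maximum, its gradient flows only through the context set that attains it, so each gradient step pushes the variational distribution toward the prior on the currently worst-matched set. Drawing more context sets with the same random stream can only increase the estimate, but never beyond the true supremum.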

3.3 Stochastic Estimation of the Approximate Function-Space Variational Objective

Let $q_{\bm{\Theta}}$ be a Gaussian mean-field variational distribution, let $p_{\bm{\Theta}}$ be an isotropic Gaussian prior, let $(\mathbf{X}_{\mathcal{B}},\mathbf{y}_{\mathcal{B}})$ be a mini-batch of the training data, and reparameterize $\bm{\Theta}$ as $\hat{\bm{\Theta}}(\bm{\mu},\bm{\Sigma},\bm{\epsilon}^{(j)}) \,\dot{=}\, \bm{\mu} + \bm{\Sigma}\odot\bm{\epsilon}^{(j)}$. Using the estimator $\hat{G}(\mathcal{X}_{\mathcal{C}}^{S})$ defined above and estimating the expected log-likelihood via Monte Carlo sampling, we obtain a Monte Carlo estimator of the function-space variational objective:

$$\bar{\mathcal{F}}(\bm{\mu},\bm{\Sigma}) = \frac{1}{M}\sum_{j=1}^{M}\log p_{\mathbf{y}|f(\mathbf{X};\bm{\Theta})}\big(\mathbf{y}_{\mathcal{B}} \,|\, f(\mathbf{X}_{\mathcal{B}};\hat{\bm{\Theta}}(\bm{\mu},\bm{\Sigma},\bm{\epsilon}^{(j)}))\big) - \max_{\mathbf{X}\in\mathcal{X}_{\mathcal{C}}^{S}}\mathbb{D}_{\textrm{KL}}\big(\tilde{q}_{\tilde{f}(\mathbf{X};\hat{\bm{\Theta}})} \,\|\, \tilde{p}_{\tilde{f}(\mathbf{X};\hat{\bm{\Theta}})}\big) \tag{17}$$

with $\bm{\epsilon}^{(j)} \sim \mathcal{N}(\mathbf{0},\mathbf{I}_{P})$ and $\mathcal{X}_{\mathcal{C}}^{S}$ as defined above. This Monte Carlo estimator is biased due to the linearization and context-set approximations but allows for scalable gradient-based stochastic optimization.
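The following sketch evaluates Equation 17 for a given $(\bm{\mu},\bm{\Sigma})$, where $\bm{\Sigma}$ is the vector of standard deviations used in the reparameterization $\bm{\mu} + \bm{\Sigma}\odot\bm{\epsilon}$. Both `log_lik` and `kl_at` are hypothetical callables supplied by the user, and passing $(\bm{\mu},\bm{\Sigma})$ to the KL term reflects our reading that the $\hat{\bm{\Theta}}$ in its subscript indicates dependence on the variational parameters. The sketch only evaluates the objective; in practice it would be written in an automatic-differentiation framework so that gradients with respect to $\bm{\mu}$ and $\bm{\Sigma}$ pass through the reparameterized samples.

```python
import numpy as np


def fsvi_objective_estimate(log_lik, kl_at, mu, sigma, context_sets, M, rng):
    """Monte Carlo estimate of Eq. (17).

    log_lik(theta): log p(y_B | f(X_B; theta)) on the current mini-batch.
    kl_at(mu, sigma, X): KL between the linearized variational and prior
        distributions over functions at context set X (hypothetical callable).
    """
    eps = rng.normal(size=(M, mu.size))   # eps^(j) ~ N(0, I_P)
    thetas = mu + sigma * eps             # reparameterization: mu + Sigma ⊙ eps^(j)
    expected_ll = np.mean([log_lik(theta) for theta in thetas])
    max_kl = max(kl_at(mu, sigma, X) for X in context_sets)  # estimator of Eq. (16)
    return expected_ll - max_kl


# Toy usage with hypothetical stand-ins for the likelihood and the KL term
log_lik = lambda theta: -0.5 * float(np.sum((theta - 1.0) ** 2))
kl_at = lambda mu, sigma, X: float(np.sum(X ** 2))
context_sets = [np.ones((2, 1)), 2.0 * np.ones((2, 1))]
value = fsvi_objective_estimate(log_lik, kl_at, np.array([0.5, -0.5]), np.full(2, 0.1),
                                context_sets, M=4, rng=np.random.default_rng(0))
```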

Selection of Prior.  For all experiments that involve uncertainty quantification, we chose a prior distribution over parameters that induces a prior distribution over functions $p_{f(\cdot;\bm{\Theta})}$ and a prior predictive distribution with a high degree of predictive uncertainty at evaluation points in regions of the input space where $p_{\mathcal{X}_{\mathcal{C}}}$ has support and, under smoothness constraints, at evaluation points in nearby regions. In settings where prior information is encoded in data, for example in the form of expert demonstrations of robotic manipulation tasks (Rudner et al., 2021) or in the form of pre-trained networks in continual or transfer learning (Rudner et al., 2022), an empirical prior that reflects this information can be specified. For further details, see Appendix D.

Selection of Context Distribution.  The distribution $p_{\mathcal{X}_{\mathcal{C}}}$ allows us to incorporate information about the data-generating process into training and to encourage the variational distribution to match the prior over functions in relevant parts of the input space. Because unlabeled data is abundant in many real-world settings, context distributions can be constructed from large datasets like ImageNet (Krizhevsky et al., 2012), from small but diverse datasets like CIFAR-100, or from any set of task-related unlabeled data. In our experiments, we use two types of context distributions: one is constructed from the training data and contains only randomly sampled monochrome images, and the other is constructed from a real-world dataset generated from a data distribution related to that of the training data. For example, when training on FashionMNIST, we use KMNIST as the context distribution, and when training on CIFAR-10, we use CIFAR-100. For further details, see Appendix D.
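The two kinds of context distribution can be sketched as samplers that return one context set of $K$ inputs per call. The monochrome sampler below rests on an assumption: it reads "randomly sampled monochrome images" as images in which every pixel shares one uniformly random intensity per channel, with FashionMNIST-sized $28 \times 28 \times 1$ inputs; the construction actually used is described in Appendix D. The pool sampler assumes the related dataset (for example, KMNIST) is already loaded as an array, and `pool` is a hypothetical name.

```python
import numpy as np


def sample_monochrome_images(rng, K, height=28, width=28, channels=1):
    """Draw K images whose pixels all share one random intensity per channel
    (one possible reading of "randomly sampled monochrome images")."""
    levels = rng.uniform(0.0, 1.0, size=(K, 1, 1, channels))
    return np.broadcast_to(levels, (K, height, width, channels)).copy()


def sample_from_pool(rng, pool, K):
    """Draw K inputs without replacement from an unlabeled pool of related data."""
    idx = rng.choice(len(pool), size=K, replace=False)
    return pool[idx]


rng = np.random.default_rng(0)
context_set = sample_monochrome_images(rng, K=16)   # shape (16, 28, 28, 1)
```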

Posterior Predictive Distribution.  After optimizing the variational objective with respect to the parameters of the variational distribution $q_{\bm{\Theta}}$, we obtain function draws by sampling from the distribution over parameters and use them to form an approximate posterior predictive distribution

$$\begin{aligned} q(\mathbf{y}_{\ast} \,|\, \mathbf{x}_{\ast}) &= \int p(\mathbf{y}_{\ast} \,|\, f(\mathbf{x}_{\ast};\bm{\theta}))\, q_{f(\mathbf{x}_{\ast};\bm{\Theta})}\, \mathrm{d}f(\mathbf{x}_{\ast};\bm{\theta}) \\ &\approx \frac{1}{M_{\ast}}\sum_{j=1}^{M_{\ast}} p(\mathbf{y}_{\ast} \,|\, f(\mathbf{x}_{\ast};\bm{\Theta}^{(j)})) \quad \text{with} \quad \bm{\Theta}^{(j)} \sim q_{\bm{\Theta}}, \end{aligned} \tag{18}$$

where $M_{\ast}$ is the number of Monte Carlo samples used to estimate the predictive distribution.
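For a classification model with a softmax likelihood, the Monte Carlo approximation in Equation 18 averages class probabilities, not logits, over $M_{\ast}$ parameter draws from the mean-field $q_{\bm{\Theta}}$. In the sketch below, `logits_fn` is a hypothetical model (a linear softmax classifier, used only for illustration) and `sigma` holds the per-parameter standard deviations.

```python
import numpy as np


def softmax(logits):
    z = logits - logits.max(axis=-1, keepdims=True)  # subtract max for numerical stability
    e = np.exp(z)
    return e / e.sum(axis=-1, keepdims=True)


def mc_predictive(logits_fn, x, mu, sigma, M_star, rng):
    """Approximate q(y* | x*) by averaging class probabilities over M_star
    parameter samples Theta^(j) ~ N(mu, diag(sigma**2))."""
    probs = [softmax(logits_fn(x, mu + sigma * rng.normal(size=mu.size)))
             for _ in range(M_star)]
    return np.mean(probs, axis=0)


rng = np.random.default_rng(0)
logits_fn = lambda x, theta: x @ theta.reshape(3, 4)   # hypothetical linear softmax model
x_star = rng.normal(size=(2, 3))                       # two test inputs
mu = rng.normal(size=12)
probs = mc_predictive(logits_fn, x_star, mu, np.full(12, 0.5), M_star=50, rng=rng)
```

Note that the predictive distribution is formed with the original network $f$, not with its linearization.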

4 Related Work

There is a growing body of work on function-space approaches to inference in BNNs, to deep learning more broadly, and to applications such as continual learning (Benjamin et al., 2019; Sun et al., 2019; Titsias et al., 2020; Burt et al., 2021; Pan et al., 2020; Ma and Hernández-Lobato, 2021; Rudner et al., 2022).

Function-Space Inference in Bayesian Neural Networks.

Previously proposed methods for FSVI in BNNs are based on approximate gradient estimators and either replace the supremum in Equation 7 with an expectation (Sun et al., 2019) or do not define an explicit variational objective (Wang et al., 2019). Sun et al. (2019) and Carvalho et al. (2020) use Gaussian process priors over functions for which the function-space variational inference problem is not well-defined (see Section 2.1 and Burt et al. (2021)). More recent work has attempted to circumvent the intractability of the variational objective in Equation 6 by proposing alternative objectives for function-space inference in BNNs (Ma et al., 2019; Ober and Aitchison, 2020; Ma and Hernández-Lobato, 2021). Rudner et al. (2022) extend the approach presented in Section 3 to sequential inference problems and apply it to continual learning.

Linear Models.

Immer et al. (2020) and Khan et al. (2019) show that approximate BNN posterior distributions obtained via the Laplace and generalized Gauss-Newton approximations correspond to exact posteriors of linearized models. Unlike our approach, they use a Laplace approximation instead of variational inference and do not optimize the variance parameters. Furthermore, Immer et al. (2020) and Khan et al. (2019) use the neural network model to obtain a maximum a posteriori (MAP) estimate of the parameters but then use a linearization of the model to compute a posterior predictive distribution. In contrast, our work uses the linearization only to obtain an estimator of the variational objective and uses the original, non-linearized model to construct the posterior predictive distribution.

Pathologies of Variational Inference in Bayesian Neural Networks.

Burt et al. (2021) consider the function-space variational objective in Equation 6 and show that the KL divergence between BNNs with different network architectures is not well-defined. A parallel line of research showed that the posterior predictive distributions of shallow BNNs with mean-field variational distributions have a limited ability to represent complex covariance structures in function space (Foong et al., 2019, 2020), but that deep BNNs do not suffer from this limitation (Farquhar et al., 2020b). Our results are consistent with the finding of Farquhar et al. (2020b) that mean-field variational distributions over the parameters of deep BNNs can represent complex covariance structures in function space.

5 Empirical Evaluation

In this section, we evaluate FSVI on high-dimensional classification tasks that were out of reach for the function-space variational inference methods proposed in prior work and compare FSVI to several well-established and state-of-the-art Bayesian deep learning and deterministic uncertainty quantification methods. We show that FSVI outperforms existing Bayesian and non-Bayesian methods, sometimes significantly, in terms of in-distribution uncertainty calibration and out-of-distribution predictive uncertainty estimation. For details on the models, training and validation procedures, and datasets used, see Appendix D. For a comparison to Sun et al. (2019) on small-scale regression tasks, see Section B.2.

5.1 Predictive Performance, Uncertainty Estimation, and Distribution Shift Detection

In this set of experiments, we assess the reliability of the uncertainty estimates generated by FSVI. If a BNN trained via FSVI performs reliable uncertainty estimation, its predictive uncertainty should be significantly higher on inputs generated by a data-generating distribution different from that of the training data. For models trained on the FashionMNIST dataset, we use the MNIST and NotMNIST datasets as out-of-distribution evaluation points; for models trained on the CIFAR-10 dataset, we use the SVHN dataset as out-of-distribution evaluation points.

For models trained on either FashionMNIST or CIFAR-10, we evaluate in-distribution performance in terms of test accuracy, test log-likelihood, and test calibration. To evaluate the quality of the different models' uncertainty estimates, we compute uncertainty estimates for the pairs FashionMNIST/MNIST, FashionMNIST/NotMNIST, and CIFAR-10/SVHN and measure, across a range of thresholds, how well the two datasets in each pair can be separated based solely on the uncertainty estimates. This experimental setup follows prior work by van Amersfoort et al. (2020) and Immer et al. (2020). We report the area under the receiver operating characteristic (ROC) curve (AUROC) in Tables 1 and 2.
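The evaluation metrics can be sketched as follows. Using predictive entropy as the uncertainty score is an assumption here, since the text does not specify which uncertainty measure is thresholded, and the ECE uses ten equal-width confidence bins, a common default that may differ from the protocol in Appendix D. The AUROC is computed through its rank interpretation: the probability that a randomly chosen out-of-distribution input receives a higher uncertainty score than a randomly chosen in-distribution input, counting ties as one half.

```python
import numpy as np


def predictive_entropy(probs):
    """Entropy of each predictive distribution (rows of probs); higher = more uncertain."""
    return -np.sum(probs * np.log(np.clip(probs, 1e-12, 1.0)), axis=1)


def auroc(scores_in, scores_out):
    """AUROC for flagging out-of-distribution inputs (positives) by their
    uncertainty scores: P(score_out > score_in) + 0.5 * P(tie)."""
    s_in = np.asarray(scores_in)[:, None]
    s_out = np.asarray(scores_out)[None, :]
    return float(np.mean((s_out > s_in) + 0.5 * (s_out == s_in)))


def expected_calibration_error(probs, labels, n_bins=10):
    """ECE with equal-width confidence bins: sum_b (|B_b| / n) |acc(B_b) - conf(B_b)|."""
    conf = probs.max(axis=1)
    correct = (probs.argmax(axis=1) == labels).astype(float)
    edges = np.linspace(0.0, 1.0, n_bins + 1)
    ece = 0.0
    for lo, hi in zip(edges[:-1], edges[1:]):
        in_bin = (conf > lo) & (conf <= hi)
        if in_bin.any():
            ece += in_bin.mean() * abs(correct[in_bin].mean() - conf[in_bin].mean())
    return ece


probs_in = np.array([[0.9, 0.1], [0.8, 0.2]])     # confident in-distribution predictions
probs_out = np.array([[0.55, 0.45], [0.5, 0.5]])  # uncertain out-of-distribution predictions
score = auroc(predictive_entropy(probs_in), predictive_entropy(probs_out))
```

The pairwise comparison uses memory proportional to the product of the two dataset sizes; for full test sets, a rank-based implementation such as scikit-learn's `roc_auc_score` is preferable.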

Table 1: Comparison of in- and out-of-distribution performance metrics on FashionMNIST (mean ± standard error over ten random seeds). The last two columns show the AUROC for binary in- vs. out-of-distribution detection on MNIST (M) and NotMNIST (NM), which are used as the out-of-distribution datasets. Best overall results for single and ensemble models are printed in boldface with gray shading. Results within a 95% confidence interval of the best overall result are printed in boldface only. All methods use the same four-layer CNN architecture. For further details about model architectures and training and evaluation protocols, see Appendix D.
| Method | Accuracy ↑ | ECE ↓ | AUROC M ↑ | AUROC NM ↑ |
| --- | --- | --- | --- | --- |
| map | 91.73 ± 0.08 | 0.037 ± 0.001 | 87.00 ± 0.30 | 74.85 ± 1.31 |
| mfvi (Blundell et al., 2015) | 91.03 ± 0.04 | 0.038 ± 0.001 | 93.10 ± 0.34 | 88.88 ± 0.74 |
| mfvi (tempered) | 91.38 ± 0.05 | 0.058 ± 0.001 | 86.30 ± 0.29 | 80.78 ± 0.68 |
| mfvi (radial) (Farquhar et al., 2020a) | 90.31 ± 0.11 | 0.035 ± 0.001 | 84.40 ± 0.68 | 82.11 ± 1.15 |
| mc dropout (Gal and Ghahramani, 2016) | 90.55 ± 0.04 | 0.012 ± 0.001 | 88.46 ± 0.57 | 80.02 ± 1.04 |
| swag (Maddox et al., 2019) | 92.56 ± 0.05 | 0.043 ± 0.001 | 85.18 ± 0.35 | 80.31 ± 0.30 |
| duq (van Amersfoort et al., 2020) | 92.40 ± 0.20 | -- | 95.50 ± 0.70 | 94.60 ± 1.80 |
| bnn-laplace (Immer et al., 2020) | 92.25 ± 0.10 | 0.012 ± 0.003 | 95.55 ± 0.60 | -- |
| spg (Ma and Hernández-Lobato, 2021) | 91.60 ± 0.14 | -- | 95.60 ± 6.00 | -- |
| fsvi ($p_{\mathbf{X}_{\mathcal{C}}}$ = random monochrome) | 93.13 ± 0.13 | 0.012 ± 0.002 | 96.23 ± 0.46 | 95.02 ± 0.69 |
| fsvi ($p_{\mathbf{X}_{\mathcal{C}}}$ = KMNIST) | **93.48** ± 0.12 | **0.010** ± 0.001 | **99.80** ± 0.20 | **97.26** ± 0.23 |
| Deep Ensemble | 92.49 ± 0.01 | **0.019** ± 0.000 | 89.22 ± 0.09 | 83.17 ± 0.91 |
| fsvi Ensemble ($p_{\mathbf{X}_{\mathcal{C}}}$ = random monochrome) | **94.44** ± 0.07 | 0.020 ± 0.001 | **97.85** ± 0.15 | **96.95** ± 0.20 |

Predictive Performance and Calibration.  To assess in-distribution predictive performance and calibration, we report the test accuracy, negative log-likelihood (NLL), and expected calibration error (ECE) for models trained on FashionMNIST and CIFAR-10 in Tables 1 and 2. On both FashionMNIST and CIFAR-10, fsvi achieves the lowest NLL and either the best or the second-best predictive accuracy and ECE among single-model methods. Notably, fsvi significantly outperforms spg (Ma and Hernández-Lobato, 2021), an alternative function-space variational inference method.
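The section does not state how ECE is binned. The sketch below assumes the common convention of grouping test predictions into equal-width bins by their top-class confidence and averaging the per-bin gap between accuracy and mean confidence, weighted by the fraction of examples in each bin; the bin count of 10 and the function name are assumptions.

```python
from typing import Sequence


def expected_calibration_error(
    confidences: Sequence[float],
    correct: Sequence[bool],
    num_bins: int = 10,
) -> float:
    """ECE = sum_b (|B_b| / n) * |acc(B_b) - conf(B_b)| over equal-width bins.

    confidences[i] is the top-class probability for test example i, and
    correct[i] says whether that top class matched the label.
    """
    n = len(confidences)
    if n == 0 or n != len(correct):
        raise ValueError("inputs must be non-empty and of equal length")
    conf_sum = [0.0] * num_bins
    hits = [0] * num_bins
    counts = [0] * num_bins
    for c, ok in zip(confidences, correct):
        # Bin b covers [b / num_bins, (b + 1) / num_bins); c = 1.0 joins the last bin.
        b = min(int(c * num_bins), num_bins - 1)
        conf_sum[b] += c
        hits[b] += int(ok)
        counts[b] += 1
    ece = 0.0
    for b in range(num_bins):
        if counts[b]:
            gap = abs(hits[b] / counts[b] - conf_sum[b] / counts[b])
            ece += counts[b] / n * gap
    return ece


# Four predictions made with 95% confidence, only half of which are correct.
print(round(expected_calibration_error([0.95] * 4, [True, True, False, False]), 3))  # prints 0.45
```

Other binning schemes (for example equal-mass bins) give different numbers, so ECE values are only comparable across methods evaluated with the same convention.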

Predictive Uncertainty under Distribution Shift.  In Tables 1 and 2, we report evaluation metrics that elucidate the reliability of different methods' predictive uncertainty under distribution shift. fsvi produces reliable predictive uncertainty estimates that allow distinguishing between in- and out-of-distribution inputs with high accuracy. As would be expected, we observe that using context distributions that reflect our knowledge about the data-generating process can significantly improve uncertainty quantification under fsvi. For the FashionMNIST experiment, we used the KMNIST dataset, which contains grayscale images of Kuzushiji characters, and for the CIFAR-10 experiment, we used the CIFAR-100 dataset, which contains RGB images of 100 classes. Both KMNIST and CIFAR-100 differ from the OOD datasets used to compute the OOD-AUROC metrics in Tables 1 and 2 (MNIST and NotMNIST for FashionMNIST, and SVHN for CIFAR-10), but using them as context distributions significantly increased the ability of bnns trained via fsvi to identify distributionally shifted samples. Since the variational objective encourages the approximate posterior to match the prior (which we chose to have high variance) on samples from the context distribution, choosing a suitable context distribution can improve uncertainty estimation in regions of the input space far from the training data.
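The last point can be made concrete with a toy calculation. Assume, purely for illustration, that both the variational distribution and the prior induce fully factorized Gaussian distributions over the network outputs at a batch of context points; the section does not state the form of these distributions, and the method may well use correlated Gaussians. The function-space KL term is then a sum of closed-form univariate Gaussian KL divergences, and a variational distribution that is overconfident at the context points is penalized far more than one that matches a high-variance prior. The helper name and all numbers are hypothetical.

```python
import math
from typing import Sequence


def diag_gaussian_kl(
    q_mean: Sequence[float],
    q_var: Sequence[float],
    p_mean: Sequence[float],
    p_var: Sequence[float],
) -> float:
    """KL(q || p) for factorized Gaussians over function values.

    Each index is one (context point, output dimension) pair, i.e. the vectors
    hold f(x_c)_k flattened over context points x_c and outputs k.
    Per index: 0.5 * (log(vp / vq) + (vq + (mq - mp)^2) / vp - 1).
    """
    if not (len(q_mean) == len(q_var) == len(p_mean) == len(p_var)):
        raise ValueError("all inputs must have the same length")
    kl = 0.0
    for mq, vq, mp, vp in zip(q_mean, q_var, p_mean, p_var):
        kl += 0.5 * (math.log(vp / vq) + (vq + (mq - mp) ** 2) / vp - 1.0)
    return kl


# Hypothetical high-variance prior over 3 context-point outputs.
prior_mean, prior_var = [0.0] * 3, [10.0] * 3
confident_q = diag_gaussian_kl([5.0] * 3, [0.01] * 3, prior_mean, prior_var)
matched_q = diag_gaussian_kl([0.0] * 3, [10.0] * 3, prior_mean, prior_var)
print(confident_q > matched_q, matched_q)  # prints True 0.0
```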

Table 2: Comparison of in- and out-of-distribution performance metrics on CIFAR-10 (mean ± standard error over ten random seeds). SVHN and corrupted CIFAR-10 (C-CIFAR) are used as out-of-distribution datasets. The penultimate column shows the AUROC for binary in- vs. out-of-distribution detection on SVHN, and the last column shows the test accuracy on C-CIFAR. Best overall results for single and ensemble models are printed in boldface with gray shading. Results within a 95% confidence interval of the best overall result are printed in boldface only. All methods use a ResNet-18 architecture. For further details about model architectures and training and evaluation protocols, see Appendix D.
| Method | Accuracy ↑ | ECE ↓ | OOD-AUROC ↑ | C-CIFAR Acc ↑ |
| --- | --- | --- | --- | --- |
| map | 93.19 ± 0.11 | 0.043 ± 0.001 | 94.65 ± 0.27 | 78.87 ± 1.39 |
| mfvi (Blundell et al., 2015) | 89.98 ± 0.09 | 0.040 ± 0.001 | 92.14 ± 0.34 | 79.36 ± 1.35 |
| mfvi (tempered) | 90.87 ± 0.11 | 0.048 ± 0.001 | 91.82 ± 0.90 | 79.86 ± 1.32 |
| mc dropout (Gal and Ghahramani, 2016) | 93.55 ± 0.07 | 0.040 ± 0.001 | 92.44 ± 0.57 | 80.13 ± 1.37 |
| swag (Maddox et al., 2019) | 93.13 ± 0.14 | 0.067 ± 0.002 | 89.79 ± 0.50 | 76.12 ± 0.51 |
| vogn (Osawa et al., 2019) | 84.27 ± 0.20 | 0.040 ± 0.002 | 87.60 ± 0.20 | -- |
| duq (van Amersfoort et al., 2020) | **94.10** ± 0.20 | -- | 92.70 ± 1.30 | -- |
| spg (Ma and Hernández-Lobato, 2021) | 77.69 ± 0.64 | -- | 88.30 ± 4.00 | -- |
| fsvi ($p_{\mathbf{X}_{\mathcal{C}}}$ = random monochrome) | 93.35 ± 0.04 | 0.034 ± 0.001 | 94.76 ± 0.24 | **80.81** ± 0.43 |
| fsvi ($p_{\mathbf{X}_{\mathcal{C}}}$ = CIFAR-100) | 93.57 ± 0.04 | **0.026** ± 0.001 | **98.07** ± 0.10 | **81.20** ± 0.42 |
| Deep Ensemble | 95.13 ± 0.06 | 0.019 ± 0.001 | **98.04** ± 0.07 | **81.22** ± 0.37 |
| fsvi Ensemble ($p_{\mathbf{X}_{\mathcal{C}}}$ = random monochrome) | **95.19** ± 0.03 | 0.013 ± 0.001 | **99.19** ± 0.41 | **81.35** ± 0.48 |

5.2 Generalization and Reliability of Predictive Uncertainty under Distribution Shift

To assess the reliability of predictive models in deep learning, Ovadia et al. (2019) propose the following desiderata: in order for a model to be considered reliable, it ought to (i) exhibit low predictive uncertainty on training data and high predictive uncertainty on out-of-distribution inputs, (ii) generate predictive uncertainty estimates that allow distinguishing in- from out-of-distribution inputs, and (iii) if possible, maintain high predictive accuracy even under distribution shift. Models that satisfy these desiderata are less likely to make poor, high-confidence predictions and are better suited for use in safety-critical downstream tasks.

Figure 2: Predictive uncertainty and accuracy on rotated MNIST. Models with reliable uncertainty estimates would exhibit higher predictive uncertainty the more the digits are rotated. Ideally, such models would maintain high predictive accuracy (low Brier score).

To illustrate these desiderata, we follow Ovadia et al. (2019) and consider the rotated MNIST task, in which a model is trained on MNIST and evaluated on rotated MNIST digits. The goal is to maintain a high level of predictive accuracy (measured in terms of the Brier score) while exhibiting increasing predictive uncertainty as the magnitude of the distribution shift increases. Figure 2 shows Brier scores (lower is better) and predictive entropy estimates (higher means more uncertain) for four different models. Since rotating the MNIST digits gradually shifts the data distribution, we would expect Brier scores to increase (corresponding to worse predictive accuracy) as the rotation angle increases. A model with reliable predictive uncertainty estimates would experience only a small increase in its Brier score under distribution shift while exhibiting a large increase in predictive uncertainty. As can be seen in the plot, fsvi's Brier score increases the least, while fsvi's predictive uncertainty is significantly higher than that of the other models. To assess the reliability of different uncertainty quantification methods on a more challenging distribution-shift task, we consider corrupted CIFAR-10 inputs at the second-mildest corruption level used by Ovadia et al. (2019) and report our results in Table 2. Consistent with the rotated MNIST results, fsvi achieves the highest accuracy on the corrupted data among single-model methods.
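For reference, the Brier score used as the accuracy measure above can be computed as the squared distance between the predicted class-probability vector and the one-hot label, averaged over examples. The sketch below uses the unnormalized multiclass form (summing over classes, so each example contributes between 0 and 2); some implementations instead divide by the number of classes, and the section does not say which variant Figure 2 uses. The function name is hypothetical.

```python
from typing import Sequence


def brier_score(probs: Sequence[Sequence[float]], labels: Sequence[int]) -> float:
    """Mean over examples of sum_k (p_k - 1[label == k])^2 (lower is better)."""
    if not probs or len(probs) != len(labels):
        raise ValueError("inputs must be non-empty and of equal length")
    total = 0.0
    for p, y in zip(probs, labels):
        # Squared error against the one-hot encoding of the label y.
        total += sum((pk - (1.0 if k == y else 0.0)) ** 2 for k, pk in enumerate(p))
    return total / len(probs)


# A confident correct prediction scores 0, a confident wrong one scores 2,
# and a uniform 2-class prediction scores 0.5; the mean is 2.5 / 3.
print(brier_score([[1.0, 0.0], [0.0, 1.0], [0.5, 0.5]], [0, 0, 0]))
```

Because the Brier score penalizes confident mistakes heavily, it degrades sharply for a model that keeps predicting with high confidence on rotated digits it misclassifies.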

5.3 Safety-Critical Uncertainty-Aware Selective Prediction: Diabetic Retinopathy Diagnosis

Figure 3: Retina scan examples. Top: healthy. Bottom: unhealthy.

To evaluate the reliability of the predictive uncertainty of fsvi in a real-world safety-critical setting, we consider the task of diagnosing diabetic retinopathy (DR), a medical condition that can lead to impaired vision, from retina scans (Leibig et al., 2017; Filos et al., 2019; Band et al., 2021). We use two publicly available datasets, EyePACS (2015) and APTOS (2019), each containing RGB images of human retinas graded by a medical expert on the following scale: 0 (no DR), 1 (mild DR), 2 (moderate DR), 3 (severe DR), and 4 (proliferative DR). The EyePACS dataset, released through a Kaggle competition and therefore also referred to as the Kaggle dataset, was collected from patients in the United States, while the APTOS dataset was collected from patients in India using cheaper but more modern scanning devices. We follow Leibig et al. (2017), Filos et al. (2019), and Band et al. (2021) and binarize all examples from both the EyePACS and APTOS datasets by splitting the classes into sight-threatening diabetic retinopathy, defined as moderate diabetic retinopathy or worse (classes {2, 3, 4}), and non-sight-threatening diabetic retinopathy, defined as no or mild diabetic retinopathy (classes {0, 1}). This results in a binary prediction task.
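The binarization rule above amounts to a threshold on the 0 to 4 grade. A minimal sketch, with a hypothetical function name and an assumed label encoding (1 for sight-threatening, 0 otherwise):

```python
from typing import Iterable, List

# Grades on the 0-4 scale: 0 no DR, 1 mild, 2 moderate, 3 severe, 4 proliferative.
SIGHT_THREATENING_GRADES = frozenset({2, 3, 4})


def binarize_dr_grades(grades: Iterable[int]) -> List[int]:
    """Return 1 for sight-threatening DR (moderate or worse), else 0."""
    binary = []
    for g in grades:
        if g not in range(5):
            raise ValueError(f"unexpected DR grade: {g}")
        binary.append(1 if g in SIGHT_THREATENING_GRADES else 0)
    return binary


print(binarize_dr_grades([0, 1, 2, 3, 4]))  # prints [0, 0, 1, 1, 1]
```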

To assess the reliability of predictive models when medical training and test data are obtained from different patient populations or collected with different medical equipment, we follow Band et al. (2021) and use the EyePACS (Kaggle) dataset for training and the distributionally shifted APTOS dataset for evaluation. The results are shown in Figure 4, which plots the ROC curves for the binary prediction problems as well as the area under the ROC curve for an uncertainty-aware selective prediction task. For further details about the uncertainty-aware selective prediction evaluation protocol, see Section D.4. Figure 4 shows that fsvi performs well on all four tasks and is outperformed only by mc dropout. For full tabular results, see Section B.1.
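The selective prediction protocol itself is specified in Section D.4. The sketch below assumes a variant in which test images are ranked by predictive uncertainty, the most uncertain fraction is referred to a (simulated) expert, and the diagnostic AUROC is computed on the retained images for each retained fraction. All names are hypothetical, and in practice the retained fractions would form a fine grid rather than the two values shown.

```python
import math
from typing import List, Sequence, Tuple


def binary_auroc(scores: Sequence[float], labels: Sequence[int]) -> float:
    """P(random positive outscores random negative); ties count one half."""
    pos = [s for s, y in zip(scores, labels) if y == 1]
    neg = [s for s, y in zip(scores, labels) if y == 0]
    if not pos or not neg:
        return math.nan  # AUROC is undefined when only one class remains
    wins = sum(
        1.0 if sp > sn else 0.5 if sp == sn else 0.0 for sp in pos for sn in neg
    )
    return wins / (len(pos) * len(neg))


def selective_prediction_auroc(
    probs: Sequence[float],          # predicted P(sight-threatening DR) per image
    labels: Sequence[int],           # binarized labels, 1 = sight-threatening
    uncertainties: Sequence[float],  # higher = referred to an expert first
    retain_fractions: Sequence[float],
) -> List[Tuple[float, float]]:
    """Diagnostic AUROC on the most certain images for each retained fraction."""
    # Stable sort: most certain images first, ties keep their input order.
    order = sorted(range(len(probs)), key=lambda i: uncertainties[i])
    curve = []
    for frac in retain_fractions:
        n_keep = max(1, round(frac * len(order)))
        keep = order[:n_keep]
        auc = binary_auroc([probs[i] for i in keep], [labels[i] for i in keep])
        curve.append((frac, auc))
    return curve


# Toy data: the two most uncertain images are also the misranked ones.
probs = [0.9, 0.1, 0.8, 0.2, 0.3, 0.7]
labels = [1, 0, 1, 0, 1, 0]
uncertainties = [0.1, 0.1, 0.2, 0.2, 0.9, 0.9]
print(selective_prediction_auroc(probs, labels, uncertainties, [1.0, 0.5]))
```

On this toy data, retaining only the more certain half of the images raises the diagnostic AUROC from 8/9 to 1, which is the behavior a model with informative uncertainty estimates should show.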

Figure 4 panels: (a) ROC: In-Domain; (b) ROC: Country Shift; (c) Selective Prediction AUROC: In-Domain; (d) Selective Prediction AUROC: Country Shift.
Figure 4: We jointly assess model predictive performance and uncertainty quantification on both in-domain and distributionally shifted data. Left: The receiver operating characteristic (ROC) curve for (a) in-population diagnosis on the EyePACS (2015) test set and (b) changing medical equipment and patient populations on the APTOS (2019) test set. The black dot denotes the NHS-recommended 85% sensitivity and 80% specificity ratios (Widdowson, 2016). Right: Selective prediction AUROC in the (c) EyePACS (2015) and (d) APTOS (2019) settings. Shading denotes standard error over six random seeds. See Section B.1 for tabular results.

6 Conclusion

This paper proposed a scalable and effective approach to function-space variational inference in bnns. We demonstrated that the proposed estimator of the function-space variational objective can be scaled up to high-dimensional data and large neural network architectures, and that fsvi exhibits consistently reliable in- and out-of-distribution predictive performance on a wide range of datasets compared to well-established and state-of-the-art uncertainty quantification methods. We hope that this work will lead to further research into function-space variational inference and the development of more sophisticated data-driven prior distributions over functions.

Acknowledgements

We thank Bryn Elesedy, Bobby He, and Andrew Jesson for feedback on an early draft of this paper. We thank Joost van Amersfoort for helpful discussions about experiment design and implementations. Tim G. J. Rudner is funded by the Rhodes Trust and the Engineering and Physical Sciences Research Council (EPSRC). We gratefully acknowledge donations of computing resources by the Alan Turing Institute.

References

  • Amodei et al. (2016) Dario Amodei, Chris Olah, Jacob Steinhardt, Paul Christiano, John Schulman, and Dan Mané. Concrete problems in AI safety, 2016.
  • APTOS (2019) APTOS. APTOS 2019 Blindness Detection Dataset, 2019.
  • Band et al. (2021) Neil Band, Tim G. J. Rudner, Qixuan Feng, Angelos Filos, Zachary Nado, Michael W. Dusenberry, Ghassen Jerfel, Dustin Tran, and Yarin Gal. Benchmarking Bayesian Deep Learning on Diabetic Retinopathy Detection Tasks. 2021.
  • Benjamin et al. (2019) Ari Benjamin, David Rolnick, and Konrad Kording. Measuring and regularizing networks in function space. In International Conference on Learning Representations, 2019.
  • Blundell et al. (2015) Charles Blundell, Julien Cornebise, Koray Kavukcuoglu, and Daan Wierstra. Weight uncertainty in neural networks. In Proceedings of the 32nd International Conference on Machine Learning, volume 37 of Proceedings of Machine Learning Research, pages 1613–1622, Lille, France, 07–09 Jul 2015. PMLR.
  • Burt et al. (2021) David R. Burt, Sebastian W. Ober, Adrià Garriga-Alonso, and Mark van der Wilk. Understanding variational inference in function-space. In Third Symposium on Advances in Approximate Bayesian Inference, 2021.
  • Carvalho et al. (2020) Eduardo D. C. Carvalho, Ronald Clark, Andrea Nicastro, and Paul H. J. Kelly. Scalable uncertainty for computer vision with functional variational inference. In IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR), June 2020.
  • EyePACS (2015) EyePACS. Diabetic Retinopathy Detection Dataset, 2015.
  • Farquhar et al. (2020a) Sebastian Farquhar, Michael A. Osborne, and Yarin Gal. Radial Bayesian neural networks: Beyond discrete support in large-scale Bayesian deep learning. In Silvia Chiappa and Roberto Calandra, editors, Proceedings of the Twenty Third International Conference on Artificial Intelligence and Statistics, volume 108 of Proceedings of Machine Learning Research, pages 1352–1362. PMLR, 26–28 Aug 2020a.
  • Farquhar et al. (2020b) Sebastian Farquhar, Lewis Smith, and Yarin Gal. Liberty or depth: Deep Bayesian neural nets do not need complex weight posterior approximations. In Advances in Neural Information Processing Systems 33: Annual Conference on Neural Information Processing Systems 2020, NeurIPS 2020, December 6-12, 2020, virtual, 2020b.
  • Filos et al. (2019) Angelos Filos, Sebastian Farquhar, Aidan N. Gomez, Tim G. J. Rudner, Zachary Kenton, Lewis Smith, Milad Alizadeh, Arnoud de Kroon, and Yarin Gal. A systematic comparison of Bayesian deep learning robustness in diabetic retinopathy tasks, 2019.
  • Foong et al. (2019) Andrew Y. K. Foong, Yingzhen Li, José Miguel Hernández-Lobato, and Richard E. Turner. 'In-between' uncertainty in Bayesian neural networks, 2019.
  • Foong et al. (2020) Andrew Y. K. Foong, David R. Burt, Yingzhen Li, and Richard E. Turner. On the expressiveness of approximate inference in Bayesian neural networks. In Advances in Neural Information Processing Systems 33: Annual Conference on Neural Information Processing Systems 2020, NeurIPS 2020, December 6-12, 2020, virtual, 2020.
  • Gal and Ghahramani (2016) Yarin Gal and Zoubin Ghahramani. Dropout as a Bayesian approximation: Representing model uncertainty in deep learning. In Proceedings of the 33rd International Conference on International Conference on Machine Learning - Volume 48, ICML 2016, pages 1050–1059, 2016.
  • Graves (2011) Alex Graves. Practical variational inference for neural networks. In Proceedings of the 24th International Conference on Neural Information Processing Systems, NIPS’11, page 2348–2356, Red Hook, NY, USA, 2011. Curran Associates Inc. ISBN 9781618395993.
  • Hendrycks et al. (2021) Dan Hendrycks, Nicholas Carlini, John Schulman, and Jacob Steinhardt. Unsolved problems in ML safety, 2021.
  • Hinton and van Camp (1993) Geoffrey E. Hinton and Drew van Camp. Keeping the neural networks simple by minimizing the description length of the weights. In Proceedings of the Sixth Annual Conference on Computational Learning Theory, COLT ’93, page 5–13, New York, NY, USA, 1993. Association for Computing Machinery. ISBN 0897916115.
  • Hoffman et al. (2013) Matthew D. Hoffman, David M. Blei, Chong Wang, and John Paisley. Stochastic variational inference. Journal of Machine Learning Research, 14(1):1303–1347, May 2013. ISSN 1532-4435.
  • Immer et al. (2020) Alexander Immer, Maciej Korzepa, and Matthias Bauer. Improving predictions of Bayesian neural networks via local linearization, 2020.
  • Izmailov et al. (2020) Pavel Izmailov, Wesley J. Maddox, Polina Kirichenko, Timur Garipov, Dmitry Vetrov, and Andrew Gordon Wilson. Subspace inference for Bayesian deep learning. In Ryan P. Adams and Vibhav Gogate, editors, Proceedings of The 35th Uncertainty in Artificial Intelligence Conference, volume 115 of Proceedings of Machine Learning Research, pages 1169–1179. PMLR, 22–25 Jul 2020.
  • Jumper et al. (2021) John Jumper, Richard Evans, Alexander Pritzel, Tim Green, Michael Figurnov, Olaf Ronneberger, Kathryn Tunyasuvunakool, Russ Bates, Augustin Zidek, Anna Potapenko, Alex Bridgland, Clemens Meyer, Simon A A Kohl, Andrew J Ballard, Andrew Cowie, Bernardino Romera-Paredes, Stanislav Nikolov, Rishub Jain, Jonas Adler, Trevor Back, Stig Petersen, David Reiman, Ellen Clancy, Michal Zielinski, Martin Steinegger, Michalina Pacholska, Tamas Berghammer, Sebastian Bodenstein, David Silver, Oriol Vinyals, Andrew W Senior, Koray Kavukcuoglu, Pushmeet Kohli, and Demis Hassabis. Highly accurate protein structure prediction with AlphaFold. Nature, 596(7873):583–589, 2021. doi: 10.1038/s41586-021-03819-2.
  • Khan et al. (2019) Mohammad Emtiyaz E Khan, Alexander Immer, Ehsan Abedi, and Maciej Korzepa. Approximate inference turns deep networks into Gaussian processes. In Advances in Neural Information Processing Systems 32, pages 3094–3104. Curran Associates, Inc., 2019.
  • Krizhevsky et al. (2012) Alex Krizhevsky, Ilya Sutskever, and Geoffrey E. Hinton. ImageNet classification with deep convolutional neural networks. In Advances in Neural Information Processing Systems 25, pages 1106–1114, 2012.
  • Leibig et al. (2017) Christian Leibig, Vaneeda Allken, Murat Seçkin Ayhan, Philipp Berens, and Siegfried Wahl. Leveraging Uncertainty Information From Deep Neural Networks for Disease Detection. Nature Scientific Reports, 7(1):17816, 2017.
  • Ma and Hernández-Lobato (2021) Chao Ma and José Miguel Hernández-Lobato. Functional variational inference based on stochastic process generators. In A. Beygelzimer, Y. Dauphin, P. Liang, and J. Wortman Vaughan, editors, Advances in Neural Information Processing Systems, 2021.
  • Ma et al. (2019) Chao Ma, Yingzhen Li, and José Miguel Hernández-Lobato. Variational implicit processes. In Kamalika Chaudhuri and Ruslan Salakhutdinov, editors, Proceedings of the 36th International Conference on Machine Learning, volume 97 of Proceedings of Machine Learning Research, pages 4222–4233. PMLR, 09–15 Jun 2019.
  • MacKay (1992) David J. C. MacKay. A practical Bayesian framework for backpropagation networks. Neural Comput., 4(3):448–472, May 1992. ISSN 0899-7667. doi: 10.1162/neco.1992.4.3.448.
  • Maddox et al. (2019) Wesley J Maddox, Pavel Izmailov, Timur Garipov, Dmitry P Vetrov, and Andrew Gordon Wilson. A simple baseline for Bayesian uncertainty in deep learning. In Advances in Neural Information Processing Systems, pages 13153–13164, 2019.
  • Matthews et al. (2016) Alexander G. de G. Matthews, James Hensman, Richard Turner, and Zoubin Ghahramani. On sparse variational methods and the Kullback-Leibler divergence between stochastic processes. In Proceedings of the 19th International Conference on Artificial Intelligence and Statistics, volume 51 of Proceedings of Machine Learning Research, pages 231–239, Cadiz, Spain, 09–11 May 2016. PMLR.
  • Mnih et al. (2013) Volodymyr Mnih, Koray Kavukcuoglu, David Silver, Alex Graves, Ioannis Antonoglou, Daan Wierstra, and Martin Riedmiller. Playing Atari with deep reinforcement learning. In NIPS Deep Learning Workshop, 2013.
  • Neal (1996) Radford M. Neal. Bayesian Learning for Neural Networks. Springer, 1996.
  • Ober and Aitchison (2020) Sebastian W. Ober and Laurence Aitchison. Global inducing point variational posteriors for Bayesian neural networks and deep Gaussian processes, 2020.
  • Osawa et al. (2019) Kazuki Osawa, Siddharth Swaroop, Mohammad Emtiyaz E Khan, Anirudh Jain, Runa Eschenhagen, Richard E Turner, and Rio Yokota. Practical deep learning with Bayesian principles. In Advances in Neural Information Processing Systems, volume 32, pages 4287–4299. Curran Associates, Inc., 2019.
  • Ovadia et al. (2019) Yaniv Ovadia, Emily Fertig, Jie Ren, Zachary Nado, D. Sculley, Sebastian Nowozin, Joshua Dillon, Balaji Lakshminarayanan, and Jasper Snoek. Can you trust your model’s uncertainty? Evaluating predictive uncertainty under dataset shift. In Advances in Neural Information Processing Systems 32. 2019.
  • Pan et al. (2020) Pingbo Pan, Siddharth Swaroop, Alexander Immer, Runa Eschenhagen, Richard E. Turner, and Mohammad Emtiyaz Khan. Continual deep learning by functional regularisation of memorable past. In Advances in Neural Information Processing Systems 33: Annual Conference on Neural Information Processing Systems 2020, NeurIPS 2020, December 6-12, 2020, virtual, 2020.
  • Polyanskiy and Wu (2017) Yury Polyanskiy and Yihong Wu. Strong data-processing inequalities for channels and Bayesian networks. In Eric Carlen, Mokshay Madiman, and Elisabeth M. Werner, editors, Convexity and Concentration, pages 211–249, New York, NY, 2017. Springer New York. ISBN 978-1-4939-7005-6.
  • Rudner and Toner (2021a) Tim G. J. Rudner and Helen Toner. Key Concepts in AI Safety: An Overview. In CSET Issue Briefs, 2021a.
  • Rudner and Toner (2021b) Tim G. J. Rudner and Helen Toner. Key Concepts in AI Safety: Robustness and Adversarial Examples. In CSET Issue Briefs, 2021b.
  • Rudner et al. (2021) Tim G. J. Rudner, Cong Lu, Michael A. Osborne, Yarin Gal, and Yee Whye Teh. On Pathologies in KL-Regularized Reinforcement Learning from Expert Demonstrations. In Advances in Neural Information Processing Systems 34, 2021.
  • Rudner et al. (2022) Tim G. J. Rudner, Freddie Bickford Smith, Qixuan Feng, Yee Whye Teh, and Yarin Gal. Continual Learning via Sequential Function-Space Variational Inference. In Proceedings of the 39th International Conference on Machine Learning, Proceedings of Machine Learning Research. PMLR, 2022.
  • Schervish (1995) M. J. Schervish. Theory of Statistics. Springer-Verlag, New York, NY, 1995.
  • Silver et al. (2016) David Silver, Aja Huang, Chris J. Maddison, Arthur Guez, Laurent Sifre, George van den Driessche, Julian Schrittwieser, Ioannis Antonoglou, Veda Panneershelvam, Marc Lanctot, Sander Dieleman, Dominik Grewe, John Nham, Nal Kalchbrenner, Ilya Sutskever, Timothy Lillicrap, Madeleine Leach, Koray Kavukcuoglu, Thore Graepel, and Demis Hassabis. Mastering the game of Go with deep neural networks and tree search. Nature, 529(7587):484–489, 2016.
  • Snelson and Ghahramani (2006) Edward Snelson and Zoubin Ghahramani. Sparse Gaussian processes using pseudo-inputs. In Y. Weiss, B. Schölkopf, and J. C. Platt, editors, Advances in Neural Information Processing Systems 18, pages 1257–1264. MIT Press, 2006.
  • Sun et al. (2019) Shengyang Sun, Guodong Zhang, Jiaxin Shi, and Roger B. Grosse. Functional variational Bayesian neural networks. In 7th International Conference on Learning Representations, ICLR 2019, New Orleans, LA, USA, May 6-9, 2019. OpenReview.net, 2019.
  • Titsias et al. (2020) Michalis K. Titsias, Jonathan Schwarz, Alexander G. de G. Matthews, Razvan Pascanu, and Yee Whye Teh. Functional regularisation for continual learning with Gaussian processes. In International Conference on Learning Representations, 2020.
  • van Amersfoort et al. (2020) Joost van Amersfoort, Lewis Smith, Yee Whye Teh, and Yarin Gal. Uncertainty estimation using a single deep deterministic neural network. In International Conference on Machine Learning, 2020.
  • van Amersfoort et al. (2021) Joost van Amersfoort, Lewis Smith, Andrew Jesson, Oscar Key, and Yarin Gal. Variational deterministic uncertainty quantification, 2021.
  • Wainwright and Jordan (2008) Martin J Wainwright and Michael I Jordan. Graphical Models, Exponential Families, and Variational Inference. Now Publishers Inc., Hanover, MA, USA, 2008. ISBN 1601981848.
  • Wang et al. (2019) Ziyu Wang, Tongzheng Ren, Jun Zhu, and Bo Zhang. Function space particle optimization for Bayesian neural networks. In International Conference on Learning Representations, 2019.
  • Widdowson (2016) D. T. S. Widdowson. The Management of Grading Quality: Good Practice in the Quality Assurance of Grading. Tech. Rep., 2016.
  • Wolpert (1993) David H. Wolpert. Bayesian backpropagation over i-o functions rather than weights. In J. Cowan, G. Tesauro, and J. Alspector, editors, Advances in Neural Information Processing Systems, volume 6. Morgan-Kaufmann, 1993.

 

Appendix

 


Appendix A Proofs & Derivations

A.1 Function-Space Variational Objective

This proof follows steps from Matthews et al. [2016]. Consider measures $\hat{P}$ and $P$, both of which define distributions over some function $f$ indexed by an infinite index set $X$. Let $\mathcal{D}$ be a dataset and let $\mathbf{X}_{\mathcal{D}}$ denote a set of inputs and $\mathbf{y}_{\mathcal{D}}$ a set of targets. Consider the measure-theoretic version of Bayes' Theorem [Schervish, 1995]:

$$\frac{d\hat{P}}{dP}(f) = \frac{p_{X}(Y \,|\, f)}{p(Y)}, \tag{A.1}$$

where $p_{X}(Y \,|\, f)$ is the likelihood and $p(Y) = \int_{\mathbb{R}^{X}} p_{X}(Y \,|\, f) \, dP(f)$ is the marginal likelihood. We assume that the likelihood function is evaluated at a finite subset of the index set $X$. Denote by $\pi_{C} : \mathbb{R}^{X} \to \mathbb{R}^{C}$ a projection function that takes a function and returns the same function evaluated at a finite set of points $C$, so we can write

$$\frac{d\hat{P}}{dP}(f) = \frac{d\hat{P}_{\mathbf{X}_{\mathcal{D}}}}{dP_{\mathbf{X}_{\mathcal{D}}}}\big(\pi_{\mathbf{X}_{\mathcal{D}}}(f)\big) = \frac{p(\mathbf{y}_{\mathcal{D}} \,|\, \pi_{\mathbf{X}_{\mathcal{D}}}(f))}{p(\mathbf{y}_{\mathcal{D}})}, \tag{A.2}$$

and similarly, the marginal likelihood becomes $p(\mathbf{y}_{\mathcal{D}}) = \int p_{\mathbf{y}|f_{\mathbf{X}}}(\mathbf{y}_{\mathcal{D}} \,|\, f_{\mathbf{X}_{\mathcal{D}}}) \, \mathrm{d}P_{\mathbf{X}_{\mathcal{D}}}(f_{\mathbf{X}_{\mathcal{D}}})$. Now, considering the measure-theoretic version of the KL divergence between an approximating stochastic process $Q$ and a posterior stochastic process $\hat{P}$, we can write

$$
\mathbb{D}_{\mathrm{KL}}(Q \,\|\, \hat{P}) = \int \log \frac{dQ}{dP}(f) \, \mathrm{d}Q(f) - \int \log \frac{d\hat{P}}{dP}(f) \, \mathrm{d}Q(f), \tag{A.3}
$$

where $P$ is the prior stochastic process. Applying the measure-theoretic Bayes' theorem, we obtain

\begin{align}
\mathbb{D}_{\mathrm{KL}}(Q \,\|\, \hat{P})
&= \int \log \frac{dQ}{dP}(f) \, \mathrm{d}Q(f) - \int \log \frac{d\hat{P}}{dP}(f) \, \mathrm{d}Q(f) \tag{A.4} \\
&= \int \log \frac{dQ^{\pi}}{dP^{\pi}}(f) \, \mathrm{d}Q^{\pi}(f) - \int \log \frac{d\hat{P}_{\mathbf{X}_{\mathcal{D}}}}{dP_{\mathbf{X}_{\mathcal{D}}}}\left(f_{\mathbf{X}_{\mathcal{D}}}\right) \mathrm{d}Q_{\mathbf{X}_{\mathcal{D}}}\left(f_{\mathbf{X}_{\mathcal{D}}}\right) \tag{A.5} \\
&= \int \log \frac{dQ^{\pi}}{dP^{\pi}}(f) \, \mathrm{d}Q^{\pi}(f) - \mathbb{E}_{Q_{\mathbf{X}_{\mathcal{D}}}}\left[\log p\left(\mathbf{y}_{\mathcal{D}} \,|\, f_{\mathbf{X}_{\mathcal{D}}}\right)\right] + \log p(\mathbf{y}_{\mathcal{D}}), \tag{A.6}
\end{align}

where $\frac{dQ^{\pi}}{dP^{\pi}}(f)$ is marginally consistent under the projection $\pi$. Rearranging, we obtain

\begin{align}
\log p(\mathbf{y}_{\mathcal{D}})
&= \mathbb{E}_{Q_{\mathbf{X}_{\mathcal{D}}}}\left[\log p_{\mathbf{y}|f_{\mathbf{X}}}(\mathbf{y}_{\mathcal{D}} \,|\, f_{\mathbf{X}_{\mathcal{D}}})\right] - \int \log \frac{dQ^{\pi}}{dP^{\pi}}(f) \, \mathrm{d}Q^{\pi}(f) + \mathbb{D}_{\mathrm{KL}}(Q^{\pi} \,\|\, \hat{P}) \tag{A.7} \\
&\geq \mathbb{E}_{Q_{\mathbf{X}_{\mathcal{D}}}}\left[\log p_{\mathbf{y}|f_{\mathbf{X}}}(\mathbf{y}_{\mathcal{D}} \,|\, f_{\mathbf{X}_{\mathcal{D}}})\right] - \int \log \frac{dQ^{\pi}}{dP^{\pi}}(f) \, \mathrm{d}Q^{\pi}(f) \tag{A.8} \\
&= \mathbb{E}_{Q_{\mathbf{X}_{\mathcal{D}}}}\left[\log p_{\mathbf{y}|f_{\mathbf{X}}}(\mathbf{y}_{\mathcal{D}} \,|\, f_{\mathbf{X}_{\mathcal{D}}})\right] - \mathbb{D}_{\mathrm{KL}}(Q^{\pi} \,\|\, P^{\pi}). \tag{A.9}
\end{align}
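Equation (A.2) and the decomposition in Equations (A.7) to (A.9) can be checked numerically in a toy finite-dimensional setting. The sketch below assumes that the processes are only ever evaluated at $\mathbf{X}_{\mathcal{D}}$, so that every measure reduces to a distribution over the finite vector $f_{\mathbf{X}_{\mathcal{D}}}$. It also assumes that the prior is a zero-mean Gaussian with an arbitrary covariance, that the likelihood is Gaussian with known noise variance, and that $Q_{\mathbf{X}_{\mathcal{D}}}$ is an arbitrary Gaussian. All helper names and numbers are hypothetical and are not taken from the paper's implementation. Under these assumptions, the log Radon-Nikodym derivative in Equation (A.2) is a difference of Gaussian log densities. The gap between $\log p(\mathbf{y}_{\mathcal{D}})$ and the lower bound is exactly the KL divergence to the posterior.

```python
import numpy as np

def mvn_logpdf(x, mean, cov):
    """Log density of N(mean, cov) at x, computed via a Cholesky factor."""
    L = np.linalg.cholesky(cov)
    z = np.linalg.solve(L, x - mean)
    return -0.5 * z @ z - np.sum(np.log(np.diag(L))) - 0.5 * len(x) * np.log(2 * np.pi)

def gauss_kl(m0, S0, m1, S1):
    """Closed-form KL( N(m0, S0) || N(m1, S1) )."""
    S1_inv = np.linalg.inv(S1)
    diff = m1 - m0
    logdet0 = np.linalg.slogdet(S0)[1]
    logdet1 = np.linalg.slogdet(S1)[1]
    return 0.5 * (np.trace(S1_inv @ S0) + diff @ S1_inv @ diff - len(m0) + logdet1 - logdet0)

def expected_loglik(y, q_mean, q_cov, noise_var):
    """Closed-form E_{f ~ N(q_mean, q_cov)}[log N(y | f, noise_var * I)]."""
    return mvn_logpdf(y, q_mean, noise_var * np.eye(len(y))) - 0.5 * np.trace(q_cov) / noise_var

rng = np.random.default_rng(0)
N, noise_var = 4, 0.3
A = rng.normal(size=(N, N))
prior_mean, prior_cov = np.zeros(N), A @ A.T + 0.5 * np.eye(N)   # assumed Gaussian P_{X_D}
y = rng.normal(size=N)                                              # toy targets y_D

# Exact conjugate posterior P_hat_{X_D} and log marginal likelihood log p(y_D).
cov_y = prior_cov + noise_var * np.eye(N)
gain = prior_cov @ np.linalg.inv(cov_y)
post_mean = gain @ y
post_cov = prior_cov - gain @ prior_cov
post_cov = 0.5 * (post_cov + post_cov.T)          # symmetrize against round-off
log_evidence = mvn_logpdf(y, prior_mean, cov_y)

def log_rn_derivative(f):
    """log dP_hat/dP at f_{X_D}: posterior log density minus prior log density."""
    return mvn_logpdf(f, post_mean, post_cov) - mvn_logpdf(f, prior_mean, prior_cov)

def elbo(q_mean, q_cov):
    """Expected log-likelihood minus KL(Q || P), i.e. the right-hand side of (A.9)."""
    return expected_loglik(y, q_mean, q_cov, noise_var) - gauss_kl(q_mean, q_cov, prior_mean, prior_cov)

# (A.2): log dP_hat/dP(f) = log p(y_D | f) - log p(y_D).
f = rng.normal(size=N)
print(np.isclose(log_rn_derivative(f), mvn_logpdf(y, f, noise_var * np.eye(N)) - log_evidence))

# (A.7): log p(y_D) = bound + KL(Q || P_hat) for an arbitrary Gaussian Q.
B = rng.normal(size=(N, N))
q_mean, q_cov = rng.normal(size=N), 0.2 * B @ B.T + 0.1 * np.eye(N)
gap = gauss_kl(q_mean, q_cov, post_mean, post_cov)
print(np.isclose(log_evidence, elbo(q_mean, q_cov) + gap), gap >= 0)
print(np.isclose(elbo(post_mean, post_cov), log_evidence))   # bound is tight at Q = P_hat
```

Setting $Q_{\mathbf{X}_{\mathcal{D}}}$ equal to the exact posterior closes the gap. This illustrates why the inequality in Equation (A.8) is tight only when the KL divergence to $\hat{P}$ vanishes.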

Here, the inequality in Equation (A.8) follows from the non-negativity of the KL divergence. Finally, this lower bound can equivalently be expressed as

$$
\log p(\mathbf{y}_{\mathcal{D}}) \geq \mathbb{E}_{Q_{\mathbf{X}_{\mathcal{D}}}}\left[\log p_{\mathbf{y}|f_{\mathbf{X}}}(\mathbf{y}_{\mathcal{D}} \,|\, f_{\mathbf{X}_{\mathcal{D}}})\right] - \mathbb{D}_{\mathrm{KL}}\big(Q_{\mathbf{X}_{\mathcal{D}}, \mathbf{X}_{\backslash\mathcal{D}}} \,\|\, P_{\mathbf{X}_{\mathcal{D}}, \mathbf{X}_{\backslash\mathcal{D}}}\big), \tag{A.10}
$$

where $\mathbf{X}_{\backslash\mathcal{D}}$ is an infinite index set disjoint from the finite index set $\mathbf{X}_{\mathcal{D}}$, that is, $\mathbf{X}_{\backslash\mathcal{D}} \cap \mathbf{X}_{\mathcal{D}} = \varnothing$. Alternatively, by Theorem 1 in Sun et al. [2019], we can write

$$
\log p(\mathbf{y}_{\mathcal{D}}) \geq \mathbb{E}_{Q_{\mathbf{X}_{\mathcal{D}}}}\left[\log p_{\mathbf{y}|f_{\mathbf{X}}}(\mathbf{y}_{\mathcal{D}} \,|\, f_{\mathbf{X}_{\mathcal{D}}})\right] - \sup_{\mathbf{X} \in \mathcal{X}_{\mathbb{N}}} \mathbb{D}_{\mathrm{KL}}(Q_{\mathbf{X}} \,\|\, P_{\mathbf{X}}), \tag{A.11}
$$

where $\mathcal{X}_{\mathbb{N}} \,\dot{=}\, \bigcup_{n \in \mathbb{N}} \{\mathbf{X} \in \mathcal{X}_{n} \,|\, \mathcal{X}_{n} \subseteq \mathbb{R}^{n \times D}\}$ is the collection of all finite sets of evaluation points.
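The supremum in Equation (A.11) ranges over finite evaluation sets. A useful property here is that the KL divergence between finite-dimensional marginals cannot decrease when points are added to the evaluation set, because marginalizing out coordinates cannot increase a KL divergence. The sketch below illustrates this for two Gaussian processes with squared-exponential kernels evaluated on nested sets. The kernels, means, white-noise jitter, and evaluation points are assumptions chosen for illustration only.

```python
import numpy as np

def rbf_kernel(x1, x2, lengthscale, variance):
    """Squared-exponential kernel matrix for 1-D inputs."""
    sq_dist = (x1[:, None] - x2[None, :]) ** 2
    return variance * np.exp(-0.5 * sq_dist / lengthscale**2)

def gauss_kl(m0, S0, m1, S1):
    """Closed-form KL( N(m0, S0) || N(m1, S1) )."""
    S1_inv = np.linalg.inv(S1)
    diff = m1 - m0
    logdet0 = np.linalg.slogdet(S0)[1]
    logdet1 = np.linalg.slogdet(S1)[1]
    return 0.5 * (np.trace(S1_inv @ S0) + diff @ S1_inv @ diff - len(m0) + logdet1 - logdet0)

points = np.array([0.0, 1.5, -1.2, 2.7, -2.4, 0.6])
jitter = 1e-2   # white-noise term: keeps both processes consistent and well conditioned

kls = []
for n in range(1, len(points) + 1):
    X = points[:n]                      # nested evaluation sets X_1 ⊂ X_2 ⊂ ...
    I = np.eye(n)
    q_mean = 0.3 * np.ones(n)           # hypothetical "Q" process: mean 0.3, shorter lengthscale
    q_cov = rbf_kernel(X, X, 0.7, 0.5) + jitter * I
    p_mean = np.zeros(n)                # hypothetical "P" process: zero mean, unit lengthscale
    p_cov = rbf_kernel(X, X, 1.0, 1.0) + jitter * I
    kls.append(gauss_kl(q_mean, q_cov, p_mean, p_cov))

# KL between the marginals is non-decreasing along the nested sets.
print(np.all(np.diff(kls) >= -1e-9))
```

In practice, the supremum cannot be evaluated exactly. This monotonicity is one reason why estimates based on a finite set of evaluation points yield a lower value than the supremum.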

A.2 Distribution under Linearized Function Mapping

Proposition 1 (Distribution under Linearized Mapping).

For a stochastic function $f(\cdot\,;\bm{\Theta})$ defined in terms of stochastic parameters $\bm{\Theta}$ distributed according to distribution $g_{\bm{\Theta}}$ with $\mathbf{m} \,\dot{=}\, \mathbb{E}_{g_{\bm{\Theta}}}[\bm{\Theta}]$ and $\mathbf{S} \,\dot{=}\, \mathrm{Cov}_{g_{\bm{\Theta}}}[\bm{\Theta}]$, denote the linearization of the stochastic function $f(\cdot\,;\bm{\Theta})$ about $\mathbf{m}$ by

$$
f(\cdot\,;\bm{\Theta}) \approx \tilde{f}(\cdot\,;\mathbf{m},\bm{\Theta}) \,\dot{=}\, f(\cdot\,;\mathbf{m}) + \mathcal{J}(\cdot\,;\mathbf{m})(\bm{\Theta} - \mathbf{m}),
$$

where $\mathcal{J}(\cdot\,;\mathbf{m}) \,\dot{=}\, (\partial f(\cdot\,;\bm{\Theta}) / \partial \bm{\Theta})|_{\bm{\Theta}=\mathbf{m}}$ is the Jacobian of $f(\cdot\,;\bm{\Theta})$ evaluated at $\bm{\Theta} = \mathbf{m}$. Then the mean and covariance of the distribution over the linearized mapping $\tilde{f}$ at $\mathbf{X}, \mathbf{X}' \in \mathcal{X}$ are given by

$$
\begin{aligned}
\mathbb{E}[\tilde{f}(\mathbf{X};\mathbf{m},\bm{\Theta})] &= f(\mathbf{X};\mathbf{m}), \\
\mathrm{Cov}[\tilde{f}(\mathbf{X};\mathbf{m},\bm{\Theta}), \tilde{f}(\mathbf{X}';\mathbf{m},\bm{\Theta})] &= \mathcal{J}(\mathbf{X};\mathbf{m})\,\mathbf{S}\,\mathcal{J}(\mathbf{X}';\mathbf{m})^{\top}.
\end{aligned}
$$
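Proposition 1 can be checked by Monte Carlo for a small model whose Jacobian is available in closed form. The sketch below assumes a hypothetical four-parameter model $f(x;\bm{\theta}) = a\tanh(bx + c) + d$ with scalar inputs and outputs. It draws $\bm{\Theta}$ from a Gaussian purely for convenience, since the proposition itself only uses the mean $\mathbf{m}$ and covariance $\mathbf{S}$ of $g_{\bm{\Theta}}$. All names and numbers are illustrative.

```python
import numpy as np

def f(x, theta):
    """Hypothetical toy model f(x; theta) = a * tanh(b * x + c) + d."""
    a, b, c, d = theta
    return a * np.tanh(b * x + c) + d

def jacobian(x, theta):
    """Closed-form Jacobian df/dtheta, shape (len(x), 4), one row per input."""
    a, b, c, _ = theta
    t = np.tanh(b * x + c)
    sech2 = 1.0 - t**2                       # derivative of tanh
    return np.stack([t, a * x * sech2, a * sech2, np.ones_like(x)], axis=1)

def linearized_moments(X, X_prime, m, S):
    """Mean at X and cross-covariance between X and X' of the linearization about m."""
    return f(X, m), jacobian(X, m) @ S @ jacobian(X_prime, m).T

rng = np.random.default_rng(1)
m = np.array([1.0, 2.0, -0.5, 0.1])          # mean of g_Theta (illustrative)
A = rng.normal(size=(4, 4))
S = 0.005 * A @ A.T + 0.005 * np.eye(4)      # covariance of g_Theta (illustrative)
X = np.array([-1.0, 0.0, 0.5])
X_prime = np.array([0.25, 2.0])

mean, cov = linearized_moments(X, X_prime, m, S)

# Monte Carlo: sample Theta and push each sample through the linearized map.
thetas = rng.multivariate_normal(m, S, size=200_000)

def f_tilde(x):
    """Samples of f_tilde(x; m, Theta) = f(x; m) + J(x; m)(Theta - m)."""
    return f(x, m)[None, :] + (thetas - m) @ jacobian(x, m).T

s, s_prime = f_tilde(X), f_tilde(X_prime)
mc_mean = s.mean(axis=0)
mc_cov = (s - mc_mean).T @ (s_prime - s_prime.mean(axis=0)) / len(thetas)
print(np.abs(mc_mean - mean).max(), np.abs(mc_cov - cov).max())   # both close to zero
```

Because $\tilde{f}$ is linear in $\bm{\Theta}$, the empirical cross-covariance of the samples equals $\mathcal{J}(\mathbf{X};\mathbf{m})\,\hat{\mathbf{S}}\,\mathcal{J}(\mathbf{X}';\mathbf{m})^{\top}$ exactly, where $\hat{\mathbf{S}}$ is the empirical parameter covariance. The only remaining discrepancy is Monte Carlo error in estimating $\mathbf{m}$ and $\mathbf{S}$.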
Proof.

We wish to find $\mathbb{E}[\tilde{f}(\mathbf{X};\mathbf{m},\bm{\Theta})]$ and

$$
\begin{aligned}
&\mathrm{Cov}(\tilde{f}(\mathbf{X};\mathbf{m},\bm{\Theta}), \tilde{f}(\mathbf{X}';\mathbf{m},\bm{\Theta})) \\
&= \mathbb{E}\big[(\tilde{f}(\mathbf{X};\mathbf{m},\bm{\Theta}) - \mathbb{E}[\tilde{f}(\mathbf{X};\mathbf{m},\bm{\Theta})])\,(\tilde{f}(\mathbf{X}';\mathbf{m},\bm{\Theta}) - \mathbb{E}[\tilde{f}(\mathbf{X}';\mathbf{m},\bm{\Theta})])^{\top}\big].
\end{aligned}
\tag{A.12}
$$

To see that $\mathbb{E}[\tilde{f}(\mathbf{X};\mathbf{m},\bm{\Theta})] = f(\mathbf{X};\mathbf{m})$, note that, by linearity of expectation, we have

$$
\begin{aligned}
\mathbb{E}[\tilde{f}(\mathbf{X};\mathbf{m},\bm{\Theta})] &= \mathbb{E}[f(\mathbf{X};\mathbf{m}) + \mathcal{J}(\mathbf{X};\mathbf{m})(\bm{\Theta} - \mathbf{m})] \\
&= f(\mathbf{X};\mathbf{m}) + \mathcal{J}(\mathbf{X};\mathbf{m})(\mathbb{E}[\bm{\Theta}] - \mathbf{m}) = f(\mathbf{X};\mathbf{m}).
\end{aligned}
\tag{A.13}
$$

To see that $\mathrm{Cov}(\tilde{f}(\mathbf{X};\mathbf{m},\bm{\Theta}), \tilde{f}(\mathbf{X}';\mathbf{m},\bm{\Theta})) = \mathcal{J}(\mathbf{X};\mathbf{m})\,\mathbf{S}\,\mathcal{J}(\mathbf{X}';\mathbf{m})^{\top}$, note that, in general, for multivariate random variables $\mathbf{Z}$ and $\mathbf{Z}'$, $\mathrm{Cov}(\mathbf{Z}, \mathbf{Z}') = \mathbb{E}[\mathbf{Z}\mathbf{Z}'^{\top}] - \mathbb{E}[\mathbf{Z}]\,\mathbb{E}[\mathbf{Z}']^{\top}$, and hence,

$$
\begin{aligned}
&\mathrm{Cov}(\tilde{f}(\mathbf{X};\mathbf{m},\bm{\Theta}), \tilde{f}(\mathbf{X}';\mathbf{m},\bm{\Theta})) \\
&= \mathbb{E}[\tilde{f}(\mathbf{X};\mathbf{m},\bm{\Theta})\,\tilde{f}(\mathbf{X}';\mathbf{m},\bm{\Theta})^{\top}] - \mathbb{E}[\tilde{f}(\mathbf{X};\mathbf{m},\bm{\Theta})]\,\mathbb{E}[\tilde{f}(\mathbf{X}';\mathbf{m},\bm{\Theta})]^{\top}.
\end{aligned}
\tag{A.14}
$$
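The identity $\mathrm{Cov}(\mathbf{Z}, \mathbf{Z}') = \mathbb{E}[\mathbf{Z}\mathbf{Z}'^{\top}] - \mathbb{E}[\mathbf{Z}]\,\mathbb{E}[\mathbf{Z}']^{\top}$ applied in Equation (A.14) also holds exactly for empirical moments with $1/n$ normalization. This gives a quick numerical check. The random vectors below are arbitrary illustrative choices with nonzero means.

```python
import numpy as np

rng = np.random.default_rng(2)
n = 1_000
Z = rng.normal(size=(n, 3)) + np.array([1.0, -2.0, 0.5])                # samples of Z
Z_prime = Z[:, :2] @ rng.normal(size=(2, 2)) + np.array([3.0, -4.0])    # correlated samples of Z'

# Left-hand side: empirical cross-covariance from centered samples (1/n normalization).
cross_cov = (Z - Z.mean(axis=0)).T @ (Z_prime - Z_prime.mean(axis=0)) / n
# Right-hand side: empirical second moment minus the outer product of the means.
second_moment = Z.T @ Z_prime / n
outer_means = np.outer(Z.mean(axis=0), Z_prime.mean(axis=0))

print(np.allclose(cross_cov, second_moment - outer_means))   # True
print(np.allclose(cross_cov, second_moment + outer_means))   # the "+" variant does not match here
```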

We already know that $\mathbb{E}[\tilde{f}(\mathbf{X};\mathbf{m},\bm{\Theta})] = f(\mathbf{X};\mathbf{m})$, so we only need to find $\mathbb{E}[\tilde{f}(\mathbf{X};\mathbf{m},\bm{\Theta})\,\tilde{f}(\mathbf{X}';\mathbf{m},\bm{\Theta})^{\top}]$:

\begin{align}
&\mathbb{E}_{g_{\bm{\Theta}}}[\tilde{f}(\mathbf{X};\mathbf{m},\bm{\Theta})\,\tilde{f}(\mathbf{X}';\mathbf{m},\bm{\Theta})^{\top}] \nonumber \\
&= \mathbb{E}_{g_{\bm{\Theta}}}\big[(f(\mathbf{X};\mathbf{m}) + \mathcal{J}(\mathbf{X};\mathbf{m})(\bm{\Theta}-\mathbf{m}))(f(\mathbf{X}';\mathbf{m}) + \mathcal{J}(\mathbf{X}';\mathbf{m})(\bm{\Theta}-\mathbf{m}))^{\top}\big] \tag{A.15} \\
&= \mathbb{E}_{g_{\bm{\Theta}}}\big[f(\mathbf{X};\mathbf{m})f(\mathbf{X}';\mathbf{m})^{\top} + (\mathcal{J}(\mathbf{X};\mathbf{m})(\bm{\Theta}-\mathbf{m}))(\mathcal{J}(\mathbf{X}';\mathbf{m})(\bm{\Theta}-\mathbf{m}))^{\top} \nonumber \\
&\qquad\qquad + f(\mathbf{X};\mathbf{m})(\mathcal{J}(\mathbf{X}';\mathbf{m})(\bm{\Theta}-\mathbf{m}))^{\top} + \mathcal{J}(\mathbf{X};\mathbf{m})(\bm{\Theta}-\mathbf{m})f(\mathbf{X}';\mathbf{m})^{\top}\big] \tag{A.16} \\
&= \mathbb{E}_{g_{\bm{\Theta}}}\big[f(\mathbf{X};\mathbf{m})f(\mathbf{X}';\mathbf{m})^{\top} + \mathcal{J}(\mathbf{X};\mathbf{m})(\bm{\Theta}-\mathbf{m})(\bm{\Theta}-\mathbf{m})^{\top}\mathcal{J}(\mathbf{X}';\mathbf{m})^{\top} \nonumber \\
&\qquad\qquad + f(\mathbf{X};\mathbf{m})(\mathcal{J}(\mathbf{X}';\mathbf{m})(\bm{\Theta}-\mathbf{m}))^{\top} + \mathcal{J}(\mathbf{X};\mathbf{m})(\bm{\Theta}-\mathbf{m})f(\mathbf{X}';\mathbf{m})^{\top}\big] \tag{A.17} \\
&= f(\mathbf{X};\mathbf{m})f(\mathbf{X}';\mathbf{m})^{\top} + \mathcal{J}(\mathbf{X};\mathbf{m})\,\mathbb{E}_{g_{\bm{\Theta}}}[(\bm{\Theta}-\mathbf{m})(\bm{\Theta}-\mathbf{m})^{\top}]\,\mathcal{J}(\mathbf{X}';\mathbf{m})^{\top} \nonumber \\
&\qquad\qquad + f(\mathbf{X};\mathbf{m})\big(\mathcal{J}(\mathbf{X}';\mathbf{m})\underbrace{(\mathbb{E}_{g_{\bm{\Theta}}}[\bm{\Theta}]-\mathbf{m})}_{=0}\big)^{\top} + \mathcal{J}(\mathbf{X};\mathbf{m})\underbrace{(\mathbb{E}_{g_{\bm{\Theta}}}[\bm{\Theta}]-\mathbf{m})}_{=0}f(\mathbf{X}';\mathbf{m})^{\top}, \tag{A.18}
\end{align}

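where the last line follows from linearity of expectation and the definition $\mathbf{m} \,\dot{=}\, \mathbb{E}_{g_{\bm{\Theta}}}[\bm{\Theta}]$. Since $\mathbb{E}_{g_{\bm{\Theta}}}[(\bm{\Theta}-\mathbf{m})(\bm{\Theta}-\mathbf{m})^{\top}] = \mathrm{Cov}(\bm{\Theta})$ by the definition of the covariance, we then obtain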

\begin{align}
&\mathbb{E}_{g_{\bm{\Theta}}}[\tilde{f}(\mathbf{X};\mathbf{m},\bm{\Theta})\,\tilde{f}(\mathbf{X}';\mathbf{m},\bm{\Theta})^{\top}] \nonumber \\
&= f(\mathbf{X};\mathbf{m})f(\mathbf{X}';\mathbf{m})^{\top} + \mathcal{J}(\mathbf{X};\mathbf{m})\,\mathbb{E}_{g_{\bm{\Theta}}}[(\bm{\Theta}-\mathbf{m})(\bm{\Theta}-\mathbf{m})^{\top}]\,\mathcal{J}(\mathbf{X}';\mathbf{m})^{\top} \tag{A.19} \\
&= f(\mathbf{X};\mathbf{m})f(\mathbf{X}';\mathbf{m})^{\top} + \mathcal{J}(\mathbf{X};\mathbf{m})\,\mathrm{Cov}(\bm{\Theta})\,\mathcal{J}(\mathbf{X}';\mathbf{m})^{\top}. \tag{A.20}
\end{align}

With this result, we obtain the covariance function

$$
\begin{aligned}
&\text{Cov}\big(\tilde{f}(\mathbf{X};\mathbf{m},\bm{\Theta}),\tilde{f}(\mathbf{X}';\mathbf{m},\bm{\Theta})\big)\\
&\quad=\mathbb{E}\big[\tilde{f}(\mathbf{X};\mathbf{m},\bm{\Theta})\,\tilde{f}(\mathbf{X}';\mathbf{m},\bm{\Theta})^{\top}\big]-\mathbb{E}\big[\tilde{f}(\mathbf{X};\mathbf{m},\bm{\Theta})\big]\,\mathbb{E}\big[\tilde{f}(\mathbf{X}';\mathbf{m},\bm{\Theta})\big]^{\top}
\end{aligned}
\tag{A.21}
$$
$$
\quad=\mathbb{E}\big[\tilde{f}(\mathbf{X};\mathbf{m},\bm{\Theta})\,\tilde{f}(\mathbf{X}';\mathbf{m},\bm{\Theta})^{\top}\big]-f(\mathbf{X};\mathbf{m})\,f(\mathbf{X}';\mathbf{m})^{\top} \tag{A.22}
$$

$$
\quad=f(\mathbf{X};\mathbf{m})\,f(\mathbf{X}';\mathbf{m})^{\top}+\mathcal{J}(\mathbf{X};\mathbf{m})\,\text{Cov}(\bm{\Theta})\,\mathcal{J}(\mathbf{X}';\mathbf{m})^{\top}-f(\mathbf{X};\mathbf{m})\,f(\mathbf{X}';\mathbf{m})^{\top} \tag{A.23}
$$

$$
\quad=\mathcal{J}(\mathbf{X};\mathbf{m})\,\text{Cov}(\bm{\Theta})\,\mathcal{J}(\mathbf{X}';\mathbf{m})^{\top}. \tag{A.24}
$$

Here, Equation A.22 uses $\mathbb{E}_{g_{\bm{\Theta}}}[\tilde{f}(\mathbf{X};\mathbf{m},\bm{\Theta})]=f(\mathbf{X};\mathbf{m})$, which holds because $\mathbb{E}_{g_{\bm{\Theta}}}[\bm{\Theta}-\mathbf{m}]=\mathbf{0}$, and Equation A.23 substitutes Equation A.20.

Finally, substituting $\text{Cov}(\bm{\Theta})=\mathbf{S}$ yields $\text{Cov}\big(\tilde{f}(\mathbf{X};\mathbf{m},\bm{\Theta}),\tilde{f}(\mathbf{X}';\mathbf{m},\bm{\Theta})\big)=\mathcal{J}(\mathbf{X};\mathbf{m})\,\mathbf{S}\,\mathcal{J}(\mathbf{X}';\mathbf{m})^{\top}$. This concludes the proof. ∎
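As a quick numerical sanity check of this result, the sketch below samples $\bm{\Theta}\sim\mathcal{N}(\mathbf{m},\mathbf{S})$, pushes the samples through the linearization $\tilde{f}(\cdot\,;\mathbf{m},\bm{\Theta})$, and compares the empirical cross-covariance between two input batches $\mathbf{X}$ and $\mathbf{X}'$ with $\mathcal{J}(\mathbf{X};\mathbf{m})\,\mathbf{S}\,\mathcal{J}(\mathbf{X}';\mathbf{m})^{\top}$. It assumes NumPy and a hypothetical three-parameter scalar model $f(x;\bm{\theta})=\theta_3\tanh(\theta_1 x+\theta_2)$ with a hand-derived Jacobian. The model, inputs, and covariance are illustrative choices, not the networks or settings used in the paper.

```python
import numpy as np

# Hypothetical toy model (not from the paper): f(x; theta) = theta[2] * tanh(theta[0] * x + theta[1]),
# applied elementwise to a batch of scalar inputs x, so f(x; theta) has shape (len(x),).
def f(x, theta):
    return theta[2] * np.tanh(theta[0] * x + theta[1])

def jacobian(x, m):
    """Analytic Jacobian df/dtheta evaluated at theta = m, shape (len(x), 3)."""
    t = np.tanh(m[0] * x + m[1])
    sech2 = 1.0 - t**2
    return np.stack([m[2] * sech2 * x, m[2] * sech2, t], axis=1)

def f_lin(x, m, thetas):
    """Linearization f~(x; m, theta) = f(x; m) + J(x; m)(theta - m), one row per sample."""
    return f(x, m)[None, :] + (thetas - m) @ jacobian(x, m).T

rng = np.random.default_rng(0)
m = np.array([0.8, -0.3, 1.5])
A = rng.normal(size=(3, 3))
S = 0.1 * (A @ A.T) + 0.01 * np.eye(3)        # arbitrary positive-definite covariance

X = np.array([-1.0, 0.0, 2.0])                # first input batch
Xp = np.array([0.5, 1.0])                     # second input batch (X')

thetas = rng.multivariate_normal(m, S, size=200_000)   # Theta ~ N(m, S)
F, Fp = f_lin(X, m, thetas), f_lin(Xp, m, thetas)

# Empirical Cov(f~(X), f~(X')) versus the closed form J(X; m) S J(X'; m)^T.
empirical = (F - F.mean(axis=0)).T @ (Fp - Fp.mean(axis=0)) / (len(thetas) - 1)
closed_form = jacobian(X, m) @ S @ jacobian(Xp, m).T
print(np.abs(empirical - closed_form).max())
```

Because $\tilde{f}$ is affine in $\bm{\Theta}$, the two matrices agree up to Monte Carlo error; applying the same check to the non-linearized $f$ would in general not reproduce the closed form exactly.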

Proposition 2 (Approximate Distribution under Linearized Mapping).

For a stochastic function $f(\cdot\,;\bm{\Theta})$ defined in terms of stochastic parameters $\bm{\Theta}$ distributed according to the distribution $g_{\bm{\Theta}}=\mathcal{N}(\mathbf{m},\mathbf{S})$, denote the linearization of $f(\cdot\,;\bm{\Theta})$ about $\mathbf{m}$ by

$$
f(\cdot\,;\bm{\Theta})\approx\tilde{f}(\cdot\,;\mathbf{m},\bm{\Theta})\,\dot{=}\,f(\cdot\,;\mathbf{m})+\mathcal{J}(\cdot\,;\mathbf{m})(\bm{\Theta}-\mathbf{m}),
$$

where $\mathcal{J}(\cdot\,;\mathbf{m})\,\dot{=}\,(\partial f(\cdot\,;\bm{\Theta})/\partial\bm{\Theta})|_{\bm{\Theta}=\mathbf{m}}$ is the Jacobian of $f(\cdot\,;\bm{\Theta})$ evaluated at $\bm{\Theta}=\mathbf{m}$. Then, for a partition of the set of parameters into sets $\alpha$ and $\beta$ and a distribution $g_{\bm{\Theta}}=\mathcal{N}(\mathbf{m},\mathbf{S})$ under which $\bm{\Theta}_{\alpha}\perp\bm{\Theta}_{\beta}$, the distribution $\tilde{g}_{\tilde{f}(\mathbf{X};\mathbf{m},\bm{\Theta})}$ can be approximated via the Monte Carlo estimator

$$
\hat{\tilde{g}}_{\tilde{f}(\mathbf{X};\mathbf{m},\bm{\Theta})}=\frac{1}{R}\sum_{j=1}^{R}\mathcal{N}\Big(f(\mathbf{X};\mathbf{m})+\tilde{f}_{\alpha}(\mathbf{X};\mathbf{m},\bm{\Theta}_{\alpha})^{(j)},\;\mathcal{J}_{\beta}(\mathbf{X};\mathbf{m})\,\mathbf{S}_{\beta}\,\mathcal{J}_{\beta}(\mathbf{X};\mathbf{m})^{\top}\Big), \tag{A.25}
$$

where $g_{\bm{\Theta}_{\alpha}}=\mathcal{N}(\mathbf{m}_{\alpha},\mathbf{S}_{\alpha})$, $g_{\bm{\Theta}_{\beta}}=\mathcal{N}(\mathbf{m}_{\beta},\mathbf{S}_{\beta})$, and

$$
\tilde{f}_{\alpha}(\cdot\,;\mathbf{m},\bm{\Theta}_{\alpha})\,\dot{=}\,\mathcal{J}_{\alpha}(\cdot\,;\mathbf{m})(\bm{\Theta}_{\alpha}-\mathbf{m}_{\alpha}), \tag{A.26}
$$

with $\mathcal{J}_{\alpha}(\cdot\,;\mathbf{m})$ denoting the columns of the Jacobian matrix corresponding to the set of parameters $\alpha$, and with the samples $\tilde{f}_{\alpha}(\mathbf{X};\mathbf{m},\bm{\Theta}_{\alpha})^{(j)}$, for $j=1,\ldots,R$, obtained by sampling parameters from the distribution $g_{\bm{\Theta}_{\alpha}}=\mathcal{N}(\mathbf{m}_{\alpha},\mathbf{S}_{\alpha})$.

Proof.

Consider a partition of the set of parameters into sets $\alpha$ and $\beta$ and express the linearized mapping as

$$
\tilde{f}(\cdot\,;\mathbf{m},\bm{\Theta})=\tilde{f}_{\alpha}(\cdot\,;\mathbf{m},\bm{\Theta}_{\alpha})+\tilde{f}_{\beta}(\cdot\,;\mathbf{m},\bm{\Theta}_{\beta}), \tag{A.27}
$$

with

$$
\tilde{f}_{\alpha}(\cdot\,;\mathbf{m},\bm{\Theta}_{\alpha})\,\dot{=}\,\mathcal{J}_{\alpha}(\cdot\,;\mathbf{m})(\bm{\Theta}_{\alpha}-\mathbf{m}_{\alpha}), \tag{A.28}
$$

and

$$
\tilde{f}_{\beta}(\cdot\,;\mathbf{m},\bm{\Theta}_{\beta})\,\dot{=}\,f(\cdot\,;\mathbf{m})+\mathcal{J}_{\beta}(\cdot\,;\mathbf{m})(\bm{\Theta}_{\beta}-\mathbf{m}_{\beta}), \tag{A.29}
$$

where $\mathcal{J}_{\alpha}(\cdot\,;\mathbf{m})$ and $\mathcal{J}_{\beta}(\cdot\,;\mathbf{m})$ are the columns of the Jacobian matrix corresponding to the sets of parameters $\alpha$ and $\beta$, respectively, and $\bm{\Theta}_{\alpha}$ and $\bm{\Theta}_{\beta}$ are the corresponding random parameter vectors.
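The decomposition in Equations A.27 to A.29 amounts to splitting the columns of the Jacobian, $\mathcal{J}(\cdot\,;\mathbf{m})(\bm{\Theta}-\mathbf{m})=\mathcal{J}_{\alpha}(\cdot\,;\mathbf{m})(\bm{\Theta}_{\alpha}-\mathbf{m}_{\alpha})+\mathcal{J}_{\beta}(\cdot\,;\mathbf{m})(\bm{\Theta}_{\beta}-\mathbf{m}_{\beta})$, with the offset $f(\cdot\,;\mathbf{m})$ assigned to the $\beta$ term. A minimal NumPy sketch of this bookkeeping follows; random arrays stand in for $f(\mathbf{X};\mathbf{m})$ and $\mathcal{J}(\mathbf{X};\mathbf{m})$, and the index set `alpha` is a hypothetical choice.

```python
import numpy as np

rng = np.random.default_rng(2)
n_out, n_params = 4, 5
f_m = rng.normal(size=n_out)                     # stands in for f(X; m)
J = rng.normal(size=(n_out, n_params))           # stands in for the Jacobian J(X; m)
m = rng.normal(size=n_params)                    # linearization point
theta = m + rng.normal(size=n_params)            # one draw of Theta

alpha = np.array([0, 3])                         # hypothetical index set alpha
beta = np.setdiff1d(np.arange(n_params), alpha)  # its complement beta = [1, 2, 4]

f_lin = f_m + J @ (theta - m)                          # linearized mapping f~(X; m, theta)
f_alpha = J[:, alpha] @ (theta[alpha] - m[alpha])      # Equation A.28
f_beta = f_m + J[:, beta] @ (theta[beta] - m[beta])    # Equation A.29
print(np.allclose(f_lin, f_alpha + f_beta))            # True, i.e. Equation A.27 holds
```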

Noting that Equation A.27 expresses $\tilde{f}$ as a sum of (affine transformations of) random variables that are independent because $\bm{\Theta}_{\alpha}\perp\bm{\Theta}_{\beta}$, we can use the fact that for independent random variables $\mathbf{X}$ and $\mathbf{Y}$, the density $h_{\mathbf{Z}}$ of $\mathbf{Z}=\mathbf{X}+\mathbf{Y}$ is the convolution of the densities $h_{\mathbf{X}}$ and $h_{\mathbf{Y}}$, and thereby express the density of $\tilde{f}$ as a convolution. In particular, for $\mathbf{Z}=\mathbf{X}+\mathbf{Y}$,

$$
h_{\mathbf{Z}}(\mathbf{z})=\int_{-\infty}^{\infty}h_{\mathbf{Y}}(\mathbf{z}-\mathbf{x})\,h_{\mathbf{X}}(\mathbf{x})\,\mathrm{d}\mathbf{x}. \tag{A.30}
$$
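As a concrete instance of Equation A.30, convolving two Gaussian densities gives a Gaussian whose mean and variance are the sums of the individual means and variances. The sketch below checks this in one dimension by approximating the integral with a Riemann sum computed by `numpy.convolve` on a uniform grid; the grid, means, and variances are arbitrary illustrative values.

```python
import numpy as np

def normal_pdf(x, mean, var):
    """Univariate Gaussian density N(x; mean, var)."""
    return np.exp(-0.5 * (x - mean) ** 2 / var) / np.sqrt(2.0 * np.pi * var)

dx = 0.01
grid = np.arange(-15.0, 15.0 + dx / 2, dx)   # 3001 points, symmetric about 0

h_X = normal_pdf(grid, 1.0, 0.5)     # density of X
h_Y = normal_pdf(grid, -2.0, 2.0)    # density of Y, independent of X

# Riemann-sum approximation of h_Z(z) = integral of h_Y(z - x) h_X(x) dx on the same grid.
# With an odd-length grid symmetric about 0, mode="same" aligns output index i with grid[i].
h_Z = np.convolve(h_Y, h_X, mode="same") * dx

# Z = X + Y should have density N(1.0 + (-2.0), 0.5 + 2.0).
h_Z_exact = normal_pdf(grid, -1.0, 2.5)
print(np.abs(h_Z - h_Z_exact).max())
```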

Letting $\mathbf{X}=\tilde{f}_{\alpha}(\mathbf{X};\mathbf{m},\bm{\Theta}_{\alpha})$, $\mathbf{Y}=\tilde{f}_{\beta}(\mathbf{X};\mathbf{m},\bm{\Theta}_{\beta})$, and $\mathbf{Z}=\tilde{f}(\mathbf{X};\mathbf{m},\bm{\Theta})$, we can write

$$
\begin{aligned}
&\tilde{g}_{\tilde{f}(\mathbf{X};\mathbf{m},\bm{\Theta})}\big(\tilde{f}(\mathbf{X};\mathbf{m},\bm{\theta})\big)\\
&\quad=\int_{-\infty}^{\infty}\tilde{g}_{\tilde{f}_{\beta}(\mathbf{X};\mathbf{m},\bm{\Theta}_{\beta})}\big(\tilde{f}(\mathbf{X};\mathbf{m},\bm{\theta})-\tilde{f}_{\alpha}(\mathbf{X};\mathbf{m},\bm{\theta}_{\alpha})\big)\,\tilde{g}_{\tilde{f}_{\alpha}(\mathbf{X};\mathbf{m},\bm{\Theta}_{\alpha})}\big(\tilde{f}_{\alpha}(\mathbf{X};\mathbf{m},\bm{\theta}_{\alpha})\big)\,\mathrm{d}\tilde{f}_{\alpha}(\mathbf{X};\mathbf{m},\bm{\theta}_{\alpha})
\end{aligned}
\tag{A.31}
$$
$$
\quad=\int_{-\infty}^{\infty}\mathcal{N}\big(\tilde{f}(\mathbf{X};\mathbf{m},\bm{\theta});\,\bm{\mu}(\mathbf{X};\mathbf{m},\bm{\theta}_{\alpha},\tilde{f}_{\alpha}),\,\bm{\Sigma}(\mathbf{X};\mathbf{m},\mathbf{S}_{\beta},\mathcal{J}_{\beta})\big)\,\tilde{g}_{\tilde{f}_{\alpha}(\mathbf{X};\mathbf{m},\bm{\Theta}_{\alpha})}\big(\tilde{f}_{\alpha}(\mathbf{X};\mathbf{m},\bm{\theta}_{\alpha})\big)\,\mathrm{d}\tilde{f}_{\alpha}(\mathbf{X};\mathbf{m},\bm{\theta}_{\alpha}), \tag{A.32}
$$

with

$$
\bm{\mu}(\mathbf{X};\mathbf{m},\bm{\theta}_{\alpha},\tilde{f}_{\alpha})=f(\mathbf{X};\mathbf{m})+\tilde{f}_{\alpha}(\mathbf{X};\mathbf{m},\bm{\theta}_{\alpha}) \tag{A.33}
$$

and

$$
\bm{\Sigma}(\mathbf{X};\mathbf{m},\mathbf{S}_{\beta},\mathcal{J}_{\beta})=\mathcal{J}_{\beta}(\mathbf{X};\mathbf{m})\,\mathbf{S}_{\beta}\,\mathcal{J}_{\beta}(\mathbf{X};\mathbf{m})^{\top}, \tag{A.34}
$$

where we have used that $\tilde{f}_{\beta}(\mathbf{X};\mathbf{m},\bm{\Theta}_{\beta})$ is Gaussian with mean $f(\mathbf{X};\mathbf{m})$ and covariance $\bm{\Sigma}(\mathbf{X};\mathbf{m},\mathbf{S}_{\beta},\mathcal{J}_{\beta})$, together with the fact that for a Gaussian distribution with mean $m$ and covariance $S$, $\mathcal{N}(z-y;m,S)=\mathcal{N}(z;m+y,S)$. We can then approximate the probability density function $\tilde{g}_{\tilde{f}(\mathbf{X};\mathbf{m},\bm{\Theta})}(\tilde{f}(\mathbf{X};\mathbf{m},\bm{\theta}))$ via the Monte Carlo estimator

$$
\begin{aligned}
&\hat{\tilde{g}}_{\tilde{f}(\mathbf{X};\mathbf{m},\bm{\Theta})}\big(\tilde{f}(\mathbf{X};\mathbf{m},\bm{\theta})\big)\\
&\quad=\frac{1}{R}\sum_{j=1}^{R}\mathcal{N}\Big(\tilde{f}(\mathbf{X};\mathbf{m},\bm{\theta});\,\bm{\mu}\big(\mathbf{X};\mathbf{m},\tilde{f}_{\alpha}(\mathbf{X};\mathbf{m},\bm{\theta}_{\alpha})^{(j)}\big),\,\bm{\Sigma}(\mathbf{X};\mathbf{m},\mathbf{S}_{\beta},\mathcal{J}_{\beta})\Big)
\end{aligned}
\tag{A.35}
$$

with $\tilde{f}_{\alpha}(\mathbf{X};\mathbf{m},\bm{\theta}_{\alpha})^{(j)}\sim\tilde{g}_{\tilde{f}_{\alpha}(\mathbf{X};\mathbf{m},\bm{\Theta}_{\alpha})}$. Finally, we can express the distribution $\hat{\tilde{g}}_{\tilde{f}(\mathbf{X};\mathbf{m},\bm{\Theta})}$ as

$$
\hat{\tilde{g}}_{\tilde{f}(\mathbf{X};\mathbf{m},\bm{\Theta})}=\frac{1}{R}\sum_{j=1}^{R}\mathcal{N}\Big(f(\mathbf{X};\mathbf{m})+\tilde{f}_{\alpha}(\mathbf{X};\mathbf{m},\bm{\Theta}_{\alpha})^{(j)},\;\mathcal{J}_{\beta}(\mathbf{X};\mathbf{m})\,\mathbf{S}_{\beta}\,\mathcal{J}_{\beta}(\mathbf{X};\mathbf{m})^{\top}\Big), \tag{A.36}
$$

where $g_{\bm{\Theta}_{\beta}}=\mathcal{N}(\mathbf{m}_{\beta},\mathbf{S}_{\beta})$ and the samples $\tilde{f}_{\alpha}(\mathbf{X};\mathbf{m},\bm{\Theta}_{\alpha})^{(j)}$ are obtained by sampling parameters from the distribution $g_{\bm{\Theta}_{\alpha}}=\mathcal{N}(\mathbf{m}_{\alpha},\mathbf{S}_{\alpha})$. This concludes the proof. ∎
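The estimator in Equation A.36 can be sketched as follows. The sketch assumes NumPy and small, hand-picked stand-ins for $f(\mathbf{X};\mathbf{m})$, $\mathcal{J}_{\alpha}(\mathbf{X};\mathbf{m})$, $\mathcal{J}_{\beta}(\mathbf{X};\mathbf{m})$, $\mathbf{S}_{\alpha}$, and $\mathbf{S}_{\beta}$. All of these stand-ins are hypothetical, and $\mathbf{S}$ is taken to be block-diagonal so that $\bm{\Theta}_{\alpha}\perp\bm{\Theta}_{\beta}$.

Only $\bm{\Theta}_{\alpha}$ is sampled. The $\beta$ block enters every mixture component through the shared covariance $\mathcal{J}_{\beta}\mathbf{S}_{\beta}\mathcal{J}_{\beta}^{\top}$. In this setting, the linearized model is exactly Gaussian with covariance $\mathcal{J}_{\alpha}\mathbf{S}_{\alpha}\mathcal{J}_{\alpha}^{\top}+\mathcal{J}_{\beta}\mathbf{S}_{\beta}\mathcal{J}_{\beta}^{\top}$, so the mixture's log density can be compared against that closed form. The function names `gaussian_logpdf` and `mc_log_density` are illustrative, not the authors' code.

```python
import numpy as np

def gaussian_logpdf(y, mean, cov):
    """log N(y; mean, cov); `mean` may carry a leading axis of mixture components."""
    d = y.shape[-1]
    L = np.linalg.cholesky(cov)
    z = np.linalg.solve(L, (y - mean).T).T            # whitened residuals L^{-1}(y - mean)
    log_det = 2.0 * np.log(np.diag(L)).sum()
    return -0.5 * (np.sum(z**2, axis=-1) + log_det + d * np.log(2.0 * np.pi))

def mc_log_density(y, f_m, J_alpha, J_beta, m_alpha, S_alpha, S_beta, R, rng):
    """Log of the R-component Gaussian-mixture estimate of the density of f~(X; m, Theta) at y."""
    theta_alpha = rng.multivariate_normal(m_alpha, S_alpha, size=R)   # Theta_alpha^(j) ~ N(m_alpha, S_alpha)
    f_alpha = (theta_alpha - m_alpha) @ J_alpha.T                     # f~_alpha(X; m, Theta_alpha)^(j), (R, N)
    cov = J_beta @ S_beta @ J_beta.T                                  # shared component covariance
    log_comp = gaussian_logpdf(y, f_m + f_alpha, cov)                 # one log density per component
    top = log_comp.max()                                              # log-mean-exp for numerical stability
    return top + np.log(np.mean(np.exp(log_comp - top)))

rng = np.random.default_rng(1)
f_m = np.array([0.2, -0.1])                               # stands in for f(X; m), N = 2 outputs
J_alpha = np.array([[0.8, -0.2], [0.3, 0.6]])             # 2 parameters in alpha
J_beta = np.array([[1.0, 0.5, 0.0], [0.0, 1.0, -0.5]])    # 3 parameters in beta
m_alpha = np.array([0.4, -1.0])
S_alpha, S_beta = 0.5 * np.eye(2), 0.3 * np.eye(3)

y = f_m + np.array([0.3, -0.2])                           # point at which to evaluate the density
approx = mc_log_density(y, f_m, J_alpha, J_beta, m_alpha, S_alpha, S_beta, R=20_000, rng=rng)
exact = gaussian_logpdf(y, f_m, J_alpha @ S_alpha @ J_alpha.T + J_beta @ S_beta @ J_beta.T)
print(approx, exact)
```

The Cholesky factorization in this sketch requires $\mathcal{J}_{\beta}\mathbf{S}_{\beta}\mathcal{J}_{\beta}^{\top}$ to be positive definite. For that reason, $\beta$ is given at least as many parameters as there are outputs here. In the limiting case $\alpha=\emptyset$, no sampling is needed: every component equals $\mathcal{N}(f(\mathbf{X};\mathbf{m}),\mathcal{J}\mathbf{S}\mathcal{J}^{\top})$.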

Appendix B Further Empirical Results

B.1 Tabular Results for Diabetic Retinopathy Diagnosis Tasks

The results below were reproduced from Band et al. [2021] using the retina benchmark.

Table 3: Country Shift. Prediction and uncertainty quality of baseline methods in terms of the area under the receiver operating characteristic curve (AUC) and classification accuracy, as a function of the proportion of data referred to a medical expert. All methods are tuned on in-domain validation AUC, and ensembles have $K=3$ constituent models (true for all subsequent tables unless specified otherwise). On in-domain data, mc dropout performs best across all thresholds. On distributionally shifted data, no method consistently performs best.
Method | No Referral: AUC (%) ↑ | No Referral: Accuracy (%) ↑ | 50% Data Referred: AUC (%) ↑ | 50% Data Referred: Accuracy (%) ↑ | 70% Data Referred: AUC (%) ↑ | 70% Data Referred: Accuracy (%) ↑
EyePACS Dataset (In-Domain)
map (Deterministic) | 87.4 ± 1.3 | 88.6 ± 0.7 | 91.1 ± 1.8 | 95.9 ± 0.4 | 94.9 ± 1.1 | 96.5 ± 0.3
mfvi | 83.3 ± 0.2 | 85.7 ± 0.1 | 85.5 ± 0.7 | 94.5 ± 0.1 | 88.2 ± 0.7 | 95.9 ± 0.1
radial-mfvi | 83.2 ± 0.5 | 74.2 ± 5.0 | 88.9 ± 0.9 | 81.8 ± 6.0 | 91.2 ± 1.3 | 83.8 ± 5.5
fsvi | 88.5 ± 0.1 | 89.8 ± 0.0 | 91.0 ± 0.4 | 96.4 ± 0.0 | 94.3 ± 0.3 | 97.2 ± 0.1
mc dropout | 91.4 ± 0.2 | 90.9 ± 0.1 | 95.3 ± 0.2 | 97.4 ± 0.1 | 97.4 ± 0.1 | 98.1 ± 0.0
rank-1 | 85.6 ± 1.4 | 87.7 ± 0.8 | 87.1 ± 2.3 | 95.3 ± 0.5 | 90.9 ± 2.0 | 96.4 ± 0.4
deep ensemble | 90.3 ± 0.2 | 90.3 ± 0.3 | 91.7 ± 0.6 | 97.2 ± 0.0 | 95.0 ± 0.5 | 97.9 ± 0.0
mfvi ensemble | 85.4 ± 0.0 | 87.8 ± 0.0 | 86.3 ± 0.4 | 95.4 ± 0.0 | 89.2 ± 0.4 | 96.7 ± 0.1
radial-mfvi ensemble | 84.9 ± 0.1 | 74.2 ± 1.5 | 91.4 ± 0.2 | 83.4 ± 1.7 | 93.3 ± 0.3 | 85.9 ± 1.6
fsvi ensemble | 90.3 ± 0.1 | 90.6 ± 0.0 | 92.1 ± 0.2 | 97.1 ± 0.0 | 95.2 ± 0.2 | 97.8 ± 0.1
mc dropout ensemble | 92.5 ± 0.0 | 91.6 ± 0.0 | 95.8 ± 0.1 | 97.8 ± 0.0 | 97.7 ± 0.1 | 98.4 ± 0.0
rank-1 ensemble | 89.5 ± 0.8 | 89.3 ± 0.4 | 88.5 ± 1.3 | 96.9 ± 0.3 | 91.6 ± 1.2 | 97.6 ± 0.3
APTOS 2019 Dataset (Population Shift)
map (Deterministic) | 92.2 ± 0.2 | 86.2 ± 0.6 | 80.1 ± 3.6 | 87.6 ± 1.5 | 55.4 ± 4.3 | 85.4 ± 1.2
mfvi | 91.4 ± 0.2 | 84.1 ± 0.3 | 93.8 ± 0.4 | 92.1 ± 0.5 | 93.0 ± 0.6 | 92.7 ± 0.5
radial-mfvi | 90.7 ± 0.7 | 71.8 ± 4.6 | 82.0 ± 2.5 | 81.5 ± 2.7 | 66.4 ± 2.1 | 85.9 ± 1.0
fsvi | 94.1 ± 0.1 | 87.6 ± 0.5 | 90.6 ± 0.9 | 90.7 ± 0.7 | 77.2 ± 4.6 | 89.8 ± 0.3
mc dropout 94.0±0.2plus-or-minus94.00.294.0{\scriptstyle\pm 0.2}94.0 ± 0.2 86.8±0.2plus-or-minus86.80.286.8{\scriptstyle\pm 0.2}86.8 ± 0.2 87.4±0.3plus-or-minus87.40.387.4{\scriptstyle\pm 0.3}87.4 ± 0.3 88.1±0.2plus-or-minus88.10.288.1{\scriptstyle\pm 0.2}88.1 ± 0.2 65.3±1.7plus-or-minus65.31.765.3{\scriptstyle\pm 1.7}65.3 ± 1.7 88.2±0.4plus-or-minus88.20.488.2{\scriptstyle\pm 0.4}88.2 ± 0.4
rank-1 92.5±0.3plus-or-minus92.50.392.5{\scriptstyle\pm 0.3}92.5 ± 0.3 86.2±0.5plus-or-minus86.20.586.2{\scriptstyle\pm 0.5}86.2 ± 0.5 90.1±2.5plus-or-minus90.12.590.1{\scriptstyle\pm 2.5}90.1 ± 2.5 91.4±1.1plus-or-minus91.41.191.4{\scriptstyle\pm 1.1}91.4 ± 1.1 75.1±7.8plus-or-minus75.17.875.1{\scriptstyle\pm 7.8}75.1 ± 7.8 89.5±1.5plus-or-minus89.51.589.5{\scriptstyle\pm 1.5}89.5 ± 1.5
deep ensemble 94.2±0.2plus-or-minus94.20.294.2{\scriptstyle\pm 0.2}94.2 ± 0.2 87.5±0.1plus-or-minus87.50.187.5{\scriptstyle\pm 0.1}87.5 ± 0.1 91.2±1.9plus-or-minus91.21.991.2{\scriptstyle\pm 1.9}91.2 ± 1.9 92.4±0.9plus-or-minus92.40.992.4{\scriptstyle\pm 0.9}92.4 ± 0.9 67.4±7.3plus-or-minus67.47.367.4{\scriptstyle\pm 7.3}67.4 ± 7.3 90.1±1.2plus-or-minus90.11.290.1{\scriptstyle\pm 1.2}90.1 ± 1.2
mfvi ensemble 93.2±0.1plus-or-minus93.20.193.2{\scriptstyle\pm 0.1}93.2 ± 0.1 87.0±0.2plus-or-minus87.00.287.0{\scriptstyle\pm 0.2}87.0 ± 0.2 94.9±0.3plus-or-minus94.90.3\mathbf{94.9{\scriptstyle\pm 0.3}}bold_94.9 ± bold_0.3 93.7±0.3plus-or-minus93.70.3\mathbf{93.7{\scriptstyle\pm 0.3}}bold_93.7 ± bold_0.3 94.2±0.3plus-or-minus94.20.3\mathbf{94.2{\scriptstyle\pm 0.3}}bold_94.2 ± bold_0.3 94.0±0.3plus-or-minus94.00.3\mathbf{94.0{\scriptstyle\pm 0.3}}bold_94.0 ± bold_0.3
radial-mfvi ensemble 91.8±0.2plus-or-minus91.80.291.8{\scriptstyle\pm 0.2}91.8 ± 0.2 69.0±1.9plus-or-minus69.01.969.0{\scriptstyle\pm 1.9}69.0 ± 1.9 78.6±0.6plus-or-minus78.60.678.6{\scriptstyle\pm 0.6}78.6 ± 0.6 79.8±0.9plus-or-minus79.80.979.8{\scriptstyle\pm 0.9}79.8 ± 0.9 60.9±0.3plus-or-minus60.90.360.9{\scriptstyle\pm 0.3}60.9 ± 0.3 86.7±0.2plus-or-minus86.70.286.7{\scriptstyle\pm 0.2}86.7 ± 0.2
fsvi ensemble 94.6±0.1plus-or-minus94.60.1\mathbf{94.6{\scriptstyle\pm 0.1}}bold_94.6 ± bold_0.1 88.9±0.2plus-or-minus88.90.2\mathbf{88.9{\scriptstyle\pm 0.2}}bold_88.9 ± bold_0.2 90.7±0.5plus-or-minus90.70.590.7{\scriptstyle\pm 0.5}90.7 ± 0.5 91.1±0.6plus-or-minus91.10.691.1{\scriptstyle\pm 0.6}91.1 ± 0.6 74.1±3.4plus-or-minus74.13.474.1{\scriptstyle\pm 3.4}74.1 ± 3.4 89.8±0.2plus-or-minus89.80.289.8{\scriptstyle\pm 0.2}89.8 ± 0.2
mc dropout ensemble 94.1±0.1plus-or-minus94.10.194.1{\scriptstyle\pm 0.1}94.1 ± 0.1 87.6±0.1plus-or-minus87.60.187.6{\scriptstyle\pm 0.1}87.6 ± 0.1 86.8±0.2plus-or-minus86.80.286.8{\scriptstyle\pm 0.2}86.8 ± 0.2 88.0±0.2plus-or-minus88.00.288.0{\scriptstyle\pm 0.2}88.0 ± 0.2 62.3±0.4plus-or-minus62.30.462.3{\scriptstyle\pm 0.4}62.3 ± 0.4 87.7±0.2plus-or-minus87.70.287.7{\scriptstyle\pm 0.2}87.7 ± 0.2
rank-1 ensemble 94.1±0.2plus-or-minus94.10.294.1{\scriptstyle\pm 0.2}94.1 ± 0.2 88.3±0.2plus-or-minus88.30.288.3{\scriptstyle\pm 0.2}88.3 ± 0.2 94.9±0.4plus-or-minus94.90.4\mathbf{94.9{\scriptstyle\pm 0.4}}bold_94.9 ± bold_0.4 93.5±0.3plus-or-minus93.50.393.5{\scriptstyle\pm 0.3}93.5 ± 0.3 92.4±1.5plus-or-minus92.41.592.4{\scriptstyle\pm 1.5}92.4 ± 1.5 93.8±0.3plus-or-minus93.80.393.8{\scriptstyle\pm 0.3}93.8 ± 0.3

B.2 UCI Regression

Table 4: Comparison of the predictive performance of the method proposed in this paper with that of the method proposed by Sun et al. [2019] on six datasets from the UCI repository. We followed the same training protocol as Sun et al. [2019]. We used the code provided by the authors to load and process the data, and we used the same network architecture (one hidden layer with 50 hidden units). We report results for the best set of hyperparameters, computed over ten random seeds. Lower RMSE and higher log-likelihood are better. Best results are highlighted. The first five rows are small-scale UCI experiments, and the sixth row ("Protein") is a larger-scale experiment (45,740 data points).
RMSE Log-Likelihood
Dataset Sun et al. [2019] Ours Sun et al. [2019] Ours
Boston **2.378 ± 0.104** 3.632 ± 0.515 **-2.301 ± 0.038** -3.150 ± 0.495
Concrete 4.935 ± 0.180 **4.177 ± 0.443** -3.096 ± 0.016 **-2.855 ± 0.116**
Energy 0.412 ± 0.017 **0.409 ± 0.060** -0.684 ± 0.020 **-0.539 ± 0.138**
Wine 0.673 ± 0.014 **0.615 ± 0.033** -1.040 ± 0.013 **-0.959 ± 0.034**
Yacht 0.607 ± 0.068 **0.514 ± 0.242** -1.033 ± 0.033 **-0.888 ± 0.334**
Protein 4.326 ± 0.019 **4.248 ± 0.043** -2.892 ± 0.004 **-2.866 ± 0.009**
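The two metrics in Table 4 can be illustrated with a minimal sketch, which rests on several assumptions:

- The predictive distribution at each test point is Gaussian, with its own mean and variance.
- The ± values are standard errors over the ten seeds. The caption states neither this nor the Gaussian assumption.
- The function names are hypothetical and are not taken from the code of Sun et al. [2019].

```python
import math
from statistics import mean, stdev


def rmse(y_true, y_pred):
    """Root-mean-squared error between targets and predictive means."""
    return math.sqrt(mean((t - p) ** 2 for t, p in zip(y_true, y_pred)))


def gaussian_log_likelihood(y_true, pred_mean, pred_var):
    """Average test log-density under per-point Gaussian predictive distributions.

    Assumption: the predictive distribution at each point is N(mean, var).
    """
    return mean(
        -0.5 * math.log(2.0 * math.pi * v) - 0.5 * (t - m) ** 2 / v
        for t, m, v in zip(y_true, pred_mean, pred_var)
    )


def mean_and_stderr(values):
    """Mean and standard error of a metric across random seeds (assumed meaning of ±)."""
    return mean(values), stdev(values) / math.sqrt(len(values))
```

Calling `rmse` and `gaussian_log_likelihood` once per seed and passing the ten results to `mean_and_stderr` yields one table cell under these assumptions.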

Appendix C Illustrative Examples

C.1 Two Moons Classification Task

[Figure 5 panels: (a) fsvi: Posterior Predictive Mean; (b) fsvi: Posterior Predictive Variance; (c) mfvi: Posterior Predictive Mean; (d) mfvi: Posterior Predictive Variance; (e) map Ensemble: Predictive Mean; (f) map Ensemble: Predictive Variance]
Figure 5: Binary classification on the Two Moons dataset. The plots show the posterior predictive mean and variance of a bnn trained via fsvi (Figure 5a and Figure 5b), of a bnn trained via mfvi (Figure 5c and Figure 5d), and of an ensemble of map models (Figure 5e and Figure 5f). The predictive means represent the expected class probabilities, and the predictive variances represent the model's epistemic uncertainty over the class probabilities. With fsvi, the predictive distribution faithfully captures the geometry of the data manifold and exhibits high uncertainty over the class probabilities in regions of the input space about which the data is not informative. In contrast, neither mfvi nor map ensembles accurately capture the geometry of the data manifold; both exhibit high uncertainty only around the decision boundary.
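The predictive mean and variance plotted in Figure 5 can be sketched for a single input. The sketch assumes that $S$ Monte Carlo samples of the logit $f(x; \theta_s)$ are available at that input, one per posterior sample (or one per ensemble member for the map ensemble). It then takes the mean and variance of the resulting class-1 probabilities. The function names are illustrative.

```python
import math
from statistics import mean, pvariance


def sigmoid(z):
    """Numerically stable logistic function."""
    if z >= 0:
        return 1.0 / (1.0 + math.exp(-z))
    e = math.exp(z)
    return e / (1.0 + e)


def predictive_mean_and_variance(logit_samples):
    """Mean and variance of the class-1 probability across samples at one input.

    logit_samples: logits f(x; theta_s), one per (assumed) posterior sample
    or ensemble member.
    """
    probs = [sigmoid(z) for z in logit_samples]
    return mean(probs), pvariance(probs)
```

Evaluating this function on a grid of inputs yields maps analogous to the mean and variance panels. When the samples agree, the variance is small; when they disagree, it approaches its maximum of 0.25.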

C.2 Synthetic 1D Regression Datasets

[Figure 6 panels, two plots each: (a) "Snelson" Dataset (Snelson and Ghahramani [2006]); (b) "OAT-1D" Dataset (van Amersfoort et al. [2021]); (c) "Subspace Inference" Dataset (Izmailov et al. [2020])]
Figure 6: 1D regression with fsvi on a selection of datasets used in prior work to demonstrate desirable predictive uncertainty estimates. The plots in the left column are zoomed in.

Appendix D Implementation, Training, and Evaluation Details

D.1 Hyperparameter Selection Protocol

For fsvi, we used a holdout validation set (10% of the training set) to conduct a hyperparameter search. The search covered the prior variance, the number of context points used to evaluate the KL divergence, the context distribution, and the number of Monte Carlo samples used to estimate the expected log-likelihood. For all experiments, we selected the set of hyperparameters that yielded the highest validation log-likelihood. The hyperparameters selected for the different datasets are stated below.
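The selection protocol above can be sketched as a holdout split followed by an exhaustive grid search that keeps the configuration with the highest validation log-likelihood. Several details are assumptions:

- The search is a full grid; the text does not say so.
- The split is shuffled with a fixed seed.
- The hyperparameter names and the `train_and_score` callable are hypothetical placeholders for training fsvi and returning its validation log-likelihood.

```python
import itertools
import random


def holdout_split(data, frac=0.1, seed=0):
    """Shuffle indices and hold out a fraction of the training set for validation."""
    idx = list(range(len(data)))
    random.Random(seed).shuffle(idx)
    n_val = max(1, int(round(frac * len(data))))
    val = [data[i] for i in idx[:n_val]]
    train = [data[i] for i in idx[n_val:]]
    return train, val


def select_hyperparameters(grid, train_and_score, train, val):
    """Return the configuration with the highest validation log-likelihood.

    grid: dict mapping a (hypothetical) hyperparameter name to candidate values.
    train_and_score: hypothetical callable (config, train, val) -> validation log-likelihood.
    """
    best_config, best_ll = None, float("-inf")
    keys = sorted(grid)
    for values in itertools.product(*(grid[k] for k in keys)):
        config = dict(zip(keys, values))
        ll = train_and_score(config, train, val)
        if ll > best_ll:
            best_config, best_ll = config, ll
    return best_config, best_ll
```

A grid in this sketch might have keys such as `prior_var`, `num_context_points`, `context_dist`, and `num_mc_samples`, mirroring the four quantities listed above. These names are illustrative only.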

For the baseline methods, we used a holdout validation set of the same size and selected the best-performing hyperparameters. We used the implementations provided by the authors of mfvi (radial) and swag. All other methods were implemented from scratch unless stated otherwise.

D.2 FashionMNIST vs. MNIST/NotMNIST

We train all models on the FashionMNIST dataset and evaluate their predictive uncertainty on out-of-distribution data from the MNIST dataset. Both datasets consist of images of size $28 \times 28$ pixels. The FashionMNIST dataset is normalized to have zero mean and unit standard deviation. The MNIST dataset is normalized with the same transformation, that is, using the mean and standard deviation of the in-distribution data. We chose FashionMNIST/MNIST instead of MNIST/NotMNIST because the latter is notably easier than the former.
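The key point of the normalization above is that the statistics are fitted on the in-distribution training data only. The same transformation is then reused for the out-of-distribution data. Below is a minimal sketch, assuming a single global scalar mean and standard deviation over all pixels; the text does not say whether per-pixel statistics were used instead. The function names are hypothetical.

```python
from statistics import mean, pstdev


def fit_normalizer(train_pixels):
    """Global mean and standard deviation of in-distribution (FashionMNIST) pixels."""
    return mean(train_pixels), pstdev(train_pixels)


def normalize(pixels, mu, sigma):
    """Apply the in-distribution statistics to any split, including OOD (MNIST) data."""
    return [(x - mu) / sigma for x in pixels]
```

After this transformation the OOD data is generally not zero-mean with unit variance, which is intended: the model sees it exactly as it would see any other test input.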

In this experiment, a network architecture with two convolutional layers of 32 and 64 $3 \times 3$ filters and a fully-connected final layer of 128 hidden units is used. A max pooling operation is placed after each convolutional layer, and ReLU activations are used. We do not use batch normalization. All models are trained for 30 epochs with a mini-batch size of 128 using SGD with a learning rate of $5 \times 10^{-3}$, momentum (with momentum parameter 0.9), and a cosine learning rate schedule with parameter 0.05.
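The text does not define the "parameter 0.05" of the cosine schedule. A plausible reading, which is only an assumption here, is the final learning rate expressed as a fraction of the initial one. This is the role of the `alpha` argument in `optax.cosine_decay_schedule`. Below is a sketch under that reading; `decay_steps` would be the total number of training steps (epochs times mini-batches per epoch).

```python
import math


def cosine_lr(step, decay_steps, init_lr=5e-3, alpha=0.05):
    """Cosine-decayed learning rate with a floor of alpha * init_lr.

    Assumption: alpha is the "parameter 0.05" from the text, interpreted as the
    final learning rate as a fraction of init_lr. Steps beyond decay_steps
    keep the final value.
    """
    t = min(step, decay_steps) / decay_steps
    cosine = 0.5 * (1.0 + math.cos(math.pi * t))
    return init_lr * ((1.0 - alpha) * cosine + alpha)
```

Under this reading, the learning rate starts at $5 \times 10^{-3}$ and decays smoothly to $2.5 \times 10^{-4}$ by the end of training.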

For fsvi with $p_{\mathbf{X}_{\mathcal{C}}} =$ random monochrome, we sampled 50% of the context points for each gradient step from the mini-batch and the other 50% according to the method described in Section D.8. For fsvi with $p_{\mathbf{X}_{\mathcal{C}}} =$ KMNIST, we used the KMNIST dataset.
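The 50/50 split of context points described above can be sketched as follows. The sketch assumes that mini-batch points are drawn without replacement and that `sample_ood` is a hypothetical zero-argument callable returning one context input, for instance a monochrome image as in Section D.8. The mixing fraction is exposed as a parameter.

```python
import random


def build_context_batch(minibatch, sample_ood, num_context, frac_from_batch=0.5, rng=random):
    """Assemble one step's context points from the mini-batch and a context distribution.

    minibatch: list of training inputs in the current mini-batch.
    sample_ood: hypothetical callable returning one sample from p_{X_C}.
    Assumption: mini-batch points are sampled without replacement.
    """
    n_batch = min(int(round(frac_from_batch * num_context)), len(minibatch))
    from_batch = rng.sample(minibatch, n_batch)
    from_dist = [sample_ood() for _ in range(num_context - n_batch)]
    return from_batch + from_dist
```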

D.3 CIFAR-10 vs. SVHN

We train all models on the CIFAR-10 dataset and evaluate their predictive uncertainty on out-of-distribution data from the SVHN dataset. Both datasets consist of images of size $32 \times 32 \times 3$, with RGB channels. The CIFAR-10 dataset is normalized to have zero mean and unit standard deviation. The SVHN dataset is normalized with the same transformation, that is, using the mean and standard deviation of the in-distribution data. The training data is augmented with random horizontal flips (with a probability of 0.5) and random crops (after padding each side with 4 zero-valued pixels).
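The augmentation described above can be sketched on an image stored as a list of channels, each an $H \times W$ nested list. The sketch makes the following assumptions:

- One crop offset and one flip decision are shared by all three channels.
- The padding value is 0.0, applied to whatever representation the image is in at this point. The text does not say whether padding happens before or after normalization.
- The function name is hypothetical.

```python
import random


def pad_crop_flip(image, pad=4, flip_prob=0.5, rng=random):
    """Zero-pad, randomly crop back to the original size, and randomly flip horizontally.

    image: list of channels, each an H x W list of lists of floats.
    Assumption: the same crop offset and flip are applied to every channel.
    """
    h, w = len(image[0]), len(image[0][0])
    top = rng.randint(0, 2 * pad)   # inclusive bounds, so the crop stays inside the padded image
    left = rng.randint(0, 2 * pad)
    flip = rng.random() < flip_prob
    out = []
    for channel in image:
        blank = [[0.0] * (w + 2 * pad) for _ in range(pad)]
        padded = (
            blank
            + [[0.0] * pad + list(row) + [0.0] * pad for row in channel]
            + [list(r) for r in blank]
        )
        crop = [row[left:left + w] for row in padded[top:top + h]]
        if flip:
            crop = [row[::-1] for row in crop]
        out.append(crop)
    return out
```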

In this experiment, a standard ResNet-18 network architecture was used. All models are trained for 200 epochs with a mini-batch size of 128 using SGD with a learning rate of $5 \times 10^{-3}$, momentum (with momentum parameter 0.9), and a cosine learning rate schedule with parameter 0.05.

For fsvi with $p_{\mathbf{X}_{\mathcal{C}}} =$ random monochrome, we sampled 100% of the context points for each gradient step from the mini-batch and the other 50% according to the method described in Section D.8. For fsvi with $p_{\mathbf{X}_{\mathcal{C}}} =$ CIFAR-100, we used the CIFAR-100 dataset.

D.4 Diabetic Retinopathy Diagnosis

Prediction and Expert Referral.

In real-world settings where the evaluation data may be sampled from a shifted distribution, incorrect predictions may become more likely. To account for that possibility, predictive uncertainty estimates can be used to identify data points for which an incorrect prediction is particularly likely and refer them for further review. We consider a corresponding selective prediction task, in which the predictive performance of a given model is evaluated for varying expert referral rates. That is, for a given referral rate $\gamma \in [0, 1]$, a model's predictive uncertainty is used to identify the proportion $\gamma$ of images in the evaluation set for which the model's predictions are most uncertain. Those images are referred to a medical professional for further review, and the model is assessed on its predictions on the remaining proportion $(1 - \gamma)$ of images. We repeat this process for all possible referral rates and assess the model's predictive performance on the retained images. This estimates how reliable the model would be in a safety-critical downstream task in which predictive uncertainty estimates are used in conjunction with human expertise to avoid harmful predictions. Importantly, selective prediction accommodates out-of-distribution examples. Even if unfamiliar features appear in certain images, a model with reliable uncertainty estimates will assign these images high epistemic (and hence predictive) uncertainty. It will therefore refer them to an expert at a lower $\gamma$ and perform better in selective prediction.
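The referral procedure above can be sketched for a binary task. The sketch makes the following assumptions, none of which are fixed by the text:

- Each model outputs a predicted probability of the positive class.
- Predictive entropy serves as the uncertainty score.
- Accuracy on the retained images is the performance metric. Other choices, such as AUC, fit the same loop.

The function names are hypothetical.

```python
import math


def predictive_entropy(p):
    """Binary predictive entropy (in nats) of a predicted class-1 probability p."""
    return -sum(q * math.log(q) for q in (p, 1.0 - p) if q > 0.0)


def accuracy_at_referral_rate(probs, labels, gamma):
    """Refer the gamma fraction of most-uncertain images; return accuracy on the rest.

    probs: predicted class-1 probabilities; labels: 0/1 ground truth.
    Assumptions: entropy as the uncertainty score, a 0.5 decision threshold.
    """
    n = len(probs)
    n_refer = int(math.floor(gamma * n))
    # Sort from most to least uncertain and drop the first n_refer images.
    order = sorted(range(n), key=lambda i: predictive_entropy(probs[i]), reverse=True)
    retained = order[n_refer:]
    if not retained:
        return float("nan")
    correct = sum((probs[i] >= 0.5) == bool(labels[i]) for i in retained)
    return correct / len(retained)
```

Sweeping `gamma` over a grid such as 0.0, 0.1, ..., 0.9 traces out a referral curve. A model with well-calibrated uncertainty should show accuracy rising as `gamma` grows.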

For all methods, experiments are performed using a ResNet-50 network architecture. Training and evaluation scripts as well as model checkpoints can be found at

github.com/google/uncertainty-baselines/.../diabetic_retinopathy_detection.

D.5 Two Moons

In this experiment, we use a multi-layer perceptron (MLP) consisting of two fully-connected layers with 30 hidden units each and tanh activations. We train all models with a learning rate of $10^{-3}$.

For fsvi, we sampled context points uniformly from $[-10, 10] \times [-10, 10]$.
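Sampling context points uniformly from a box, as described above for Two Moons (and, in one dimension, for the 1D regression tasks in Section D.6), amounts to drawing each coordinate independently from its interval. Below is a minimal sketch; the function name is hypothetical.

```python
import random


def sample_uniform_box(num_points, lower, upper, rng=random):
    """Draw points uniformly from the box [lower[0], upper[0]] x ... x [lower[D-1], upper[D-1]]."""
    return [
        [rng.uniform(lo, hi) for lo, hi in zip(lower, upper)]
        for _ in range(num_points)
    ]
```

For Two Moons, `sample_uniform_box(n, [-10, -10], [10, 10])` matches the box above. For the 1D tasks, `sample_uniform_box(n, [-10], [10])` would apply.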

D.6 1D Regression

In this experiment, we use a multi-layer perceptron (MLP) consisting of two fully-connected layers with 100 hidden units each and ReLU activations.

For fsvi, we sampled context points uniformly from $[-10, 10]$.

D.7 Further Implementation Details

We use the Adam optimizer with $\beta_1 = 0.9$, $\beta_2 = 0.99$, and $\epsilon = 10^{-8}$ for all experiments. The deterministic neural networks used for the ensemble were trained with a weight decay of $\lambda = 10^{-1}$. mfvi (tempered) was trained with a KL scaling factor of 0.1 to obtain a cold posterior.
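To make the roles of $\beta_1$, $\beta_2$, and $\epsilon$ concrete, the following is a minimal sketch of the standard Adam update over a flat list of parameters, using the values stated above. Note that the commonly used default for $\beta_2$ is 0.999 (for example in `optax.adam`); the sketch uses 0.99 as stated in the text. The class is illustrative and is not the optimizer implementation used in the experiments.

```python
import math


class Adam:
    """Minimal Adam optimizer over a flat list of float parameters."""

    def __init__(self, lr, b1=0.9, b2=0.99, eps=1e-8):
        self.lr, self.b1, self.b2, self.eps = lr, b1, b2, eps
        self.m = None  # first-moment (mean) estimates
        self.v = None  # second-moment (uncentered variance) estimates
        self.t = 0     # step counter used for bias correction

    def step(self, params, grads):
        if self.m is None:
            self.m = [0.0] * len(params)
            self.v = [0.0] * len(params)
        self.t += 1
        new_params = []
        for i, (p, g) in enumerate(zip(params, grads)):
            self.m[i] = self.b1 * self.m[i] + (1.0 - self.b1) * g
            self.v[i] = self.b2 * self.v[i] + (1.0 - self.b2) * g * g
            m_hat = self.m[i] / (1.0 - self.b1 ** self.t)
            v_hat = self.v[i] / (1.0 - self.b2 ** self.t)
            new_params.append(p - self.lr * m_hat / (math.sqrt(v_hat) + self.eps))
        return new_params
```

On the first step, the bias-corrected update reduces to approximately `lr * sign(g)`, which is why Adam's step size is largely insensitive to gradient scale.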

D.8 Selection of Context Distribution

We estimate the supremum at every gradient step by sampling a set of context points $\mathbf{X}_{\mathcal{C}}$ from a distribution $p_{\mathbf{X}_{\mathcal{C}}}$. For tasks with image inputs, we construct $p_{\mathbf{X}_{\mathcal{C}}}$ as a uniform distribution over images with monochromatic channels. To generate a sample from this "monochrome images" distribution, we first take all images in the training data, flatten each channel, and concatenate the flattened values of each channel across all images into a single vector per channel. We then draw a random element (i.e., a pixel value) from each channel vector. Using these values, we generate a monochrome image of a given resolution by setting every pixel in each channel to the value drawn for that channel. For regression tasks with a $D$-dimensional input space, $p_{\mathbf{X}_{\mathcal{C}}}$ is defined as a uniform distribution whose lower and upper bounds are set to the empirical lower and upper bounds of the training data. For further details on the effect of different sampling schemes on the performance of the posterior predictive distribution, see Appendix B.
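Both context distributions described above can be sketched as follows. For images, the sketch pools all training pixel values per channel, draws one value per channel, and fills that channel with it. For regression, it draws uniformly within the per-dimension empirical minimum and maximum of the training inputs. It assumes images are stored as lists of $H \times W$ channels and that each channel's value is drawn independently. The function names are hypothetical.

```python
import random


def channel_pools(images):
    """Pool pixel values per channel across all training images.

    images: list of images, each a list of C channels of H x W lists.
    Returns C flat lists, one per channel.
    """
    pools = [[] for _ in range(len(images[0]))]
    for image in images:
        for c, channel in enumerate(image):
            for row in channel:
                pools[c].extend(row)
    return pools


def sample_monochrome_image(pools, height, width, rng=random):
    """Draw one pixel value per channel and fill that whole channel with it."""
    image = []
    for pool in pools:
        value = rng.choice(pool)
        image.append([[value] * width for _ in range(height)])
    return image


def empirical_bounds(inputs):
    """Per-dimension minimum and maximum of D-dimensional training inputs."""
    columns = list(zip(*inputs))
    return [min(col) for col in columns], [max(col) for col in columns]


def sample_uniform_context(inputs, num_points, rng=random):
    """Uniform context points within the empirical bounds of the training inputs."""
    lower, upper = empirical_bounds(inputs)
    return [
        [rng.uniform(lo, hi) for lo, hi in zip(lower, upper)]
        for _ in range(num_points)
    ]
```

In practice `channel_pools` would be computed once, before training. The cheap per-step call is `sample_monochrome_image`, which matches the text's requirement that context points be resampled at every gradient step.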

D.9 Compute Resources

All experiments were carried out on an NVIDIA V100 GPU with 32 GB of memory.