Tractable Function-Space Variational Inference in Bayesian Neural Networks
Abstract
Reliable predictive uncertainty estimation plays an important role in enabling the deployment of neural networks to safety-critical settings. A popular approach for estimating the predictive uncertainty of neural networks is to define a prior distribution over the network parameters, infer an approximate posterior distribution, and use it to make stochastic predictions. However, explicit inference over neural network parameters makes it difficult to incorporate meaningful prior information about the data-generating process into the model. In this paper, we pursue an alternative approach. Recognizing that the primary object of interest in most settings is the distribution over functions induced by the posterior distribution over neural network parameters, we frame Bayesian inference in neural networks explicitly as inferring a posterior distribution over functions and propose a scalable function-space variational inference method that allows incorporating prior information and results in reliable predictive uncertainty estimates. We show that the proposed method leads to state-of-the-art uncertainty estimation and predictive performance on a range of prediction tasks and demonstrate that it performs well on a challenging safety-critical medical diagnosis task in which reliable uncertainty estimation is essential.
1 Introduction
Machine learning models succeed at an increasingly wide range of narrowly defined tasks (Krizhevsky et al., 2012; Mnih et al., 2013; Silver et al., 2016; Jumper et al., 2021) but may fail without warning when used on inputs that are meaningfully different from the data they were trained on (Amodei et al., 2016; Hendrycks et al., 2021; Rudner and Toner, 2021a, b). To deploy machine learning models in safety-critical environments where failures are costly or may endanger human lives, machine learning methods must be reliable and have the ability to ‘fail gracefully.’ A promising tool for incorporating fail-safe mechanisms into machine learning systems, predictive uncertainty quantification allows machine learning models to express their confidence in the correctness of their predictions.
In this paper, we develop a method for obtaining reliable uncertainty estimates in Bayesian neural networks (bnns, Neal (1996)). While bnns have promised to combine the advantages of deep learning and Bayesian inference, existing approaches for approximate inference in bnns fall short of this promise and have been demonstrated to result in approximate posterior predictive distributions that underperform ‘non-Bayesian’ methods both in terms of predictive accuracy and uncertainty quantification—making them of limited use in practice (Ovadia et al., 2019; Foong et al., 2019; Farquhar et al., 2020a; Band et al., 2021). A potential reason for this shortcoming is that commonly used parameter-space inference methods make it difficult to define meaningful priors that effectively incorporate information about the data-generating process into inference.
To avoid this limitation, we follow Sun et al. (2019) and consider a variational objective defined explicitly in terms of distributions over functions induced by distributions over parameters. In contrast to prior works that rely on approximation techniques that prevent such function-space variational objectives from being used with high-dimensional inputs and highly overparameterized neural networks, we propose a simple estimator of the Kullback-Leibler divergence between distributions over functions that enables us to perform stochastic variational inference. The proposed estimation procedure allows defining priors that explicitly encourage high predictive uncertainty away from the training data as well as priors that reflect relevant information about the task at hand.
We demonstrate that this approach leads to posterior approximations that exhibit significantly improved predictive uncertainty estimates compared to a wide array of state-of-the-art Bayesian and non-Bayesian methods. Figure 1 shows examples of predictive distributions obtained via function-space variational inference on low-dimensional, easy-to-visualize datasets. As can be seen in the figures, the predictive distributions fit the training data well while also exhibiting a high degree of predictive uncertainty in parts of the input space far away from the training data, as desired.
Contributions. We propose a simple estimation procedure for performing function-space variational inference in bnns. The variational method allows for the incorporation of meaningful prior information about the data-generating process into the inference and produces reliable predictive uncertainty estimates. We perform a thorough empirical evaluation in which we compare the proposed approach to a wide array of competing methods and show that it consistently results in high predictive performance and reliable predictive uncertainty estimates, outperforming other methods in terms of predictive accuracy, robustness to distribution shifts, and uncertainty-based detection of distributionally-shifted data samples. We evaluate the proposed method on standard benchmarking datasets as well as on a safety-critical medical diagnosis task in which reliable uncertainty estimation is essential. Our code can be accessed at https://github.com/timrudner/FSVI.
2 Preliminaries
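We consider supervised learning tasks on data $\mathcal{D} \doteq \{(x_n, y_n)\}_{n=1}^{N} = (X_{\mathcal{D}}, y_{\mathcal{D}})$ with inputs $x_n \in \mathcal{X} \subseteq \mathbb{R}^{D}$ and targets $y_n \in \mathcal{Y}$, where $\mathcal{Y} \subseteq \mathbb{R}^{Q}$ for regression and $\mathcal{Y} \subseteq \{0, 1\}^{Q}$ for classification tasks. Bayesian neural networks (bnns) are stochastic neural networks trained using (approximate) Bayesian inference. Denoting the parameters of such a stochastic neural network by the multivariate random variable $\Theta \in \mathbb{R}^{P}$ and letting the function mapping defined by a neural network architecture be given by $f : \mathcal{X} \times \mathbb{R}^{P} \to \mathbb{R}^{Q}$, we obtain a random function $f(\cdot\,; \Theta)$. For a parameter realization $\theta$, we obtain a corresponding function realization, $f(\cdot\,; \theta)$. When evaluated at a finite collection of points $X \doteq \{x_i\}_{i=1}^{M}$, $f(X; \Theta)$ is a multivariate random variable and $f(X; \theta)$ is a vector.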
Letting $p(y \mid f(x; \theta))$ be a likelihood function and $p(y_{\mathcal{D}} \mid f(X_{\mathcal{D}}; \theta))$ be the likelihood of observing the targets $y_{\mathcal{D}}$ under the stochastic function $f(\cdot\,; \Theta)$ evaluated at inputs $X_{\mathcal{D}}$ and letting $p_{\Theta}$ be a prior distribution over the stochastic network parameters $\Theta$, we can use Bayes’ Theorem to find the posterior distribution, $p_{\Theta \mid \mathcal{D}}(\theta \mid \mathcal{D}) \propto p(y_{\mathcal{D}} \mid f(X_{\mathcal{D}}; \theta))\, p_{\Theta}(\theta)$ (MacKay, 1992; Neal, 1996). However, since the mapping $f$ is a nonlinear function of the stochastic parameters $\Theta$, exact inference is analytically intractable. Variational inference is an approach that seeks to sidestep this intractability by framing posterior inference as a variational optimization problem, where the goal is to find a distribution $q_{\Theta}$ in a variational family $\mathcal{Q}_{\Theta}$ that solves the variational problem $\min_{q_{\Theta} \in \mathcal{Q}_{\Theta}} \mathbb{D}_{\mathrm{KL}}(q_{\Theta} \,\|\, p_{\Theta \mid \mathcal{D}})$ (Wainwright and Jordan, 2008). If $\mathcal{Q}_{\Theta}$ is the family of mean-field Gaussian distributions and the prior distribution over parameters is given by a diagonal Gaussian distribution, the resulting variational objective is amenable to stochastic variational inference and can be optimized using gradient-based methods (Hinton and van Camp, 1993; Graves, 2011; Hoffman et al., 2013; Blundell et al., 2015).
2.1 A Function-Space Perspective on Variational Inference in Bayesian Neural Networks
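Instead of seeking to infer an approximate posterior distribution over parameters, we frame variational inference in stochastic neural networks as inferring an approximation to the posterior distribution over functions $p_{f \mid \mathcal{D}}$ induced by the posterior distribution over parameters $p_{\Theta \mid \mathcal{D}}$, that is,
$$ p_{f \mid \mathcal{D}}(f \mid \mathcal{D}) = \int_{\mathbb{R}^{P}} p_{\Theta \mid \mathcal{D}}(\theta' \mid \mathcal{D})\, \delta\big(f - f(\cdot\,; \theta')\big)\, \mathrm{d}\theta', \tag{1} $$
where $\delta(\cdot)$ is the Dirac delta function (Wolpert, 1993). Considering the prior distribution over functions $p_{f}$ induced by a prior distribution over parameters $p_{\Theta}$,
$$ p_{f}(f) = \int_{\mathbb{R}^{P}} p_{\Theta}(\theta')\, \delta\big(f - f(\cdot\,; \theta')\big)\, \mathrm{d}\theta', \tag{2} $$
and the variational distribution over functions $q_{f}$ induced by a variational distribution over parameters $q_{\Theta}$,
$$ q_{f}(f) = \int_{\mathbb{R}^{P}} q_{\Theta}(\theta')\, \delta\big(f - f(\cdot\,; \theta')\big)\, \mathrm{d}\theta', \tag{3} $$
we can express the problem of finding a posterior distribution over functions variationally as
$$ \min_{q_{\Theta} \in \mathcal{Q}_{\Theta}} \mathbb{D}_{\mathrm{KL}}(q_{f} \,\|\, p_{f \mid \mathcal{D}}), \tag{4} $$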
which allows us to effectively incorporate meaningful prior information about the underlying data-generating process into training. As discussed by Burt et al. (2021), this variational objective is guaranteed to be well-defined for suitably chosen prior distributions over functions. Specifically, the KL divergence between two distributions over functions generated from different distributions over parameters applied to the same mapping (e.g., the same neural network architecture) is well-defined (i.e., finite) if the KL divergence between the distributions over parameters is finite, since, by the strong data processing inequality (Polyanskiy and Wu, 2017),
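$$ \mathbb{D}_{\mathrm{KL}}(q_{f} \,\|\, p_{f}) \leq \mathbb{D}_{\mathrm{KL}}(q_{\Theta} \,\|\, p_{\Theta}). \tag{5} $$
As a result, if $\mathbb{D}_{\mathrm{KL}}(q_{\Theta} \,\|\, p_{\Theta}) < \infty$, which is the case for finite-dimensional parameter vectors and $q_{\Theta}$ absolutely continuous with respect to $p_{\Theta}$, then the function-space KL divergence is finite and thus well-defined as a variational objective.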
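The inequality can be checked numerically in a case where both sides are available in closed form. The sketch below is only an illustration under stated assumptions: it uses a hypothetical model that is linear in its parameters, $f(x; \theta) = \theta_0 + \theta_1 x$, so that Gaussian distributions over parameters induce Gaussian distributions over function values, and it compares the parameter-space KL divergence with the KL divergence between the induced distributions at a single evaluation point. All names and numerical values are illustrative and do not come from the paper's code.

```python
import math


def kl_gauss_1d(mean_q, var_q, mean_p, var_p):
    """KL( N(mean_q, var_q) || N(mean_p, var_p) ) for scalar Gaussians."""
    return 0.5 * (math.log(var_p / var_q) + (var_q + (mean_q - mean_p) ** 2) / var_p - 1.0)


def pushforward(features, mean, var):
    """Mean and variance of f = sum_i features[i] * theta[i] for theta ~ N(mean, diag(var))."""
    f_mean = sum(a * m for a, m in zip(features, mean))
    f_var = sum(a * a * v for a, v in zip(features, var))
    return f_mean, f_var


# Hypothetical model that is linear in its parameters: f(x; theta) = theta[0] + theta[1] * x.
features = [1.0, 2.0]                       # [1, x] at the evaluation point x = 2
q_mean, q_var = [1.0, -1.0], [0.25, 0.25]   # mean-field Gaussian q over parameters
p_mean, p_var = [0.0, 0.0], [1.0, 1.0]      # isotropic Gaussian prior p over parameters

# Parameter-space KL: diagonal Gaussians factorize, so per-dimension KLs add up.
kl_parameters = sum(
    kl_gauss_1d(mq, vq, mp, vp) for mq, vq, mp, vp in zip(q_mean, q_var, p_mean, p_var)
)

# KL between the induced distributions over f(x; theta) at the single evaluation point.
kl_functions = kl_gauss_1d(
    *pushforward(features, q_mean, q_var), *pushforward(features, p_mean, p_var)
)

print(f"function-space KL = {kl_functions:.4f}, parameter-space KL = {kl_parameters:.4f}")
```

In this toy setting the divergence between the induced distributions stays below the divergence between the distributions over parameters, as the inequality requires.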
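Hence, for a likelihood function defined on a finite set of training targets $y_{\mathcal{D}}$ and a suitably defined prior distribution over functions, we can express the variational problem above equivalently as the well-defined maximization problem $\max_{q_{\Theta} \in \mathcal{Q}_{\Theta}} \mathcal{F}(q_{\Theta})$ with
$$ \mathcal{F}(q_{\Theta}) \doteq \mathbb{E}_{q_{\Theta}}\big[\log p(y_{\mathcal{D}} \mid f(X_{\mathcal{D}}; \Theta))\big] - \mathbb{D}_{\mathrm{KL}}(q_{f} \,\|\, p_{f}), \tag{6} $$
where $\mathbb{D}_{\mathrm{KL}}(q_{f} \,\|\, p_{f})$ is also a KL divergence between distributions over functions.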
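Unfortunately, evaluating the KL divergence in Equation 6 is in general intractable for arbitrary mappings $f$. To obtain a tractable objective, Sun et al. (2019) showed that $\mathbb{D}_{\mathrm{KL}}(q_{f} \,\|\, p_{f})$ can be expressed as the supremum of the KL divergence from $q_{f}$ to $p_{f}$ over all finite sets of evaluation points, resulting in the objective function
$$ \mathcal{F}(q_{\Theta}) = \mathbb{E}_{q_{\Theta}}\big[\log p(y_{\mathcal{D}} \mid f(X_{\mathcal{D}}; \Theta))\big] - \sup_{X \in \mathcal{X}_{\mathbb{N}}} \mathbb{D}_{\mathrm{KL}}(q_{f(X)} \,\|\, p_{f(X)}), \tag{7} $$
where $\mathcal{X}_{\mathbb{N}}$ is the collection of all finite sets of evaluation points and $q_{f(X)}$ and $p_{f(X)}$ denote the distributions of $f(X; \Theta)$ under $q_{\Theta}$ and $p_{\Theta}$, respectively. However, this objective function is still challenging to optimize in practice: The supremum cannot be obtained analytically and the KL divergence term itself is analytically intractable and difficult to estimate in high dimensions, even for a single evaluation point.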
In the next section, we will describe an approximation and estimation procedure that allows scaling function-space variational inference to large neural networks and high-dimensional input data.
3 Deriving a Tractable Function-Space Variational Objective
The primary obstacle to computing the objective in Equation 6 is the KL divergence from $q_{f}$ to $p_{f}$. There are two reasons why the KL divergence in Equation 7 is intractable: First, for bnns or other non-linear models, we do not have access to the probability density functions of the multivariate distributions $q_{f(X)}$ and $p_{f(X)}$; second, for all but extremely simple input spaces, we are unable to compute the supremum over all possible finite sets of evaluation points. In the remainder of this section, we outline an approach for obtaining an estimator of a locally accurate approximation to the KL divergence that allows for scalable gradient-based optimization of Equation 7.
We begin with the problem of computing the KL divergence between two bnns evaluated at a finite set of points. To do so, we first derive tractable approximations to the distributions over functions $q_{f}$ and $p_{f}$. Next, we show that under these approximations, we are able to obtain a closed-form approximation to the KL divergence and describe a simple Monte Carlo estimator of the supremum in the function-space KL divergence.
3.1 Approximating Distributions over Functions via Local Linearization
To obtain an approximation to the probability distributions $q_{f}$ and $p_{f}$, we use a first-order Taylor expansion of the mapping $f$ about the mean parameters of $q_{\Theta}$ and $p_{\Theta}$, respectively, and derive the induced distributions under the linearized mapping.
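For a stochastic function $f(\cdot\,; \Theta)$ defined in terms of stochastic parameters $\Theta$ distributed according to distribution $g_{\Theta}$ with $m \doteq \mathbb{E}_{g_{\Theta}}[\Theta]$ and $\Sigma \doteq \mathrm{Cov}_{g_{\Theta}}[\Theta]$, we denote the linearization of the stochastic function about $m$ by
$$ f(\cdot\,; \Theta) \approx \tilde{f}(\cdot\,; \Theta) \doteq f(\cdot\,; m) + \mathcal{J}(\cdot\,; m)(\Theta - m), \tag{8} $$
where $\mathcal{J}(\cdot\,; m) \doteq (\partial f(\cdot\,; \theta) / \partial \theta)|_{\theta = m}$ is the Jacobian of $f(\cdot\,; \theta)$ evaluated at $\theta = m$, and the mean and covariance of the distribution over the linearized mapping at $X, X' \in \mathcal{X}$ are given by
$$ \mathbb{E}\big[\tilde{f}(X; \Theta)\big] = f(X; m), \tag{9} $$
$$ \mathrm{Cov}\big[\tilde{f}(X; \Theta), \tilde{f}(X'; \Theta)\big] = \mathcal{J}(X; m)\, \Sigma\, \mathcal{J}(X'; m)^{\top}. \tag{10} $$
For a derivation of this result, see Appendix A. Since Gaussianity is preserved under affine transformations, if $g_{\Theta}$ is a multivariate Gaussian distribution with mean $m$ and diagonal covariance $\Sigma$, then the distribution over $\tilde{f}(X; \Theta)$ is given by
$$ \tilde{g}_{f(X)} = \mathcal{N}\big(f(X; m),\; \mathcal{J}(X; m)\, \Sigma\, \mathcal{J}(X; m)^{\top}\big). \tag{11} $$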
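To make the linearized distribution concrete, the following sketch computes the mean vector and covariance matrix of a linearized function at a few inputs. It is a minimal illustration, assuming a hypothetical three-parameter toy network with scalar inputs and outputs and an analytically derived Jacobian; the helper names (`f`, `jacobian_row`, `linearized_moments`) and all numerical values are assumptions made for this example. For real networks, the Jacobian would be obtained by automatic differentiation.

```python
import math


def f(x, theta):
    """Hypothetical toy network: f(x; theta) = w2 * tanh(w1 * x) + b."""
    w1, w2, b = theta
    return w2 * math.tanh(w1 * x) + b


def jacobian_row(x, theta):
    """Analytic gradient of f(x; theta) with respect to (w1, w2, b)."""
    w1, w2, _ = theta
    t = math.tanh(w1 * x)
    return [w2 * x * (1.0 - t * t), t, 1.0]


def linearized_moments(xs, mean, variances):
    """Mean vector and covariance matrix of the linearized function at inputs xs
    for a Gaussian over parameters with the given mean and diagonal covariance."""
    jac = [jacobian_row(x, mean) for x in xs]      # M x P Jacobian at the mean
    f_mean = [f(x, mean) for x in xs]              # f(X; m)
    f_cov = [
        [sum(ji * v * jk for ji, v, jk in zip(row_i, variances, row_k)) for row_k in jac]
        for row_i in jac
    ]                                              # J diag(variances) J^T
    return f_mean, f_cov


xs = [-1.0, 0.5, 2.0]      # evaluation points
m = [0.8, -1.2, 0.3]       # mean of the distribution over parameters
v = [0.05, 0.10, 0.02]     # diagonal of its covariance
f_mean, f_cov = linearized_moments(xs, m, v)
```

The resulting covariance matrix is dense even though the covariance over parameters is diagonal, which is how a mean-field distribution over parameters can induce correlations between function values at different inputs.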
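For stochastic functions parameterized by many millions of parameters, obtaining the covariance of $\tilde{f}(X; \Theta)$, which requires computing an inner product of two Jacobian matrices, can be computationally expensive. Instead of computing the distribution over the linearized mapping exactly, we can construct a suitable Monte Carlo estimator. To do so, we consider a partition of the set of parameters into sets $\alpha$ and $\beta$ (with $\alpha \cup \beta = \{1, \dots, P\}$ and $\alpha \cap \beta = \emptyset$) and note that the linearized mapping can then be expressed as
$$ \tilde{f}(\cdot\,; \Theta) = \mathcal{J}_{\alpha}(\cdot\,; m)(\Theta_{\alpha} - m_{\alpha}) + \tilde{f}_{\beta}(\cdot\,; \Theta_{\beta}) \tag{12} $$
with
$$ \tilde{f}_{\beta}(\cdot\,; \Theta_{\beta}) \doteq f(\cdot\,; m) + \mathcal{J}_{\beta}(\cdot\,; m)(\Theta_{\beta} - m_{\beta}), \tag{13} $$
where $\mathcal{J}_{\alpha}$ and $\mathcal{J}_{\beta}$ are the columns of the Jacobian matrix corresponding to the sets of parameters $\alpha$ and $\beta$, respectively, and $\Theta_{\alpha}$ and $\Theta_{\beta}$ are the corresponding random parameter vectors. Noting that Equation 12 expresses $\tilde{f}$ as a sum of (affine transformations of) random variables, we can use the fact that for independent Gaussian random variables $A$ and $B$, the distribution of $A + B$ is equal to the convolution of the distributions of $A$ and $B$ to obtain an approximation to $\tilde{q}_{f(X)}$. In particular, we can show that if $q_{\Theta}$ is a multivariate Gaussian distribution with mean $m$ and diagonal covariance $\Sigma$, the distribution $\tilde{q}_{f(X)}$ can be approximated by the Monte Carlo estimator
$$ \hat{\tilde{q}}_{f(X)} \doteq \frac{1}{R} \sum_{j=1}^{R} \mathcal{N}\Big(f(X; m) + \mathcal{J}_{\alpha}(X; m)\big(\theta_{\alpha}^{(j)} - m_{\alpha}\big),\; \mathcal{J}_{\beta}(X; m)\, \Sigma_{\beta}\, \mathcal{J}_{\beta}(X; m)^{\top}\Big), \tag{14} $$
where $\Sigma_{\beta}$ is the covariance of $\Theta_{\beta}$ and samples $\theta_{\alpha}^{(j)}$ are obtained by sampling parameters from the distribution $q_{\Theta_{\alpha}}$. For a derivation of this result, see Appendix A. This estimator is biased for finite $R$ but converges to $\tilde{q}_{f(X)}$ as $R \to \infty$. Similarly, for finite $R$, the smaller $|\alpha|$, the more accurate and less biased the estimator will be. In our empirical evaluation, we use a single Monte Carlo sample, $R = 1$, to preserve Gaussianity and choose $\alpha$ to be the set of parameters in neural network layers $1, \dots, L-1$ and $\beta$ to be the set of parameters in the final neural network layer $L$.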
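The single-sample version of this estimator can be sketched as follows. This is an illustration under stated assumptions rather than the paper's implementation: it reuses a hypothetical three-parameter toy network, treats the hidden-layer weight as the sampled block and the final-layer weight and bias as the analytically integrated block, and uses an analytic Jacobian. The function name `partitioned_estimate` and all numerical values are made up for this example.

```python
import math
import random


def f(x, theta):
    """Hypothetical toy network: f(x; theta) = w2 * tanh(w1 * x) + b."""
    w1, w2, b = theta
    return w2 * math.tanh(w1 * x) + b


def jacobian_row(x, theta):
    """Analytic gradient of f(x; theta) with respect to (w1, w2, b)."""
    w1, w2, _ = theta
    t = math.tanh(w1 * x)
    return [w2 * x * (1.0 - t * t), t, 1.0]


def partitioned_estimate(xs, mean, variances, alpha, beta, rng):
    """Single-sample (R = 1) Gaussian estimate of the linearized function distribution.
    Parameters indexed by alpha are sampled; those indexed by beta are integrated out."""
    jac = [jacobian_row(x, mean) for x in xs]
    # One draw of (theta_alpha - m_alpha) from the mean-field Gaussian over the alpha block.
    delta = {i: rng.gauss(0.0, math.sqrt(variances[i])) for i in alpha}
    est_mean = [
        f(x, mean) + sum(row[i] * delta[i] for i in alpha) for x, row in zip(xs, jac)
    ]
    # Closed-form covariance from the beta block: J_beta diag(v_beta) J_beta^T.
    est_cov = [
        [sum(ri[i] * variances[i] * rk[i] for i in beta) for rk in jac] for ri in jac
    ]
    return est_mean, est_cov, delta


xs = [-1.0, 0.5, 2.0]
m = [0.8, -1.2, 0.3]
v = [0.05, 0.10, 0.02]
alpha, beta = [0], [1, 2]   # hidden-layer weight vs. final-layer weight and bias
est_mean, est_cov, delta = partitioned_estimate(xs, m, v, alpha, beta, random.Random(0))
```

Only the Jacobian columns of the final-layer block enter the covariance, so the cost of the matrix product scales with the size of that block rather than with the total number of parameters.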
3.2 Approximating the Function-Space Kullback-Leibler Divergence
From Section 3.1, we know that if $q_{\Theta}$ and $p_{\Theta}$ are both Gaussian distributions, then the induced distributions under the linearized mapping $\tilde{f}$ evaluated at a finite set of evaluation points $X$ will be Gaussian as well. This means that for Gaussian variational and prior distributions over $\Theta$, we can obtain locally accurate approximations $\tilde{q}_{f(X)}$ and $\tilde{p}_{f(X)}$ to the induced distributions $q_{f(X)}$ and $p_{f(X)}$ and use them to approximate the KL divergence in the variational objective by $\mathbb{D}_{\mathrm{KL}}(\tilde{q}_{f(X)} \,\|\, \tilde{p}_{f(X)})$. Moreover, for an isotropic Gaussian prior and a mean-field Gaussian variational distribution, $\mathbb{D}_{\mathrm{KL}}(\tilde{q}_{f(X)} \,\|\, \tilde{p}_{f(X)})$ is a KL divergence between two multivariate Gaussians and can be obtained analytically.
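For reference, the closed-form KL divergence between two multivariate Gaussians can be computed as in the sketch below. It is a generic, standard-library-only illustration and not the paper's code; it assumes both covariance matrices are symmetric positive definite (in practice, a small diagonal jitter may be needed when a linearized covariance is rank-deficient), and the helper names are hypothetical.

```python
import math


def cholesky(a):
    """Lower-triangular Cholesky factor of a symmetric positive-definite matrix."""
    n = len(a)
    l = [[0.0] * n for _ in range(n)]
    for i in range(n):
        for j in range(i + 1):
            s = a[i][j] - sum(l[i][k] * l[j][k] for k in range(j))
            l[i][j] = math.sqrt(s) if i == j else s / l[j][j]
    return l


def solve_lower(l, b):
    """Solve L y = b for lower-triangular L by forward substitution."""
    y = []
    for i, row in enumerate(l):
        y.append((b[i] - sum(row[k] * y[k] for k in range(i))) / row[i])
    return y


def kl_mvn(mean_q, cov_q, mean_p, cov_p):
    """KL( N(mean_q, cov_q) || N(mean_p, cov_p) ) computed from Cholesky factors."""
    n = len(mean_q)
    lq, lp = cholesky(cov_q), cholesky(cov_p)
    # tr(cov_p^{-1} cov_q) equals the squared Frobenius norm of L_p^{-1} L_q.
    trace = 0.0
    for j in range(n):
        column = solve_lower(lp, [lq[i][j] for i in range(n)])
        trace += sum(c * c for c in column)
    # Mahalanobis term (mean_p - mean_q)^T cov_p^{-1} (mean_p - mean_q).
    diff = solve_lower(lp, [mp - mq for mq, mp in zip(mean_q, mean_p)])
    mahalanobis = sum(d * d for d in diff)
    # log det(cov_p) - log det(cov_q) from the diagonals of the factors.
    log_det_ratio = 2.0 * sum(math.log(lp[i][i]) - math.log(lq[i][i]) for i in range(n))
    return 0.5 * (trace + mahalanobis - n + log_det_ratio)


identity = [[1.0, 0.0], [0.0, 1.0]]
kl_example = kl_mvn([0.0, 0.0], [[2.0, 0.5], [0.5, 1.0]], [0.0, 0.0], identity)
```

Working with Cholesky factors avoids forming explicit matrix inverses and determinants, which is the usual choice for numerical stability.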
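Using this approximation, we obtain an estimator of the variational objective given by
$$ \tilde{\mathcal{F}}(q_{\Theta}) \doteq \mathbb{E}_{q_{\Theta}}\big[\log p(y_{\mathcal{D}} \mid f(X_{\mathcal{D}}; \Theta))\big] - \sup_{X \in \mathcal{X}_{\mathbb{N}}} \mathbb{D}_{\mathrm{KL}}(\tilde{q}_{f(X)} \,\|\, \tilde{p}_{f(X)}), \tag{15} $$
where the arguments of the KL divergence have been replaced by the (locally accurate) approximations to the variational and prior distributions over functions evaluated at $X$, respectively. Since the stochastic functions induced by $q_{\Theta}$ and $p_{\Theta}$ under the linearized mapping $\tilde{f}$ will be closer to the stochastic function under $f$ the smaller the variance of $q_{\Theta}$ and $p_{\Theta}$, respectively, the approximation to the KL divergence will be more accurate the smaller the variance of $q_{\Theta}$ and $p_{\Theta}$.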
Next, we turn to computing the supremum. Unlike Sun et al. (2019), who consider the supremum as a separate optimization problem, we do not seek to compute the supremum by searching over points but instead propose to estimate the supremum at every gradient step via a simple finite-sample estimator. Specifically, letting $X_{\mathcal{C}} \doteq \{X_{\mathcal{C}}^{(j)}\}_{j=1}^{J}$, we estimate $\sup_{X \in \mathcal{X}_{\mathbb{N}}} \mathbb{D}_{\mathrm{KL}}(\tilde{q}_{f(X)} \,\|\, \tilde{p}_{f(X)})$ using the Monte Carlo estimator
$$ \frac{1}{J} \sum_{j=1}^{J} \mathbb{D}_{\mathrm{KL}}\big(\tilde{q}_{f(X_{\mathcal{C}}^{(j)})} \,\|\, \tilde{p}_{f(X_{\mathcal{C}}^{(j)})}\big), \tag{16} $$
where $X_{\mathcal{C}}$ is a collection of $J$ sets of context points jointly sampled from a context distribution $p_{X_{\mathcal{C}}}$. Each context set $X_{\mathcal{C}}^{(j)}$ can be viewed as a single Monte Carlo sample from the input space so that the estimator provides a $J$-sample Monte Carlo estimate of the supremum. While this estimator is crude and only provides a rough approximation to the true supremum, it encourages the variational distribution over functions to match the prior distribution over functions on the sets of context points. The choice of the context distribution can be informed by knowledge about the prediction task and should be viewed as a problem-specific modeling choice. Similarly, the numbers of samples $R$ and $J$ are hyperparameters to be optimized with a validation set. For details on how $p_{X_{\mathcal{C}}}$ is chosen for the empirical evaluation in Section 5, see Appendix D.
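The sketch below illustrates the finite-sample estimator on a toy problem. It rests on several assumptions that are not taken from the paper: a hypothetical model that is linear in its parameters (so the function-space KL divergence at a point is available in closed form), context sets that each hold a single scalar input, and a uniform distribution standing in for the context distribution. The names `kl_at`, `sample_context_set`, and `estimate_sup_kl` are illustrative.

```python
import math
import random


def kl_gauss_1d(mean_q, var_q, mean_p, var_p):
    """KL( N(mean_q, var_q) || N(mean_p, var_p) ) for scalar Gaussians."""
    return 0.5 * (math.log(var_p / var_q) + (var_q + (mean_q - mean_p) ** 2) / var_p - 1.0)


def kl_at(context_set):
    """Function-space KL at a context set holding one input x, for the hypothetical
    model f(x; theta) = theta[0] + theta[1] * x."""
    (x,) = context_set
    q_mean, q_var = 1.0 - x, 0.25 * (1.0 + x * x)   # induced by q = N([1, -1], 0.25 * I)
    p_mean, p_var = 0.0, 1.0 + x * x                # induced by p = N([0, 0], I)
    return kl_gauss_1d(q_mean, q_var, p_mean, p_var)


def sample_context_set(rng):
    """Stand-in for the context distribution: one input drawn uniformly from [-3, 3]."""
    return [rng.uniform(-3.0, 3.0)]


def estimate_sup_kl(kl_fn, sampler, num_sets, rng):
    """Average the KL over freshly sampled context sets; also return the largest value."""
    values = [kl_fn(sampler(rng)) for _ in range(num_sets)]
    return sum(values) / num_sets, max(values)


average_kl, largest_kl = estimate_sup_kl(kl_at, sample_context_set, 10, random.Random(0))
```

The average can never exceed the largest sampled value, which in turn cannot exceed the true supremum, so the estimator is a lower estimate of the supremum that depends on where the context distribution places its mass.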
3.3 Stochastic Estimation of the Approximate Function-Space Variational Objective
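Let $q_{\Theta} = \mathcal{N}(m, \Sigma)$ be a Gaussian mean-field variational distribution, let $p_{\Theta}$ be an isotropic Gaussian prior, let $(X_{\mathcal{B}}, y_{\mathcal{B}})$ be a mini-batch of the training data, and reparameterize $\Theta$ as $\Theta(\epsilon) \doteq m + \Sigma^{1/2} \epsilon$ with $\epsilon \sim \mathcal{N}(0, I)$. Using the estimator defined above and estimating the expected log-likelihood via Monte Carlo sampling, we obtain a Monte Carlo estimator for the function-space variational objective:
$$ \hat{\mathcal{F}}(m, \Sigma) \doteq \frac{1}{S} \sum_{i=1}^{S} \log p\big(y_{\mathcal{B}} \mid f(X_{\mathcal{B}}; \Theta(\epsilon^{(i)}))\big) - \frac{1}{J} \sum_{j=1}^{J} \mathbb{D}_{\mathrm{KL}}\big(\tilde{q}_{f(X_{\mathcal{C}}^{(j)})} \,\|\, \tilde{p}_{f(X_{\mathcal{C}}^{(j)})}\big), \tag{17} $$
with $\tilde{q}_{f(X)}$ and $\tilde{p}_{f(X)}$ as defined above. This Monte Carlo estimator is biased due to the linearization and context-set approximations but allows for scalable gradient-based stochastic optimization.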
Selection of Prior. For all experiments that involve uncertainty quantification, we chose a prior distribution over parameters that induces a prior distribution over functions and a prior predictive distribution that exhibits a high degree of predictive uncertainty at evaluation points from regions in input space where $p_{X_{\mathcal{C}}}$ has non-zero support and, under smoothness constraints, on evaluation points in nearby regions. For settings where prior information is encoded in data, for example in the form of expert demonstrations of robotic manipulation tasks (Rudner et al., 2021) or in the form of pre-trained networks in continual or transfer learning (Rudner et al., 2022), an empirical prior that reflects this information can be specified. For further details, see Appendix D.
Selection of Context Distribution. The distribution $p_{X_{\mathcal{C}}}$ allows us to incorporate information about the data-generating process into training and encourage the variational distribution to match the prior over functions in relevant parts of the input space. By taking advantage of the abundance of data available in real-world settings, context distributions can be constructed from large datasets like ImageNet (Krizhevsky et al., 2012), from small but diverse datasets like CIFAR-100, or by using any set of task-related unlabeled data. In our experiments, we choose two types of context distributions. One of the context distributions is constructed from the training data and only contains randomly sampled monochrome images, and one is constructed from a real-world dataset generated from a data distribution related to that of the training data. For example, when training on FashionMNIST, we use KMNIST as the context distribution, and when training on CIFAR-10, we use CIFAR-100 as the context distribution. For further details, see Appendix D.
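Posterior Predictive Distribution. After optimizing the variational objective with respect to the parameters of the variational distribution $q_{\Theta}$, we use the fact that we can obtain function draws by sampling from the distribution over parameters to obtain an approximate posterior predictive distribution
$$ q(y_{*} \mid x_{*}) = \int p(y_{*} \mid f(x_{*}; \theta))\, q_{\Theta}(\theta)\, \mathrm{d}\theta \approx \frac{1}{S} \sum_{i=1}^{S} p\big(y_{*} \mid f(x_{*}; \theta^{(i)})\big), \quad \theta^{(i)} \sim q_{\Theta}, \tag{18} $$
where $S$ is the number of Monte Carlo samples used to estimate the predictive distribution.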
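For a classification model, this Monte Carlo average can be sketched as follows. The example is a minimal illustration under stated assumptions: the two-class linear model `logit_fn`, the softmax likelihood, and all numerical values are hypothetical and only serve to show how class probabilities are averaged over parameter draws from a mean-field Gaussian.

```python
import math
import random


def softmax(logits):
    """Numerically stable softmax over a list of logits."""
    peak = max(logits)
    exps = [math.exp(z - peak) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]


def predictive_probs(logit_fn, x, mean, std, num_samples, rng):
    """Average class probabilities over parameter draws theta ~ N(mean, diag(std^2))."""
    totals = None
    for _ in range(num_samples):
        theta = [rng.gauss(m, s) for m, s in zip(mean, std)]
        probs = softmax(logit_fn(x, theta))
        totals = probs if totals is None else [t + p for t, p in zip(totals, probs)]
    return [t / num_samples for t in totals]


def logit_fn(x, theta):
    """Hypothetical two-class linear model with one weight per class."""
    return [theta[0] * x, theta[1] * x]


probs = predictive_probs(
    logit_fn, 1.5, mean=[1.0, -1.0], std=[0.3, 0.3], num_samples=100, rng=random.Random(0)
)
```

Note that the probabilities, not the logits, are averaged; averaging logits before the softmax would yield a different and generally more confident prediction.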
4 Related Work
There is a growing body of work on function-space approaches to inference in bnns, deep learning, and applications such as continual learning (Benjamin et al., 2019; Sun et al., 2019; Titsias et al., 2020; Burt et al., 2021; Pan et al., 2020; Ma and Hernández-Lobato, 2021; Rudner et al., 2022).
Function-Space Inference in Bayesian Neural Networks.
Previously proposed methods for fsvi in bnns are based on approximate gradient estimators and either replace the supremum in Equation 7 with an expectation (Sun et al., 2019) or do not define an explicit variational objective (Wang et al., 2019). Sun et al. (2019) and Carvalho et al. (2020) use Gaussian process priors over functions for which the function-space variational inference problem is not well-defined (see Section 2.1 and Burt et al. (2021)). More recent work has attempted to circumvent the intractability of the variational objective in Equation 6 by proposing alternative objectives for function-space inference in bnns (Ma et al., 2019; Ober and Aitchison, 2020; Ma and Hernández-Lobato, 2021). Rudner et al. (2022) extend the approach presented in Section 3 to sequential inference problems and apply it to continual learning.
Linear Models.
Immer et al. (2020) and Khan et al. (2019) show that approximate bnn posterior distributions obtained via the Laplace and generalized Gauss-Newton approximations correspond to exact posteriors under linearizations of different models. Unlike our approach, they use a Laplace approximation, do not perform variational inference, and do not optimize the variance parameters. Furthermore, Immer et al. (2020) and Khan et al. (2019) use a neural network model to obtain a parameter maximum a posteriori estimate, but then use a linearization of the neural network model to compute a posterior predictive distribution. In contrast, our work only uses the linearization to obtain an estimator of the variational objective but uses the unlinearized model to construct a posterior predictive distribution.
Pathologies of Variational Inference in Bayesian Neural Networks.
Burt et al. (2021) consider the function-space variational objective in Equation 6 and show that the KL divergence between bnns with different network architectures is not well-defined. A parallel line of research showed that posterior predictive distributions of shallow bnns with mean-field variational distributions have a limited ability to represent complex covariance structures in function space (Foong et al., 2019, 2020) but that deep bnns do not suffer from this limitation (Farquhar et al., 2020b). Our results are consistent with the findings of Farquhar et al. (2020b) that mean-field variational distributions are able to represent complex covariance structures in function space.
5 Empirical Evaluation
In this section, we evaluate fsvi on high-dimensional classification tasks that were out of reach for function-space variational inference methods proposed in prior works and compare fsvi to several well-established and state-of-the-art Bayesian deep learning and deterministic uncertainty quantification methods. We show that fsvi (sometimes significantly) outperforms existing Bayesian and non-Bayesian methods in terms of their in-distribution uncertainty calibration and out-of-distribution predictive uncertainty estimation. For details on models, training and validation procedures, and datasets used, see Appendix D. For a comparison to Sun et al. (2019) on small-scale regression tasks, see Section B.2.
5.1 Predictive Performance, Uncertainty Estimation, and Distribution Shift Detection
In this set of experiments, we assess the reliability of the uncertainty estimates generated by fsvi. If a bnn trained via fsvi is able to perform reliable uncertainty estimation, its predictive uncertainty will be significantly higher on input points that were generated according to a different data-generating distribution than the training data. For models trained on the FashionMNIST dataset, we use the MNIST and NotMNIST datasets as out-of-distribution evaluation points, while for models trained on the CIFAR-10 dataset, we use the SVHN dataset as out-of-distribution evaluation points.
For models trained on either FashionMNIST or CIFAR-10, we evaluate their in-distribution performance in terms of test accuracy, test log-likelihood, and test calibration. To evaluate the quality of different models’ uncertainty estimates, we compute uncertainty estimates for the pairs FashionMNIST/MNIST, FashionMNIST/NotMNIST, and CIFAR-10/SVHN and measure, for a range of thresholds, how well the datasets in each pair can be separated solely based on the uncertainty estimates. This experimental setup follows prior work by van Amersfoort et al. (2020) and Immer et al. (2020). We report the area under the receiver operating characteristic (ROC) curve in Tables 1 and 2.
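The separation metric can be computed directly from two lists of uncertainty scores. The sketch below is a generic illustration, not the evaluation code used for the tables: it uses the rank-based form of the area under the ROC curve, which is equivalent to sweeping over all thresholds, and the scores shown are made-up values standing in for, e.g., predictive entropies.

```python
def auroc(in_dist_scores, ood_scores):
    """Probability that a random out-of-distribution input receives a higher
    uncertainty score than a random in-distribution input (ties count one half).
    This equals the area under the ROC curve for separating the two sets."""
    wins = 0.0
    for ood in ood_scores:
        for ind in in_dist_scores:
            if ood > ind:
                wins += 1.0
            elif ood == ind:
                wins += 0.5
    return wins / (len(in_dist_scores) * len(ood_scores))


# Made-up uncertainty scores for illustration.
in_dist = [0.1, 0.2, 0.4]
ood = [0.3, 0.8, 0.9]
score = auroc(in_dist, ood)
```

A value of 1 means that every out-of-distribution input is more uncertain than every in-distribution input, while 0.5 corresponds to uncertainty scores that carry no information about the shift.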
Method | Accuracy (%) | ECE | OOD-AUROC, MNIST (%) | OOD-AUROC, NotMNIST (%) |
---|---|---|---|---|
map | 91.73 | 0.037 | 87.00 | 74.85 |
mfvi (Blundell et al., 2015) | 91.03 | 0.038 | 93.10 | 88.88 |
mfvi (tempered) | 91.38 | 0.058 | 86.30 | 80.78 |
mfvi (radial) (Farquhar et al., 2020a) | 90.31 | 0.035 | 84.40 | 82.11 |
mc dropout (Gal and Ghahramani, 2016) | 90.55 | | 88.46 | 80.02 |
swag (Maddox et al., 2019) | 92.56 | 0.043 | 85.18 | 80.31 |
duq (van Amersfoort et al., 2020) | | | 95.50 | 94.60 |
bnn-laplace (Immer et al., 2020) | 92.25 | | 95.55 | |
spg (Ma and Hernández-Lobato, 2021) | 91.60 | | 95.60 | |
fsvi ($p_{X_{\mathcal{C}}}$ = random monochrome) | | | | |
fsvi ($p_{X_{\mathcal{C}}}$ = KMNIST) | | | | |
Deep Ensemble | 92.49 | | 89.22 | 83.17 |
fsvi Ensemble ($p_{X_{\mathcal{C}}}$ = random monochrome) | | 0.020 | | |
Predictive Performance and Calibration. To assess in-distribution predictive performance and calibration, we report the test accuracy, negative log-likelihood (NLL), and expected calibration error (ECE) for models trained on FashionMNIST and CIFAR-10 in Tables 1 and 2. On both FashionMNIST and CIFAR-10, fsvi achieves the lowest NLL and either the best or second-best predictive accuracy and ECE, respectively, across all methods. Notably, fsvi significantly outperforms spg (Ma and Hernández-Lobato, 2021), an alternative function-space variational inference method.
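As a reminder of how the calibration metric is defined, the sketch below computes an expected calibration error with equal-width confidence bins. It is a generic illustration: the number of bins and the binning scheme are assumptions made here and are not taken from the paper, and the confidences and correctness indicators are made-up values.

```python
def expected_calibration_error(confidences, correct, num_bins=10):
    """Weighted average gap between mean confidence and accuracy per confidence bin."""
    n = len(confidences)
    bins = [[] for _ in range(num_bins)]
    for conf, hit in zip(confidences, correct):
        index = min(int(conf * num_bins), num_bins - 1)   # confidence 1.0 goes in the last bin
        bins[index].append((conf, hit))
    ece = 0.0
    for members in bins:
        if members:
            avg_conf = sum(c for c, _ in members) / len(members)
            accuracy = sum(h for _, h in members) / len(members)
            ece += len(members) / n * abs(avg_conf - accuracy)
    return ece


# Made-up predictions: confidence of the predicted class and whether it was correct.
confidences = [0.95, 0.95, 0.55, 0.55]
correct = [1, 1, 1, 0]
ece = expected_calibration_error(confidences, correct)
```

In this example both occupied bins have a gap of 0.05 between confidence and accuracy and each holds half of the data, so the error is 0.05.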
Predictive Uncertainty under Distribution Shift. In Tables 1 and 2, we report evaluation metrics that elucidate the reliability of different methods’ predictive uncertainty under distribution shift. fsvi exhibits reliable predictive uncertainty estimates that allow distinguishing between in- and out-of-distribution inputs with high accuracy. As would be expected, we observe that using context distributions that reflect our knowledge about the data-generating process can significantly improve uncertainty quantification under fsvi. For the FashionMNIST experiment, we used the KMNIST dataset, which contains grayscale images of Kuzushiji letters, and for the CIFAR-10 experiment, we used the CIFAR-100 dataset, which contains RGB images of 100 classes. Both KMNIST and CIFAR-100 differ from the OOD datasets (MNIST and NotMNIST, and SVHN, respectively) used to compute OOD-AUROC metrics in Tables 1 and 2, but using them as context distributions significantly increased the ability of bnns trained via fsvi to identify distributionally shifted samples. Since the variational objective encourages matching the prior (which we chose to have high variance) on samples from the context distribution, training with such context distributions can improve uncertainty estimation in regions of the input space far from the training data.
Method | Accuracy (%) | ECE | OOD-AUROC, SVHN (%) | Corrupted CIFAR-10 Accuracy (%) |
---|---|---|---|---|
map | 93.19 | 0.043 | 94.65 | 78.87 |
mfvi (Blundell et al., 2015) | 89.98 | 0.040 | 92.14 | 79.36 |
mfvi (tempered) | 90.87 | 0.048 | 91.82 | 79.86 |
mc dropout (Gal and Ghahramani, 2016) | 93.55 | 0.040 | 92.44 | 80.13 |
swag (Maddox et al., 2019) | 93.13 | 0.067 | 89.79 | 76.12 |
vogn (Osawa et al., 2019) | 84.27 | 0.040 | 87.60 | |
duq (van Amersfoort et al., 2020) | | | 92.70 | |
spg (Ma and Hernández-Lobato, 2021) | 77.69 | | 88.30 | |
fsvi ($p_{X_{\mathcal{C}}}$ = random monochrome) | | 0.034 | | |
fsvi ($p_{X_{\mathcal{C}}}$ = CIFAR-100) | | | | |
Deep Ensemble | | 0.019 | | |
fsvi Ensemble ($p_{X_{\mathcal{C}}}$ = random monochrome) | | 0.013 | | |
5.2 Generalization and Reliability of Predictive Uncertainty under Distribution Shift
To assess the reliability of predictive models in deep learning, Ovadia et al. (2019) propose the following desiderata: In order for a model to be considered reliable, it ought to (i) exhibit low predictive uncertainty on training data and high predictive uncertainty on out-of-distribution inputs, (ii) generate predictive uncertainty estimates that allow distinguishing in- from out-of-distribution inputs, and (iii) if possible, maintain high predictive accuracy even under distribution shift. Models that satisfy these desiderata are less likely to make poor, high-confidence predictions and more amenable for use in safety-critical downstream tasks.
To illustrate these desiderata, we follow Ovadia et al. (2019) and consider the rotated MNIST task, where a model is trained on MNIST and evaluated on rotated MNIST digits. The goal is to maintain a high level of predictive accuracy (measured in terms of Brier scores) while exhibiting an increasing level of predictive uncertainty on distribution shifts of increasing magnitude. Figure 2 shows Brier scores (lower is better) and predictive entropy estimates (higher means more uncertain) of four different models. As rotating the MNIST digits gradually shifts the data distribution, we would expect Brier scores to increase (corresponding to worse predictive accuracy) as the rotation angle increases. A reliable model would only experience a small decrease in predictive accuracy under distribution shift while exhibiting a large increase in predictive uncertainty. As can be seen in the plot, the Brier score of fsvi deteriorates the least, while fsvi’s uncertainty is significantly higher than that of the other models. To assess the reliability of different uncertainty quantification methods on a more challenging distribution-shift task, we consider corrupted CIFAR-10 inputs under the second-mildest corruption level used by Ovadia et al. (2019) and report our results in Table 2. Consistent with the rotated MNIST results, fsvi achieves the highest accuracy on the corrupted data.
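The two quantities plotted in this experiment can be computed per example as in the sketch below. This is a generic illustration with made-up probability vectors; it uses the multi-class Brier score summed over classes, which is one common convention (some implementations additionally divide by the number of classes).

```python
import math


def brier_score(probs, label):
    """Squared error between predicted class probabilities and the one-hot label."""
    return sum((p - (1.0 if k == label else 0.0)) ** 2 for k, p in enumerate(probs))


def predictive_entropy(probs):
    """Shannon entropy (in nats) of a predictive distribution over classes."""
    return -sum(p * math.log(p) for p in probs if p > 0.0)


confident = [0.9, 0.05, 0.05]       # sharp prediction, true label assumed to be class 0
uncertain = [1 / 3, 1 / 3, 1 / 3]   # maximally uncertain prediction over three classes

brier_confident = brier_score(confident, 0)
brier_uncertain = brier_score(uncertain, 0)
entropy_confident = predictive_entropy(confident)
entropy_uncertain = predictive_entropy(uncertain)
```

A model that becomes less accurate under shift should move from the first regime toward the second: its Brier score rises, and a reliable model signals this through a rising predictive entropy.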
5.3 Safety-Critical Uncertainty-Aware Selective Prediction: Diabetic Retinopathy Diagnosis
To evaluate the reliability of the predictive uncertainty of fsvi in a real-world safety-critical setting, we consider the task of diagnosing diabetic retinopathy (DR), a medical condition that can lead to impaired vision, from retina scans (Leibig et al., 2017; Filos et al., 2019; Band et al., 2021). We use two publicly available datasets, EyePACS (2015) and APTOS (2019), each containing RGB images of a human retina graded by a medical expert on the following scale: 0 (no DR), 1 (mild DR), 2 (moderate DR), 3 (severe DR), and 4 (proliferative DR). The EyePACS dataset, which we also refer to as the Kaggle dataset, was collected from patients in the United States, while the APTOS dataset was collected from patients in India using cheaper but more modern scanning devices. We follow Leibig et al. (2017), Filos et al. (2019), and Band et al. (2021) and binarize all examples from both the EyePACS and APTOS datasets by dividing the classes up into sight-threatening diabetic retinopathy, defined as moderate diabetic retinopathy or worse (classes $\{2, 3, 4\}$), and non-sight-threatening diabetic retinopathy, defined as no or mild diabetic retinopathy (classes $\{0, 1\}$). This results in a binary prediction task.
To assess the reliability of predictive models when medical training and test data are obtained from different patient populations or collected with different medical equipment, we follow Band et al. (2021) and use the Kaggle dataset for training and the distributionally shifted APTOS dataset for evaluation. The results are shown in Figure 4, which plots the ROC curves for the binary prediction problems as well as the area under the ROC curve for an uncertainty-aware selective prediction task. For further details about the uncertainty-aware selective prediction evaluation protocol, see Section D.4. Figure 4 shows that fsvi performs well on all four tasks and is only outperformed by mc dropout. For full tabular results, see Section B.1.
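The idea behind uncertainty-aware selective prediction can be sketched as follows. This is a simplified illustration and not the protocol of Section D.4: it assumes that the most uncertain examples are referred to an expert and that performance on the retained examples is measured by accuracy, whereas the evaluation described above reports the area under the ROC curve. The function name and all values are hypothetical.

```python
def selective_accuracy_curve(uncertainties, correct, retain_fractions):
    """Accuracy on the retained examples when the most uncertain ones are referred."""
    order = sorted(range(len(correct)), key=lambda i: uncertainties[i])
    curve = []
    for fraction in retain_fractions:
        keep = max(1, round(fraction * len(order)))   # always retain at least one example
        kept = order[:keep]
        curve.append(sum(correct[i] for i in kept) / keep)
    return curve


# Made-up per-example uncertainties and correctness indicators.
uncertainties = [0.1, 0.9, 0.2, 0.8]
correct = [1, 0, 1, 1]
curve = selective_accuracy_curve(uncertainties, correct, [0.5, 0.75, 1.0])
```

If a model's uncertainty is informative, its errors concentrate among the referred examples, so performance on the retained examples improves as fewer examples are retained.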
6 Conclusion
The paper proposed a scalable and effective approach to function-space variational inference in bnns. We demonstrated that the proposed estimator of the function-space variational objective can be scaled up to high-dimensional data and large neural network architectures and that fsvi exhibits consistently reliable in- and out-of-distribution predictive performance on a wide range of datasets when compared to well-established and state-of-the-art uncertainty quantification methods. We hope that this work will lead to further research into function-space variational inference and the development of more sophisticated data-driven prior distributions over functions.
Acknowledgements
We thank Bryn Elesedy, Bobby He, and Andrew Jesson for feedback on an early draft of this paper. We thank Joost van Amersfoort for helpful discussions about experiment design and implementations. Tim G. J. Rudner is funded by the Rhodes Trust and the Engineering and Physical Sciences Research Council (EPSRC). We gratefully acknowledge donations of computing resources by the Alan Turing Institute.
References
- Amodei et al. (2016) Dario Amodei, Chris Olah, Jacob Steinhardt, Paul Christiano, John Schulman, and Dan Mané. Concrete problems in AI safety, 2016.
- APTOS (2019) APTOS. APTOS 2019 Blindness Detection Dataset, 2019.
- Band et al. (2021) Neil Band, Tim G. J. Rudner, Qixuan Feng, Angelos Filos, Zachary Nado, Michael W. Dusenberry, Ghassen Jerfel, Dustin Tran, and Yarin Gal. Benchmarking Bayesian Deep Learning on Diabetic Retinopathy Detection Tasks. 2021.
- Benjamin et al. (2019) Ari Benjamin, David Rolnick, and Konrad Kording. Measuring and regularizing networks in function space. In International Conference on Learning Representations, 2019.
- Blundell et al. (2015) Charles Blundell, Julien Cornebise, Koray Kavukcuoglu, and Daan Wierstra. Weight uncertainty in neural networks. volume 37 of Proceedings of Machine Learning Research, pages 1613–1622, Lille, France, 07–09 Jul 2015. PMLR.
- Burt et al. (2021) David R. Burt, Sebastian W. Ober, Adrià Garriga-Alonso, and Mark van der Wilk. Understanding variational inference in function-space. In Third Symposium on Advances in Approximate Bayesian Inference, 2021.
- Carvalho et al. (2020) Eduardo D. C. Carvalho, Ronald Clark, Andrea Nicastro, and Paul H. J. Kelly. Scalable uncertainty for computer vision with functional variational inference. In IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR), June 2020.
- EyePACS (2015) EyePACS. Diabetic Retinopathy Detection Dataset, 2015.
- Farquhar et al. (2020a) Sebastian Farquhar, Michael A. Osborne, and Yarin Gal. Radial Bayesian neural networks: Beyond discrete support in large-scale Bayesian deep learning. In Silvia Chiappa and Roberto Calandra, editors, Proceedings of the Twenty Third International Conference on Artificial Intelligence and Statistics, volume 108 of Proceedings of Machine Learning Research, pages 1352–1362. PMLR, 26–28 Aug 2020a.
- Farquhar et al. (2020b) Sebastian Farquhar, Lewis Smith, and Yarin Gal. Liberty or depth: Deep Bayesian neural nets do not need complex weight posterior approximations. In Advances in Neural Information Processing Systems 33: Annual Conference on Neural Information Processing Systems 2020, NeurIPS 2020, December 6-12, 2020, virtual, 2020b.
- Filos et al. (2019) Angelos Filos, Sebastian Farquhar, Aidan N. Gomez, Tim G. J. Rudner, Zachary Kenton, Lewis Smith, Milad Alizadeh, Arnoud de Kroon, and Yarin Gal. A systematic comparison of Bayesian deep learning robustness in diabetic retinopathy tasks, 2019.
- Foong et al. (2019) Andrew Y. K. Foong, Yingzhen Li, José Miguel Hernández-Lobato, and Richard E. Turner. 'In-between' uncertainty in Bayesian neural networks, 2019.
- Foong et al. (2020) Andrew Y. K. Foong, David R. Burt, Yingzhen Li, and Richard E. Turner. On the expressiveness of approximate inference in Bayesian neural networks. In Advances in Neural Information Processing Systems 33: Annual Conference on Neural Information Processing Systems 2020, NeurIPS 2020, December 6-12, 2020, virtual, 2020.
- Gal and Ghahramani (2016) Yarin Gal and Zoubin Ghahramani. Dropout as a Bayesian approximation: Representing model uncertainty in deep learning. In Proceedings of the 33rd International Conference on International Conference on Machine Learning - Volume 48, ICML 2016, pages 1050–1059, 2016.
- Graves (2011) Alex Graves. Practical variational inference for neural networks. In Proceedings of the 24th International Conference on Neural Information Processing Systems, NIPS’11, page 2348–2356, Red Hook, NY, USA, 2011. Curran Associates Inc. ISBN 9781618395993.
- Hendrycks et al. (2021) Dan Hendrycks, Nicholas Carlini, John Schulman, and Jacob Steinhardt. Unsolved problems in ML safety, 2021.
- Hinton and van Camp (1993) Geoffrey E. Hinton and Drew van Camp. Keeping the neural networks simple by minimizing the description length of the weights. In Proceedings of the Sixth Annual Conference on Computational Learning Theory, COLT ’93, page 5–13, New York, NY, USA, 1993. Association for Computing Machinery. ISBN 0897916115.
- Hoffman et al. (2013) Matthew D. Hoffman, David M. Blei, Chong Wang, and John Paisley. Stochastic variational inference. Journal of Machine Learning Research, 14(1):1303–1347, May 2013. ISSN 1532-4435.
- Immer et al. (2020) Alexander Immer, Maciej Korzepa, and Matthias Bauer. Improving predictions of Bayesian neural networks via local linearization, 2020.
- Izmailov et al. (2020) Pavel Izmailov, Wesley J. Maddox, Polina Kirichenko, Timur Garipov, Dmitry Vetrov, and Andrew Gordon Wilson. Subspace inference for Bayesian deep learning. In Ryan P. Adams and Vibhav Gogate, editors, Proceedings of The 35th Uncertainty in Artificial Intelligence Conference, volume 115 of Proceedings of Machine Learning Research, pages 1169–1179. PMLR, 22–25 Jul 2020.
- Jumper et al. (2021) John Jumper, Richard Evans, Alexander Pritzel, Tim Green, Michael Figurnov, Olaf Ronneberger, Kathryn Tunyasuvunakool, Russ Bates, Augustin Zidek, Anna Potapenko, Alex Bridgland, Clemens Meyer, Simon A A Kohl, Andrew J Ballard, Andrew Cowie, Bernardino Romera-Paredes, Stanislav Nikolov, Rishub Jain, Jonas Adler, Trevor Back, Stig Petersen, David Reiman, Ellen Clancy, Michal Zielinski, Martin Steinegger, Michalina Pacholska, Tamas Berghammer, Sebastian Bodenstein, David Silver, Oriol Vinyals, Andrew W Senior, Koray Kavukcuoglu, Pushmeet Kohli, and Demis Hassabis. Highly accurate protein structure prediction with AlphaFold. Nature, 596(7873):583–589, 2021. doi: 10.1038/s41586-021-03819-2.
- Khan et al. (2019) Mohammad Emtiyaz E Khan, Alexander Immer, Ehsan Abedi, and Maciej Korzepa. Approximate inference turns deep networks into Gaussian processes. In Advances in Neural Information Processing Systems 32, pages 3094–3104. Curran Associates, Inc., 2019.
- Krizhevsky et al. (2012) Alex Krizhevsky, Ilya Sutskever, and Geoffrey E. Hinton. ImageNet classification with deep convolutional neural networks. In Advances in Neural Information Processing Systems 25, pages 1106–1114, 2012.
- Leibig et al. (2017) Christian Leibig, Vaneeda Allken, Murat Seçkin Ayhan, Philipp Berens, and Siegfried Wahl. Leveraging Uncertainty Information From Deep Neural Networks for Disease Detection. Nature Scientific Reports, 7(1):17816, 2017.
- Ma and Hernández-Lobato (2021) Chao Ma and José Miguel Hernández-Lobato. Functional variational inference based on stochastic process generators. In A. Beygelzimer, Y. Dauphin, P. Liang, and J. Wortman Vaughan, editors, Advances in Neural Information Processing Systems, 2021.
- Ma et al. (2019) Chao Ma, Yingzhen Li, and Jose Miguel Hernandez-Lobato. Variational implicit processes. In Kamalika Chaudhuri and Ruslan Salakhutdinov, editors, Proceedings of the 36th International Conference on Machine Learning, volume 97 of Proceedings of Machine Learning Research, pages 4222–4233. PMLR, 09–15 Jun 2019.
- MacKay (1992) David J. C. MacKay. A practical Bayesian framework for backpropagation networks. Neural Comput., 4(3):448–472, May 1992. ISSN 0899-7667. doi: 10.1162/neco.1992.4.3.448.
- Maddox et al. (2019) Wesley J Maddox, Pavel Izmailov, Timur Garipov, Dmitry P Vetrov, and Andrew Gordon Wilson. A simple baseline for Bayesian uncertainty in deep learning. In Advances in Neural Information Processing Systems, pages 13153–13164, 2019.
- Matthews et al. (2016) Alexander G. de G. Matthews, James Hensman, Richard Turner, and Zoubin Ghahramani. On sparse variational methods and the Kullback-Leibler divergence between stochastic processes. volume 51 of Proceedings of Machine Learning Research, pages 231–239, Cadiz, Spain, 09–11 May 2016. PMLR.
- Mnih et al. (2013) Volodymyr Mnih, Koray Kavukcuoglu, David Silver, Alex Graves, Ioannis Antonoglou, Daan Wierstra, and Martin Riedmiller. Playing Atari with deep reinforcement learning. In NIPS Deep Learning Workshop, 2013.
- Neal (1996) Radford M Neal. Bayesian Learning for Neural Networks. 1996.
- Ober and Aitchison (2020) Sebastian W. Ober and Laurence Aitchison. Global inducing point variational posteriors for Bayesian neural networks and deep Gaussian processes, 2020.
- Osawa et al. (2019) Kazuki Osawa, Siddharth Swaroop, Mohammad Emtiyaz E Khan, Anirudh Jain, Runa Eschenhagen, Richard E Turner, and Rio Yokota. Practical deep learning with Bayesian principles. In Advances in Neural Information Processing Systems, volume 32, pages 4287–4299. Curran Associates, Inc., 2019.
- Ovadia et al. (2019) Yaniv Ovadia, Emily Fertig, Jie Ren, Zachary Nado, D. Sculley, Sebastian Nowozin, Joshua Dillon, Balaji Lakshminarayanan, and Jasper Snoek. Can you trust your model’s uncertainty? Evaluating predictive uncertainty under dataset shift. In Advances in Neural Information Processing Systems 32. 2019.
- Pan et al. (2020) Pingbo Pan, Siddharth Swaroop, Alexander Immer, Runa Eschenhagen, Richard E. Turner, and Mohammad Emtiyaz Khan. Continual deep learning by functional regularisation of memorable past. In Advances in Neural Information Processing Systems 33: Annual Conference on Neural Information Processing Systems 2020, NeurIPS 2020, December 6-12, 2020, virtual, 2020.
- Polyanskiy and Wu (2017) Yury Polyanskiy and Yihong Wu. Strong data-processing inequalities for channels and Bayesian networks. In Eric Carlen, Mokshay Madiman, and Elisabeth M. Werner, editors, Convexity and Concentration, pages 211–249, New York, NY, 2017. Springer New York. ISBN 978-1-4939-7005-6.
- Rudner and Toner (2021a) Tim G. J. Rudner and Helen Toner. Key Concepts in AI Safety: An Overview. In CSET Issue Briefs, 2021a.
- Rudner and Toner (2021b) Tim G. J. Rudner and Helen Toner. Key Concepts in AI Safety: Robustness and Adversarial Examples. In CSET Issue Briefs, 2021b.
- Rudner et al. (2021) Tim G. J. Rudner, Cong Lu, Michael A. Osborne, Yarin Gal, and Yee Whye Teh. On Pathologies in KL-Regularized Reinforcement Learning from Expert Demonstrations. In Advances in Neural Information Processing Systems 34, 2021.
- Rudner et al. (2022) Tim G. J. Rudner, Freddie Bickford Smith, Qixuan Feng, Yee Whye Teh, and Yarin Gal. Continual Learning via Sequential Function-Space Variational Inference. In Proceedings of the 38th International Conference on Machine Learning, Proceedings of Machine Learning Research. PMLR, 2022.
- Schervish (1995) M. J. Schervish. Theory of Statistics. Springer-Verlag, New York, NY, 1995.
- Silver et al. (2016) David Silver, Aja Huang, Chris J. Maddison, Arthur Guez, Laurent Sifre, George van den Driessche, Julian Schrittwieser, Ioannis Antonoglou, Veda Panneershelvam, Marc Lanctot, Sander Dieleman, Dominik Grewe, John Nham, Nal Kalchbrenner, Ilya Sutskever, Timothy Lillicrap, Madeleine Leach, Koray Kavukcuoglu, Thore Graepel, and Demis Hassabis. Mastering the game of Go with deep neural networks and tree search. Nature, 529(7587):484–489, 2016.
- Snelson and Ghahramani (2006) Edward Snelson and Zoubin Ghahramani. Sparse Gaussian processes using pseudo-inputs. In Y. Weiss, B. Schölkopf, and J. C. Platt, editors, Advances in Neural Information Processing Systems 18, pages 1257–1264. MIT Press, 2006.
- Sun et al. (2019) Shengyang Sun, Guodong Zhang, Jiaxin Shi, and Roger B. Grosse. Functional variational Bayesian neural networks. In 7th International Conference on Learning Representations, ICLR 2019, New Orleans, LA, USA, May 6-9, 2019. OpenReview.net, 2019.
- Titsias et al. (2020) Michalis K. Titsias, Jonathan Schwarz, Alexander G. de G. Matthews, Razvan Pascanu, and Yee Whye Teh. Functional regularisation for continual learning with Gaussian processes. In International Conference on Learning Representations, 2020.
- van Amersfoort et al. (2020) Joost van Amersfoort, Lewis Smith, Yee Whye Teh, and Yarin Gal. Uncertainty estimation using a single deep deterministic neural network. In International Conference on Machine Learning, 2020.
- van Amersfoort et al. (2021) Joost van Amersfoort, Lewis Smith, Andrew Jesson, Oscar Key, and Yarin Gal. Variational deterministic uncertainty quantification, 2021.
- Wainwright and Jordan (2008) Martin J Wainwright and Michael I Jordan. Graphical Models, Exponential Families, and Variational Inference. Now Publishers Inc., Hanover, MA, USA, 2008. ISBN 1601981848.
- Wang et al. (2019) Ziyu Wang, Tongzheng Ren, Jun Zhu, and Bo Zhang. Function space particle optimization for Bayesian neural networks. In International Conference on Learning Representations, 2019.
- Widdowson (2016) D. T. S. Widdowson. The Management of Grading Quality: Good Practice in the Quality Assurance of Grading. Tech. Rep., 2016.
- Wolpert (1993) David H. Wolpert. Bayesian backpropagation over i-o functions rather than weights. In J. Cowan, G. Tesauro, and J. Alspector, editors, Advances in Neural Information Processing Systems, volume 6. Morgan-Kaufmann, 1993.
Appendix
Appendix A Proofs & Derivations
A.1 Function-Space Variational Objective
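This proof follows steps from Matthews et al. [2016]. Consider measures $\hat{P}$ and $P$, both of which define distributions over some function $f$, indexed by an infinite index set $\mathcal{X}$. Let $\mathcal{D}$ be a dataset and let $\mathbf{X}_{\mathcal{D}}$ denote a set of inputs and $\mathbf{y}_{\mathcal{D}}$ a set of targets. Consider the measure-theoretic version of Bayes' Theorem [Schervish, 1995]:

$$\frac{d\hat{P}}{dP}(f) = \frac{p_{\mathcal{X}}(\mathbf{y} \mid f)}{p(\mathbf{y})}, \tag{A.1}$$

where $p_{\mathcal{X}}(\mathbf{y} \mid f)$ is the likelihood and $p(\mathbf{y}) = \int p_{\mathcal{X}}(\mathbf{y} \mid f) \, dP(f)$ is the marginal likelihood. We assume that the likelihood function is evaluated at a finite subset of the index set $\mathcal{X}$. Denote by $\pi_{\mathbf{C}} : \mathbb{R}^{\mathcal{X}} \to \mathbb{R}^{\mathbf{C}}$ a projection function that takes a function and returns the same function, evaluated at a finite set of points $\mathbf{C}$, so we can write

$$\frac{d\hat{P}}{dP}(f) = \frac{d\hat{P}_{\mathbf{X}_{\mathcal{D}}}}{dP_{\mathbf{X}_{\mathcal{D}}}}\big(\pi_{\mathbf{X}_{\mathcal{D}}}(f)\big) = \frac{p\big(\mathbf{y}_{\mathcal{D}} \mid \pi_{\mathbf{X}_{\mathcal{D}}}(f)\big)}{p(\mathbf{y}_{\mathcal{D}})}, \tag{A.2}$$

and similarly, the marginal likelihood becomes $p(\mathbf{y}_{\mathcal{D}}) = \int p(\mathbf{y}_{\mathcal{D}} \mid f_{\mathbf{X}_{\mathcal{D}}}) \, dP_{\mathbf{X}_{\mathcal{D}}}(f_{\mathbf{X}_{\mathcal{D}}})$, with $f_{\mathbf{X}_{\mathcal{D}}} = \pi_{\mathbf{X}_{\mathcal{D}}}(f)$. Now, considering the measure-theoretic version of the KL divergence between an approximating stochastic process $Q$ and a posterior stochastic process $\hat{P}$, we can write

$$D_{\mathrm{KL}}(Q \,\|\, \hat{P}) = \int \log \frac{dQ}{dP}(f) \, dQ(f) - \int \log \frac{d\hat{P}}{dP}(f) \, dQ(f), \tag{A.3}$$

where $P$ is some prior stochastic process. Now, we can apply the measure-theoretic Bayes' Theorem to obtain

$$D_{\mathrm{KL}}(Q \,\|\, \hat{P}) = \int \log \frac{dQ}{dP}(f) \, dQ(f) - \int \log \frac{d\hat{P}_{\mathbf{X}_{\mathcal{D}}}}{dP_{\mathbf{X}_{\mathcal{D}}}}(f_{\mathbf{X}_{\mathcal{D}}}) \, dQ_{\mathbf{X}_{\mathcal{D}}}(f_{\mathbf{X}_{\mathcal{D}}}) \tag{A.4}$$

$$= D_{\mathrm{KL}}(Q \,\|\, P) - \mathbb{E}_{Q_{\mathbf{X}_{\mathcal{D}}}}\!\left[ \log \frac{p(\mathbf{y}_{\mathcal{D}} \mid f_{\mathbf{X}_{\mathcal{D}}})}{p(\mathbf{y}_{\mathcal{D}})} \right] \tag{A.5}$$

$$= D_{\mathrm{KL}}(Q \,\|\, P) - \mathbb{E}_{Q_{\mathbf{X}_{\mathcal{D}}}}\!\left[ \log p(\mathbf{y}_{\mathcal{D}} \mid f_{\mathbf{X}_{\mathcal{D}}}) \right] + \log p(\mathbf{y}_{\mathcal{D}}), \tag{A.6}$$

where $Q_{\mathbf{X}_{\mathcal{D}}}$ is marginally consistent given the projection $\pi_{\mathbf{X}_{\mathcal{D}}}$. Rearranging, we can get

$$\log p(\mathbf{y}_{\mathcal{D}}) = \mathbb{E}_{Q_{\mathbf{X}_{\mathcal{D}}}}\!\left[ \log p(\mathbf{y}_{\mathcal{D}} \mid f_{\mathbf{X}_{\mathcal{D}}}) \right] - D_{\mathrm{KL}}(Q \,\|\, P) + D_{\mathrm{KL}}(Q \,\|\, \hat{P}) \tag{A.7}$$

$$\Longrightarrow \quad \mathbb{E}_{Q_{\mathbf{X}_{\mathcal{D}}}}\!\left[ \log p(\mathbf{y}_{\mathcal{D}} \mid f_{\mathbf{X}_{\mathcal{D}}}) \right] - D_{\mathrm{KL}}(Q \,\|\, P) = \log p(\mathbf{y}_{\mathcal{D}}) - D_{\mathrm{KL}}(Q \,\|\, \hat{P}) \tag{A.8}$$

$$\leq \log p(\mathbf{y}_{\mathcal{D}}), \tag{A.9}$$

where the inequality follows from the non-negativity of the KL divergence. Finally, this lower bound can equivalently be expressed as

$$\mathbb{E}_{Q_{\mathbf{X}_{\mathcal{D}}}}\!\left[ \log p(\mathbf{y}_{\mathcal{D}} \mid f_{\mathbf{X}_{\mathcal{D}}}) \right] - D_{\mathrm{KL}}(Q_{\mathbf{X}_{\mathcal{D}}} \,\|\, P_{\mathbf{X}_{\mathcal{D}}}) - \mathbb{E}_{Q_{\mathbf{X}_{\mathcal{D}}}}\!\left[ D_{\mathrm{KL}}\big(Q_{\mathcal{X} \setminus \mathbf{X}_{\mathcal{D}} \mid \mathbf{X}_{\mathcal{D}}} \,\|\, P_{\mathcal{X} \setminus \mathbf{X}_{\mathcal{D}} \mid \mathbf{X}_{\mathcal{D}}}\big) \right], \tag{A.10}$$

where $\mathcal{X} \setminus \mathbf{X}_{\mathcal{D}}$ is an infinite index set excluding the finite index set $\mathbf{X}_{\mathcal{D}}$, that is, $(\mathcal{X} \setminus \mathbf{X}_{\mathcal{D}}) \cap \mathbf{X}_{\mathcal{D}} = \emptyset$, or by Theorem 1 in Sun et al. [2019], we can write

$$\mathbb{E}_{Q_{\mathbf{X}_{\mathcal{D}}}}\!\left[ \log p(\mathbf{y}_{\mathcal{D}} \mid f_{\mathbf{X}_{\mathcal{D}}}) \right] - \sup_{\mathbf{X} \in \mathcal{X}_{\mathbb{N}}} D_{\mathrm{KL}}(Q_{\mathbf{X}} \,\|\, P_{\mathbf{X}}), \tag{A.11}$$

where $\mathcal{X}_{\mathbb{N}}$ is the collection of all finite sets of evaluation points.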
A.2 Distribution under Linearized Function Mapping
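Proposition 1 (Distribution under Linearized Mapping).
For a stochastic function $f(\cdot\,; \Theta)$ defined in terms of stochastic parameters $\Theta$ distributed according to distribution $g_{\Theta}$ with $m = \mathbb{E}[\Theta]$ and $S = \mathrm{Cov}[\Theta]$, denote the linearization of the stochastic function $f(\cdot\,; \Theta)$ about $m$ by

$$\tilde{f}(\cdot\,; m, \Theta) = f(\cdot\,; m) + J(\cdot\,; m)(\Theta - m),$$

where $J(\cdot\,; m) = (\partial f(\cdot\,; \Theta) / \partial \Theta)\big|_{\Theta = m}$ is the Jacobian of $f(\cdot\,; \Theta)$ evaluated at $\Theta = m$. Then the mean and covariance of the distribution over the linearized mapping at $X, X'$ are given by

$$\mathbb{E}[\tilde{f}(X; m, \Theta)] = f(X; m), \qquad \mathrm{Cov}[\tilde{f}(X; m, \Theta), \tilde{f}(X'; m, \Theta)] = J(X; m) \, S \, J(X'; m)^{\top}.$$

Proof.
We wish to find $\mathbb{E}[\tilde{f}(X; m, \Theta)]$ and

$$\mathrm{Cov}[\tilde{f}(X; m, \Theta), \tilde{f}(X'; m, \Theta)]. \tag{A.12}$$

To see that $\mathbb{E}[\tilde{f}(X; m, \Theta)] = f(X; m)$, note that, by linearity of expectation, we have

$$\mathbb{E}[\tilde{f}(X; m, \Theta)] = f(X; m) + J(X; m)(\mathbb{E}[\Theta] - m) = f(X; m). \tag{A.13}$$

To see that $\mathrm{Cov}[\tilde{f}(X; m, \Theta), \tilde{f}(X'; m, \Theta)] = J(X; m) \, S \, J(X'; m)^{\top}$, note that in general, for multivariate random variables $Z$ and $Z'$, $\mathrm{Cov}[Z, Z'] = \mathbb{E}[Z Z'^{\top}] - \mathbb{E}[Z] \, \mathbb{E}[Z']^{\top}$, and hence,

$$\mathrm{Cov}[\tilde{f}(X; m, \Theta), \tilde{f}(X'; m, \Theta)] = \mathbb{E}[\tilde{f}(X; m, \Theta) \, \tilde{f}(X'; m, \Theta)^{\top}] - \mathbb{E}[\tilde{f}(X; m, \Theta)] \, \mathbb{E}[\tilde{f}(X'; m, \Theta)]^{\top}. \tag{A.14}$$

We already know that $\mathbb{E}[\tilde{f}(X; m, \Theta)] = f(X; m)$, so we only need to find $\mathbb{E}[\tilde{f}(X; m, \Theta) \, \tilde{f}(X'; m, \Theta)^{\top}]$:

$$\mathbb{E}[\tilde{f}(X; m, \Theta) \, \tilde{f}(X'; m, \Theta)^{\top}] = \mathbb{E}\big[ \big(f(X; m) + J(X; m)(\Theta - m)\big) \big(f(X'; m) + J(X'; m)(\Theta - m)\big)^{\top} \big] \tag{A.15}$$

$$= \mathbb{E}\big[ f(X; m) f(X'; m)^{\top} + f(X; m)(\Theta - m)^{\top} J(X'; m)^{\top} + J(X; m)(\Theta - m) f(X'; m)^{\top} + J(X; m)(\Theta - m)(\Theta - m)^{\top} J(X'; m)^{\top} \big] \tag{A.16}$$

$$= f(X; m) f(X'; m)^{\top} + f(X; m)(\mathbb{E}[\Theta] - m)^{\top} J(X'; m)^{\top} + J(X; m)(\mathbb{E}[\Theta] - m) f(X'; m)^{\top} + J(X; m) \, \mathbb{E}[(\Theta - m)(\Theta - m)^{\top}] \, J(X'; m)^{\top} \tag{A.17}$$

$$= f(X; m) f(X'; m)^{\top} + J(X; m) \, \mathbb{E}[(\Theta - m)(\Theta - m)^{\top}] \, J(X'; m)^{\top}, \tag{A.18}$$

where the last line follows from the definition of $m$. By definition of the covariance, we then obtain

$$\mathrm{Cov}[\tilde{f}(X; m, \Theta), \tilde{f}(X'; m, \Theta)] = \mathbb{E}[\tilde{f}(X; m, \Theta) \, \tilde{f}(X'; m, \Theta)^{\top}] - \mathbb{E}[\tilde{f}(X; m, \Theta)] \, \mathbb{E}[\tilde{f}(X'; m, \Theta)]^{\top} \tag{A.19}$$

$$= \mathbb{E}[\tilde{f}(X; m, \Theta) \, \tilde{f}(X'; m, \Theta)^{\top}] - f(X; m) f(X'; m)^{\top}. \tag{A.20}$$

With this result, we obtain the covariance function

$$\mathrm{Cov}[\tilde{f}(X; m, \Theta), \tilde{f}(X'; m, \Theta)] = \mathbb{E}[\tilde{f}(X; m, \Theta) \, \tilde{f}(X'; m, \Theta)^{\top}] - f(X; m) f(X'; m)^{\top} \tag{A.21}$$

$$= f(X; m) f(X'; m)^{\top} + J(X; m) \, \mathbb{E}[(\Theta - m)(\Theta - m)^{\top}] \, J(X'; m)^{\top} - f(X; m) f(X'; m)^{\top} \tag{A.22}$$

$$= J(X; m) \, \mathbb{E}[(\Theta - m)(\Theta - m)^{\top}] \, J(X'; m)^{\top} \tag{A.23}$$

$$= J(X; m) \, \mathrm{Cov}[\Theta] \, J(X'; m)^{\top}. \tag{A.24}$$

Finally, $\mathrm{Cov}[\Theta] = S$ yields $\mathrm{Cov}[\tilde{f}(X; m, \Theta), \tilde{f}(X'; m, \Theta)] = J(X; m) \, S \, J(X'; m)^{\top}$. This concludes the proof. ∎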
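As a numerical illustration of Proposition 1, the following sketch computes the mean and the cross-covariance of a linearized mapping, that is, the function value at the mean parameters and the product of the Jacobian at the first input, the parameter covariance, and the transposed Jacobian at the second input. It is a minimal sketch under stated assumptions: the parameter covariance is taken to be diagonal, the Jacobian is approximated by central differences rather than automatic differentiation, and the names `numerical_jacobian`, `linearized_moments`, and `toy_model` are hypothetical.

```python
def numerical_jacobian(f, x, theta, h=1e-5):
    """Central-difference Jacobian of f(x, theta) w.r.t. theta (rows index outputs)."""
    base = f(x, theta)
    jac = [[0.0] * len(theta) for _ in base]
    for j in range(len(theta)):
        up = list(theta)
        up[j] += h
        down = list(theta)
        down[j] -= h
        f_up, f_down = f(x, up), f(x, down)
        for i in range(len(base)):
            jac[i][j] = (f_up[i] - f_down[i]) / (2.0 * h)
    return jac


def linearized_moments(f, x1, x2, mean, variances):
    """Mean at x1 and cross-covariance between x1 and x2 under the linearization.

    Assumes a diagonal parameter covariance S = diag(variances).
    """
    j1 = numerical_jacobian(f, x1, mean)
    j2 = numerical_jacobian(f, x2, mean)
    mu = f(x1, mean)  # the linearized mean is the function at the mean parameters
    cov = [
        [sum(j1[a][k] * variances[k] * j2[b][k] for k in range(len(mean)))
         for b in range(len(j2))]
        for a in range(len(j1))
    ]
    return mu, cov


def toy_model(x, theta):
    """Hypothetical scalar-output model that is nonlinear in its parameters."""
    return [theta[0] * x + theta[1] ** 2]


mu, cov = linearized_moments(toy_model, 3.0, -1.0, [1.0, 2.0], [0.5, 0.1])
```

For this toy model, the Jacobians at the two inputs are approximately (3, 4) and (-1, 4), so the cross-covariance is 3 · 0.5 · (-1) + 4 · 0.1 · 4 = 0.1 up to finite-difference error.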
Proposition 2 (Approximate Distribution under Linearized Mapping).
For a stochastic function defined in terms of stochastic parameters distributed according to distribution , denote the linearization of the stochastic function about by
where is the Jacobian of evaluated at . Then, for a partition of the set of parameters into sets and , a distribution with , the distribution can be approximated via the Monte Carlo estimator
(A.25) |
where , , and
(A.26) |
with denoting the columns of the Jacobian matrix corresponding to the sets of parameters and for obtained by sampling parameters from the distribution .
Proof.
Consider a partition of the set of parameters into sets and and express the linearized mapping as
(A.27) |
with
(A.28) |
and
(A.29) |
where and are the columns of the Jacobian matrix corresponding to the sets of parameters and , respectively, and and are the corresponding random parameter vectors.
Noting that Equation A.27 expresses as a sum of (affine transformations of) random variables, we can use the fact that for independent Gaussian random variables and , the distribution of is equal to the convolution of the distributions and to obtain an approximation to . In particular, for ,
(A.30) |
Letting , , and , we can write
(A.31) | ||||
(A.32) |
with
(A.33) |
and
(A.34) |
where we have used the fact that for a Gaussian distribution with mean and covariance , . We can then approximate the probability density function via the Monte Carlo estimator
(A.35) |
with . Finally, we can express the distribution as
(A.36) |
where and samples are obtained by sampling parameters from the distribution . This concludes the proof. ∎
Appendix B Further Empirical Results
B.1 Tabular Results for Diabetic Retinopathy Diagnosis Tasks
The results below were reproduced from Band et al. [2021] using the retina benchmark.
No Referral | Data Referred | Data Referred | ||||
Method | AUC (%) | Accuracy (%) | AUC (%) | Accuracy (%) | AUC (%) | Accuracy |
EyePACS Dataset (In-Domain) | ||||||
map (Deterministic) | ||||||
mfvi | ||||||
radial-mfvi | ||||||
fsvi | ||||||
mc dropout | ||||||
rank-1 | ||||||
deep ensemble | ||||||
mfvi ensemble | ||||||
radial-mfvi ensemble | ||||||
fsvi ensemble | ||||||
mc dropout ensemble | ||||||
rank-1 ensemble | ||||||
APTOS 2019 Dataset (Population Shift) | ||||||
map (Deterministic) | ||||||
mfvi | ||||||
radial-mfvi | ||||||
fsvi | ||||||
mc dropout | ||||||
rank-1 | ||||||
deep ensemble | ||||||
mfvi ensemble | ||||||
radial-mfvi ensemble | ||||||
fsvi ensemble | ||||||
mc dropout ensemble | ||||||
rank-1 ensemble |
B.2 UCI Regression
RMSE | Log-Likelihood | |||
---|---|---|---|---|
Sun et al. [2019] | Ours | Sun et al. [2019] | Ours | |
Boston | ||||
Concrete | ||||
Energy | ||||
Wine | ||||
Yacht | ||||
Protein |
Appendix C Illustrative Examples
C.1 Two Moons Classification Task
C.2 Synthetic 1D Regression Datasets
Appendix D Implementation, Training, and Evaluation Details
D.1 Hyperparameter Selection Protocol
For fsvi, we used a holdout validation set (10% of the training set) to conduct a hyperparameter search over the prior variance, the number of context points used to evaluate the KL divergence, the context distribution, and the number of Monte Carlo samples used to evaluate the expected log-likelihood. We selected the set of hyperparameters that yielded the highest validation log-likelihood for all experiments. We state the hyperparameters selected for the different datasets below.
For other methods, we used a holdout validation set of the same size and selected the best-performing hyperparameters. We used implementations provided by the authors of mfvi (radial) and swag. All other methods were implemented from scratch unless stated otherwise.
D.2 FashionMNIST vs. MNIST/NotMNIST
We train all models on the FashionMNIST dataset and evaluate the models’ predictive uncertainty performance on out-of-distribution data on the MNIST dataset. Both datasets consist of images of size $28 \times 28$ pixels. The FashionMNIST dataset is normalized to have zero mean and a standard deviation of one. The MNIST dataset is normalized with the same transformation, that is, using the same mean and standard deviation used for the in-distribution data. We chose FashionMNIST/MNIST instead of MNIST/NotMNIST because the latter is notably easier than the former.
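The normalization protocol above, in which statistics are fitted on the in-distribution data and reused unchanged for the out-of-distribution data, can be sketched as follows. This is an illustrative sketch that assumes a single scalar mean and standard deviation over all pixels; the helper names `fit_normalizer` and `normalize` and the toy pixel values are hypothetical.

```python
import statistics


def fit_normalizer(pixels):
    """Return the mean and (population) standard deviation of in-distribution pixels."""
    mean = statistics.fmean(pixels)
    std = statistics.pstdev(pixels)
    if std == 0.0:
        raise ValueError("cannot normalize constant data")
    return mean, std


def normalize(pixels, mean, std):
    """Apply a fixed affine normalization; the statistics are never refitted here."""
    return [(p - mean) / std for p in pixels]


in_dist_pixels = [0.0, 0.2, 0.4, 0.6, 0.8]  # stand-in for in-distribution pixels
ood_pixels = [0.1, 0.9]                     # stand-in for out-of-distribution pixels

mean, std = fit_normalizer(in_dist_pixels)
in_dist_normalized = normalize(in_dist_pixels, mean, std)
ood_normalized = normalize(ood_pixels, mean, std)  # same mean and std as above
```

Because the out-of-distribution data reuses the in-distribution statistics, it generally does not have zero mean or unit standard deviation after the transformation.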
In this experiment, a network architecture with two convolutional layers of 32 and 64 filters and a fully-connected final layer of 128 hidden units is used. A max pooling operation is placed after each convolutional layer and ReLU activations are used. We do not use batch normalization. All models are trained for 30 epochs with a mini-batch size of 128 using SGD with a learning rate of , momentum (with momentum parameter 0.9), and a cosine learning rate schedule with parameter .
For fsvi with the random monochrome context distribution, we sampled 50% of the context points for each gradient step from the mini-batch and the other 50% according to the method described in Section D.8. For fsvi with the context distribution set to KMNIST, we used the KMNIST dataset.
D.3 CIFAR-10 vs. SVHN
We train all models on the CIFAR-10 dataset and evaluate the models’ predictive uncertainty performance on out-of-distribution data on the SVHN dataset. Both datasets consist of images of size $32 \times 32$ pixels with RGB channels. The CIFAR-10 dataset is normalized to have zero mean and a standard deviation of one. The SVHN dataset is normalized with the same transformation, that is, using the same mean and standard deviation used for the in-distribution data. The training data is augmented with random horizontal flips (with a probability of 0.5) and random crops (4 zero pixels on all sides).
In this experiment, a standard ResNet-18 network architecture was used. All models are trained for 200 epochs with a mini-batch size of 128 using SGD with a learning rate of , momentum (with momentum parameter 0.9), and a cosine learning rate schedule with parameter .
For fsvi with the random monochrome context distribution, we sampled 50% of the context points for each gradient step from the mini-batch and the other 50% according to the method described in Section D.8. For fsvi with the context distribution set to CIFAR-100, we used the CIFAR-100 dataset.
D.4 Diabetic Retinopathy Diagnosis
Prediction and Expert Referral.
In real-world settings where the evaluation data may be sampled from a shifted distribution, incorrect predictions may become increasingly likely. To account for that possibility, predictive uncertainty estimates can be used to identify datapoints where the likelihood of an incorrect prediction is particularly high and refer them for further review. We consider a corresponding selective prediction task, where the predictive performance of a given model is evaluated for varying expert referral rates. That is, for a given referral rate of $\tau \in [0, 1]$, a model’s predictive uncertainty is used to identify the proportion $\tau$ of images in the evaluation set for which the model’s predictions are most uncertain. Those images are referred to a medical professional for further review, and the model is assessed on its predictions on the remaining proportion $1 - \tau$ of images. By repeating this process for all possible referral rates and assessing the model’s predictive performance on the retained images, we estimate how reliable it would be in a safety-critical downstream task, where predictive uncertainty estimates are used in conjunction with human expertise to avoid harmful predictions. Importantly, selective prediction tolerates out-of-distribution examples. For example, even if unfamiliar features appear in certain images, a model with reliable uncertainty estimates will perform better in selective prediction by assigning these images high epistemic (and predictive) uncertainty, therefore referring them to an expert at a lower referral rate $\tau$.
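The referral procedure described above can be illustrated with a short Python sketch. It is not the benchmark's evaluation code: it assumes a binary classifier that outputs a probability for the positive class, uses the entropy of that probability as the uncertainty score, and reports accuracy on the retained examples. The names `binary_entropy` and `selective_accuracy` and the toy data are hypothetical.

```python
import math


def binary_entropy(p):
    """Entropy (in nats) of a Bernoulli(p) prediction; assumed uncertainty score."""
    if p <= 0.0 or p >= 1.0:
        return 0.0
    return -(p * math.log(p) + (1.0 - p) * math.log(1.0 - p))


def selective_accuracy(labels, probs, referral_rate):
    """Accuracy on the examples retained after referring the most uncertain ones."""
    n = len(labels)
    if n == 0 or len(probs) != n:
        raise ValueError("labels and probs must be non-empty and of equal length")
    if not 0.0 <= referral_rate < 1.0:
        raise ValueError("referral_rate must lie in [0, 1)")
    # Number of examples sent to the expert; always retain at least one example.
    n_referred = min(n - 1, round(referral_rate * n))
    # Sort indices from least to most uncertain and keep the most certain ones.
    order = sorted(range(n), key=lambda i: binary_entropy(probs[i]))
    retained = order[: n - n_referred]
    correct = sum(int(probs[i] >= 0.5) == labels[i] for i in retained)
    return correct / len(retained)


labels = [1, 0, 0, 1, 0, 1]
probs = [0.95, 0.10, 0.60, 0.45, 0.20, 0.70]
curve = {rate: selective_accuracy(labels, probs, rate) for rate in (0.0, 0.5)}
```

In this toy example, the two misclassified inputs are also the two most uncertain ones, so accuracy on the retained examples rises from 4/6 with no referral to 1.0 when half of the inputs are referred.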
For all methods, experiments are performed using a ResNet-50 network architecture. Training and evaluation scripts as well as model checkpoints can be found at
github.com/google/uncertainty-baselines/.../diabetic_retinopathy_detection.
D.5 Two Moons
In this experiment, we use a multi-layer perceptron (MLP) consisting of two fully-connected layers with 30 hidden units each and tanh activations. We train all models with a learning rate of .
For fsvi, we sampled context points uniformly from .
D.6 1D Regression
In this experiment, we use a multi-layer perceptron (MLP) consisting of two fully-connected layers with 100 hidden units each and ReLU activations.
For fsvi, we sampled context points uniformly from .
D.7 Further Implementation Details
We use the Adam optimizer with default settings of , and for all experiments. The deterministic neural networks that were used for the ensemble were trained with a weight decay of = 1e-1. mfvi (tempered) was trained with a KL scaling factor of 0.1 to obtain a cold posterior.
D.8 Selection of Context Distribution
We estimate the supremum by sampling a set of context points from a context distribution at every gradient step. For tasks with image inputs, we define the context distribution as a uniform distribution over images with monochromatic channels. To generate a sample from this “monochrome images” distribution, we first take all images in the training data, flatten each channel, and stack the flattened image channels into a single vector per channel. We then draw a random element (i.e., a pixel) from each channel vector and use these pixels to generate a monochrome image of a given resolution by setting every pixel in each channel equal to the value drawn for that channel. For regression tasks, the context distribution is defined as a uniform distribution over the input space with lower and upper bounds set to the empirical lower and upper bounds of the training data. For further details on the effect of different sampling schemes on the posterior predictive distribution’s performance, see Appendix B.
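The “monochrome images” sampling procedure can be sketched as follows. This is a minimal illustration under stated assumptions: images are stored as nested lists indexed as `image[channel][row][column]`, and the function name `sample_monochrome_image` and the toy training data are hypothetical.

```python
import random


def sample_monochrome_image(train_images, height, width, rng):
    """Draw one pixel value per channel from the training data and tile it."""
    num_channels = len(train_images[0])
    image = []
    for c in range(num_channels):
        # Flatten channel c across all training images into a single vector.
        channel_pixels = [p for img in train_images for row in img[c] for p in row]
        # Draw one pixel value uniformly at random from that vector.
        value = rng.choice(channel_pixels)
        # Fill the whole channel of the new image with the drawn value.
        image.append([[value] * width for _ in range(height)])
    return image


rng = random.Random(0)
train_images = [
    [[[0.1, 0.2], [0.3, 0.4]], [[0.5, 0.6], [0.7, 0.8]]],
    [[[0.9, 1.0], [1.1, 1.2]], [[1.3, 1.4], [1.5, 1.6]]],
]
context_image = sample_monochrome_image(train_images, height=2, width=2, rng=rng)
```

Each channel of the sampled image is constant, and its value is an actual pixel intensity observed in the corresponding channel of the training data.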
D.9 Compute Resources
All experiments were carried out on an NVIDIA V100 GPU with 32GB of memory.