This GitHub repository contains the code for the reproducible experiments presented in our paper MMD Aggregated Two-Sample Test.
We provide the code to run the experiments that generate Figures 1-10 and Table 2 of our paper, which can be found in the media directory. The code for the Failing Loudly experiment (with results reported in Table 1) can be found in the FL-MMDAgg repository.
To use our MMDAgg test in practice, we recommend using our mmdagg package; more details are available on the mmdagg repository.
Our implementation uses two quantile estimation methods (wild bootstrap and permutations). The MMDAgg test aggregates over different types of kernels (e.g. Gaussian, Laplace, Inverse Multi-Quadric (IMQ), and Matérn kernels with various parameters), each with several bandwidths. In practice, we recommend aggregating over both Gaussian and Laplace kernels, each with 10 bandwidths.
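For concreteness, a minimal usage sketch of this recommended configuration is given below; the import path follows mmdagg/np.py (described further down), while the parameter names kernel and number_bandwidths and their values are assumptions made for illustration, so please refer to the mmdagg repository for the exact interface.

```python
# Minimal sketch (assumed interface): MMDAgg aggregating over Gaussian and
# Laplace kernels, each with 10 bandwidths.
import numpy as np
from mmdagg.np import mmdagg  # Numpy implementation of MMDAgg

rng = np.random.default_rng(0)
X = rng.normal(0, 1, size=(500, 2))    # first sample
Y = rng.normal(0.5, 1, size=(500, 2))  # second sample

# "kernel" and "number_bandwidths" are assumed parameter names (illustration only)
output = mmdagg(X, Y, kernel="laplace_gaussian", number_bandwidths=10)
print(output)  # 1 if the test rejects the null hypothesis of equal distributions, 0 otherwise
```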
Python 3.9
The packages in requirements.txt are required to run our tests and the ones we compare against.
Additionally, the jax and jaxlib packages are required to run the Jax implementation of MMDAgg in mmdagg/jax.py.
In a chosen directory, clone the repository and change to its directory by executing
git clone git@github.com:antoninschrab/mmdagg-paper.git
cd mmdagg-paper
We then recommend creating and activating a virtual environment by either
- using venv:
python3 -m venv mmdagg-env
source mmdagg-env/bin/activate
# can be deactivated by running:
# deactivate
- or using conda:
conda create --name mmdagg-env python=3.9
conda activate mmdagg-env
# can be deactivated by running:
# conda deactivate
The packages required for reproducibility of the experiments can then be installed in the virtual environment by running
python -m pip install -r requirements.txt
To use the Jax implementation of MMDAgg, Jax needs to be installed (instructions). For example, this can be done by running
- for GPU:
pip install --upgrade "jax[cuda11_pip]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
# conda install -c conda-forge -c nvidia pip numpy scipy cuda-nvcc "jaxlib=0.4.1=*cuda*" jax
- or, for CPU:
conda install -c conda-forge -c nvidia pip jaxlib=0.4.1 jax
To run the experiments, the following command can be executed
python experiments.py
This command saves the results in dedicated .csv and .pkl files in a new directory user/raw.
The output of this command is already provided in paper/raw.
The results of the remaining experiments, saved in the results directory, can be obtained by running the Computations_mmdagg.ipynb and Computations_autotst.ipynb notebooks; the latter uses the autotst package introduced in the AutoML Two-Sample Test paper.
The actual figures of the paper can be obtained from the saved results by running the code in the figures.ipynb notebook.
All the experiments consist of 'embarrassingly parallel' for loops; a significant speed-up can be obtained by using parallel computing libraries such as joblib or dask (see the sketch below).
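As an illustration, one way to parallelise such a loop with joblib is sketched below; run_single_repetition is a hypothetical placeholder for one iteration of an experiment loop, not a function from this repository.

```python
# Sketch: parallelising an embarrassingly parallel for loop with joblib.
from joblib import Parallel, delayed

def run_single_repetition(seed):
    # hypothetical placeholder for one repetition of an experiment
    return seed  # replace with the actual computation

# run 100 repetitions on all available CPU cores
results = Parallel(n_jobs=-1)(
    delayed(run_single_repetition)(seed) for seed in range(100)
)
```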
Half of the experiments use a down-sampled version of the MNIST dataset, which is created as a .data file in a new directory mnist_dataset when running the script experiments.py.
This dataset can also be generated on its own by executing
python mnist.py
The other half of the experiments use samples drawn from a perturbed uniform density (Eq. 17 in the paper). A rejection sampler f_theta_sampler for this density is implemented in sampling.py.
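To reproduce the experiments, f_theta_sampler from sampling.py should be used directly; purely as an illustration of the underlying rejection-sampling idea, a generic sampler for a density bounded by M on [0, 1] might look as follows (the density f below is a toy placeholder, not the perturbed uniform density of Eq. 17).

```python
# Generic rejection sampler for a density f bounded by M on [0, 1] (illustration only).
import numpy as np

def rejection_sampler(f, M, n, rng):
    samples = []
    while len(samples) < n:
        x = rng.uniform(0, 1)  # proposal drawn from the uniform distribution on [0, 1]
        u = rng.uniform(0, M)  # uniform height under the constant envelope M
        if u <= f(x):          # accept x with probability f(x) / M
            samples.append(x)
    return np.array(samples)

# toy bounded density on [0, 1] (placeholder, not Eq. 17)
f = lambda x: 1 + 0.5 * np.sin(2 * np.pi * x)
samples = rejection_sampler(f, M=1.5, n=1000, rng=np.random.default_rng(0))
```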
The MMDAgg test is implemented as the function mmdagg in mmdagg/np.py for the Numpy version and in mmdagg/jax.py for the Jax version.
For the Numpy implementation of our MMDAgg test, we only require the numpy and scipy packages.
For the Jax implementation of our MMDAgg test, we only require the jax and jaxlib packages.
To use our tests in practice, we recommend using our mmdagg package, which is available on the mmdagg repository. It can be installed by running
pip install git+https://github.com/antoninschrab/mmdagg.git
Installation instructions and example code are available on the mmdagg repository.
We also provide some code showing how to use our MMDAgg test in the demo_speed.ipynb notebook, which also contains speed comparisons between the Jax and Numpy implementations, as reported below.
| Speed in s | Numpy (CPU) | Jax (CPU) | Jax (GPU) |
| --- | --- | --- | --- |
| MMDAgg | 43.1 | 14.9 | 0.495 |
In practice, we recommend using the Jax implementation as it runs considerably faster (almost 100 times faster on GPU in the table above; see the demo_speed.ipynb notebook).
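A rough way to reproduce such a comparison is sketched below; it assumes both modules expose a function mmdagg(X, Y) taking two sample arrays (as described above), and it times the second call to the Jax version since the first call includes just-in-time compilation.

```python
# Rough timing comparison of the Numpy and Jax implementations (illustration only).
import time
import numpy as np
from mmdagg.np import mmdagg as mmdagg_np
from mmdagg.jax import mmdagg as mmdagg_jax

rng = np.random.default_rng(0)
X = rng.normal(0, 1, size=(500, 10))
Y = rng.normal(0.5, 1, size=(500, 10))

start = time.time()
mmdagg_np(X, Y)
print("Numpy:", time.time() - start, "seconds")

mmdagg_jax(X, Y)  # first call includes JIT compilation
start = time.time()
mmdagg_jax(X, Y)  # time the compiled version
print("Jax:", time.time() - start, "seconds")
```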
Failing Loudly: An Empirical Study of Methods for Detecting Dataset Shift. Stephan Rabanser, Stephan Günnemann, Zachary C. Lipton. (paper, code)
Learning Kernel Tests Without Data Splitting. Jonas M. Kübler, Wittawat Jitkrittum, Bernhard Schölkopf, Krikamol Muandet. (paper, code)
AutoML Two-Sample Test. Jonas M. Kübler, Vincent Stimper, Simon Buchholz, Krikamol Muandet, Bernhard Schölkopf. (paper, code)
For a computationally efficient version of MMDAgg which can run in linear time, check out our paper Efficient Aggregated Kernel Tests using Incomplete U-statistics with reproducible experiments in the agginc-paper repository and a package in the agginc repository.
If you have any issues running our code, please do not hesitate to contact Antonin Schrab.
Centre for Artificial Intelligence, Department of Computer Science, University College London
Gatsby Computational Neuroscience Unit, University College London
Inria London
@article{schrab2021mmd,
author = {Antonin Schrab and Ilmun Kim and M{\'e}lisande Albert and B{\'e}atrice Laurent and Benjamin Guedj and Arthur Gretton},
title = {{MMD} Aggregated Two-Sample Test},
journal = {Journal of Machine Learning Research},
year = {2023},
volume = {24},
number = {194},
pages = {1--81},
url = {http://jmlr.org/papers/v24/21-1289.html}
}
MIT License (see LICENSE.md).