
The Autoencoding Variational Autoencoder

This is the code for the models in the NeurIPS submission "The Autoencoding Variational Autoencoder" (AVAE).

This folder contains code to train the AVAE model in JAX. The evaluation setup will be uploaded soon.

Code files in the folder

  • checkpointer.py: Checkpointing abstraction
  • data_iterators.py: Datasets to be used
  • decoders.py: VAE decoder network architectures
  • encoders.py: VAE encoder network architectures
  • kl.py: KL divergence computation between two Gaussians
  • train.py: Function to train given ELBO, network and data
  • train_main.py: Main file to train AVAE
  • vae.py: VAE model defining various ELBOs
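
For readers unfamiliar with the quantity that kl.py is described as computing, the following is a minimal sketch of the standard closed-form KL divergence between two diagonal Gaussians. It is not the repository's implementation: the function name `kl_diag_gaussians` and its argument names are hypothetical, and it is written in plain Python rather than JAX so that it runs without extra dependencies.

```python
import math


def kl_diag_gaussians(mu_p, sigma_p, mu_q, sigma_q):
    """KL(p || q) for diagonal Gaussians (hypothetical helper, not from kl.py).

    Each argument is a sequence of per-dimension means or standard
    deviations. The per-dimension closed form is
        log(sigma_q / sigma_p)
        + (sigma_p^2 + (mu_p - mu_q)^2) / (2 * sigma_q^2)
        - 1/2
    and the dimensions are summed because they are independent.
    """
    total = 0.0
    for mp, sp, mq, sq in zip(mu_p, sigma_p, mu_q, sigma_q):
        total += (
            math.log(sq / sp)
            + (sp ** 2 + (mp - mq) ** 2) / (2.0 * sq ** 2)
            - 0.5
        )
    return total


# KL of a distribution with itself is zero.
print(kl_diag_gaussians([0.0, 1.0], [1.0, 2.0], [0.0, 1.0], [1.0, 2.0]))
# Shifting a unit-variance mean by 1 gives 0.5 nats.
print(kl_diag_gaussians([1.0], [1.0], [0.0], [1.0]))
```

In a VAE-style ELBO, a term of this form typically measures how far the encoder's approximate posterior is from the prior; a JAX version would replace the loop with vectorized `jax.numpy` operations.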

Setup

To set up a Python 3 virtual environment with the required dependencies, run:

python -m venv avae_env
source avae_env/bin/activate
pip install --upgrade pip
pip install -r avae/requirements.txt

Running AVAE training

The following command runs AVAE training on the ColorMnist dataset using MLP network architectures.

python -m avae.train_main \
  --dataset='color_mnist' \
  --latent_dim=64 \
  --checkpoint_dir='/tmp/avae_checkpoints' \
  --checkpoint_filename='color_mnist_mlp_avae' \
  --rho=0.975 \
  --encoder='color_mnist_mlp_encoder' \
  --decoder='color_mnist_mlp_decoder'
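
When running several configurations, it can help to generate the command from a dictionary of flags. The sketch below only assembles and prints command lines using the flag names shown above. The helper `build_train_command`, the alternative `latent_dim` value of 32, and the checkpoint filename pattern are illustrative assumptions, not part of this codebase.

```python
import shlex


def build_train_command(flags):
    """Build the argv list for avae.train_main from a flag dict (hypothetical helper)."""
    cmd = ["python", "-m", "avae.train_main"]
    for name, value in flags.items():
        cmd.append(f"--{name}={value}")
    return cmd


# Flag values copied from the command above.
base_flags = {
    "dataset": "color_mnist",
    "latent_dim": 64,
    "checkpoint_dir": "/tmp/avae_checkpoints",
    "checkpoint_filename": "color_mnist_mlp_avae",
    "rho": 0.975,
    "encoder": "color_mnist_mlp_encoder",
    "decoder": "color_mnist_mlp_decoder",
}

# Assumed sweep: the value 32 and the filename suffix are illustrative only.
for latent_dim in (32, 64):
    flags = dict(
        base_flags,
        latent_dim=latent_dim,
        checkpoint_filename=f"color_mnist_mlp_avae_z{latent_dim}",
    )
    # Print a shell-ready command line; pass the list to subprocess.run to execute it.
    print(shlex.join(build_train_command(flags)))
```

Giving each run its own checkpoint filename keeps the runs from overwriting one another in the shared checkpoint directory.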

References

Citing our work

If you use this code for your research, please consider citing our paper:

@article{cemgil2020autoencoding,
  title={The Autoencoding Variational Autoencoder},
  author={Cemgil, Taylan and Ghaisas, Sumedh and Dvijotham, Krishnamurthy and Gowal, Sven and Kohli, Pushmeet},
  journal={Advances in Neural Information Processing Systems},
  volume={33},
  year={2020}
}