EmilienDupont/augmented-neural-odes
Pytorch implementation of Augmented Neural ODEs :sunflower:
repo name | EmilienDupont/augmented-neural-odes |
repo link | https://github.com/EmilienDupont/augmented-neural-odes |
homepage | |
language | Jupyter Notebook |
size (curr.) | 5844 kB |
stars (curr.) | 302 |
created | 2019-05-19 |
license | MIT License |
Augmented Neural ODEs
This repo contains code for the paper Augmented Neural ODEs (2019).
Requirements
The requirements that can be installed directly from PyPI are listed in requirements.txt. This code also builds on the awesome torchdiffeq library, which provides various ODE solvers on GPU. Instructions for installing torchdiffeq can be found in the torchdiffeq repo.
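torchdiffeq's `odeint` integrates a function `func(t, y)` that returns dy/dt, and it also offers adaptive-step solvers. For intuition only, here is a pure-Python fixed-step fourth-order Runge-Kutta (RK4) integrator showing what "solving the ODE" means. It is not torchdiffeq's implementation, and the name `rk4_solve` is hypothetical.

```python
import math


def rk4_solve(f, y0, t0, t1, num_steps):
    """Integrate dy/dt = f(t, y) from t0 to t1 with fixed-step RK4.

    y is a list of floats; f(t, y) must return a list of the same length.
    (Hypothetical helper for illustration, not part of torchdiffeq.)
    """
    h = (t1 - t0) / num_steps
    t, y = t0, list(y0)
    for _ in range(num_steps):
        k1 = f(t, y)
        k2 = f(t + h / 2, [yi + h / 2 * ki for yi, ki in zip(y, k1)])
        k3 = f(t + h / 2, [yi + h / 2 * ki for yi, ki in zip(y, k2)])
        k4 = f(t + h, [yi + h * ki for yi, ki in zip(y, k3)])
        # Weighted average of the four slope estimates.
        y = [yi + h / 6 * (a + 2 * b + 2 * c + d)
             for yi, a, b, c, d in zip(y, k1, k2, k3, k4)]
        t += h
    return y


# dy/dt = y with y(0) = 1 has exact solution y(1) = e.
approx = rk4_solve(lambda t, y: [y[0]], [1.0], 0.0, 1.0, num_steps=20)
print(approx[0], math.e)
```

In the actual models, the vector field is a small neural network and the state is a torch tensor, but the role of the solver is the same.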
Usage
The usage pattern is simple:
import torch
from anode.conv_models import ConvODENet
from anode.models import ODENet
from anode.training import Trainer

# Use a GPU if one is available
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# ... Load some data into `dataloader` (a torch.utils.data.DataLoader) ...

# Instantiate a model
# For regular data...
anode = ODENet(device, data_dim=2, hidden_dim=16, augment_dim=1)
# ... or for images (use one model or the other)
# anode = ConvODENet(device, img_size=(1, 28, 28), num_filters=32, augment_dim=1)

# Instantiate an optimizer and a trainer
optimizer = torch.optim.Adam(anode.parameters(), lr=1e-3)
trainer = Trainer(anode, optimizer, device)
# Train model on your dataloader
trainer.train(dataloader, num_epochs=10)
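The `augment_dim` argument reflects the paper's central idea: an input x in R^d is lifted to [x, 0] in R^(d+p) by appending p zeros before the ODE is solved, which gives trajectories extra room so they do not need to cross. Below is a minimal pure-Python sketch of that lifting step. The function `augment` is hypothetical and not part of the `anode` package; the models presumably perform the equivalent operation internally on torch tensors.

```python
def augment(batch, augment_dim):
    """Append `augment_dim` zeros to every feature vector in `batch`.

    Mirrors the ANODE lifting x -> [x, 0]; `batch` is a list of lists of floats.
    (Hypothetical helper for illustration.)
    """
    return [list(x) + [0.0] * augment_dim for x in batch]


# A batch of two 2D points (data_dim=2) lifted to 3D (augment_dim=1).
batch = [[1.0, -1.0], [0.5, 2.0]]
print(augment(batch, augment_dim=1))  # [[1.0, -1.0, 0.0], [0.5, 2.0, 0.0]]
```

With `augment_dim=0` the input is left unchanged, which corresponds to a plain Neural ODE.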
More detailed examples and tutorials can be found in the augmented-neural-ode-example.ipynb and vector-field-visualizations.ipynb notebooks.
Running experiments
To run a large number of repeated experiments on toy datasets, use the following command:
python main_experiment.py config.json
where the specifications for the experiment are given in config.json. This will log all the information about the experiments and generate plots for losses, NFEs (numbers of function evaluations) and so on.
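The NFE is the number of times the ODE solver calls the vector field during a forward pass, so it serves as a measure of computational cost. A minimal sketch of how such a count can be collected by wrapping the vector field is shown below. The class `CountingField` and the forward-Euler loop are hypothetical illustrations, not the repo's logging code. The repo's models use torchdiffeq solvers, so real counts will differ.

```python
class CountingField:
    """Wrap a vector field f(t, y) and count how often it is evaluated."""

    def __init__(self, f):
        self.f = f
        self.nfe = 0

    def __call__(self, t, y):
        self.nfe += 1
        return self.f(t, y)


def euler_solve(f, y0, t0, t1, num_steps):
    """Forward-Euler integration: exactly one evaluation of f per step."""
    h = (t1 - t0) / num_steps
    t, y = t0, list(y0)
    for _ in range(num_steps):
        dy = f(t, y)
        y = [yi + h * di for yi, di in zip(y, dy)]
        t += h
    return y


field = CountingField(lambda t, y: [-yi for yi in y])  # dy/dt = -y
euler_solve(field, [1.0, 2.0], 0.0, 1.0, num_steps=50)
print(field.nfe)  # 50
```

A fixed-step solver always uses the same NFE. With an adaptive solver, the count depends on how hard the learned dynamics are to integrate, which is why it is worth plotting next to the loss.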
Running experiments on image datasets
To run large experiments on image datasets, use the following command:
python main_experiment_img.py config_img.json
where the specifications for the experiment are given in config_img.json.
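For images, the paper augments the channel dimension rather than a flattened pixel vector. An input of size (C, H, W) is therefore solved as a (C + p, H, W) state whose extra channels start at zero. We assume that with `img_size=(1, 28, 28)` and `augment_dim=1`, `ConvODENet` works on a 2-channel state. Below is a minimal pure-Python sketch of the channel augmentation on nested lists. `augment_channels` is a hypothetical helper, not the repo's code.

```python
def augment_channels(image, augment_dim):
    """Append `augment_dim` all-zero channels to an image.

    `image` is a nested list with shape (channels, height, width).
    (Hypothetical helper for illustration.)
    """
    height, width = len(image[0]), len(image[0][0])
    # Build each zero channel separately so the channels do not share rows.
    extra = [[[0.0] * width for _ in range(height)] for _ in range(augment_dim)]
    # Copy the original channels so the input is not modified.
    return [[list(row) for row in channel] for channel in image] + extra


# A 1-channel 28x28 "image" (all ones) with augment_dim=1 becomes 2 x 28 x 28.
image = [[[1.0] * 28 for _ in range(28)]]
augmented = augment_channels(image, augment_dim=1)
print(len(augmented), len(augmented[0]), len(augmented[0][0]))  # 2 28 28
```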
Demos
We also provide two demo notebooks that show how to reproduce some of the results and figures from the paper.
Vector fields
The vector-field-visualizations.ipynb notebook contains a demo and tutorial for reproducing the experiments on 1D ODE flows in the paper.
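The key observation behind the 1D experiments is that ODE trajectories cannot cross, so a 1D flow preserves the ordering of its inputs. For this reason a Neural ODE cannot represent a map such as g(x) = -x, which sends 1 to -1 and -1 to 1. Below is a small numerical check of the ordering property, using an arbitrary, hypothetical vector field and forward Euler with a small step size. It is a sketch for intuition, not code from the notebook.

```python
import math


def field(t, x):
    """An arbitrary (hypothetical) smooth 1D vector field."""
    return math.sin(3 * x) - 0.5 * x + t


def flow_1d(f, x0, t1=1.0, num_steps=100):
    """Approximate the time-t1 flow of dx/dt = f(t, x) with forward Euler."""
    h = t1 / num_steps
    x, t = x0, 0.0
    for _ in range(num_steps):
        x += h * f(t, x)
        t += h
    return x


inputs = [-2.0, -1.0, -0.5, 0.0, 0.5, 1.0, 2.0]
outputs = [flow_1d(field, x) for x in inputs]
# Sorted inputs stay sorted, so this flow can never swap 1 and -1.
print(outputs == sorted(outputs))  # True
```

The step size matters for the numerical check. Here |df/dx| <= 3.5, so with h = 0.01 each Euler step is a strictly increasing map of x, which matches the behaviour of the exact flow.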
Augmented Neural ODEs
The augmented-neural-ode-example.ipynb notebook contains a demo and tutorial for reproducing the experiments comparing Neural ODEs and Augmented Neural ODEs on simple 2D functions.
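One of the simple 2D functions in the paper is a "concentric circles" task. Points inside an inner disk (norm at most r1) map to -1, and points in an outer annulus (r2 <= norm <= r3, with r1 < r2 < r3) map to +1. Below is a minimal stdlib sketch for sampling such data. `sample_concentric` and its default radii are hypothetical choices, not the notebook's settings. Radii are drawn uniformly, so points are not uniform in area.

```python
import math
import random


def sample_concentric(num_points, r1=0.5, r2=1.0, r3=1.5, seed=0):
    """Sample 2D points labelled -1 (norm <= r1) or +1 (r2 <= norm <= r3).

    Even indices come from the inner disk, odd indices from the outer annulus.
    Returns a list of ([x1, x2], label) pairs. (Hypothetical helper.)
    """
    rng = random.Random(seed)
    data = []
    for i in range(num_points):
        if i % 2 == 0:
            radius, label = rng.uniform(0.0, r1), -1.0
        else:
            radius, label = rng.uniform(r2, r3), 1.0
        angle = rng.uniform(0.0, 2 * math.pi)
        data.append(([radius * math.cos(angle), radius * math.sin(angle)], label))
    return data


data = sample_concentric(100)
print(data[0])
```

Converting these pairs into tensors and a `DataLoader` for `ODENet(device, data_dim=2, ...)` is left to the caller.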
Data
The MNIST and CIFAR10 datasets can be downloaded directly using torchvision (this happens automatically when you run the code, unless the datasets are already downloaded). To run experiments on Tiny ImageNet, you will need to download the data from the Tiny ImageNet website.
Citing
If you find this code useful in your research, consider citing with
@article{dupont2019augmented,
title={Augmented Neural ODEs},
author={Dupont, Emilien and Doucet, Arnaud and Teh, Yee Whye},
journal={arXiv preprint arXiv:1904.01681},
year={2019}
}
License
MIT