August 14, 2019

iShohei220/torch-gqn

PyTorch Implementation of Generative Query Network

repo name iShohei220/torch-gqn
repo link https://github.com/iShohei220/torch-gqn
homepage
language Python
size (curr.) 55 kB
stars (curr.) 105
created 2018-12-04
license

Original paper: Neural scene representation and rendering (Eslami et al., 2018)

https://deepmind.com/blog/neural-scene-representation-and-rendering

Pixyz Implementation: https://github.com/masa-su/pixyzoo/tree/master/GQN

Requirements

  • Python >=3.6
  • PyTorch
  • TensorBoardX

How to Train

python train.py --train_data_dir /path/to/dataset/train --test_data_dir /path/to/dataset/test

# Using multiple GPUs.
python train.py --device_ids 0 1 2 3 --train_data_dir /path/to/dataset/train --test_data_dir /path/to/dataset/test
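The commands above pass directory paths and a space-separated list of GPU ids. As an illustration only, here is a hypothetical argparse setup that would accept the flags named in this README (--train_data_dir, --test_data_dir, --device_ids, --shared_core, --layers). It is not copied from train.py, and the helper str2bool and all defaults are assumptions.

```python
import argparse


def str2bool(value: str) -> bool:
    """Parse 'True'/'False'-style strings into a real bool."""
    lowered = value.lower()
    if lowered in ("true", "1", "yes"):
        return True
    if lowered in ("false", "0", "no"):
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean, got {value!r}")


def build_parser() -> argparse.ArgumentParser:
    # Hypothetical parser mirroring the flags named in the README; not the repo's train.py.
    parser = argparse.ArgumentParser(description="Train a GQN (illustrative sketch)")
    parser.add_argument("--train_data_dir", type=str, default=None)
    parser.add_argument("--test_data_dir", type=str, default=None)
    # nargs="+" turns `--device_ids 0 1 2 3` into the list [0, 1, 2, 3].
    parser.add_argument("--device_ids", type=int, nargs="+", default=[0])
    parser.add_argument("--shared_core", type=str2bool, default=False)
    parser.add_argument("--layers", type=int, default=12)
    return parser


if __name__ == "__main__":
    args = build_parser().parse_args([
        "--device_ids", "0", "1", "2", "3",
        "--train_data_dir", "/path/to/dataset/train",
        "--test_data_dir", "/path/to/dataset/test",
    ])
    print(args.device_ids, args.shared_core, args.layers)
```

A dedicated converter such as str2bool matters because argparse's type=bool simply calls bool() on the string, so --shared_core False would otherwise parse as True.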

❗️❗️CAUTION❗️❗️

This implementation requires a lot of compute and GPU memory because of its very large number of parameters.

The default settings follow the original GQN paper, whose experiments used four GPUs with 24 GB of memory. If your GPU memory is limited, I recommend setting --shared_core True (default: False) or --layers 8 (default: 12) to reduce the number of parameters. In my experiments, these changes did not affect the quality of the results much, although the settings then differ from those of the original paper.
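To see why --shared_core True shrinks the model, a rough parameter count helps. This sketch assumes that --layers sets the number of core steps, each with its own ConvLSTM cell unless the core is shared, and that a cell computes its four gates with one convolution over the concatenated input and hidden state. The channel sizes (64 in, 128 hidden, 5x5 kernels) and helper names are hypothetical, and the real cores take extra inputs, so absolute numbers will differ from the actual model.

```python
def conv_lstm_params(in_channels: int, hidden_channels: int, kernel_size: int) -> int:
    """Weights + biases of one ConvLSTM cell whose four gates come from a single
    convolution over the concatenated [input, hidden] tensor."""
    weights = 4 * hidden_channels * (in_channels + hidden_channels) * kernel_size ** 2
    biases = 4 * hidden_channels
    return weights + biases


def core_stack_params(layers: int, shared_core: bool,
                      in_channels: int = 64, hidden_channels: int = 128,
                      kernel_size: int = 5) -> int:
    # Hypothetical channel sizes; chosen only to illustrate the scaling.
    per_core = conv_lstm_params(in_channels, hidden_channels, kernel_size)
    # A shared core reuses one set of weights at every step; otherwise each step owns one.
    return per_core if shared_core else layers * per_core


if __name__ == "__main__":
    for layers, shared in [(12, False), (8, False), (12, True)]:
        count = core_stack_params(layers, shared)
        print(f"layers={layers} shared_core={shared}: {count:,} parameters")
```

Under these assumptions the count grows linearly with --layers when each step has its own core and stays constant when the core is shared, which is why either option reduces the parameter budget.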

https://github.com/iShohei220/torch-gqn/issues/1

Dataset

https://github.com/deepmind/gqn-datasets

Usage

dataset/convert2torch.py

Converts the dataset's TFRecords into a format usable by this PyTorch implementation.
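The DeepMind datasets ship as TFRecord files. As a sketch of what a converter has to read, assuming only the standard TFRecord framing (a little-endian uint64 length, a masked CRC of the length, the payload, a masked CRC of the payload), the hypothetical reader below yields raw record payloads using only the standard library. It skips CRC verification, and each payload is still a serialized tf.train.Example protobuf that must be decoded separately; convert2torch.py may well rely on TensorFlow for both steps instead.

```python
import io
import struct
from typing import BinaryIO, Iterator


def iter_tfrecords(stream: BinaryIO) -> Iterator[bytes]:
    """Yield raw record payloads from a TFRecord byte stream (CRCs not checked)."""
    while True:
        header = stream.read(12)  # 8-byte length + 4-byte masked CRC of the length
        if not header:
            return  # clean end of file
        if len(header) < 12:
            raise ValueError("truncated record header")
        (length,) = struct.unpack("<Q", header[:8])
        payload = stream.read(length)
        footer = stream.read(4)  # masked CRC of the payload, ignored here
        if len(payload) < length or len(footer) < 4:
            raise ValueError("truncated record body")
        yield payload


if __name__ == "__main__":
    def frame(payload: bytes) -> bytes:
        # Build a record with zeroed CRC fields, which this reader ignores.
        return struct.pack("<Q", len(payload)) + b"\x00" * 4 + payload + b"\x00" * 4

    data = io.BytesIO(frame(b"first") + frame(b"second"))
    print(list(iter_tfrecords(data)))
```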

representation.py

Representation networks (see Figure S1 in the paper's Supplementary Materials).
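In the GQN paper, the representation network encodes each context image together with its viewpoint separately, and the per-view outputs are summed into a single scene representation. A minimal sketch of that aggregation step, using flat Python lists in place of C×H×W tensors (the function name is hypothetical):

```python
from typing import List, Sequence


def aggregate_representations(per_view: Sequence[Sequence[float]]) -> List[float]:
    """Element-wise sum of per-view representation vectors."""
    if not per_view:
        raise ValueError("need at least one context view")
    size = len(per_view[0])
    total = [0.0] * size
    for vec in per_view:
        if len(vec) != size:
            raise ValueError("all per-view representations must have the same size")
        for i, value in enumerate(vec):
            total[i] += value
    return total


if __name__ == "__main__":
    views = [[1.0, 2.0], [3.0, 4.0], [0.5, 0.5]]
    print(aggregate_representations(views))
```

Because addition is commutative, the result does not depend on the order of the context views, and the same code handles any number of them.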

core.py

Core networks for inference and generation (see Figure S2 in the paper's Supplementary Materials).
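In the GQN paper, at each generation step the inference core parameterizes a posterior over the latent variable and the generation core parameterizes a prior, and the training objective includes the KL divergence between the two, summed over steps. Assuming diagonal Gaussians, as in the paper, the per-step term has a closed form. This sketch uses scalars and plain sequences instead of tensors, and the helper names are hypothetical:

```python
import math
from typing import Sequence


def gaussian_kl(mu_q: float, sigma_q: float, mu_p: float, sigma_p: float) -> float:
    """KL(N(mu_q, sigma_q^2) || N(mu_p, sigma_p^2)) for one latent dimension."""
    return (math.log(sigma_p / sigma_q)
            + (sigma_q ** 2 + (mu_q - mu_p) ** 2) / (2 * sigma_p ** 2)
            - 0.5)


def diagonal_kl(mu_q: Sequence[float], sigma_q: Sequence[float],
                mu_p: Sequence[float], sigma_p: Sequence[float]) -> float:
    # A diagonal Gaussian factorizes, so the KL is a sum over latent dimensions.
    return sum(gaussian_kl(a, b, c, d) for a, b, c, d in zip(mu_q, sigma_q, mu_p, sigma_p))


if __name__ == "__main__":
    # Posterior from a (hypothetical) inference step vs. prior from the generation step.
    print(diagonal_kl([0.2, -0.1], [0.9, 1.1], [0.0, 0.0], [1.0, 1.0]))
```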

conv_lstm.py

Implementation of the convolutional LSTM used in core.py.
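For reference, a common ConvLSTM formulation without peephole connections is written out below. Whether conv_lstm.py uses exactly this form (for example, with or without peepholes, or with one fused convolution for all four gates) is an assumption to check against the code. Here * denotes 2D convolution, ⊙ element-wise multiplication, and [x_t, h_{t-1}] channel-wise concatenation.

```latex
\begin{aligned}
i_t &= \sigma\left(W_i * [x_t, h_{t-1}] + b_i\right) \\
f_t &= \sigma\left(W_f * [x_t, h_{t-1}] + b_f\right) \\
o_t &= \sigma\left(W_o * [x_t, h_{t-1}] + b_o\right) \\
g_t &= \tanh\left(W_g * [x_t, h_{t-1}] + b_g\right) \\
c_t &= f_t \odot c_{t-1} + i_t \odot g_t \\
h_t &= o_t \odot \tanh(c_t)
\end{aligned}
```

Replacing the matrix products of a standard LSTM with convolutions keeps the hidden and cell states as spatial feature maps, which suits image generation.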

gqn_dataset.py

Dataset class.
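The GQN paper encodes each camera pose as a 7-D viewpoint vector: the 3-D position plus the cosine and sine of yaw and pitch. A dataset class is a natural place for this conversion. This sketch assumes each raw camera is stored as (x, y, z, yaw, pitch) with angles in radians; the function name is hypothetical, and gqn_dataset.py should be checked for the actual layout.

```python
import math
from typing import List, Sequence


def transform_viewpoint(camera: Sequence[float]) -> List[float]:
    """Map (x, y, z, yaw, pitch) to (x, y, z, cos yaw, sin yaw, cos pitch, sin pitch)."""
    x, y, z, yaw, pitch = camera
    return [x, y, z, math.cos(yaw), math.sin(yaw), math.cos(pitch), math.sin(pitch)]


if __name__ == "__main__":
    print(transform_viewpoint([0.5, -1.0, 0.2, math.pi / 4, 0.1]))
```

Feeding cos/sin pairs instead of raw angles avoids a jump at ±π, so nearby orientations map to nearby vectors.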

model.py

Main module of the Generative Query Network.
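The training objective that a main module ties together is a negative evidence lower bound: the Gaussian negative log-likelihood of the target image under the generated mean with pixel standard deviation σ, plus the KL terms from every generation step. A hedged sketch over flat pixel lists follows; the names are hypothetical, and model.py may organize this differently (for example, averaging over a batch).

```python
import math
from typing import Sequence


def gaussian_nll(target: Sequence[float], mean: Sequence[float], sigma: float) -> float:
    """Negative log-likelihood of target pixels under N(mean, sigma^2), summed over pixels."""
    const = 0.5 * math.log(2 * math.pi * sigma ** 2)
    return sum(const + (t - m) ** 2 / (2 * sigma ** 2) for t, m in zip(target, mean))


def negative_elbo(target: Sequence[float], mean: Sequence[float], sigma: float,
                  kl_per_layer: Sequence[float]) -> float:
    """Reconstruction term plus the KL contributions of all generation steps."""
    return gaussian_nll(target, mean, sigma) + sum(kl_per_layer)


if __name__ == "__main__":
    loss = negative_elbo([0.2, 0.8, 0.5], [0.25, 0.7, 0.5], sigma=0.7, kl_per_layer=[0.1, 0.05])
    print(loss)
```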

train.py

Training algorithm.
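One detail of the training procedure described in the GQN paper is that the pixel standard deviation σ is not fixed: as we understand the paper, it is annealed linearly from 2.0 to 0.7 over the first 2×10^5 steps and then held constant. The sketch below encodes that schedule; the function name is hypothetical, and the defaults are assumptions to verify against train.py.

```python
def annealed_sigma(step: int, sigma_i: float = 2.0, sigma_f: float = 0.7,
                   anneal_steps: int = 200_000) -> float:
    """Linearly anneal the pixel std from sigma_i to sigma_f, then hold it at sigma_f."""
    fraction = min(step / anneal_steps, 1.0)
    return sigma_f + (sigma_i - sigma_f) * (1.0 - fraction)


if __name__ == "__main__":
    for step in (0, 50_000, 200_000, 1_000_000):
        print(step, round(annealed_sigma(step), 4))
```

A large σ early on makes the reconstruction loss forgiving while the model is poor; shrinking it later pushes for sharper outputs.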

scheduler.py

Learning-rate scheduler used in train.py.
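The GQN paper also anneals the learning rate linearly; as we understand it, from 5×10^-4 to 5×10^-5 over 1.6×10^6 steps. A minimal stateful sketch of such a scheduler follows. The class name, method names, and defaults are assumptions, not scheduler.py's actual interface.

```python
class LinearAnnealingLR:
    """Hypothetical step-based scheduler: linear decay from lr_i to lr_f over n steps."""

    def __init__(self, lr_i: float = 5e-4, lr_f: float = 5e-5, n: int = 1_600_000):
        self.lr_i, self.lr_f, self.n = lr_i, lr_f, n
        self.step_count = 0

    def get_lr(self) -> float:
        # Fraction of the annealing period completed, capped at 1.
        fraction = min(self.step_count / self.n, 1.0)
        return self.lr_f + (self.lr_i - self.lr_f) * (1.0 - fraction)

    def step(self) -> float:
        # Advance one optimizer step and return the learning rate to use next.
        self.step_count += 1
        return self.get_lr()


if __name__ == "__main__":
    scheduler = LinearAnnealingLR(n=4)
    print([scheduler.get_lr()] + [scheduler.step() for _ in range(5)])
```

With PyTorch, the returned value would be written into each optimizer.param_groups[i]["lr"] after every optimizer step.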

Results (WIP)

Dataset                   Ground truth                     Generation
Shepard-Metzler objects   [image: shepard_ground_truth]    [image: shepard_generation]
Mazes                     [image: mazes_ground_truth]      [image: mazes_generation]