March 13, 2020

mit-han-lab/gan-compression

[CVPR 2020] GAN Compression: Efficient Architectures for Interactive Conditional GANs

repo name mit-han-lab/gan-compression
repo link https://github.com/mit-han-lab/gan-compression
homepage
language Python
size (curr.) 25133 kB
stars (curr.) 262
created 2020-03-05
license Other

GAN Compression

paper | demo

[NEW!] The PyTorch implementation of a general conditional GAN Compression framework is released.

We introduce GAN Compression, a general-purpose method for compressing conditional GANs. Our method reduces the computation of widely used conditional GAN models, including pix2pix, CycleGAN, and GauGAN, by 9-21x while preserving the visual fidelity. Our method is effective for a wide range of generator architectures, learning objectives, and both paired and unpaired settings.

GAN Compression: Efficient Architectures for Interactive Conditional GANs
Muyang Li, Ji Lin, Yaoyao Ding, Zhijian Liu, Jun-Yan Zhu, and Song Han
MIT, Adobe Research, SJTU
In CVPR 2020.

Overview

GAN Compression framework: ① Given a pre-trained teacher generator G', we distill a smaller “once-for-all” student generator G that contains all possible channel numbers through weight sharing. We choose different channel numbers for the student generator G at each training step. ② We then extract many sub-generators from the “once-for-all” generator and evaluate their performance. No retraining is needed, which is the advantage of the “once-for-all” generator. ③ Finally, we choose the best sub-generator given the compression ratio target and performance target (FID or mAP), perform fine-tuning, and obtain the final compressed model.
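
The sampling and selection logic in steps ① to ③ can be sketched in plain Python. This is a toy illustration under stated assumptions, not the repository's code: the channel choices, the number of layers, the `cost` proxy (a stand-in for MACs), and the `evaluate` callback (a stand-in for measuring FID on a sub-generator) are all hypothetical.

```python
import itertools
import random

# Hypothetical search space; the real one is defined by the repository.
CHANNEL_CHOICES = [16, 32, 48, 64]
NUM_LAYERS = 3


def sample_config(rng):
    """Step 1: pick one channel number per layer for a single training step."""
    return tuple(rng.choice(CHANNEL_CHOICES) for _ in range(NUM_LAYERS))


def cost(config, in_channels=3):
    """Toy cost proxy: sum of in*out channel products (stands in for MACs)."""
    total, prev = 0, in_channels
    for channels in config:
        total += prev * channels
        prev = channels
    return total


def select_best(evaluate, budget):
    """Steps 2-3: score every sub-generator and keep the best one within budget.

    `evaluate` maps a config to a lower-is-better score such as FID.
    """
    candidates = itertools.product(CHANNEL_CHOICES, repeat=NUM_LAYERS)
    feasible = [c for c in candidates if cost(c) <= budget]
    if not feasible:
        raise ValueError("no sub-generator fits the budget")
    return min(feasible, key=evaluate)


def fake_fid(config):
    """Made-up score in which wider networks do better; for illustration only."""
    return 1000.0 / sum(config)


if __name__ == "__main__":
    rng = random.Random(0)
    print("sampled for one training step:", sample_config(rng))
    print("selected under budget:", select_best(fake_fid, budget=3000))
```

In the actual method the selected sub-generator is then fine-tuned; the sketch stops at selection because fine-tuning depends on the training code.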

Colab Notebook

PyTorch Colab notebook: CycleGAN and pix2pix.

Prerequisites

  • Linux
  • Python 3
  • CPU or NVIDIA GPU + CUDA CuDNN

Getting Started

Installation

  • Clone this repo:

    git clone git@github.com:mit-han-lab/gan-compression.git
    cd gan-compression
    
  • Install PyTorch 1.4 and other dependencies (e.g., torchvision).

    • For pip users, please type the command pip install -r requirements.txt.
    • For Conda users, we provide an installation script ./scripts/conda_deps.sh. Alternatively, you can create a new Conda environment using conda env create -f environment.yml.
  • Install torchprofile.

    pip install --upgrade git+https://github.com/mit-han-lab/torchprofile.git
    

CycleGAN

Setup

  • Download the CycleGAN dataset (e.g., horse2zebra).

    bash datasets/download_cyclegan_dataset.sh horse2zebra
    
  • Get the statistics of the ground-truth images of your dataset, which are needed to compute FID. We provide precomputed real statistics for several datasets. For example,

    bash ./datasets/download_real_stat.sh horse2zebra A
    bash ./datasets/download_real_stat.sh horse2zebra B
    

Apply a Pre-trained Model

  • Download the pre-trained models.

    python scripts/download_model.py --model cycle_gan --task horse2zebra --stage full
    python scripts/download_model.py --model cycle_gan --task horse2zebra --stage compressed
    
  • Test the original full model.

    bash scripts/cycle_gan/horse2zebra/test_full.sh
    
  • Test the compressed model.

    bash scripts/cycle_gan/horse2zebra/test_compressed.sh
    

Pix2pix

Setup

  • Download the pix2pix dataset (e.g., edges2shoes-r).

    bash ./datasets/download_pix2pix_dataset.sh edges2shoes-r
    
  • Get the statistics of the ground-truth images of your dataset, which are needed to compute FID. We provide precomputed real statistics for several datasets. For example,

    bash datasets/download_real_stat.sh edges2shoes-r B
    

Apply a Pre-trained Model

  • Download the pre-trained models.

    python scripts/download_model.py --model pix2pix --task edges2shoes-r --stage full
    python scripts/download_model.py --model pix2pix --task edges2shoes-r --stage compressed
    
  • Test the original full model.

    bash scripts/pix2pix/edges2shoes-r/test_full.sh
    
  • Test the compressed model.

    bash scripts/pix2pix/edges2shoes-r/test_compressed.sh
    

Cityscapes Dataset

We cannot redistribute the Cityscapes dataset because of licensing restrictions. Please download it from https://cityscapes-dataset.com and preprocess it with the script datasets/prepare_cityscapes_dataset.py. You need to download gtFine_trainvaltest.zip and leftImg8bit_trainvaltest.zip and unzip them in the same folder; for example, you may put gtFine and leftImg8bit in database/cityscapes-origin. You can then prepare the dataset with the following command:

python datasets/prepare_cityscapes_dataset.py \
--gtFine_dir database/cityscapes-origin/gtFine \
--leftImg8bit_dir database/cityscapes-origin/leftImg8bit \
--output_dir database/cityscapes \
--table_path ./dataset/table.txt

You will get a preprocessed dataset in database/cityscapes and a mapping table (used to compute mAP) in dataset/table.txt.
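
Before running the preparation script, it can help to confirm that both archives were unzipped side by side. The helper below is a hypothetical convenience and not part of the repository; it only checks that the two folders named above exist under a common root.

```python
from pathlib import Path

# Folder names taken from the two archives mentioned above.
REQUIRED_DIRS = ("gtFine", "leftImg8bit")


def missing_cityscapes_dirs(root):
    """Return the required sub-folders that are absent under `root`."""
    root = Path(root)
    return [name for name in REQUIRED_DIRS if not (root / name).is_dir()]


if __name__ == "__main__":
    missing = missing_cityscapes_dirs("database/cityscapes-origin")
    if missing:
        print("missing folders:", ", ".join(missing))
    else:
        print("layout looks ready for prepare_cityscapes_dataset.py")
```

An empty list means both folders are present; the check says nothing about whether their contents are complete.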

FID Computation

To compute the FID score, you need some statistics from the ground-truth images of your dataset. We provide a script, get_real_stat.py, to extract them. For example, for the edges2shoes-r dataset, you could run the following command:

python get_real_stat.py \
--dataroot database/edges2shoes-r \
--output_path real_stat/edges2shoes-r_B.npz \
--direction AtoB
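
For reference, FID is conventionally computed from the mean and covariance of Inception features. That the saved .npz file holds exactly these two quantities is an assumption here, not something this README states. Writing $(\mu_r, \Sigma_r)$ for the statistics of the real images and $(\mu_g, \Sigma_g)$ for those of the generated images, the standard definition is:

```latex
\mathrm{FID} = \lVert \mu_r - \mu_g \rVert_2^2
  + \operatorname{Tr}\!\left( \Sigma_r + \Sigma_g - 2\,(\Sigma_r \Sigma_g)^{1/2} \right)
```

Lower values indicate that the generated feature distribution is closer to the real one, which is why the real-image statistics only need to be computed once per dataset and direction.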

Citation

If you use this code for your research, please cite our paper.

@inproceedings{li2020gan,
  title={GAN Compression: Efficient Architectures for Interactive Conditional GANs},
  author={Li, Muyang and Lin, Ji and Ding, Yaoyao and Liu, Zhijian and Zhu, Jun-Yan and Han, Song},
  booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition},
  year={2020}
}

Acknowledgements

Our code is developed based on pytorch-CycleGAN-and-pix2pix and SPADE.

We also thank pytorch-fid for FID computation and drn for mAP computation.