sthalles/SimCLR
PyTorch implementation of SimCLR: A Simple Framework for Contrastive Learning of Visual Representations
repo name | sthalles/SimCLR |
repo link | https://github.com/sthalles/SimCLR |
homepage | https://sthalles.github.io/simple-self-supervised-learning/ |
language | Jupyter Notebook |
size (curr.) | 84446 kB |
stars (curr.) | 676 |
created | 2020-02-17 |
license | MIT License |
PyTorch SimCLR: A Simple Framework for Contrastive Learning of Visual Representations
Blog post with full documentation: Exploring SimCLR: A Simple Framework for Contrastive Learning of Visual Representations
See also PyTorch Implementation for BYOL - Bootstrap Your Own Latent: A New Approach to Self-Supervised Learning.
Installation
$ conda env create --name simclr --file env.yml
$ conda activate simclr
$ python run.py
Config file
Before running SimCLR, make sure you choose the correct running configuration. You can change it by passing command-line arguments to the run.py file.
$ python run.py -data ./datasets --dataset-name stl10 --log-every-n-steps 100 --epochs 100
If you want to run it on CPU (for debugging purposes), use the --disable-cuda option.
For 16-bit precision GPU training, make sure to install NVIDIA apex and use the --fp16_precision flag.
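The options above map naturally onto Python's `argparse`. The sketch below is a hypothetical parser that mirrors only the flags named in this README (`-data`, `--dataset-name`, `--log-every-n-steps`, `--epochs`, `--disable-cuda`, `--fp16_precision`). The actual run.py may define more options and different defaults, and the `pick_device` helper and the `choices` list are assumptions made for illustration.

```python
import argparse


def build_parser() -> argparse.ArgumentParser:
    # Hypothetical parser mirroring only the flags named in this README;
    # the real run.py may define more options and different defaults.
    parser = argparse.ArgumentParser(description="SimCLR pretraining (sketch)")
    parser.add_argument("-data", metavar="DIR", default="./datasets",
                        help="root directory for the datasets")
    parser.add_argument("--dataset-name", default="stl10",
                        choices=["stl10", "cifar10"])  # assumed choices
    parser.add_argument("--epochs", type=int, default=100)
    parser.add_argument("--log-every-n-steps", type=int, default=100)
    parser.add_argument("--disable-cuda", action="store_true",
                        help="run on CPU, e.g. for debugging")
    parser.add_argument("--fp16_precision", action="store_true",
                        help="16-bit precision GPU training")
    return parser


def pick_device(args: argparse.Namespace, cuda_available: bool) -> str:
    # Hypothetical helper: --disable-cuda forces the CPU even if a GPU exists.
    # In real code, cuda_available would come from torch.cuda.is_available().
    return "cuda" if cuda_available and not args.disable_cuda else "cpu"


# Equivalent of: python run.py -data ./datasets --dataset-name stl10 --log-every-n-steps 100 --epochs 100
args = build_parser().parse_args(
    ["-data", "./datasets", "--dataset-name", "stl10",
     "--log-every-n-steps", "100", "--epochs", "100"]
)
debug_args = build_parser().parse_args(["--disable-cuda"])
```

Note that argparse turns dashes in long option names into underscores, so `--dataset-name` is read back as `args.dataset_name` and the single-dash `-data` as `args.data`.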
Feature Evaluation
Feature evaluation is done using the linear evaluation protocol.
First, we learn features using SimCLR on the STL10 unlabeled set. Then, we train a linear classifier on top of the frozen SimCLR features. The linear model is trained on features extracted from the STL10 train set and evaluated on the STL10 test set.
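To make the protocol concrete, here is a dependency-free sketch of a linear probe: a fixed (frozen) encoder maps inputs to features, and only a multinomial logistic-regression layer is trained with Adam, in the spirit of the "Logistic Regression (Adam)" rows of the results table. The encoder, the toy 2-D data, and all hyperparameters are hypothetical stand-ins; this is not the repository's notebook code, which works with PyTorch tensors and a pretrained ResNet.

```python
import math
import random


def frozen_encoder(point):
    # Hypothetical stand-in for the pretrained SimCLR backbone. It is a fixed map:
    # nothing below updates it, mirroring the "frozen features" in the protocol.
    x, y = point
    return [x + y, x - y, x * y]


def softmax(scores):
    top = max(scores)  # subtract the max for numerical stability
    exps = [math.exp(s - top) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


def train_linear_probe(feats, labels, num_classes, epochs=300, lr=0.05,
                       beta1=0.9, beta2=0.999, eps=1e-8):
    """Multinomial logistic regression on fixed features, full-batch Adam."""
    xs = [f + [1.0] for f in feats]  # trailing 1.0 lets the last weight act as a bias
    dim, n = len(xs[0]), len(xs)
    W = [[0.0] * dim for _ in range(num_classes)]
    m = [[0.0] * dim for _ in range(num_classes)]  # Adam first-moment estimates
    v = [[0.0] * dim for _ in range(num_classes)]  # Adam second-moment estimates
    for t in range(1, epochs + 1):
        grad = [[0.0] * dim for _ in range(num_classes)]
        for x, label in zip(xs, labels):
            probs = softmax([sum(w * xi for w, xi in zip(row, x)) for row in W])
            for c in range(num_classes):
                # Gradient of the mean cross-entropy w.r.t. logit c is (p_c - onehot_c) / n
                err = (probs[c] - (1.0 if c == label else 0.0)) / n
                for j in range(dim):
                    grad[c][j] += err * x[j]
        for c in range(num_classes):
            for j in range(dim):
                g = grad[c][j]
                m[c][j] = beta1 * m[c][j] + (1 - beta1) * g
                v[c][j] = beta2 * v[c][j] + (1 - beta2) * g * g
                m_hat = m[c][j] / (1 - beta1 ** t)  # bias-corrected moments
                v_hat = v[c][j] / (1 - beta2 ** t)
                W[c][j] -= lr * m_hat / (math.sqrt(v_hat) + eps)
    return W


def predict(W, feat):
    x = feat + [1.0]
    scores = [sum(w * xi for w, xi in zip(row, x)) for row in W]
    return max(range(len(scores)), key=scores.__getitem__)


# Toy data: three Gaussian blobs of 2-D points stand in for labelled images.
rng = random.Random(0)
centers = [(3.0, 0.0), (-3.0, 0.0), (0.0, 3.0)]


def sample(per_class):
    return [((cx + rng.gauss(0, 0.3), cy + rng.gauss(0, 0.3)), label)
            for label, (cx, cy) in enumerate(centers) for _ in range(per_class)]


train_set, test_set = sample(30), sample(10)

# 1) Extract features once with the frozen encoder; 2) train only the linear layer.
train_feats = [frozen_encoder(p) for p, _ in train_set]
W = train_linear_probe(train_feats, [y for _, y in train_set], num_classes=3)

# 3) Evaluate on the held-out split.
test_preds = [predict(W, frozen_encoder(p)) for p, _ in test_set]
accuracy = sum(p == y for p, (_, y) in zip(test_preds, test_set)) / len(test_set)
```

In a PyTorch setting the same idea amounts to calling `requires_grad_(False)` on the backbone (or precomputing its outputs once) and optimizing only an `nn.Linear` layer.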
Check the notebook for reproducibility.
Note that SimCLR benefits from longer training.
Linear Classifier | Dataset | Feature Extractor | Architecture | Feature Dimensionality | Projection Head Dimensionality | Epochs | Top-1 Accuracy (%) |
---|---|---|---|---|---|---|---|
Logistic Regression (Adam) | STL10 | SimCLR | ResNet-18 | 512 | 128 | 100 | 74.45 |
Logistic Regression (Adam) | CIFAR10 | SimCLR | ResNet-18 | 512 | 128 | 100 | 69.82 |
Logistic Regression (Adam) | STL10 | SimCLR | ResNet-50 | 2048 | 128 | 50 | 70.075 |
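The table lists two dimensionalities per row. Assuming, as the column headers suggest and as in the usual SimCLR linear evaluation setup, that the linear classifier reads the backbone features (512 for ResNet-18, 2048 for ResNet-50) rather than the 128-d projection head output, the size of the probe follows directly from the feature dimensionality. The helper below is a hypothetical illustration of that arithmetic, not code from the repository.

```python
NUM_CLASSES = 10  # STL10 and CIFAR10 both have 10 classes

# (architecture, feature dimensionality, projection head dimensionality) from the table
ROWS = [("ResNet-18", 512, 128), ("ResNet-50", 2048, 128)]


def probe_parameter_count(feature_dim: int, num_classes: int = NUM_CLASSES) -> int:
    # A linear classifier: one weight per (feature, class) pair plus one bias per class.
    return feature_dim * num_classes + num_classes


for arch, feat_dim, proj_dim in ROWS:
    # The probe is sized by feat_dim; the projection head (proj_dim) is not used here.
    print(arch, feat_dim, probe_parameter_count(feat_dim))
```

Under this assumption, the ResNet-50 probe has 20490 trainable parameters versus 5130 for ResNet-18, while the frozen backbones themselves contribute none.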