A PyTorch >=1.0 implementation of DenseNets, optimized to save GPU memory.
Recent updates
Now works on PyTorch 1.0! It uses PyTorch's checkpointing feature (torch.utils.checkpoint), which makes this code WAY more memory-efficient!
Motivation
While DenseNets are fairly easy to implement in deep learning frameworks, most implementations (such as the original) tend to be memory-hungry. In particular, the number of intermediate feature maps generated by batch normalization and concatenation operations grows quadratically with network depth: layer l of a dense block takes all l preceding feature maps as input, so storing every concatenated input across an L-layer block requires O(L^2) feature maps. It is worth emphasizing that this is not a property inherent to DenseNets, but rather to the implementation.
This implementation uses a new strategy to reduce the memory consumption of DenseNets: we use checkpointing to compute the batch norm and concatenation feature maps. These intermediate feature maps are discarded during the forward pass and recomputed during the backward pass. This adds 15-20% time overhead to training, but reduces the feature map memory consumption from quadratic to linear.
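Concretely, each dense layer wraps the concatenation, batch norm, and 1x1 convolution in a single function and runs it through torch.utils.checkpoint. The sketch below is condensed from the approach in models/densenet.py; helper names such as _bn_function_factory and _DenseLayer are illustrative, so treat this as a sketch of the technique rather than the exact implementation:

```python
import torch
import torch.nn as nn
import torch.utils.checkpoint as cp

def _bn_function_factory(norm, relu, conv):
    # Closure that concatenates the input feature maps, then applies
    # BN -> ReLU -> 1x1 conv. Its intermediates (the concatenated tensor
    # and the BN output) are the large buffers we want to avoid storing.
    def bn_function(*inputs):
        concated_features = torch.cat(inputs, 1)
        return conv(relu(norm(concated_features)))
    return bn_function

class _DenseLayer(nn.Module):
    def __init__(self, num_input_features, growth_rate, bn_size, efficient=True):
        super().__init__()
        self.norm1 = nn.BatchNorm2d(num_input_features)
        self.relu1 = nn.ReLU(inplace=True)
        self.conv1 = nn.Conv2d(num_input_features, bn_size * growth_rate,
                               kernel_size=1, stride=1, bias=False)
        self.norm2 = nn.BatchNorm2d(bn_size * growth_rate)
        self.relu2 = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv2d(bn_size * growth_rate, growth_rate,
                               kernel_size=3, stride=1, padding=1, bias=False)
        self.efficient = efficient

    def forward(self, *prev_features):
        bn_function = _bn_function_factory(self.norm1, self.relu1, self.conv1)
        if self.efficient and any(f.requires_grad for f in prev_features):
            # Checkpointing: the BN/concat buffers are freed after the
            # forward pass and recomputed during the backward pass.
            bottleneck_output = cp.checkpoint(bn_function, *prev_features)
        else:
            bottleneck_output = bn_function(*prev_features)
        return self.conv2(self.relu2(self.norm2(bottleneck_output)))
```

The key design choice is that only the cheap-to-recompute operations (concatenation and batch norm) are checkpointed; the convolutions, which dominate compute, are still stored as usual, which is why the overhead stays around 15-20%.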
This implementation is inspired by the technical report Memory-Efficient Implementation of DenseNets (Pleiss et al., 2017; cited below), which outlines a strategy for efficient DenseNets via memory sharing.
Requirements
PyTorch >=1.0.0
CUDA
Usage
In your existing project: there is one file in the models folder.
models/densenet.py is an implementation based on the torchvision and project killer implementations.
If you care about speed and memory is not a concern, pass efficient=False to the DenseNet constructor. Otherwise, pass efficient=True.
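For example, constructing a memory-efficient DenseNet-BC for CIFAR-10 might look like the following (parameter names are assumptions based on the constructor in models/densenet.py; check that file for the exact signature):

```python
import torch
from models.densenet import DenseNet

# DenseNet-BC (depth 100, growth rate 12) for CIFAR-10.
# efficient=True trades ~15-20% extra training time for linear
# (rather than quadratic) feature map memory.
model = DenseNet(
    growth_rate=12,
    block_config=(16, 16, 16),  # layers per dense block
    num_classes=10,
    efficient=True,
)

x = torch.randn(1, 3, 32, 32)  # one CIFAR-sized input
out = model(x)
print(out.shape)  # torch.Size([1, 10])
```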
Reference
If you use this code, please cite:
@article{pleiss2017memory,
title={Memory-Efficient Implementation of DenseNets},
author={Pleiss, Geoff and Chen, Danlu and Huang, Gao and Li, Tongcheng and van der Maaten, Laurens and Weinberger, Kilian Q},
journal={arXiv preprint arXiv:1707.06990},
year={2017}
}