Image classification with synthetic gradients in PyTorch
I implement Decoupled Neural Interfaces using Synthetic Gradients in PyTorch.
The paper uses synthetic gradients to decouple the layers of the
network, which is pretty interesting since the layers no longer suffer from update locking. I tested my model on MNIST and got almost the same performance as the same model trained with standard backpropagation.
Requirement
pytorch
python 3.5
torchvision
seaborn (optional)
matplotlib (optional)
TODO
Use multi-threading on the GPU to analyze the speed-up from decoupled updates.
What are synthetic gradients?
We often optimize neural networks by backpropagation, which is usually implemented
in some well-known framework. However, is there another way for the
layers in a network to communicate with the other layers? Here come synthetic gradients!
They give us a way to allow neural networks to communicate and learn to
send messages between themselves in a decoupled, scalable manner, paving
the way for multiple neural networks to communicate with each other, or
for improving the long-term temporal dependencies of recurrent networks. Each layer is paired with a synthetic-gradient module that automatically produces an error signal (δâ), which the layer uses for its optimization. And how is that error
signal generated? The network still performs backpropagation,
but the true error signal (δa) from the objective function is not used to optimize the layer directly; instead, it serves as the target for training the predicted error signal (δâ) produced by the synthetic-gradient module. The following is the illustration from the paper:
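The two-step scheme above can be sketched in a few lines of PyTorch. This is a minimal illustration, not the repo's actual code: the module names (`layer`, `head`, `synth`) and the choice of a single linear layer for the synthetic-gradient predictor are my assumptions for brevity.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)

layer = nn.Linear(10, 20)   # the layer we want to decouple (assumed toy sizes)
head = nn.Linear(20, 2)     # the rest of the network
synth = nn.Linear(20, 20)   # synthetic-gradient module: predicts da = dL/da

opt_layer = torch.optim.SGD(layer.parameters(), lr=0.1)
opt_rest = torch.optim.SGD(list(head.parameters()) + list(synth.parameters()), lr=0.1)

x = torch.randn(4, 10)
y = torch.randint(0, 2, (4,))

# 1) Forward through the decoupled layer, then update it immediately
#    using the *predicted* gradient -- no need to wait for the loss.
a = layer(x)
pred_grad = synth(a.detach())
opt_layer.zero_grad()
a.backward(pred_grad.detach())   # inject the synthetic gradient at a
opt_layer.step()

# 2) Later, compute the true gradient of the loss w.r.t. the activation
#    and use it as a regression target for the synthetic-gradient module.
a2 = a.detach().requires_grad_(True)
opt_rest.zero_grad()
loss = F.cross_entropy(head(a2), y)
loss.backward()                  # populates a2.grad with the true da
synth_loss = F.mse_loss(synth(a2.detach()), a2.grad.detach())
synth_loss.backward()
opt_rest.step()
```

Note the key design point: `layer` is updated in step 1 before the loss is ever computed, which is what removes the update lock; step 2 can run asynchronously on another worker.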
Result
Feed-Forward Network
Achieves accuracy = 96% (compared to 97% for the original model trained with backpropagation).
Convolutional Neural Network
Achieves accuracy = 96% (compared to 98% for the original model).
Usage
Right now I have only implemented the FCN and CNN versions, which are set as the default network structures.