Summary: Learning both Weights and Connections for Efficient Neural Networks
What is the paper about?
- It presents a three-step method (train → prune → retrain) to remove redundant connections in neural networks; a sketch of this pipeline follows the list below.
- The core idea is to identify and keep only the “important” weights, thus producing a sparse network that reduces storage/memory costs and power consumption.
- The method compresses AlexNet by 9× and VGG-16 by 13× while preserving nearly the same accuracy.
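To make the three-step pipeline concrete, here is a minimal sketch in PyTorch on synthetic data. The model size, pruning ratio, optimizer settings, and the re-masking after each update are illustrative choices, not details taken from the paper.

```python
# Minimal train -> prune -> retrain sketch on synthetic data (illustrative only).
import torch
import torch.nn as nn

torch.manual_seed(0)
X, y = torch.randn(256, 20), torch.randint(0, 2, (256,))
model = nn.Sequential(nn.Linear(20, 64), nn.ReLU(), nn.Linear(64, 2))
loss_fn = nn.CrossEntropyLoss()

def train(model, masks=None, epochs=20):
    opt = torch.optim.SGD(model.parameters(), lr=0.1, weight_decay=1e-4)  # L2 regularization
    for _ in range(epochs):
        opt.zero_grad()
        loss_fn(model(X), y).backward()
        opt.step()
        if masks is not None:                       # keep pruned connections at zero
            with torch.no_grad():
                for layer, m in masks.items():
                    layer.weight.mul_(m)

def prune(model, ratio=0.5):
    """Zero out roughly the `ratio` fraction of smallest-magnitude weights per Linear layer."""
    masks = {}
    with torch.no_grad():
        for layer in model:
            if isinstance(layer, nn.Linear):
                flat = layer.weight.abs().flatten()
                threshold = flat.kthvalue(int(ratio * flat.numel())).values
                mask = (layer.weight.abs() > threshold).float()
                layer.weight.mul_(mask)
                masks[layer] = mask
    return masks

train(model)                  # step 1: learn which connections matter
masks = prune(model, 0.5)     # step 2: remove low-magnitude connections
train(model, masks)           # step 3: retrain the surviving weights
```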
What is new about this specific paper, compared to prior work?
- Unlike traditional pruning methods (e.g., Optimal Brain Damage/Surgeon), which rely on second-order information, the authors use a simple magnitude-based threshold, applied iteratively, to prune connections (see the sketch after this list).
- They emphasize learning the connectivity of the network rather than just the weights, effectively tuning both structure and parameters.
- They demonstrate that retraining after pruning is essential for maintaining accuracy and show the advantage of iterative pruning versus a single-shot approach.
- They highlight that both convolutional and fully connected layers can be successfully pruned, whereas some prior works primarily focused on fully connected layers.
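A compact sketch of that magnitude criterion, assuming (as the paper describes) that each layer's threshold is a quality factor times the standard deviation of its weights; the factor value below is purely illustrative.

```python
import torch

def magnitude_mask(weight: torch.Tensor, quality: float = 1.0) -> torch.Tensor:
    """Keep a connection only if its magnitude exceeds quality * std of the layer's weights."""
    threshold = quality * weight.std()
    return weight.abs() > threshold

w = torch.randn(64, 20)
mask = magnitude_mask(w, quality=0.7)         # illustrative quality factor
pruned = w * mask.float()                     # removed connections are held at zero
print(f"kept {mask.float().mean().item():.0%} of connections")
```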
What experiments were run to support the arguments in this paper?
- Experiments on MNIST (LeNet-300-100 and LeNet-5) to show large reductions in parameter count (12×) with no accuracy loss.
- Experiments on ImageNet using AlexNet and VGG-16, showing 9× and 13× compression, respectively.
- Layer-by-layer sensitivity analyses (pruning each layer to different extents) to gauge how pruning affects accuracy; a sketch of such a sweep follows this list.
- Comparisons of different regularization schemes (L1 vs. L2), with and without retraining, to confirm that L2 + iterative retraining preserves the most accuracy.
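As a rough illustration of the sensitivity sweep, the sketch below prunes one layer at a time to increasing ratios and reports accuracy. The tiny untrained model and synthetic data are placeholders to keep it self-contained; in the paper the sweep is run on trained networks with real validation data.

```python
import copy
import torch
import torch.nn as nn

torch.manual_seed(0)
X, y = torch.randn(512, 20), torch.randint(0, 2, (512,))
model = nn.Sequential(nn.Linear(20, 64), nn.ReLU(), nn.Linear(64, 2))

def evaluate(m):
    with torch.no_grad():
        return (m(X).argmax(dim=1) == y).float().mean().item()

def prune_layer(layer, ratio):
    """Zero out the `ratio` fraction of smallest-magnitude weights in one layer."""
    with torch.no_grad():
        k = max(1, int(ratio * layer.weight.numel()))
        threshold = layer.weight.abs().flatten().kthvalue(k).values
        layer.weight.mul_((layer.weight.abs() > threshold).float())

for idx, layer in enumerate(model):
    if not isinstance(layer, nn.Linear):
        continue
    for ratio in (0.25, 0.5, 0.75, 0.9):
        trial = copy.deepcopy(model)          # prune a copy so the original stays intact
        prune_layer(trial[idx], ratio)
        print(f"layer {idx} pruned {ratio:.0%}: accuracy {evaluate(trial):.3f}")
```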
What are the shortcomings/limitations of this paper?
- The resulting sparsity is unstructured, which makes it hard to accelerate on standard GPUs that are optimized for dense, regular computation; hardware specialized for sparse operations is less common.
- The approach relies on a threshold hyperparameter (proportional to weight standard deviation) that can be tricky to tune optimally.
- Pruning is iterative and requires additional retraining time, which might be expensive for very large networks.
- It has primarily been tested on classification tasks (MNIST, ImageNet); performance on other tasks (e.g., object detection or language modeling) may vary.
What is a reasonable next step to build upon this paper?
- Structured pruning approaches (e.g., removing entire neurons, channels, or filters) could simplify deployment on existing hardware and deliver better speedups in practice (see the sketch after this list).
- Combine pruning with other compression techniques (quantization, low-rank factorization) for even higher efficiency.
- Investigate pruning for more diverse tasks (object detection, speech recognition, NLP) and large-scale networks.
- Explore automated or adaptive threshold selection methods to reduce manual hyperparameter tuning.
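As an illustration of the first point, a structured alternative (not the method of this paper) removes whole output neurons so the remaining layer stays dense. A minimal sketch, assuming the next layer's input dimension is adjusted to match:

```python
import torch
import torch.nn as nn

def prune_neurons(layer: nn.Linear, keep: int) -> nn.Linear:
    """Keep the `keep` output neurons with the largest row L2 norms.
    Note: in a full network the following layer's input size must shrink to match."""
    norms = layer.weight.norm(dim=1)              # one norm per output neuron
    idx = norms.topk(keep).indices
    new = nn.Linear(layer.in_features, keep, bias=layer.bias is not None)
    with torch.no_grad():
        new.weight.copy_(layer.weight[idx])
        if layer.bias is not None:
            new.bias.copy_(layer.bias[idx])
    return new

layer = nn.Linear(20, 64)
smaller = prune_neurons(layer, keep=32)           # a genuinely smaller, dense layer
print(smaller)                                    # Linear(in_features=20, out_features=32, ...)
```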
Appendix
Optimal Brain Damage / Optimal Brain Surgeon: Classic pruning algorithms from the early 1990s, using second-order information (the Hessian matrix of the loss function) to identify parameters that contribute least to the network’s performance.
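For reference, Optimal Brain Damage scores each weight by a saliency derived from the diagonal of the Hessian (the standard OBD formula, not something introduced by this paper):

```latex
% OBD saliency of weight w_i under the diagonal-Hessian approximation
s_i = \tfrac{1}{2} H_{ii} \, w_i^2
```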
L1/L2 Regularization
- L1 regularization penalizes the absolute value of weights, pushing many weights to become exactly zero, which encourages sparsity.
- L2 regularization (weight decay) penalizes the square of weights and tends to keep weights small but non-zero (note: for weights with magnitude below 1.0, the square is smaller than the absolute value, so the penalty near zero is weak and weights are rarely driven exactly to zero). Both penalized objectives are written out below.
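Written out, the two training objectives differ only in the penalty term (lambda is the regularization strength, L_data the unregularized loss):

```latex
% L1 vs. L2 regularized objectives; \lambda is the regularization strength
L_{\mathrm{L1}}(w) = L_{\mathrm{data}}(w) + \lambda \sum_i |w_i|
\qquad
L_{\mathrm{L2}}(w) = L_{\mathrm{data}}(w) + \lambda \sum_i w_i^2
```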