Compress & Optimize Your Deep Neural Network With Pruning

Sahan Dilshan
Mar 3, 2021

Neural Network Compression

If you are reading this blog, I’m pretty sure you are familiar with Deep Learning concepts, or at least you have implemented and trained a deep learning model before. Today, Deep Neural Networks are everywhere. They are used to give recommendations, make decisions, translate languages, detect objects, and so on. In this digital era, knowing how to create your own deep learning model will come in handy in many ways. But is it enough just to know how to create a deep learning model?

When you created a deep learning model, have you ever checked how many parameters it has or how much memory it takes? If so, have you ever tried to reduce its memory usage or its total number of parameters without changing the model structure? Whaaaaaaat??? How can you reduce the parameter count or memory size without changing the model structure? If you had this same thought just now, chances are you have never heard of Neural Network Compression. In this blog, I’m going to cover the basics of Neural Network Compression and guide you through compressing a deep learning model step by step.

Neural Network Compression

Overparameterized Neural Networks have shown impressive performance in domains such as computer vision and Natural Language Processing. But because of this overparameterization, these models have large memory and storage requirements, which makes them hard to run on memory-constrained devices. Neural Network Compression addresses this issue. There are various techniques and algorithms that can be used to compress a Neural Network:

  1. Weight Sharing
  2. Network Pruning
  3. Low-Rank Matrix & Tensor Decompositions
  4. Knowledge Distillation
  5. Quantization

The methods above are some of the most popular and widely used techniques for compressing a Neural Network. In this article, I’m only going to talk about Network Pruning (I will cover the other techniques in upcoming articles).

Compress Deep Learning Model with Pruning

As you already know, Neural Networks are loosely inspired by the brain. They train on data over and over again, identifying patterns and features in that data so that they can later give meaningful outputs. The human brain does something similar inside our heads, which is one reason it takes some time for infants to learn to talk. Pruning is also an idea borrowed from the brain.

Biological inspiration for pruning (Huttenlocher, 2013)
Deep Neural Network Pruning

One picture is worth a thousand words; in this case, two pictures. The first picture shows how many synapses are present in the human brain at different ages. As it shows, a one-year-old infant has twice as many synapses as an adolescent. Does that mean a one-year-old infant has more knowledge than an adolescent? No. As some doctors and scientists point out, some synapses in our brain may never pass a single signal during our lifetime. So, rather than keeping a huge number of unused synapses, the brain drops as many of them as it can. That is why an adolescent has fewer synapses than a one-year-old infant.

Network pruning does the same thing. As the second image shows, it finds the less important parameters (weights) and removes them from the Neural Network. But how does it decide which parameters are important and which are not? In the simplest form, it uses a threshold: the magnitude (absolute value) of each parameter is compared with the threshold, and parameters whose magnitude falls below it are removed (set to zero). This can be done in two different ways:

  1. Weight Pruning: removes individual connections between Neurons in adjacent layers (does not change the shape of the Neural Network, aka Unstructured Pruning).
  2. Neuron Pruning: removes entire Neurons together with their connections from a layer (changes the shape of the Neural Network, aka Structured Pruning).
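To make the difference between the two concrete, here is a minimal sketch on a single toy weight matrix using plain tensor operations (not PyTorch's pruning module). The threshold of 0.5 and the choice to keep 2 of 4 neurons are hypothetical values picked only for illustration.

```python
import torch

torch.manual_seed(0)

# Toy weight matrix of a Linear layer: 4 output neurons, 6 inputs each.
weight = torch.randn(4, 6)

# Weight (unstructured) pruning: zero every individual weight whose magnitude
# is below a threshold. The matrix keeps its 4 x 6 shape.
threshold = 0.5  # hypothetical value, chosen only for illustration
weight_mask = (weight.abs() >= threshold).float()
weight_pruned = weight * weight_mask

# Neuron (structured) pruning: score each output neuron by the L2 norm of its
# row of incoming weights and keep only the strongest rows. In a real network
# the matching input columns of the next layer would have to be dropped too.
n_keep = 2  # hypothetical: keep 2 of the 4 neurons
row_norms = weight.norm(p=2, dim=1)
keep_rows = row_norms.topk(n_keep).indices.sort().values
neuron_pruned = weight[keep_rows]

print(weight_pruned.shape)  # torch.Size([4, 6])
print(neuron_pruned.shape)  # torch.Size([2, 6])
```

Weight pruning leaves zeros inside a tensor of the same size, while neuron pruning produces a genuinely smaller tensor.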

In this article, I’m going to talk about Weight Pruning.

Now let’s see how to apply pruning to an existing model. First of all, we need a model. For that, I’m going to use the Fashion-MNIST dataset (a variant of MNIST, the hello world program of Deep Learning 😄) with a small model.

Pruning With PyTorch

In this section, I’m going to show how to apply pruning to a Deep Neural Network with PyTorch.

Creating a simple model for the Fashion-MNIST dataset

I’m not going to explain the model code line by line, since you are already familiar with Deep Neural Network models. The important thing to remember here is the model architecture: lines 4–7 of the code show which layers the model uses and their dimensions. In this tutorial, all of them are fully connected layers (Linear layers).
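The original model code is not reproduced in this version of the post, so here is a rough sketch of a model matching the description: four Linear layers named fc1 to fc4. The layer sizes (784 → 256 → 128 → 64 → 10), the ReLU activations and the log-softmax output are my assumptions; I picked these sizes because they add up to about 242K parameters, the count reported for the original model later in the article.

```python
import torch
import torch.nn.functional as F
from torch import nn


class Classifier(nn.Module):
    """Hypothetical fully connected Fashion-MNIST classifier (layer sizes assumed)."""

    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(784, 256)  # 28 x 28 pixels flattened to 784 inputs
        self.fc2 = nn.Linear(256, 128)
        self.fc3 = nn.Linear(128, 64)
        self.fc4 = nn.Linear(64, 10)    # 10 clothing classes

    def forward(self, x):
        x = torch.flatten(x, start_dim=1)  # (batch, 1, 28, 28) -> (batch, 784)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = F.relu(self.fc3(x))
        return F.log_softmax(self.fc4(x), dim=1)


model = Classifier()
print(sum(p.numel() for p in model.parameters()))  # 242762
```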

Model accuracy on the test dataset

After just five epochs, the model reached about 86% accuracy on the test dataset. Now let’s get to the main topic: how to compress this model using pruning.
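Measuring test accuracy needs only a small evaluation loop. A minimal sketch: the helper name `accuracy` is hypothetical, and it assumes an iterable (such as a DataLoader) that yields (images, labels) batches and a model that returns one score per class.

```python
import torch
from torch import nn


@torch.no_grad()
def accuracy(model: nn.Module, loader) -> float:
    """Fraction of samples whose arg-max prediction matches the label."""
    model.eval()  # disable training-only behaviour such as dropout, if any
    correct, total = 0, 0
    for images, labels in loader:
        preds = model(images).argmax(dim=1)
        correct += (preds == labels).sum().item()
        total += labels.numel()
    return correct / total
```

Running the same helper on the original and the pruned model gives a like-for-like comparison.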

Prune the model

First of all, let’s view the state dictionary of the model

State Dictionary of the model

In the above picture, the first column shows the tensor name and the second column shows the dimensions of that tensor. When applying pruning, we have to choose which tensors to prune, and the state dictionary is really helpful for that. Every Linear (fully connected) layer has a weight tensor and a bias tensor, and both appear separately in the state dictionary. Usually, we prune only the weight tensors: they hold almost all of the parameters (an in_features × out_features matrix versus a vector of out_features biases), and pruning biases can damage the model far more than pruning weights. So in this case I have chosen to prune the following tensors in the model:

  • fc1.weight
  • fc2.weight
  • fc3.weight
  • fc4.weight
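The same listing can be produced in code. A sketch, assuming a stand-in model built with `nn.Sequential` and an `OrderedDict` so that the layer names match fc1 to fc4; the layer sizes (784, 256, 128, 64, 10) are assumed, and activations are omitted because they have no tensors.

```python
from collections import OrderedDict

from torch import nn

# Hypothetical stand-in: only names and shapes matter here, not the forward pass.
model = nn.Sequential(OrderedDict([
    ("fc1", nn.Linear(784, 256)),
    ("fc2", nn.Linear(256, 128)),
    ("fc3", nn.Linear(128, 64)),
    ("fc4", nn.Linear(64, 10)),
]))

# Print each tensor's name and shape, like the state-dictionary screenshot.
for name, tensor in model.state_dict().items():
    print(f"{name:12s} {tuple(tensor.shape)}")

# Keep only the weight matrices as pruning candidates; biases are left alone.
weight_names = [name for name in model.state_dict() if name.endswith(".weight")]
print(weight_names)  # ['fc1.weight', 'fc2.weight', 'fc3.weight', 'fc4.weight']
```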
Compress the model with pruning

Let’s go through this code step by step.

from torch.nn.utils import prune

PyTorch already provides functions to prune a model in different ways, and all of them live in the torch.nn.utils.prune module.

parameters_to_prune = (
    (pruned_model.fc1, 'weight'),
    (pruned_model.fc2, 'weight'),
    (pruned_model.fc3, 'weight'),
    (pruned_model.fc4, 'weight'),
)

Here we specify the parameters to prune as a tuple of (module, tensor name) pairs. First, we give the layer we want to prune (pruned_model.fc1, pruned_model.fc2, …), then the name of the tensor to prune in that layer ('weight').
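For bigger models, writing every pair by hand gets tedious. A sketch that collects the weight of every `nn.Linear` layer automatically; `linear_weights` is a hypothetical helper name and the stand-in model's sizes are assumed.

```python
from torch import nn


def linear_weights(model: nn.Module):
    """Return (module, 'weight') pairs for every nn.Linear layer in the model."""
    return tuple(
        (module, "weight")
        for module in model.modules()
        if isinstance(module, nn.Linear)
    )


# Hypothetical stand-in model; the ReLU layers have no weights and are skipped.
pruned_model = nn.Sequential(
    nn.Linear(784, 256), nn.ReLU(),
    nn.Linear(256, 128), nn.ReLU(),
    nn.Linear(128, 64), nn.ReLU(),
    nn.Linear(64, 10),
)
parameters_to_prune = linear_weights(pruned_model)
print(len(parameters_to_prune))  # 4
```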

prune.global_unstructured(
    parameters_to_prune,
    pruning_method=prune.L1Unstructured,
    amount=0.4,  # fraction of all the listed weights to prune (40%)
)

After specifying the tensors to prune, we call the global_unstructured function to prune the model. There are other pruning functions, such as random_unstructured and ln_structured, which will not be discussed here, since global_unstructured comes in handy in many situations and is a common choice. What makes global unstructured pruning special is that it ranks the weights of all the listed tensors together, across the entire network, rather than pruning each layer on its own.
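The difference between pruning layer by layer and pruning globally shows up clearly when layers have weights on different scales. A sketch on two hypothetical 100 × 100 layers, where the second layer's weights are scaled up 10× on purpose; it compares `prune.l1_unstructured` applied per layer with `prune.global_unstructured`.

```python
import torch
from torch import nn
from torch.nn.utils import prune

torch.manual_seed(0)


def make_layers():
    """Two hypothetical layers; the second has weights about 10x larger."""
    small, large = nn.Linear(100, 100), nn.Linear(100, 100)
    with torch.no_grad():
        large.weight.mul_(10.0)
    return small, large


def sparsity(layer):
    """Fraction of the layer's weights that are exactly zero."""
    return int((layer.weight == 0).sum()) / layer.weight.numel()


# Layer by layer: every tensor loses exactly 40% of its own weights.
small, large = make_layers()
for layer in (small, large):
    prune.l1_unstructured(layer, name="weight", amount=0.4)
print("per-layer:", sparsity(small), sparsity(large))  # per-layer: 0.4 0.4

# Global: one threshold over both tensors, 40% of all 20,000 weights in total.
small, large = make_layers()
prune.global_unstructured(
    [(small, "weight"), (large, "weight")],
    pruning_method=prune.L1Unstructured,
    amount=0.4,
)
print("global:", sparsity(small), sparsity(large))  # most zeros land in `small`
```

Global pruning still removes 8,000 weights in total, but it takes most of them from the layer whose weights are small.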

Now let’s discuss the arguments we pass to global_unstructured. The first argument is the collection of parameters (tensors) we want to prune, which we defined above. The second argument is the pruning method. In this example, I have used L1Unstructured, so PyTorch flattens all the listed tensors [(pruned_model.fc1, 'weight'), (pruned_model.fc2, 'weight'), …] into a single vector and applies L1Unstructured to that vector. You can learn more about L1Unstructured from the official PyTorch documentation. Basically, it removes (zeroes out) the weights with the lowest L1 norm, which for individual weights simply means the smallest absolute values. The third and final argument is the amount to prune: the fraction of all the listed weights that will be removed. In this example, I have given 0.4 (40%).
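To check what L1Unstructured does under global pruning, here is a sketch that builds the same mask by hand: flatten all weights, find the magnitude of the k-th smallest one, and keep only the weights above it. The two toy layers are hypothetical, and the comparison assumes no two weights have exactly the same magnitude, which holds for random float weights in practice.

```python
import torch
from torch import nn
from torch.nn.utils import prune

torch.manual_seed(0)
layers = [nn.Linear(20, 30), nn.Linear(30, 10)]  # hypothetical toy layers
amount = 0.4

# By hand: flatten every weight into one vector, find the magnitude of the
# k-th smallest weight, and keep only weights strictly larger than it.
all_weights = torch.cat([layer.weight.detach().flatten() for layer in layers])
k = round(amount * all_weights.numel())  # 0.4 * 900 = 360 weights to prune
threshold = all_weights.abs().kthvalue(k).values
manual_masks = [(layer.weight.detach().abs() > threshold).float() for layer in layers]

# With PyTorch: the same layers and the same amount.
prune.global_unstructured(
    [(layer, "weight") for layer in layers],
    pruning_method=prune.L1Unstructured,
    amount=amount,
)
print(all(torch.equal(layer.weight_mask, mask)
          for layer, mask in zip(layers, manual_masks)))  # True
```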

*Note that this percentage applies to all the listed tensors together, not to each tensor separately. PyTorch does not prune 40% of every tensor; it picks a single threshold across all of them so that 40% of the listed weights are pruned in total. Hence different tensors may end up with different compression ratios, but overall 40% of the network’s weights are pruned.
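You can see the per-tensor ratios directly by counting zeros after global pruning. A sketch on freshly initialized (untrained) layers with assumed sizes 784 → 256 → 128 → 64 → 10, so the individual percentages only illustrate that they differ; a trained model will give different ones.

```python
import torch
from torch import nn
from torch.nn.utils import prune

torch.manual_seed(0)
# Hypothetical stand-in layers with assumed sizes.
layers = {
    "fc1": nn.Linear(784, 256),
    "fc2": nn.Linear(256, 128),
    "fc3": nn.Linear(128, 64),
    "fc4": nn.Linear(64, 10),
}
prune.global_unstructured(
    [(layer, "weight") for layer in layers.values()],
    pruning_method=prune.L1Unstructured,
    amount=0.4,
)

total_zeros = total = 0
for name, layer in layers.items():
    zeros = int((layer.weight == 0).sum())
    total_zeros += zeros
    total += layer.weight.numel()
    print(f"{name}.weight: {100 * zeros / layer.weight.numel():.1f}% pruned")
print(f"all weights: {100 * total_zeros / total:.1f}% pruned")  # all weights: 40.0% pruned
```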

If we print the state dictionary of the model after running those lines, we can see an output like this.

State Dictionary of the model after pruning

In this output, you can notice that each tensor we passed to global_unstructured has been replaced by two entries with different suffixes (fc1.weight_orig, fc1.weight_mask, …). In PyTorch, pruning a tensor does not overwrite it directly. Instead, the original tensor is kept as a parameter named weight_orig, and a binary mask (1 = keep, 0 = pruned) is stored as a buffer named weight_mask. The pruned weight itself is recomputed as weight_orig * weight_mask before every forward pass. So in order to make the pruning permanent, we have to bake the mask into the weights and drop these two extra tensors.
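A quick way to see this reparametrization is to prune a single toy layer and inspect it. A sketch using `prune.l1_unstructured` on a hypothetical 8 × 4 layer with 50% pruning.

```python
import torch
from torch import nn
from torch.nn.utils import prune

torch.manual_seed(0)
layer = nn.Linear(8, 4)  # hypothetical toy layer
prune.l1_unstructured(layer, name="weight", amount=0.5)

# The original values are still stored, now as a parameter called weight_orig.
print(sorted(name for name, _ in layer.named_parameters()))  # ['bias', 'weight_orig']
# The binary mask is a buffer called weight_mask (1 = keep, 0 = pruned).
print([name for name, _ in layer.named_buffers()])  # ['weight_mask']
# `weight` is now a plain attribute equal to weight_orig * weight_mask,
# recomputed by a forward pre-hook every time the layer runs.
print(torch.equal(layer.weight, layer.weight_orig * layer.weight_mask))  # True
```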

prune.remove(pruned_model.fc1, 'weight')
prune.remove(pruned_model.fc2, 'weight')
prune.remove(pruned_model.fc3, 'weight')
prune.remove(pruned_model.fc4, 'weight')

To make the pruning permanent, we execute the lines above. prune.remove() takes the layer and the name of the pruned tensor. We pass just 'weight' (not weight_orig or weight_mask): prune.remove() computes weight_orig * weight_mask, stores the result as a regular parameter named weight, and deletes weight_orig and weight_mask. Now if you print the state dictionary, you will no longer see the duplicate tensors; it will look like a normal state dictionary again, except that the pruned weights are now zeros.
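Here is a sketch of what prune.remove() does to a single pruned toy layer (hypothetical 8 × 4 layer, 50% pruned): the extra tensors disappear and the zeros stay.

```python
import torch
from torch import nn
from torch.nn.utils import prune

torch.manual_seed(0)
layer = nn.Linear(8, 4)  # hypothetical toy layer
prune.l1_unstructured(layer, name="weight", amount=0.5)
pruned_values = layer.weight.detach().clone()

# Make the pruning permanent: `weight` becomes an ordinary parameter holding
# weight_orig * weight_mask, and weight_orig / weight_mask are deleted.
prune.remove(layer, "weight")

print(sorted(layer.state_dict()))  # ['bias', 'weight']
print(torch.equal(layer.weight, pruned_values))  # True
print(int((layer.weight == 0).sum()))  # 16
```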

Model Performance

Now let’s see the pruned model performance and effectiveness.

Pruning statistics

From the above picture, you can see the total number of parameters in the original model and in the pruned model. The original model has nearly 242K parameters, while the pruned model has only about 145K nonzero parameters. That means pruning removed roughly 97K parameters, about 40% of the pruned weight tensors (the biases were not pruned). Keep in mind that the pruned weights are stored as zeros, so a dense tensor takes the same amount of memory as before; the memory savings only materialize when the model is stored in a sparse or compressed format. Now let’s see how this compression has affected the model accuracy.
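Because pruning only zeroes weights, the "parameter count" of a pruned model has to be measured as the number of nonzero values. A sketch with a hypothetical `count_params` helper on a stand-in model with assumed sizes 784 → 256 → 128 → 64 → 10 (activations omitted, since only the tensors matter here).

```python
import torch
from torch import nn
from torch.nn.utils import prune


def count_params(model: nn.Module):
    """Return (total, nonzero) counts over all parameters of the model."""
    total = sum(p.numel() for p in model.parameters())
    nonzero = sum(int((p != 0).sum()) for p in model.parameters())
    return total, nonzero


torch.manual_seed(0)
model = nn.Sequential(nn.Linear(784, 256), nn.Linear(256, 128),
                      nn.Linear(128, 64), nn.Linear(64, 10))
linears = [m for m in model if isinstance(m, nn.Linear)]
prune.global_unstructured([(m, "weight") for m in linears],
                          pruning_method=prune.L1Unstructured, amount=0.4)
for m in linears:
    prune.remove(m, "weight")

total, nonzero = count_params(model)
print(total, nonzero)  # 242762 145840

# The zeros are still stored: the dense float32 fc1 weight matrix occupies
# 256 * 784 * 4 bytes whether or not it has been pruned.
w = linears[0].weight
print(w.numel() * w.element_size())  # 802816
```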

Model accuracy

I know you will find this result hard to believe. You would expect a huge accuracy drop from 40% compression. Not only you: I was also surprised the first time I applied pruning to a model like this. How can a network with fewer parameters match, or even beat, the accuracy of the original? After going through some research papers and discussions, I found the answer. Pruning helps us find a well-performing network hidden inside the original one; in simple terms, it finds a good sub-network of the original, larger neural network. More parameters do not always mean better accuracy. As you might already know, an overparameterized network does not always perform well on a given dataset. So how do we know whether our model is overparameterized? We could create multiple models of different sizes and train them all to see which one is best 😵, or we can just apply pruning 😉.

You can get the full code of this example from this GitHub repo. When you run this example on your machine, you might get different accuracy results. You can also try changing the model structure and compressing the model with different ratios to see which model works best for the Fashion-MNIST dataset. I hope you learned something new from this article. Feel free to ask any questions and give suggestions. See you again in a new article on another interesting topic 👋👋👋.
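If you want to try several compression ratios as suggested above, pruning a fresh copy of the trained model for each ratio keeps the original intact. A sketch; `prune_copy` and `sweep` are hypothetical helper names, and `evaluate` stands for whatever scoring function you use, for example test accuracy.

```python
import copy

from torch import nn
from torch.nn.utils import prune


def prune_copy(model: nn.Module, amount: float) -> nn.Module:
    """Return a globally L1-pruned deep copy of `model`; the original is untouched."""
    pruned = copy.deepcopy(model)
    params = [(m, "weight") for m in pruned.modules() if isinstance(m, nn.Linear)]
    prune.global_unstructured(params, pruning_method=prune.L1Unstructured, amount=amount)
    for module, name in params:
        prune.remove(module, name)  # make the pruning of this copy permanent
    return pruned


def sweep(model: nn.Module, evaluate, amounts=(0.2, 0.4, 0.6, 0.8)):
    """Map each pruning amount to the score `evaluate` gives the pruned copy."""
    return {amount: evaluate(prune_copy(model, amount)) for amount in amounts}
```

Plotting the returned scores against the amounts shows how far the model can be compressed before accuracy starts to fall.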

Sahan Dilshan

Software Engineer @WSO2 | NLP/Deep Learning Enthusiastic