13.2 Pruning

State-of-the-art deep learning techniques rely on over-parametrized models that are hard to deploy. In contrast, biological neural networks are known to use efficient sparse connectivity. Identifying techniques to compress models by reducing the number of parameters in them is important for reducing memory, battery, and hardware consumption without sacrificing accuracy. This in turn allows you to deploy lightweight models on device and to preserve privacy through on-device computation.

On the research front, pruning is used to investigate the differences in learning dynamics between over-parametrized and under-parametrized networks, to study the role of lucky sparse subnetworks and initializations (“lottery tickets”) as a destructive neural architecture search technique, and more.

In this tutorial, you will learn how to use torch.nn.utils.prune to sparsify your neural networks, and how to extend it to implement your own custom pruning technique.

13.2.1 Create a Model

In this tutorial, we use the LeNet architecture from LeCun et al., 1998.
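
A minimal sketch of such a model follows. The exact layer sizes are illustrative assumptions (they match a 28x28 single-channel, MNIST-sized input); any small convolutional network would work for the rest of the tutorial.

import torch
from torch import nn
import torch.nn.utils.prune as prune
import torch.nn.functional as F

class LeNet(nn.Module):
    def __init__(self):
        super().__init__()
        # 1 input image channel, 6 output channels, 3x3 square conv kernel
        self.conv1 = nn.Conv2d(1, 6, 3)
        self.conv2 = nn.Conv2d(6, 16, 3)
        self.fc1 = nn.Linear(16 * 5 * 5, 120)  # 5x5 spatial size after pooling
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        x = F.max_pool2d(F.relu(self.conv1(x)), 2)
        x = F.max_pool2d(F.relu(self.conv2(x)), 2)
        x = torch.flatten(x, 1)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        return self.fc3(x)

model = LeNet()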

13.2.2 Inspect a Module

Let’s inspect the (unpruned) conv1 layer in our LeNet model. For now, it contains two parameters, weight and bias, and no buffers.
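
For example, assuming the model defined above:

module = model.conv1
print(list(module.named_parameters()))  # [('weight', ...), ('bias', ...)]
print(list(module.named_buffers()))     # []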

13.2.3 Pruning a Module

To prune a module (in this example, the conv1 layer of our LeNet architecture), first select a pruning technique among those available in torch.nn.utils.prune (or implement your own by subclassing BasePruningMethod). Then, specify the module and the name of the parameter to prune within that module. Finally, using the appropriate keyword arguments required by the selected pruning technique, specify the pruning parameters.

In this example, we randomly prune 30% of the connections in the parameter named weight in the conv1 layer. The module is passed as the first argument to the function; name identifies the parameter within that module by its string identifier; and amount indicates either the fraction of connections to prune (if it is a float between 0. and 1.) or the absolute number of connections to prune (if it is a non-negative integer).
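
With module = model.conv1 from above:

prune.random_unstructured(module, name="weight", amount=0.3)

print(list(module.named_parameters()))  # 'weight' is replaced by 'weight_orig'
print(list(module.named_buffers()))     # now contains 'weight_mask'

Pruning acts by removing weight from the parameters and replacing it with a new parameter called weight_orig, which stores the unpruned version of the tensor. The binary mask is saved as a buffer named weight_mask, and a forward_pre_hook recomputes the pruned weight attribute before each forward pass.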

13.2.4 Iterative Pruning
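
The same parameter in a module can be pruned more than once; the successive masks are combined in series by an internal PruningContainer. As a sketch, we can apply structured L2-norm pruning along the channel axis on top of the random unstructured pruning from the previous section:

# prune the 50% of output channels with the lowest L2 norm,
# on top of the mask that is already in place
prune.ln_structured(module, name="weight", amount=0.5, n=2, dim=0)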

13.2.5 Serializing a Pruned Model
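
All relevant tensors, including the mask buffers and the original parameters used to compute the pruned tensors, are stored in the model’s state_dict, so a pruned model can be serialized like any other. A sketch (the file name here is an arbitrary assumption):

print(model.state_dict().keys())
# ..., 'conv1.weight_orig', 'conv1.weight_mask', ...

torch.save(model.state_dict(), "pruned_lenet.pt")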

13.2.6 Remove Pruning Re-parametrization
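
To make the pruning permanent, prune.remove strips the re-parametrization: weight_orig and weight_mask are deleted, the forward_pre_hook is removed, and the pruned tensor becomes the ordinary weight parameter. Note that this does not undo the pruning; it simply bakes it in:

prune.remove(module, "weight")
print(list(module.named_parameters()))  # 'weight' again, with the zeros baked in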

13.2.7 Pruning Multiple Parameters in a Model

By specifying the desired pruning technique and parameters, we can easily prune multiple tensors in a network, perhaps according to their type, as we will see in this example.
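
A sketch, using L1 unstructured pruning with illustrative amounts (20% for convolutional layers, 40% for linear layers) on a fresh model:

new_model = LeNet()
for name, submodule in new_model.named_modules():
    # prune 20% of connections in all 2D convolutional layers
    if isinstance(submodule, nn.Conv2d):
        prune.l1_unstructured(submodule, name="weight", amount=0.2)
    # prune 40% of connections in all linear layers
    elif isinstance(submodule, nn.Linear):
        prune.l1_unstructured(submodule, name="weight", amount=0.4)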

13.2.8 Global Pruning

So far, we have only looked at what is usually referred to as “local” pruning, i.e. the practice of pruning tensors in a model one by one, by comparing the statistics (weight magnitude, activation, gradient, etc.) of each entry exclusively to the other entries in that tensor. However, a common and perhaps more powerful technique is to prune the model all at once, by removing (for example) the lowest 20% of connections across the whole model instead of the lowest 20% of connections in each layer. This is likely to result in different pruning percentages per layer. Let’s see how to do that using global_unstructured from torch.nn.utils.prune.
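
A sketch, reusing the LeNet class from above on a fresh model:

model = LeNet()

parameters_to_prune = (
    (model.conv1, "weight"),
    (model.conv2, "weight"),
    (model.fc1, "weight"),
    (model.fc2, "weight"),
    (model.fc3, "weight"),
)

prune.global_unstructured(
    parameters_to_prune,
    pruning_method=prune.L1Unstructured,
    amount=0.2,
)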

Now we can check the sparsity induced in every pruned parameter, which will not be equal to 20% in each layer. The global sparsity, however, will be (approximately) 20%.
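
A sketch of such a check (the sparsity helper is our own illustrative function):

def sparsity(t):
    # fraction of exactly-zero entries, as a percentage
    return 100.0 * float(torch.sum(t == 0)) / t.nelement()

for name, submodule in model.named_modules():
    if isinstance(submodule, (nn.Conv2d, nn.Linear)):
        print(f"{name}.weight: {sparsity(submodule.weight):.2f}% sparsity")

global_weights = torch.cat(
    [m.weight.flatten()
     for m in [model.conv1, model.conv2, model.fc1, model.fc2, model.fc3]])
print(f"global: {sparsity(global_weights):.2f}% sparsity")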

13.2.9 Extending torch.nn.utils.prune

To implement your own pruning function, you can extend the nn.utils.prune module by subclassing the BasePruningMethod base class, the same way all other pruning methods do. The base class implements the following methods for you: __call__, apply_mask, apply, prune, and remove.

Beyond some special cases, you shouldn’t have to reimplement these methods for your new pruning technique. You will, however, have to implement __init__ (the constructor), and compute_mask (the instructions on how to compute the mask for the given tensor according to the logic of your pruning technique).

In addition, you will have to specify which type of pruning this technique implements (supported options are global, structured, and unstructured). This is needed to determine how to combine masks in the case in which pruning is applied iteratively. In other words, when pruning a prepruned parameter, the current pruning technique is expected to act on the unpruned portion of the parameter. Specifying the PRUNING_TYPE will enable the PruningContainer (which handles the iterative application of pruning masks) to correctly identify the slice of the parameter to prune.

Let’s assume, for example, that you want to implement a pruning technique that prunes every other entry in a tensor (or – if the tensor has previously been pruned – in the remaining unpruned portion of the tensor). This will be of PRUNING_TYPE='unstructured' because it acts on individual connections in a layer and not on entire units/channels ('structured'), or across different parameters ('global').
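
A minimal sketch of such a method, along with a convenience function to apply it (the names EveryOtherPruningMethod and every_other_unstructured are hypothetical, chosen for this example):

class EveryOtherPruningMethod(prune.BasePruningMethod):
    """Prune every other entry in a tensor (hypothetical example)."""
    PRUNING_TYPE = "unstructured"

    def compute_mask(self, t, default_mask):
        # start from the mask left by any previous pruning rounds,
        # then zero out every other entry of the flattened tensor
        mask = default_mask.clone()
        mask.view(-1)[::2] = 0
        return mask

def every_other_unstructured(module, name):
    """Prune every other entry in the parameter `name` of `module`, in place."""
    EveryOtherPruningMethod.apply(module, name)
    return module

# usage, e.g. on the bias of fc3:
every_other_unstructured(model.fc3, name="bias")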