13.2 Pruning
Pruning Tutorial
Created Date: 2025-06-26
State-of-the-art deep learning techniques rely on over-parametrized models that are hard to deploy. In contrast, biological neural networks are known to use efficient sparse connectivity. Identifying good techniques for compressing models by reducing their number of parameters is important for lowering memory, battery, and hardware consumption without sacrificing accuracy. This in turn allows you to deploy lightweight models on device and to help preserve privacy through on-device computation.
On the research front, pruning is used to investigate the differences in learning dynamics between over-parametrized and under-parametrized networks, to study the role of lucky sparse subnetworks and initializations (“lottery tickets”), as a destructive neural architecture search technique, and more.
In this tutorial, you will learn how to use torch.nn.utils.prune to sparsify your neural networks, and how to extend it to implement your own custom pruning technique.
13.2.1 Create a Model
In this tutorial, we use the LeNet architecture from LeCun et al., 1998.
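A minimal sketch of such a model follows. The layer sizes (one input channel, 32x32 inputs, ten output classes) are assumptions chosen to match the classic LeNet layout and are not stated in the text above; only the layer name conv1 is taken from the sections that follow.

```python
import torch
from torch import nn
import torch.nn.functional as F

# Use a GPU when one is available; the model works on CPU as well.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


class LeNet(nn.Module):
    def __init__(self):
        super().__init__()
        # Assumed sizes: 1 input channel, 6 output channels, 5x5 kernels.
        self.conv1 = nn.Conv2d(1, 6, 5)
        self.conv2 = nn.Conv2d(6, 16, 5)
        # 16 channels of 5x5 feature maps remain for a 32x32 input.
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        x = F.max_pool2d(F.relu(self.conv1(x)), 2)
        x = F.max_pool2d(F.relu(self.conv2(x)), 2)
        x = torch.flatten(x, 1)  # keep the batch dimension
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        return self.fc3(x)


model = LeNet().to(device=device)
```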
13.2.2 Inspect a Module
Let’s inspect the (unpruned) conv1 layer in our LeNet model. It contains two parameters, weight and bias, and no buffers for now.
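The inspection can be sketched as below. To keep the snippet self-contained, it builds a standalone nn.Conv2d as a stand-in for the conv1 layer; the constructor arguments are assumptions, and with a LeNet instance you would use model.conv1 instead.

```python
from torch import nn

# Stand-in for model.conv1 (assumed shape: 1 -> 6 channels, 5x5 kernel).
module = nn.Conv2d(1, 6, 5)

# named_parameters() and named_buffers() yield (name, tensor) pairs.
param_names = [name for name, _ in module.named_parameters()]
buffer_names = [name for name, _ in module.named_buffers()]

print(param_names)   # ['weight', 'bias']
print(buffer_names)  # []
```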
13.2.3 Pruning a Module
To prune a module (in this example, the conv1 layer of our LeNet architecture), first select a pruning technique among those available in torch.nn.utils.prune (or implement your own by subclassing BasePruningMethod). Then, specify the module and the name of the parameter to prune within that module. Finally, specify the pruning parameters using the keyword arguments required by the selected pruning technique.
In this example, we will prune 30% of the connections, chosen at random, in the parameter named weight in the conv1 layer. The module is passed as the first argument to the function; name identifies the parameter within that module using its string identifier; and amount indicates either the fraction of connections to prune (if it is a float between 0. and 1.), or the absolute number of connections to prune (if it is a non-negative integer).
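The call described above might look like the following sketch. It again uses a standalone nn.Conv2d as a stand-in for conv1 (an assumption made for self-containment), and it adds a second call on bias purely to illustrate the integer form of amount. After pruning, the module carries the original tensor as a parameter named weight_orig and the binary mask as a buffer named weight_mask.

```python
import torch
from torch import nn
import torch.nn.utils.prune as prune

torch.manual_seed(0)  # fixes which connections get picked; the count does not depend on it

# Stand-in for model.conv1: its weight has 6 * 1 * 5 * 5 = 150 entries.
module = nn.Conv2d(1, 6, 5)

# Float amount: prune 30% of the entries of `weight`, chosen at random.
prune.random_unstructured(module, name="weight", amount=0.3)

param_names = sorted(name for name, _ in module.named_parameters())
buffer_names = sorted(name for name, _ in module.named_buffers())
print(param_names)   # ['bias', 'weight_orig']
print(buffer_names)  # ['weight_mask']

# Zeros in the mask mark pruned connections.
weight_pruned = int((module.weight_mask == 0).sum())
print(weight_pruned, module.weight_mask.numel())  # 45 150

# Integer amount: prune exactly 3 of the 6 bias entries.
prune.random_unstructured(module, name="bias", amount=3)
bias_pruned = int((module.bias_mask == 0).sum())
print(bias_pruned)  # 3
```

Note that module.weight still exists after the call, but it is now a plain attribute computed from weight_orig and weight_mask rather than a registered parameter.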