11.4 Fusing Conv and Batch Norm

Fusing Convolution and Batch Norm using Custom Function

Created Date: 2025-07-21

Fusing adjacent convolution and batch norm layers together is typically an inference-time optimization to improve runtime. It is usually achieved by eliminating the batch norm layer entirely and updating the weight and bias of the preceding convolution [0]. However, this technique is not applicable during training.
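
To make the inference-time variant concrete, here is a minimal sketch of how batch norm can be folded into the preceding convolution, assuming an affine BatchNorm2d in eval mode with tracked running statistics (the helper name fuse_conv_bn_eval and its exact handling of the conv bias are illustrative assumptions, not part of this tutorial):

```python
import torch
from torch import nn

@torch.no_grad()
def fuse_conv_bn_eval(conv: nn.Conv2d, bn: nn.BatchNorm2d) -> nn.Conv2d:
    # Fold an eval-mode BatchNorm2d into the preceding Conv2d: rescale the conv
    # weight per output channel and shift the bias, so the BN layer can be dropped.
    fused = nn.Conv2d(conv.in_channels, conv.out_channels, conv.kernel_size,
                      stride=conv.stride, padding=conv.padding,
                      dilation=conv.dilation, groups=conv.groups, bias=True)
    scale = bn.weight / torch.sqrt(bn.running_var + bn.eps)
    fused.weight.copy_(conv.weight * scale.reshape(-1, 1, 1, 1))
    conv_bias = conv.bias if conv.bias is not None else torch.zeros_like(bn.running_mean)
    fused.bias.copy_((conv_bias - bn.running_mean) * scale + bn.bias)
    return fused
```

The fused module computes the same result as bn(conv(x)) in eval mode, so the batch norm layer can simply be removed from the network.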

In this tutorial, we will show a different technique to fuse the two layers that can be applied during training. The objective of this optimization is not improved runtime, but reduced memory usage.

The key observation behind this optimization is that both convolution and batch norm (as well as many other ops) need to save a copy of their input during the forward pass for use in the backward pass. For large batch sizes, these saved inputs are responsible for most of your memory usage, so avoiding the allocation of another input tensor for every convolution/batch norm pair can be a significant reduction.
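
To get a feel for the scale involved, here is a rough back-of-envelope calculation for a single saved activation; the shape is illustrative and not taken from this tutorial:

```python
# One fp32 activation of shape (batch, channels, height, width) saved for backward.
batch, channels, height, width = 256, 64, 112, 112
bytes_per_float32 = 4
saved_mib = batch * channels * height * width * bytes_per_float32 / 2**20
print(f"one saved activation: {saved_mib:.0f} MiB")  # ~784 MiB
```

Every convolution/batch norm pair that keeps its own copy of such a tensor adds roughly this much to the memory held until the backward pass completes.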

In this tutorial, we avoid this extra allocation by combining convolution and batch norm into a single layer (as a custom function). In the forward of this combined layer, we perform the normal convolution and batch norm as-is, with the only difference being that we save only the input to the convolution. To obtain the input to batch norm, which is necessary for backpropagating through it, we recompute the convolution forward again during the backward pass.
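
The rest of the tutorial derives the backward pass for this fused layer by hand. As a rough orientation, the sketch below shows the same memory-saving idea with a custom torch.autograd.Function that saves only the convolution input and weight and, instead of a manual derivation, recomputes the forward under autograd inside backward (a checkpoint-style shortcut). The class name FusedConvBN2dFn, the shapes in the usage lines, and the recompute-via-autograd approach are illustrative assumptions, not the tutorial's actual implementation:

```python
import torch
import torch.nn.functional as F

class FusedConvBN2dFn(torch.autograd.Function):
    @staticmethod
    def forward(ctx, X, conv_weight, eps=1e-3):
        # Save only the conv input and weight -- the conv output (the input to
        # batch norm) is deliberately NOT saved.
        ctx.eps = eps
        ctx.save_for_backward(X, conv_weight)
        Z = F.conv2d(X, conv_weight)  # bias=False, stride=1, padding=0
        var, mean = torch.var_mean(Z, dim=(0, 2, 3), unbiased=False, keepdim=True)
        return (Z - mean) / (torch.sqrt(var) + eps)  # eps outside the sqrt

    @staticmethod
    def backward(ctx, grad_out):
        X, conv_weight = ctx.saved_tensors
        # Recompute the forward pass under autograd and differentiate the
        # recomputed graph: extra compute, but Z was never stored.
        with torch.enable_grad():
            X_ = X.detach().requires_grad_(True)
            W_ = conv_weight.detach().requires_grad_(True)
            Z = F.conv2d(X_, W_)
            var, mean = torch.var_mean(Z, dim=(0, 2, 3), unbiased=False, keepdim=True)
            out = (Z - mean) / (torch.sqrt(var) + ctx.eps)
        grad_X, grad_W = torch.autograd.grad(out, (X_, W_), grad_out)
        return grad_X, grad_W, None


# Illustrative usage (shapes arbitrary):
X = torch.rand(2, 3, 8, 8, requires_grad=True, dtype=torch.double)
W = torch.rand(5, 3, 3, 3, requires_grad=True, dtype=torch.double)
FusedConvBN2dFn.apply(X, W).sum().backward()
```

Because only X and conv_weight are stored between forward and backward, the conv output (the batch norm input) does not occupy memory after the forward pass; the price is one extra convolution (and batch norm) evaluation during backward.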

It is important to note that the usefulness of this optimization is situational. Although (by avoiding one saved buffer) we always reduce the memory allocated at the end of the forward pass, there are cases when the peak memory allocated may not actually be reduced. See the final section for more details.

For simplicity, in this tutorial we hardcode bias=False, stride=1, padding=0, dilation=1, and groups=1 for Conv2d. For BatchNorm2d, we hardcode eps=1e-3, momentum=0.1, affine=False, and track_running_stats=False. Another small difference is that, in the batch norm computation, we add epsilon to the denominator outside of the square root.
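
To make the epsilon placement concrete, here is a small comparison; the tensor z is an arbitrary stand-in for a convolution output:

```python
import torch

z = torch.randn(2, 3, 4, 4)  # stand-in for a convolution output
eps = 1e-3
var, mean = torch.var_mean(z, dim=(0, 2, 3), unbiased=False, keepdim=True)
standard = (z - mean) / torch.sqrt(var + eps)      # eps inside the sqrt, as in nn.BatchNorm2d
simplified = (z - mean) / (torch.sqrt(var) + eps)  # eps outside the sqrt, as used in this tutorial
```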