SUS backprop: linear backpropagation algorithm for long inputs in transformers

It is straightforward to design an unbiased gradient estimator that stochastically cuts the backpropagation flow through any part of a computational graph. By cutting the parts that have little effect on the computation, one can potentially save a significant amount of backpropagation computation in exchange for a minimal increase in the stochastic gradient variance, in some situations. Such a situation occurs in the attention mechanism of the transformer architecture. For long sequences, attention becomes the limiting factor, as its compute requirements increase quadratically with sequence length n. At the same time, most attention weights become very small, as most attention heads tend to connect a given token with only a small fraction of other tokens in the sequence. These weights become promising targets for cutting backpropagation. We propose a simple probabilistic rule controlled by a single parameter c that cuts backpropagation through most attention weights, leaving at most c interactions per token per attention head. This brings a factor of ~n/c reduction in the compute required for the attention backpropagation, turning it from quadratic to linear complexity O(cn). We have empirically verified that, for a typical transformer model, cutting about 99% of the attention gradient flow (i.e., choosing c much smaller than n) results in a relative gradient variance increase of only about 1%, and this increase decreases with n. This approach is amenable to efficient sparse matrix implementation, thus being promising for making the cost of a backward pass negligible relative to the cost of a forward pass when training a transformer model on long sequences.
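The stochastic-cutting construction can be sketched in a few lines. The following PyTorch snippet is a minimal illustration, not the paper's exact rule: the function name sus_attention_weights and the per-weight keep probability p = min(1, c·a) are assumptions made for this example. Each attention weight keeps its backward path with probability p and is rescaled by 1/p when kept, so the gradient estimator stays unbiased (E[keep/p] = 1) while roughly c gradient interactions per query survive in expectation; the forward pass is left unchanged.

```python
import torch

def sus_attention_weights(attn, c):
    """Stochastically cut backprop through attention weights (illustrative sketch).

    attn : attention probabilities of shape (..., n_queries, n_keys),
           rows summing to 1 (output of a softmax).
    c    : target number of surviving gradient interactions per query.

    The forward value is unchanged. In the backward pass, each weight's
    gradient path survives with probability p = min(1, c * attn) and is
    rescaled by 1/p, so the estimator stays unbiased: E[keep / p] = 1.
    Since sum_j min(1, c * attn_j) <= c, roughly c interactions per query
    keep a gradient path in expectation.
    """
    p = (c * attn.detach()).clamp(min=1e-12, max=1.0)
    keep = (torch.rand_like(attn) < p).to(attn.dtype)
    # Straight-through construction: value == attn, gradient == grad * keep / p.
    return attn.detach() + (keep / p) * (attn - attn.detach())


# Usage inside a plain attention computation (batch, heads, tokens, dim).
q = torch.randn(2, 4, 128, 64, requires_grad=True)
k = torch.randn(2, 4, 128, 64, requires_grad=True)
v = torch.randn(2, 4, 128, 64, requires_grad=True)
attn = torch.softmax(q @ k.transpose(-1, -2) / 64 ** 0.5, dim=-1)
out = sus_attention_weights(attn, c=20) @ v
out.sum().backward()  # gradients flow through only ~c attention weights per query
```

This sketch only sparsifies the gradient flow; realizing the linear-time savings additionally requires a sparse-matrix backward implementation, as the abstract notes.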
@article{pankov2025_2505.15080,
  title   = {SUS backprop: linear backpropagation algorithm for long inputs in transformers},
  author  = {Sergey Pankov and Georges Harik},
  journal = {arXiv preprint arXiv:2505.15080},
  year    = {2025}
}