Optimal Gradient Checkpoint Search for Arbitrary Computation Graphs

Deep Neural Networks (DNNs) require large amounts of GPU memory when trained on modern image/video datasets. Unfortunately, the GPU memory of off-the-shelf devices is finite, which limits the image resolutions and batch sizes that can be used to improve DNN performance. Existing approaches to alleviating the memory bottleneck include better GPUs, distributed computation, and Gradient CheckPointing (GCP) training. Among them, GCP is a favorable approach because it trades computation for memory and requires no hardware upgrades. In GCP, only a subset of intermediate tensors, called Gradient Checkpoints (GCs), is stored during the forward pass; during the backward pass, extra local forward passes are run to recompute the missing tensors. The total training memory cost thus becomes the sum of (1) the memory cost of the gradient checkpoints and (2) the maximum memory cost among the local forward passes. Achieving the maximal memory reduction therefore requires an optimal algorithm for selecting GCs. Existing GCP approaches rely on either manually specified GCs or heuristic GC search on linear computation graphs (LCGs), and cannot be applied to arbitrary computation graphs (ACGs). In this paper, we present theories and optimal algorithms for GC selection that, for the first time, are applicable to ACGs and achieve maximal memory reduction. Extensive experiments show that our approach consistently outperforms existing approaches on LCGs and can cut up to 80% of training memory\footnote{Cutting 80% of training memory means one can double the input image size or quadruple the batch size on the same GPUs.} with a moderate time overhead (around 40%) on LCG and ACG DNNs such as AlexNet, VGG, ResNet, DenseNet, and Inception Net.
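As context for the GCP mechanism described above, the sketch below uses PyTorch's built-in torch.utils.checkpoint.checkpoint_sequential utility, which covers only the LCG case with manually chosen, evenly split segments. It illustrates the store-few-tensors/recompute-on-backward trade-off, not the optimal GC search proposed in this paper; the layer sizes, depth, and segment count are illustrative assumptions.

```python
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint_sequential

# A deep chain of layers -- an LCG in the paper's terminology.
# Sizes and depth are arbitrary, chosen only to make the memory effect visible.
blocks = [nn.Sequential(nn.Linear(1024, 1024), nn.ReLU()) for _ in range(16)]
model = nn.Sequential(*blocks)

x = torch.randn(8, 1024, requires_grad=True)

# Forward with checkpointing: the chain is split into 4 segments, so only the
# segment-boundary tensors (the gradient checkpoints) are kept. During backward,
# each segment is re-run locally to recompute its discarded intermediate tensors.
out = checkpoint_sequential(model, 4, x)

# Backward triggers the extra local forward passes described in the abstract.
out.sum().backward()
```

Under this scheme the peak memory follows the cost model stated above: the checkpointed boundary tensors plus the peak memory of the largest segment's local forward. The paper's contribution is choosing those checkpoints optimally for arbitrary graphs rather than splitting a linear chain evenly.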