
Jensen-Shannon Divergence Based Loss Functions for Bayesian Neural Networks

Neurocomputing, 2022
Abstract

The Kullback-Leibler (KL) divergence is widely used in the variational inference of Bayesian Neural Networks (BNNs) to approximate the posterior distribution of weights. However, the KL divergence is unbounded and asymmetric, which may lead to instabilities during optimization or yield poor generalization. To overcome these limitations, we examine the Jensen-Shannon (JS) divergence, which is more general, bounded, and symmetric. Towards this, we propose two novel loss functions for BNNs: 1) a geometric JS divergence (JS-G) based loss function that is symmetric but unbounded, with a closed-form expression for Gaussian priors, and 2) a generalized JS divergence (JS-A) based loss function that is symmetric and bounded. We show that the conventional KL divergence-based loss function is a special case of the loss functions presented in this work. To evaluate the divergence part of the proposed JS-G-based loss function, we use an exact closed-form expression for Gaussian priors. For other priors under JS-G and for the JS-A-based loss function, we use Monte Carlo approximation. We provide algorithms to optimize the loss function using both these methods. The proposed loss functions offer additional parameters that can be tuned to control the regularisation. We explain why the proposed loss functions should perform better than the state-of-the-art. Further, we derive the conditions under which the proposed JS-G loss function regularises better than the KL divergence-based loss function for Gaussian priors and posteriors. The proposed JS divergence-based Bayesian convolutional neural networks (BCNNs) outperform the state-of-the-art BCNN, as shown on the classification of the CIFAR data set with various degrees of noise and on a biased histopathology data set.
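The abstract contrasts a closed-form evaluation of the geometric JS divergence for Gaussian priors with a Monte Carlo approximation used for the generalized JS divergence (and for non-Gaussian priors under JS-G). The sketch below illustrates that contrast for two univariate Gaussians; the function names, the single skew parameter `alpha`, the Nielsen-style skew weighting, and the univariate setting are illustrative assumptions rather than the paper's actual formulation.

```python
# A minimal, self-contained sketch of the two divergence constructions described
# in the abstract. The function names, the univariate-Gaussian setting, and the
# skew-weighting convention are illustrative assumptions, not the paper's code.
import numpy as np
from scipy.stats import norm


def kl_gauss(mu_a, var_a, mu_b, var_b):
    """Closed-form KL( N(mu_a, var_a) || N(mu_b, var_b) )."""
    return 0.5 * (np.log(var_b / var_a) + (var_a + (mu_a - mu_b) ** 2) / var_b - 1.0)


def js_geometric(mu_p, var_p, mu_q, var_q, alpha=0.5):
    """Skewed geometric JS divergence (JS-G style) between two univariate Gaussians.

    The weighted geometric mean G_alpha ∝ P^(1-alpha) Q^alpha of two Gaussians
    is again Gaussian, so the divergence is available in closed form.
    """
    var_g = 1.0 / ((1.0 - alpha) / var_p + alpha / var_q)  # precision-weighted variance
    mu_g = var_g * ((1.0 - alpha) * mu_p / var_p + alpha * mu_q / var_q)
    return ((1.0 - alpha) * kl_gauss(mu_p, var_p, mu_g, var_g)
            + alpha * kl_gauss(mu_q, var_q, mu_g, var_g))


def js_arithmetic_mc(mu_p, var_p, mu_q, var_q, alpha=0.5, n=100_000, seed=0):
    """Monte Carlo estimate of a skewed arithmetic-mean JS divergence (JS-A style).

    The mixture M = (1-alpha) P + alpha Q has no Gaussian closed form, so the
    two KL terms are estimated from samples drawn from P and Q respectively.
    """
    rng = np.random.default_rng(seed)
    p, q = norm(mu_p, np.sqrt(var_p)), norm(mu_q, np.sqrt(var_q))

    def log_m(x):  # log density of the mixture, computed stably
        return np.logaddexp(np.log1p(-alpha) + p.logpdf(x),
                            np.log(alpha) + q.logpdf(x))

    xp = rng.normal(mu_p, np.sqrt(var_p), n)
    xq = rng.normal(mu_q, np.sqrt(var_q), n)
    kl_p_m = np.mean(p.logpdf(xp) - log_m(xp))
    kl_q_m = np.mean(q.logpdf(xq) - log_m(xq))
    return (1.0 - alpha) * kl_p_m + alpha * kl_q_m


if __name__ == "__main__":
    # alpha is the kind of extra tunable parameter the abstract refers to; it
    # controls the weighting of the two KL terms in the regularisation part.
    print("JS-G (closed form):", js_geometric(0.0, 1.0, 1.0, 2.0, alpha=0.3))
    print("JS-A (Monte Carlo):", js_arithmetic_mc(0.0, 1.0, 1.0, 2.0, alpha=0.3))
```

In a variational BNN loss, one of these divergence terms would replace the usual KL term between the approximate weight posterior and the prior, with the Monte Carlo route needed whenever no closed form exists.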
