Slax: A Composable JAX Library for Rapid and Flexible Prototyping of Spiking Neural Networks

Abstract

Recent advances in algorithms for training spiking neural networks (SNNs) often leverage their unique dynamics. While backpropagation through time (BPTT) with surrogate gradients dominates the field, a rich landscape of alternatives situates algorithms at various points in the space of performance, bio-plausibility, and complexity. Evaluating and comparing algorithms is currently cumbersome and error-prone, because each algorithm must be repeatedly re-implemented. We introduce Slax, a JAX-based library designed to accelerate SNN algorithm design and compatible with the broader JAX and Flax ecosystem. Slax provides optimized implementations of diverse training algorithms, allowing direct performance comparison. Its toolkit includes methods to visualize and debug algorithms through loss landscapes, gradient similarities, and other metrics of model behavior during training.
