JaxPruner: A concise library for sparsity research
Abstract
JaxPruner is an open-source JAX-based library that provides concise implementations of pruning and sparse training algorithms and integrates with existing JAX-based libraries through Optax.
This paper introduces JaxPruner, an open-source JAX-based pruning and sparse training library for machine learning research. JaxPruner aims to accelerate research on sparse neural networks by providing concise implementations of popular pruning and sparse training algorithms with minimal memory and latency overhead. Algorithms implemented in JaxPruner use a common API and work seamlessly with the popular optimization library Optax, which in turn enables easy integration with existing JAX-based libraries. We demonstrate this ease of integration by providing examples in four different codebases (Scenic, t5x, Dopamine, and FedJAX) and provide baseline experiments on popular benchmarks.
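To make the idea of pruning as a thin wrapper around an optimizer step concrete, here is a minimal sketch in plain Python. It does not use the JaxPruner or Optax APIs; the names `magnitude_mask`, `sgd_update`, and `wrap_with_pruning` are hypothetical, and magnitude pruning is assumed only as a representative algorithm.

```python
from typing import Callable, List, Tuple

Params = List[float]


def magnitude_mask(params: Params, sparsity: float) -> List[int]:
    """Return a 0/1 mask that drops the smallest-magnitude fraction of params."""
    n_prune = int(round(sparsity * len(params)))
    # Indices ordered by absolute value, smallest first.
    order = sorted(range(len(params)), key=lambda i: abs(params[i]))
    mask = [1] * len(params)
    for i in order[:n_prune]:
        mask[i] = 0
    return mask


def sgd_update(grads: Params, lr: float) -> Params:
    """Plain SGD: the update is the negative gradient scaled by the learning rate."""
    return [-lr * g for g in grads]


def wrap_with_pruning(
    base_update: Callable[[Params], Params], sparsity: float
) -> Callable[[Params, Params], Tuple[Params, List[int]]]:
    """Wrap a base update rule so that every step ends with masked parameters.

    Hypothetical helper: it only illustrates how a pruning step can sit
    on top of an unchanged optimizer update.
    """

    def step(grads: Params, params: Params) -> Tuple[Params, List[int]]:
        mask = magnitude_mask(params, sparsity)
        updates = base_update(grads)
        # Apply the optimizer update, then zero out the pruned entries.
        new_params = [(p + u) * m for p, u, m in zip(params, updates, mask)]
        return new_params, mask

    return step


params = [0.5, -0.01, 2.0, 0.03]
grads = [0.1, 0.1, -0.2, 0.0]
step = wrap_with_pruning(lambda g: sgd_update(g, lr=0.1), sparsity=0.5)
new_params, mask = step(grads, params)
print(mask)
print(new_params)
```

The base update rule is passed in unchanged, so swapping SGD for another optimizer requires no change to the pruning logic. That is the kind of separation the abstract attributes to the common API.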