JaxSGMC: Modular stochastic gradient MCMC in JAX

Stephan Thaler, Paul Fuchs, Ana Cukarska, Julija Zavadlav

Publication: Contribution to journal › Article › Peer-reviewed

1 citation (Scopus)

Abstract

We present JaxSGMC, an application-agnostic library for stochastic gradient Markov chain Monte Carlo (SG-MCMC) in JAX. SG-MCMC schemes are uncertainty quantification (UQ) methods that scale to large datasets and high-dimensional models, enabling trustworthy neural network predictions via Bayesian deep learning. JaxSGMC implements several state-of-the-art SG-MCMC samplers to promote UQ in deep learning by lowering the barrier to entry for switching from stochastic optimization to SG-MCMC sampling. Additionally, JaxSGMC allows users to build custom samplers from standard SG-MCMC building blocks. Owing to this modular structure, we anticipate that JaxSGMC will accelerate research into novel SG-MCMC schemes and facilitate their application across a broad range of domains.
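To make the step from stochastic optimization to SG-MCMC sampling concrete, here is a minimal sketch of stochastic gradient Langevin dynamics (SGLD), one of the simplest SG-MCMC schemes, written in plain JAX. It does not use the JaxSGMC API. The toy model (a unit-variance Gaussian likelihood with unknown mean and a N(0, 10²) prior) and the names `sgld_step` and `run_sgld` are hypothetical, chosen only for illustration. Each update is a half-step along a minibatch estimate of the log-posterior gradient (likelihood term rescaled by N/n), plus Gaussian noise whose variance equals the step size.

```python
import functools

import jax
import jax.numpy as jnp


# Hypothetical toy model (not from JaxSGMC): unknown mean theta,
# unit-variance Gaussian likelihood, N(0, 10^2) prior on theta.
def log_prior(theta):
    return -0.5 * theta**2 / 100.0


def log_likelihood(theta, x):
    return -0.5 * (x - theta) ** 2


def stochastic_grad(theta, batch, n_total):
    # Unbiased estimate of the log-posterior gradient: prior gradient plus the
    # minibatch likelihood gradient rescaled by N / n.
    grad_prior = jax.grad(log_prior)(theta)
    grad_lik = jax.grad(lambda t: jnp.sum(log_likelihood(t, batch)))(theta)
    return grad_prior + (n_total / batch.shape[0]) * grad_lik


def sgld_step(theta, batch, key, step_size, n_total):
    # SGLD update: half a step along the stochastic gradient, plus Gaussian
    # noise with variance equal to the step size.
    g = stochastic_grad(theta, batch, n_total)
    noise = jax.random.normal(key, shape=jnp.shape(theta))
    return theta + 0.5 * step_size * g + jnp.sqrt(step_size) * noise


@functools.partial(jax.jit, static_argnames=("num_steps", "batch_size"))
def run_sgld(data, key, num_steps, batch_size, step_size):
    n_total = data.shape[0]  # static under jit, needed for choice(replace=False)

    def body(theta, step_key):
        k_batch, k_noise = jax.random.split(step_key)
        # Draw a minibatch of indices without replacement.
        idx = jax.random.choice(k_batch, n_total, shape=(batch_size,), replace=False)
        theta = sgld_step(theta, data[idx], k_noise, step_size, n_total)
        return theta, theta  # carry the state and also record it as a sample

    keys = jax.random.split(key, num_steps)
    _, samples = jax.lax.scan(body, jnp.array(0.0), keys)
    return samples


data = 2.0 + jax.random.normal(jax.random.PRNGKey(0), (1000,))
samples = run_sgld(data, jax.random.PRNGKey(1), num_steps=3000, batch_size=32, step_size=1e-4)
posterior_samples = samples[500:]  # discard burn-in
```

Deleting the noise term turns `sgld_step` into a plain minibatch gradient-ascent step on the log-posterior, which is why switching from an optimizer to an SG-MCMC sampler can require only a small change to a training loop. The step size trades off discretization bias against mixing speed, and early samples should be discarded as burn-in.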

Original language: English
Article number: 101722
Journal: SoftwareX
Volume: 26
Publication status: Published - May 2024
