A comparison plot for sparse Bayesian regression with a half-Cauchy prior (left), support vector regression (center), and the relevance vector machine (right). The half-Cauchy prior yields a sparser solution and better error bars.
This package implements the algorithm described in Louizos et al. (2017) for use on a broad class of machine learning problems.
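The half-Cauchy distribution C+(0, τ) puts a lot of probability mass near zero but also has heavy tails. This combination is the usual motivation for using it as a sparsity-inducing prior on weight scales. The standard-library sketch below shows only the distribution. It draws samples by inverse-CDF sampling from F(x) = (2/π) arctan(x/τ) and checks that the median equals τ. It is not how `SVIHalfCauchyPrior` parameterizes its variational posterior, and the function names are hypothetical.

```python
import math
import random
import statistics


def half_cauchy_cdf(x: float, tau: float) -> float:
    """CDF of the half-Cauchy distribution with scale tau, for x >= 0."""
    return (2.0 / math.pi) * math.atan(x / tau)


def sample_half_cauchy(tau: float, n: int, rng: random.Random) -> list[float]:
    """Inverse-CDF sampling: solve F(x) = u for u ~ Uniform[0, 1)."""
    return [tau * math.tan(math.pi * rng.random() / 2.0) for _ in range(n)]


rng = random.Random(0)
tau = 1.0
samples = sample_half_cauchy(tau, 100_000, rng)

# F(tau) = 1/2, so the sample median should be close to tau
print(statistics.median(samples) / tau)

# about 6.3% of the mass lies below tau/10 and the same amount above 10*tau
print(half_cauchy_cdf(0.1 * tau, tau), 1 - half_cauchy_cdf(10 * tau, tau))
```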
```bash
pip install variationalsparsebayes
```
The library provides a high-level interface with some prebuilt sparse Bayesian models and a low-level interface for building custom sparse Bayesian models.
The high-level interface includes the following prebuilt sparse Bayesian models:
- sparse polynomial regression
- sparse Bayesian neural networks
- sparse learning with precomputed features
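Sparse polynomial regression fits y ≈ Σ_k w_k φ_k(x) over a library of candidate monomials φ_k, and the sparsity-inducing prior then prunes most of the weights w_k. The standard-library sketch below shows one way to enumerate and evaluate such a library. The function names are hypothetical, and this is not the package's own feature class.

```python
import itertools
import math


def monomial_exponents(num_vars: int, degree: int) -> list[tuple[int, ...]]:
    """All exponent tuples (e_1, ..., e_n) with e_1 + ... + e_n <= degree."""
    return [
        e
        for e in itertools.product(range(degree + 1), repeat=num_vars)
        if sum(e) <= degree
    ]


def evaluate_library(x: list[float], exponents: list[tuple[int, ...]]) -> list[float]:
    """Evaluate every monomial prod_i x_i**e_i at a single input point x."""
    return [math.prod(xi**ei for xi, ei in zip(x, e)) for e in exponents]


exps = monomial_exponents(num_vars=2, degree=2)
# 6 monomials, in the order 1, x2, x2^2, x1, x1*x2, x1^2
print(len(exps))
print(evaluate_library([2.0, 3.0], exps))  # [1.0, 3.0, 9.0, 2.0, 6.0, 4.0]
```

Enumerating through `itertools.product` costs O((degree + 1)^num_vars), which is fine for the small libraries typical of sparse regression. A large number of variables would need a smarter generator.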
To implement your own linear model, inherit from the `SparseFeaturesLibrary` class. Note that I haven't implemented the "group" sparsity idea presented in Louizos et al. (2017). Sparsification is performed at the level of individual parameters, so the computational savings are far smaller.
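A rough count shows why parameter-level sparsification saves far less computation. Pruning individual weights leaves the weight matrix at its original shape, so a dense kernel still performs every multiply-add. Pruning whole units (rows) shrinks the matrix itself. The sketch below counts nonzero weights and dense multiply-adds for one layer under both schemes. It assumes a plain dense matrix-vector implementation with no sparse kernels, and it is an illustration rather than a measurement of this library.

```python
def parameter_pruning(rows: int, cols: int, frac: float) -> tuple[int, int]:
    """Prune a fraction of individual weights; return (nonzeros, dense multiply-adds)."""
    total = rows * cols
    nonzeros = total - round(total * frac)
    # zeroed weights still occupy the dense matrix, so the dense cost is unchanged
    return nonzeros, total


def group_pruning(rows: int, cols: int, frac: float) -> tuple[int, int]:
    """Prune a fraction of whole output units (rows); return (nonzeros, dense multiply-adds)."""
    kept_rows = rows - round(rows * frac)
    # removing entire rows shrinks the matrix, so the dense cost drops too
    return kept_rows * cols, kept_rows * cols


print(parameter_pruning(512, 512, 0.9))  # (26214, 262144)
print(group_pruning(512, 512, 0.9))      # (26112, 26112)
```

Both schemes leave roughly the same number of nonzero weights. Only group pruning reduces the dense compute, by about 10x in this example.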
The most important class provided by the library is `SVIHalfCauchyPrior`, which inherits from `nn.Module`. The user is responsible for (i) transforming a batch of weights sampled from the variational posterior into a batch of predictions and (ii) adding the KL divergence provided by the prior to the negative ELBO.
```python
from torch import nn

from variationalsparsebayes import SVIHalfCauchyPrior


class MyModel(nn.Module):
    def __init__(self, num_params: int):
        super().__init__()
        # we initialize the prior with tau=1e-5 (see https://arxiv.org/pdf/1705.08665.pdf)
        self.prior = SVIHalfCauchyPrior(num_params, 1e-5)
        ...

    def forward(self, x, num_reparam_samples):
        # draw a batch of reparameterized weight samples from the variational posterior
        w_samples = self.prior.get_reparam_weights(num_reparam_samples)
        # indices of the weights that remain after sparsification
        sparse_index = self.prior.sparse_index
        # user transforms weights and inputs into predictions
        ...

    def elbo(self, x, y):
        # log_like is user-supplied: the expected log-likelihood of y given x
        return log_like(x, y) - self.prior.kl_divergence()


model = MyModel(num_params)
...
```
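Steps (i) and (ii) can be made concrete for a linear model. The standard-library sketch below does not depend on the library and makes four assumptions. First, each weight sample is a list aligned with `sparse_index`, with one entry per surviving feature. Second, the likelihood is Gaussian with a known noise scale `sigma`. Third, the KL value is passed in; in practice it would come from `self.prior.kl_divergence()`. Fourth, the helper names `predict` and `negative_elbo` and this sample layout are hypothetical, and the actual tensor shapes returned by `get_reparam_weights` may differ.

```python
import math


def predict(features, w_samples, sparse_index):
    """Turn each weight sample into predictions for every input row.

    features: one list of candidate-feature values per input row.
    w_samples: one weight vector per reparameterized sample, aligned with sparse_index.
    Returns one list of predictions per weight sample.
    """
    return [
        # only features whose weights survived sparsification contribute
        [sum(wk * row[j] for wk, j in zip(w, sparse_index)) for row in features]
        for w in w_samples
    ]


def negative_elbo(features, y, w_samples, sparse_index, kl, sigma=1.0):
    """Monte Carlo estimate of -ELBO = -E_q[log p(y | w)] + KL(q || p)."""
    log_likes = []
    for preds in predict(features, w_samples, sparse_index):
        # Gaussian log-likelihood of the whole dataset under one weight sample
        log_likes.append(sum(
            -0.5 * math.log(2 * math.pi * sigma**2) - (yi - mu) ** 2 / (2 * sigma**2)
            for yi, mu in zip(y, preds)
        ))
    expected_log_like = sum(log_likes) / len(log_likes)
    return -expected_log_like + kl


# toy data: 3 candidate features, feature 1 already pruned
features = [[1.0, 2.0, 3.0], [1.0, -1.0, 0.5]]
y = [5.0, 2.0]
sparse_index = [0, 2]
w_samples = [[2.0, 1.0], [1.9, 1.1]]
print(predict(features, w_samples, sparse_index))
print(negative_elbo(features, y, w_samples, sparse_index, kl=0.5))
```

In a real model the same computation would run on torch tensors, so that gradients flow through the reparameterized samples into the variational parameters.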
When it is time to sparsify the approximate posterior, run:
```python
model.prior.update_sparse_index()
# get the index of all weights which remain after sparsification
model.prior.sparse_index
```
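This README does not state which rule `update_sparse_index` uses to decide which weights survive. One common criterion in variational sparsification is used, for example, in sparse variational dropout. It prunes a weight whose approximate posterior is dominated by noise, meaning log α = log σ² − log μ² exceeds a threshold such as 3. The standard-library sketch below applies that rule to posterior means and variances. The function name, the criterion and the threshold are assumptions for illustration, not the library's documented behaviour.

```python
import math


def sparse_index_from_posterior(means, variances, log_alpha_threshold=3.0):
    """Indices of weights whose posterior is not dominated by noise.

    log_alpha = log(variance) - log(mean**2); a large value means the
    posterior spread dwarfs the mean, so the weight is pruned.
    """
    keep = []
    for k, (mu, var) in enumerate(zip(means, variances)):
        if mu == 0.0:
            continue  # zero mean gives log_alpha = +inf, so always prune
        log_alpha = math.log(var) - math.log(mu * mu)
        if log_alpha < log_alpha_threshold:
            keep.append(k)
    return keep


means = [1.5, 1e-3, -0.8, 0.0]
variances = [0.01, 0.01, 0.04, 0.01]
print(sparse_index_from_posterior(means, variances))  # [0, 2]
```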