Gaussian processes (2/3) - Fitting a Gaussian process kernel

In the previous post we introduced the Gaussian process model with the exponentiated quadratic covariance function. In this post we will introduce parametrized covariance functions (kernels), fit them to real-world data, and use them to make posterior predictions. This post is the second part of a series on Gaussian processes:

  1. Understanding Gaussian processes
  2. Fitting a Gaussian process kernel (this)
  3. Gaussian process kernels

We will implement the Gaussian process model in JAX, keeping the kernel and likelihood calculations explicit so the model details remain visible.

In [1]:
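
The first cell imports the libraries used throughout the notebook. A minimal sketch of those imports, based on the packages listed in the watermark at the end of this post (the exact original cell may differ):

# Imports assumed by the code cells in this post (a sketch, not the original cell)
from typing import NamedTuple, cast

import jax
import jax.numpy as jnp
import optax
import pandas as pd
from jax import Array
from tqdm import tqdm

# bokeh is used for the plots in this post
from bokeh.plotting import figure, show

# Assumption about the original setup: 64-bit floats make the linear algebra more stable
jax.config.update("jax_enable_x64", True)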

Mauna Loa CO₂ data

The dataset used in this example is the monthly average atmospheric CO₂ concentration (in parts per million, ppm) collected at the Mauna Loa Observatory in Hawaii. The observatory has been collecting these CO₂ concentrations since 1958 and provided the first significant evidence of rapidly increasing CO₂ levels in the atmosphere.

These measurements of atmospheric CO₂ concentration show different characteristics such as a long-term rising trend, variation with the seasons, and smaller irregularities. This has made the dataset a canonical example in Gaussian process modelling⁽¹⁾.

In this post the data is downloaded as a CSV file from the Scripps CO₂ Program website. The data is loaded into a pandas dataframe and plotted below.

In [2]:
In [3]:
In [4]:
617 measurements in the observed set
133 measurements in the test set
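
For reference, here is a rough sketch of how the observed/test split above could be produced. The dataframe and column names (co2_df, date, co2), as well as X_test and y_test, are assumptions; only X_observed and y_observed are referenced by the later code cells.

# Sketch: split the monthly series at 2010 into observed and test sets
# (a co2_df with a fractional-year `date` column and a `co2` column in ppm is an assumption)
date_split = 2010.0
observed_df = co2_df[co2_df.date < date_split]
test_df = co2_df[co2_df.date >= date_split]

X_observed = jnp.asarray(observed_df.date.values)
y_observed = jnp.asarray(observed_df.co2.values)
X_test = jnp.asarray(test_df.date.values)
y_test = jnp.asarray(test_df.co2.values)

print(f"{len(X_observed)} measurements in the observed set")
print(f"{len(X_test)} measurements in the test set")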

Gaussian process model

We're going to use a Gaussian process model to make posterior predictions of the atmospheric CO₂ concentrations for 2010 and after based on the observed data from before 2010.

A Gaussian process is uniquely defined by its mean function $m(x)$ and covariance function $k(x,x')$:

$$f(x) \sim \mathcal{GP}(m(x),k(x,x'))$$

Mean function

Since most of the interesting effects will be modelled by the kernel function, we will keep the mean function simple. In this example the mean function is a constant function that always returns the mean of the observations.

In [5]:
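
A constant mean function of this kind could look like the sketch below; the name mean_function is an assumption and is not referenced by the other code cells shown in this post.

# Sketch of a constant mean function that returns the mean of the observed targets
def mean_function(x: Array, observations: Array) -> Array:
    """Return the observation mean for every index point in `x`."""
    return jnp.full(shape=x.shape[:1], fill_value=jnp.mean(observations))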

Kernel function

To model the different characteristics of our dataset we will create a covariance (kernel) function by combining a few small JAX kernel functions. The different data characteristics will be modelled as:

  • Long-term smooth change in CO₂ levels over time, modelled by an exponentiated quadratic kernel, defined in the code below as smooth_kernel.
  • Seasonality, based on a local periodic kernel, which consists of a periodic kernel multiplied by an exponentiated quadratic to make the seasonality decay as it gets farther from the observations. This seasonal periodic kernel is defined in the code below as local_periodic_kernel.
  • Short- to medium-term irregularities, modelled by a rational quadratic kernel, defined in the code below as irregular_kernel.
  • Observational noise, which will be modelled directly by adding observation_noise_variance to the kernel matrix diagonal.

These different kernels will be summed into one single kernel function $k_{\theta}(x_a, x_b)$ that allows all these effects to occur together. This combined kernel is defined as combined_kernel in the code below. Each of the kernels has hyperparameters $\theta$ that can be tuned; they will be defined as unconstrained log-parameters and transformed to positive values before use. This post provides more insight into the kernels used here and the effect of their hyperparameters.

In [6]:
# Define the parameters for the kernel function we're fitting

class KernelParams(NamedTuple):
    """Named GP hyperparameters that remain a JAX pytree."""

    smooth_amplitude: Array
    smooth_length_scale: Array
    periodic_amplitude: Array
    periodic_length_scale: Array
    periodic_period: Array
    periodic_local_length_scale: Array
    irregular_amplitude: Array
    irregular_length_scale: Array
    irregular_scale_mixture: Array
    observation_noise_variance: Array
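
The fields above hold the unconstrained log-values. A helper along the following lines (the name to_positive is an assumption) can map them to the positive values the kernels expect; jax.tree_util.tree_map keeps the result a KernelParams pytree.

# Sketch: exponentiate every unconstrained log-parameter to get positive kernel parameters
def to_positive(params: KernelParams) -> KernelParams:
    """Map unconstrained log-parameters to the positive values used by the kernels."""
    return jax.tree_util.tree_map(jnp.exp, params)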
In [7]:
In [8]:
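
The cells defining the component kernels are not reproduced here; the sketches below show what they could look like, reconstructed to match the call signatures used in combined_kernel further down. They assume 1-D arrays of index points, and the exact parameterizations in the original cells may differ.

# Sketches of the three component kernels (a reconstruction, not the original cells).
# Each takes 1-D arrays of index points and returns an (n, m) covariance matrix.

def smooth_kernel(xa: Array, xb: Array, positive_params: KernelParams) -> Array:
    """Exponentiated quadratic kernel for the long-term smooth trend."""
    sq_dist = (xa[:, None] - xb[None, :]) ** 2
    return positive_params.smooth_amplitude**2 * jnp.exp(
        -0.5 * sq_dist / positive_params.smooth_length_scale**2
    )


def local_periodic_kernel(
    xa: Array,
    xb: Array,
    amplitude: Array,
    periodic_length_scale: Array,
    period: Array,
    local_length_scale: Array,
) -> Array:
    """Periodic kernel multiplied by an exponentiated quadratic (decaying seasonality)."""
    dist = jnp.abs(xa[:, None] - xb[None, :])
    periodic = jnp.exp(
        -2.0 * jnp.sin(jnp.pi * dist / period) ** 2 / periodic_length_scale**2
    )
    local = jnp.exp(-0.5 * dist**2 / local_length_scale**2)
    return amplitude**2 * periodic * local


def irregular_kernel(xa: Array, xb: Array, positive_params: KernelParams) -> Array:
    """Rational quadratic kernel for short- to medium-term irregularities."""
    sq_dist = (xa[:, None] - xb[None, :]) ** 2
    alpha = positive_params.irregular_scale_mixture
    length_scale = positive_params.irregular_length_scale
    return positive_params.irregular_amplitude**2 * (
        1.0 + sq_dist / (2.0 * alpha * length_scale**2)
    ) ** (-alpha)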
In [9]:
# Define the combined kernel we're going to be fitting

def combined_kernel(xa: Array, xb: Array, positive_params: KernelParams) -> Array:
    """Combine smooth, seasonal, and irregular CO2 covariance components."""
    return (
        smooth_kernel(xa=xa, xb=xb, positive_params=positive_params)
        + local_periodic_kernel(
            xa=xa,
            xb=xb,
            amplitude=positive_params.periodic_amplitude,
            periodic_length_scale=positive_params.periodic_length_scale,
            period=positive_params.periodic_period,
            local_length_scale=positive_params.periodic_local_length_scale,
        )
        + irregular_kernel(xa=xa, xb=xb, positive_params=positive_params)
    )

Tuning the hyperparameters

We can tune the hyperparameters $\theta$ of our Gaussian process model based on the data. This post uses JAX to fit the parameters by maximizing the marginal likelihood $p(\mathbf{y} \mid X, \theta)$ of the Gaussian process distribution based on the observed data $(X, \mathbf{y})$.

$$\hat{\theta} = \underset{\theta}{\text{argmax}} \left( p(\mathbf{y} \mid X, \theta) \right)$$

The marginal likelihood of the Gaussian process is the likelihood of a multivariate Gaussian distribution, which is defined as:

$$p(\mathbf{y} \mid \mu, \Sigma) = \frac{1}{\sqrt{(2\pi)^d \lvert\Sigma\rvert}} \exp{ \left( -\frac{1}{2}(\mathbf{y} - \mu)^{\top} \Sigma^{-1} (\mathbf{y} - \mu) \right)}$$

The mean is calculated from the parameterized mean function using the observed data $X$ as input: $\mu_{\theta} = m_{\theta}(X)$. For the observed targets, the covariance also includes observation noise on the diagonal: $\Sigma_{\theta} = k_{\theta}(X, X) + \sigma_n^2 I$. We can then write the marginal likelihood as:

$$ p(\mathbf{y} \mid X, \theta) = \frac{1}{\sqrt{(2\pi)^d \lvert \Sigma_{\theta} \rvert}} \exp{ \left( -\frac{1}{2}(\mathbf{y} - \mu_{\theta})^\top \Sigma_{\theta}^{-1} (\mathbf{y} - \mu_{\theta}) \right)}$$

With $d$ the dimensionality of the marginal and $\lvert \Sigma_{\theta} \rvert$ the determinant of the observed-data covariance matrix. We can get rid of the exponent by taking the log and maximizing the log marginal likelihood:

$$ \log{p(\mathbf{y} \mid X, \theta)} = -\frac{1}{2}(\mathbf{y} - \mu_{\theta})^\top \Sigma_{\theta}^{-1} (\mathbf{y} - \mu_{\theta}) - \frac{1}{2} \log{\lvert \Sigma_{\theta} \rvert} - \frac{d}{2} \log{2 \pi}$$

The first term ($-0.5 (\mathbf{y} - \mu_{\theta})^\top \Sigma_{\theta}^{-1} (\mathbf{y} - \mu_{\theta})$) is the data-fit while the rest ($-0.5(\log{\lvert \Sigma_{\theta} \rvert} + d\log{2 \pi})$) is a complexity penalty, also known as differential entropy⁽¹⁾.

The optimal parameters $\hat{\theta}$ can then be found by minimizing the negative of the log marginal likelihood:

$$ \hat{\theta} = \underset{\theta}{\text{argmax}} \left( p(\mathbf{y} \mid X, \theta) \right) = \underset{\theta}{\text{argmin}} { \;-\log{ p(\mathbf{y} \mid X, \theta)}}$$

Since in this case the log marginal likelihood is differentiable with respect to the kernel hyperparameters, we can use a gradient-based approach to minimize the negative log marginal likelihood (NLL). In this post we will use JAX automatic differentiation and Optax to train the hyperparameters on the full observed dataset.

In [10]:
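
The optimization needs starting values for the unconstrained log-parameters. A sketch of one possible initialization (the specific starting values here are assumptions, not the values used in the original notebook):

# Sketch: initialize all unconstrained log-parameters (starting values are a guess)
params = KernelParams(
    smooth_amplitude=jnp.log(10.0),
    smooth_length_scale=jnp.log(10.0),
    periodic_amplitude=jnp.log(1.0),
    periodic_length_scale=jnp.log(1.0),
    periodic_period=jnp.log(1.0),
    periodic_local_length_scale=jnp.log(10.0),
    irregular_amplitude=jnp.log(1.0),
    irregular_length_scale=jnp.log(1.0),
    irregular_scale_mixture=jnp.log(0.1),
    observation_noise_variance=jnp.log(1.0),
)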

We will implement the Gaussian process marginal likelihood directly from the Gaussian distribution formula above. jax.value_and_grad evaluates the negative log marginal likelihood and its gradient, while optax.adam is used to update the unconstrained log-parameters.

In [11]:
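
A sketch of this likelihood function, written to follow the log marginal likelihood formula above. A Cholesky factorization would be numerically preferable, but jnp.linalg.solve and jnp.linalg.slogdet keep the correspondence with the formula explicit; to_positive is the sketch helper from earlier.

# Sketch of the negative log marginal likelihood of the observations under the GP prior
def gp_negative_log_marginal_likelihood(
    params: KernelParams, index_points: Array, observations: Array
) -> Array:
    """Negative log marginal likelihood, following the formula in the text."""
    positive_params = to_positive(params)
    residual = observations - jnp.mean(observations)  # constant mean function
    cov = combined_kernel(index_points, index_points, positive_params)
    cov = cov + positive_params.observation_noise_variance * jnp.eye(len(index_points))
    data_fit = 0.5 * residual @ jnp.linalg.solve(cov, residual)
    _, log_det = jnp.linalg.slogdet(cov)
    return data_fit + 0.5 * log_det + 0.5 * len(observations) * jnp.log(2.0 * jnp.pi)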
In [12]:
# Fit hyperparameters
learning_rate = 0.01
optimizer = optax.adam(learning_rate=learning_rate)
opt_state = optimizer.init(params)

# Define loss and gradients
loss_and_grad = jax.value_and_grad(gp_negative_log_marginal_likelihood)


@jax.jit
def optimization_step(
    params: KernelParams,
    opt_state: optax.OptState,
    index_points: Array,
    observations: Array,
) -> tuple[KernelParams, optax.OptState, Array]:
    """Take one step toward lower full-data GP marginal likelihood."""
    loss, grads = loss_and_grad(
        params, index_points=index_points, observations=observations
    )
    updates, opt_state = optimizer.update(grads, opt_state, params)
    params = cast(KernelParams, optax.apply_updates(params, updates))
    return params, opt_state, loss


# Training loop
full_nlls: list[tuple[int, float]] = []  # Full data NLL for plotting
nb_iterations = 2500
for i in tqdm(range(nb_iterations)):
    params, opt_state, loss = optimization_step(
        params, opt_state, index_points=X_observed, observations=y_observed
    )
    if i % 10 == 0 or i == nb_iterations - 1:
        full_nlls.append((i, float(loss)))
In [13]:
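
The recorded full-data NLL values can be plotted to check convergence. With bokeh this could look roughly like the sketch below (figure and show come from the imports sketch at the top; the original plotting cell may differ).

# Sketch: plot the full-data NLL recorded during training
iterations, losses = zip(*full_nlls)
fig = figure(
    title="Negative log marginal likelihood during training",
    x_axis_label="iteration",
    y_axis_label="NLL",
    height=300,
    width=600,
)
fig.line(iterations, losses, line_width=2)
show(fig)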

Despite the complexity term in the log marginal likelihood it is still possible to overfit the data. While we didn't take extra measures against overfitting in this post, you could prevent it by adding a regularization term on the hyperparameters⁽²⁾, or by splitting the observations into training and validation sets and selecting the model that best fits the validation data among models trained on the training data.

In [14]:
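
The fitted values reported below are the positive-transformed hyperparameters. A sketch of how such a table could be produced (a reconstruction using the to_positive helper sketched earlier):

# Sketch: display the fitted (positive) hyperparameters as a table
fitted_params = to_positive(params)
pd.DataFrame(
    {"Value": [float(value) for value in fitted_params]},
    index=list(KernelParams._fields),
).rename_axis("Hyperparameters").round(4)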
Hyperparameters Value
smooth_amplitude 124.3517
smooth_length_scale 99.9925
periodic_amplitude 2.2775
periodic_length_scale 1.3713
periodic_period 0.9998
periodic_local_length_scale 99.0098
irregular_amplitude 2.3398
irregular_length_scale 1.6404
irregular_scale_mixture 0.0073
observation_noise_variance 0.0404

The fitted kernel and its components are illustrated in more detail in a follow-up post.

Posterior predictions

With the fitted kernel we can use the Gaussian process conditioning equations directly to make posterior predictions on points after 2010.
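
Concretely, with $\Sigma = k_{\theta}(X, X) + \sigma_n^2 I$ the observed-data covariance, $K_* = k_{\theta}(X, X_*)$, and $K_{**} = k_{\theta}(X_*, X_*)$, the conditioning equations give the posterior mean $\mu_* = m(X_*) + K_*^\top \Sigma^{-1} (\mathbf{y} - m(X))$ and covariance $\Sigma_* = K_{**} - K_*^\top \Sigma^{-1} K_*$. The sketch below shows how this could be computed with the fitted kernel; X_test comes from the earlier data-split sketch and the variable names are assumptions.

# Sketch: GP posterior predictions at the test index points using the fitted kernel
positive_params = to_positive(params)
mean_observed = jnp.mean(y_observed)

k_obs_obs = combined_kernel(X_observed, X_observed, positive_params)
k_obs_obs = k_obs_obs + positive_params.observation_noise_variance * jnp.eye(len(X_observed))
k_obs_test = combined_kernel(X_observed, X_test, positive_params)
k_test_test = combined_kernel(X_test, X_test, positive_params)

# Condition on the observations (solve instead of forming an explicit inverse)
solved = jnp.linalg.solve(k_obs_obs, k_obs_test)
posterior_mean = mean_observed + solved.T @ (y_observed - mean_observed)
posterior_cov = k_test_test - solved.T @ k_obs_test

# 95% prediction interval around the posterior mean (includes the observation noise)
posterior_std = jnp.sqrt(jnp.diag(posterior_cov) + positive_params.observation_noise_variance)
lower_bound = posterior_mean - 1.96 * posterior_std
upper_bound = posterior_mean + 1.96 * posterior_std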

The posterior predictions conditioned on the observed data before 2010 are plotted in the figure below together with their 95% prediction interval.

Notice that our model captures the different dataset characteristics, such as the trend and seasonality, quite well. The predictions deviate more the further they are from the observed data the model was conditioned on, and the prediction interval widens accordingly.

In [15]:
In [16]:

This post illustrated how to use JAX to combine multiple kernels and fit their hyperparameters to observed data. The fitted model was then used to make posterior predictions.

Sidenotes

  1. Note that the determinant $\lvert \Sigma \rvert$ is equal to the product of the eigenvalues of $\Sigma$, and that $\lvert \Sigma \rvert$ can be interpreted as the volume spanned by the covariance matrix $\Sigma$. Reducing $\lvert \Sigma \rvert$ will thus decrease the dispersion of the points coming from the distribution with covariance matrix $\Sigma$ and reduce the complexity.
  2. A fully Bayesian extension could define prior distributions for the hyperparameters and sample from the posterior with a JAX-based probabilistic programming library such as NumPyro.

References

  1. Gaussian Processes for Machine Learning. Chapter 5: Model Selection and Adaptation of Hyperparameters by Carl Edward Rasmussen and Christopher K. I. Williams.
  2. Gaussian Processes for Regression by Christopher K. I. Williams and Carl Edward Rasmussen.
In [17]:
Python implementation: CPython
Python version       : 3.13.12
IPython version      : 9.13.0

IPython: 9.13.0
bokeh  : 3.9.0
jax    : 0.8.3
numpy  : 2.4.4
optax  : 0.2.8
pandas : 2.3.3
tqdm   : 4.67.3

This post is generated from an IPython notebook file. Link to the full IPython notebook file