Gaussian processes (2/3) - Fitting a Gaussian process kernel
In the previous post we introduced the Gaussian process model with the exponentiated quadratic covariance function. In this post we will introduce parametrized covariance functions (kernels), fit them to real world data, and use them to make posterior predictions. This post is the second part of a series on Gaussian processes:
We will implement the Gaussian process model in JAX, keeping the kernel and likelihood calculations explicit so the model details remain visible.
# Imports
import warnings
warnings.simplefilter("ignore")
from typing import NamedTuple, TypeAlias, cast
import jax
jax.config.update("jax_enable_x64", True)
import jax.numpy as jnp
import jax.scipy as jsp
import numpy as np
import optax
import pandas as pd
from tqdm.notebook import tqdm
import bokeh
import bokeh.io
import bokeh.plotting
import bokeh.models
from IPython.display import display, HTML
bokeh.io.output_notebook(hide_banner=True)
Array: TypeAlias = jax.Array
Scalar: TypeAlias = float | Array
np.random.seed(42)
#
Mauna Loa CO₂ data
The dataset used in this example is the monthly average atmospheric CO₂ concentrations (in parts per million (ppm)) collected at the Mauna Loa Observatory in Hawaii. The observatory has been collecting these CO₂ concentrations since 1958 and showed the first significant evidence of rapidly increasing CO₂ levels in the atmosphere.
These measurements of atmospheric CO₂ concentrations show different characteristics such as a long-term rising trend, variation with the seasons, and smaller irregularities. This has made the dataset a canonical example in Gaussian process modelling ⁽¹⁾.
In this post the data is downloaded as a CSV file from the Scripps CO₂ Program website. The data is loaded into a pandas dataframe and plotted below.
# Load the data
# Load the data from the Scripps CO2 program website.
co2_df = pd.read_csv( # ty: ignore[no-matching-overload] - pandas stub mismatch
# Source: https://scrippsco2.ucsd.edu/assets/data/atmospheric/stations/in_situ_co2/monthly/monthly_in_situ_co2_mlo.csv
filepath_or_buffer="./monthly_in_situ_co2_mlo.csv",
header=3, # Data starts here
skiprows=[4, 5], # Headers consist of multiple rows
usecols=[3, 4], # Only keep the 'Date' and 'CO2' columns
na_values="-99.99", # NaNs are denoted as '-99.99'
)
# Drop missing values
co2_df.dropna(inplace=True)
# Remove whitespace from column names
co2_df.rename(columns=lambda x: x.strip(), inplace=True)
#
# Plot data
fig = bokeh.plotting.figure(
width=600,
height=300,
x_range=bokeh.models.Range1d(start=1958, end=2021),
y_range=bokeh.models.Range1d(start=310, end=420),
)
fig.xaxis.axis_label = "Date"
fig.yaxis.axis_label = "CO₂ (ppm)"
fig.add_layout(
bokeh.models.Title(
text="In situ air measurements at Mauna Loa, Observatory, Hawaii",
text_font_style="italic",
),
"above",
)
fig.add_layout(
bokeh.models.Title(text="Atmospheric CO₂ concentrations", text_font_size="14pt"),
"above",
)
observed_source = bokeh.models.ColumnDataSource(
data={"date": co2_df.Date.to_numpy(), "co2": co2_df.CO2.to_numpy()}
)
observed_renderer = fig.line(
x="date",
y="co2",
source=observed_source,
legend_label="All data",
line_width=2,
line_color="midnightblue",
)
fig.add_tools(
bokeh.models.HoverTool(
renderers=[observed_renderer],
tooltips=[("Date", "@date{0.00}"), ("CO₂", "@co2{0.00} ppm")],
mode="vline",
line_policy="nearest",
)
)
fig.legend.location = "top_left"
fig.toolbar.autohide = True
bokeh.plotting.show(fig)
#
# Split the data into observed and to predict
date_split_predict = 2010
df_observed = co2_df[co2_df.Date < date_split_predict]
print(f"{len(df_observed)} measurements in the observed set")
df_predict = co2_df[co2_df.Date >= date_split_predict]
print(f"{len(df_predict)} measurements in the test set")
#
Gaussian process model
We're going to use a Gaussian process model to make posterior predictions of the atmospheric CO₂ concentrations for 2010 and after based on the observed data from before 2010.
A Gaussian process is uniquely defined by its mean function $m(x)$ and covariance function $k(x,x')$:
$$f(x) \sim \mathcal{GP}(m(x),k(x,x'))$$
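As a quick reminder of what this definition means (a minimal sketch, using the exponentiated quadratic kernel from the previous post rather than the combined kernel defined later in this post), functions can be drawn from the prior by sampling a multivariate Gaussian with the kernel matrix as covariance:
# Sketch: draw functions from a zero-mean GP prior with an exponentiated
# quadratic covariance (see the previous post in this series).
key = jax.random.PRNGKey(0)
xs = jnp.linspace(0.0, 4.0, 100)
prior_cov = jnp.exp(-0.5 * (xs[:, None] - xs[None, :]) ** 2)
prior_samples = jax.random.multivariate_normal(
    key,
    mean=jnp.zeros(xs.shape[0]),
    cov=prior_cov + 1e-9 * jnp.eye(xs.shape[0]),  # jitter for numerical stability
    shape=(3,),
)  # 3 sampled functions, each evaluated at the 100 inputs
#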
# Define mean function which is the mean of observations
observations_mean = jnp.asarray(np.mean(df_observed.CO2.values), dtype=jnp.float64)
def mean_fn(index_points: Array) -> Array:
"""Use the empirical CO2 mean as the GP prior mean."""
return jnp.full(
shape=(index_points.shape[0],), fill_value=observations_mean, dtype=jnp.float64
)
#
Kernel function
To model the different characteristics of our dataset we will create a covariance (kernel) function by combining a few small JAX kernel functions. The different data characteristics will be modelled as:
- Long term smooth change in CO₂ levels over time, modelled by an exponentiated quadratic kernel, defined in the code below as smooth_kernel.
- Seasonality based on a local periodic kernel, which consists of a periodic kernel multiplied with an exponentiated quadratic to make the seasonality decay as it gets farther from the observations. This seasonal periodic kernel is defined in the code below as local_periodic_kernel.
- Short to medium term irregularities, modelled by a rational quadratic kernel, defined in the code below as irregular_kernel.
- Observational noise, modelled directly by adding observation_noise_variance to the kernel matrix diagonal.
These different kernels will be summed into one single kernel function $k_{\theta}(x_a, x_b)$ that allows all these effects to occur together. This kernel is defined as combined_kernel in the code below. Each of the kernels has hyperparameters $\theta$ that can be tuned; they will be defined as unconstrained log-parameters and transformed to positive values before use. This post provides more insight on the kernels used here and the effect of their hyperparameters.
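For reference, the three base covariance functions implemented below are, for one-dimensional inputs, with amplitude $\sigma$, length scale $\ell$, period $p$ and scale mixture $\alpha$:

$$k_{\text{EQ}}(x_a, x_b) = \sigma^2 \exp\left(-\frac{(x_a - x_b)^2}{2\ell^2}\right)$$

$$k_{\text{Per}}(x_a, x_b) = \sigma^2 \exp\left(-\frac{2\sin^2\!\left(\pi \lvert x_a - x_b \rvert / p\right)}{\ell^2}\right)$$

$$k_{\text{RQ}}(x_a, x_b) = \sigma^2 \left(1 + \frac{(x_a - x_b)^2}{2\alpha\ell^2}\right)^{-\alpha}$$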
# Define the parameters for the kernel function we're fitting
class KernelParams(NamedTuple):
"""Named GP hyperparameters that remain a JAX pytree."""
smooth_amplitude: Array
smooth_length_scale: Array
periodic_amplitude: Array
periodic_length_scale: Array
periodic_period: Array
periodic_local_length_scale: Array
irregular_amplitude: Array
irregular_length_scale: Array
irregular_scale_mixture: Array
observation_noise_variance: Array
# Define the kernel with trainable parameters.
# We optimize unconstrained log-parameters and exponentiate them so
# kernel hyperparameters stay positive.
# Use float64 to reduce numerical issues when computing the
# Cholesky decomposition of the kernel matrix.
JITTER = 1e-6
TINY = np.finfo(np.float64).tiny
def log_param(value: float) -> Array:
"""Initialize a positive hyperparameter in unconstrained log space."""
return jnp.log(jnp.asarray(value, dtype=jnp.float64))
# Initial unconstrained values
params = KernelParams(
smooth_amplitude=log_param(10.0),
smooth_length_scale=log_param(10.0),
periodic_amplitude=log_param(5.0),
periodic_length_scale=log_param(1.0),
periodic_period=log_param(1.0),
periodic_local_length_scale=log_param(1.0),
irregular_amplitude=log_param(1.0),
irregular_length_scale=log_param(1.0),
irregular_scale_mixture=log_param(1.0),
observation_noise_variance=log_param(1.0),
)
def positive(raw_value: Array) -> Array:
"""Map unconstrained optimizer values to positive hyperparameters."""
return jnp.exp(raw_value) + TINY
def transform_params(params: KernelParams) -> KernelParams:
"""Convert learned log-parameters into valid kernel hyperparameters."""
return KernelParams(
smooth_amplitude=positive(raw_value=params.smooth_amplitude),
smooth_length_scale=positive(raw_value=params.smooth_length_scale),
periodic_amplitude=positive(raw_value=params.periodic_amplitude),
periodic_length_scale=positive(raw_value=params.periodic_length_scale),
periodic_period=positive(raw_value=params.periodic_period),
periodic_local_length_scale=positive(
raw_value=params.periodic_local_length_scale
),
irregular_amplitude=positive(raw_value=params.irregular_amplitude),
irregular_length_scale=positive(raw_value=params.irregular_length_scale),
irregular_scale_mixture=positive(raw_value=params.irregular_scale_mixture),
observation_noise_variance=positive(
raw_value=params.observation_noise_variance
),
)
#
# Define kernel functions
# The inputs in this post are one-dimensional dates. Keeping them as
# explicit column vectors makes the pairwise covariance shapes visible.
def as_column(x: Array) -> Array:
"""Represent one-dimensional inputs as column vectors."""
return jnp.asarray(x, dtype=jnp.float64).reshape(-1, 1)
def squared_distance(xa: Array, xb: Array) -> Array:
"""Compute pairwise squared distances for stationary kernels."""
xa = as_column(xa)
xb = as_column(xb)
return jnp.sum((xa[:, None, :] - xb[None, :, :]) ** 2, axis=-1)
def exponentiated_quadratic_kernel(
xa: Array, xb: Array, amplitude: Scalar = 1.0, length_scale: Scalar = 1.0
) -> Array:
"""Model long-term smooth variation with an RBF covariance."""
return amplitude**2 * jnp.exp(
-0.5 * squared_distance(xa=xa, xb=xb) / length_scale**2
)
def rational_quadratic_kernel(
xa: Array,
xb: Array,
amplitude: Scalar = 1.0,
length_scale: Scalar = 1.0,
scale_mixture: Scalar = 1.0,
) -> Array:
"""Model variation across several length scales."""
return amplitude**2 * (
1.0 + squared_distance(xa=xa, xb=xb) / (2.0 * scale_mixture * length_scale**2)
) ** (-scale_mixture)
def periodic_kernel(
xa: Array,
xb: Array,
amplitude: Scalar = 1.0,
length_scale: Scalar = 1.0,
period: Scalar = 1.0,
) -> Array:
"""Model repeating structure with a periodic covariance."""
xa = as_column(xa)
xb = as_column(xb)
distance = jnp.abs(xa[:, None, :] - xb[None, :, :])
sine_squared = jnp.sum(jnp.sin(jnp.pi * distance / period) ** 2, axis=-1)
return amplitude**2 * jnp.exp(-2.0 * sine_squared / length_scale**2)
def local_periodic_kernel(
xa: Array,
xb: Array,
amplitude: Scalar = 1.0,
periodic_length_scale: Scalar = 1.0,
period: Scalar = 1.0,
local_length_scale: Scalar = 1.0,
) -> Array:
"""Model periodic structure that fades with distance."""
return periodic_kernel(
xa=xa,
xb=xb,
amplitude=amplitude,
length_scale=periodic_length_scale,
period=period,
) * exponentiated_quadratic_kernel(
xa=xa, xb=xb, amplitude=1.0, length_scale=local_length_scale
)
def smooth_kernel(xa: Array, xb: Array, positive_params: KernelParams) -> Array:
"""Extract the long-term smooth component of the CO2 model."""
return exponentiated_quadratic_kernel(
xa=xa,
xb=xb,
amplitude=positive_params.smooth_amplitude,
length_scale=positive_params.smooth_length_scale,
)
def irregular_kernel(xa: Array, xb: Array, positive_params: KernelParams) -> Array:
"""Extract the medium-scale irregular component of the CO2 model."""
return rational_quadratic_kernel(
xa=xa,
xb=xb,
amplitude=positive_params.irregular_amplitude,
length_scale=positive_params.irregular_length_scale,
scale_mixture=positive_params.irregular_scale_mixture,
)
#
# Define the combined kernel we're going to be fitting
def combined_kernel(xa: Array, xb: Array, positive_params: KernelParams) -> Array:
"""Combine smooth, seasonal, and irregular CO2 covariance components."""
return (
smooth_kernel(xa=xa, xb=xb, positive_params=positive_params)
+ local_periodic_kernel(
xa=xa,
xb=xb,
amplitude=positive_params.periodic_amplitude,
periodic_length_scale=positive_params.periodic_length_scale,
period=positive_params.periodic_period,
local_length_scale=positive_params.periodic_local_length_scale,
)
+ irregular_kernel(xa=xa, xb=xb, positive_params=positive_params)
)
Tuning the hyperparameters
We can tune the hyperparameters $\theta$ of our Gaussian process model based on the data. This post uses JAX to fit the parameters by maximizing the marginal likelihood $p(\mathbf{y} \mid X, \theta)$ of the Gaussian process distribution based on the observed data $(X, \mathbf{y})$.
$$\hat{\theta} = \underset{\theta}{\text{argmax}} \left( p(\mathbf{y} \mid X, \theta) \right)$$
The marginal likelihood of the Gaussian process is the likelihood of a Gaussian distribution which is defined as:
$$p(\mathbf{y} \mid \mu, \Sigma) = \frac{1}{\sqrt{(2\pi)^d \lvert\Sigma\rvert}} \exp{ \left( -\frac{1}{2}(\mathbf{y} - \mu)^{\top} \Sigma^{-1} (\mathbf{y} - \mu) \right)}$$
The mean is calculated from the parameterized mean function using the observed data $X$ as input: $\mu_{\theta} = m_{\theta}(X)$. For the observed targets, the covariance also includes observation noise on the diagonal: $\Sigma_{\theta} = k_{\theta}(X, X) + \sigma_n^2 I$. We can then write the marginal likelihood as:
$$ p(\mathbf{y} \mid X, \theta) = \frac{1}{\sqrt{(2\pi)^d \lvert \Sigma_{\theta} \rvert}} \exp{ \left( -\frac{1}{2}(\mathbf{y} - \mu_{\theta})^\top \Sigma_{\theta}^{-1} (\mathbf{y} - \mu_{\theta}) \right)}$$
With $d$ the dimensionality of the marginal and $\lvert \Sigma_{\theta} \rvert$ the determinant of the observed-data covariance matrix. We can get rid of the exponent by taking the log and maximizing the log marginal likelihood:
$$ \log{p(\mathbf{y} \mid X, \theta)} = -\frac{1}{2}(\mathbf{y} - \mu_{\theta})^\top \Sigma_{\theta}^{-1} (\mathbf{y} - \mu_{\theta}) - \frac{1}{2} \log{\lvert \Sigma_{\theta} \rvert} - \frac{d}{2} \log{2 \pi}$$
The first term, $-\frac{1}{2}(\mathbf{y} - \mu_{\theta})^\top \Sigma_{\theta}^{-1} (\mathbf{y} - \mu_{\theta})$, measures the data fit, while the remaining terms, $-\frac{1}{2}(\log{\lvert \Sigma_{\theta} \rvert} + d\log{2 \pi})$, act as a complexity penalty, closely related to the differential entropy of the Gaussian ⁽¹⁾.
The optimal parameters $\hat{\theta}$ can then be found by minimizing the negative of the log marginal likelihood:
$$ \hat{\theta} = \underset{\theta}{\text{argmax}} \left( p(\mathbf{y} \mid X, \theta) \right) = \underset{\theta}{\text{argmin}} { \;-\log{ p(\mathbf{y} \mid X, \theta)}}$$
Since in this case the log marginal likelihood is differentiable with respect to the kernel hyperparameters, we can use a gradient-based approach to minimize the negative log marginal likelihood (NLL). In this post we will use JAX automatic differentiation and Optax to train the hyperparameters on the full observed dataset.
# Convert observed and prediction data to JAX arrays
X_observed = jnp.asarray(df_observed.Date.values.reshape(-1, 1), dtype=jnp.float64)
y_observed = jnp.asarray(df_observed.CO2.values, dtype=jnp.float64)
X_predict = jnp.asarray(df_predict.Date.values.reshape(-1, 1), dtype=jnp.float64)
#
We will implement the Gaussian process marginal likelihood directly from the Gaussian distribution formula above. jax.value_and_grad evaluates the negative log marginal likelihood and its gradient, while optax.adam is used to update the unconstrained log-parameters.
# Implement the Gaussian process marginal likelihood
def solve_cholesky(cholesky: Array, rhs: Array) -> Array:
"""Solve GP linear systems stably through a Cholesky factor."""
return jsp.linalg.solve_triangular(
a=cholesky.T,
b=jsp.linalg.solve_triangular(a=cholesky, b=rhs, lower=True),
lower=False,
)
def gp_negative_log_marginal_likelihood(
params: KernelParams, index_points: Array, observations: Array
) -> Array:
"""Score how well hyperparameters explain the observed CO2 data."""
# Convert unconstrained optimizer values to valid positive hyperparameters.
positive_params = transform_params(params=params)
# Build the prior covariance over observations.
covariance = combined_kernel(
xa=index_points, xb=index_points, positive_params=positive_params
)
# Add observation noise and jitter for a stable Cholesky decomposition.
covariance = covariance + (
positive_params.observation_noise_variance + JITTER
) * jnp.eye(index_points.shape[0], dtype=jnp.float64)
# Center observations around the prior mean before evaluating the likelihood.
centered_observations = observations - mean_fn(index_points=index_points)
# Factor the covariance once and reuse it for solves and the log determinant.
cholesky = jnp.linalg.cholesky(covariance)
# Compute K^-1 y without forming the inverse explicitly.
alpha = solve_cholesky(cholesky=cholesky, rhs=centered_observations)
# Combine data fit, model complexity, and Gaussian normalization terms.
data_fit = 0.5 * jnp.dot(centered_observations, alpha)
complexity_penalty = jnp.sum(jnp.log(jnp.diag(cholesky)))
constant = 0.5 * index_points.shape[0] * jnp.log(2.0 * jnp.pi)
return data_fit + complexity_penalty + constant
#
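As a consistency check (a sketch added here, not part of the original fitting code), the hand-written negative log marginal likelihood can be compared against the multivariate normal log-density from jax.scipy.stats on a small subset of the observations:
# Sketch: verify the NLL implementation against jax.scipy.stats on a subset.
subset_x = X_observed[:50]
subset_y = y_observed[:50]
check_params = transform_params(params=params)
check_cov = combined_kernel(xa=subset_x, xb=subset_x, positive_params=check_params)
check_cov = check_cov + (
    check_params.observation_noise_variance + JITTER
) * jnp.eye(subset_x.shape[0], dtype=jnp.float64)
reference_nll = -jsp.stats.multivariate_normal.logpdf(
    subset_y, mean=mean_fn(index_points=subset_x), cov=check_cov
)
custom_nll = gp_negative_log_marginal_likelihood(
    params, index_points=subset_x, observations=subset_y
)
print(bool(jnp.allclose(reference_nll, custom_nll)))  # True up to numerical precision
#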
# Fit hyperparameters
learning_rate = 0.01
optimizer = optax.adam(learning_rate=learning_rate)
opt_state = optimizer.init(params)
# Define loss and gradients
loss_and_grad = jax.value_and_grad(gp_negative_log_marginal_likelihood)
@jax.jit
def optimization_step(
params: KernelParams,
opt_state: optax.OptState,
index_points: Array,
observations: Array,
) -> tuple[KernelParams, optax.OptState, Array]:
"""Take one step toward lower full-data GP marginal likelihood."""
loss, grads = loss_and_grad(
params, index_points=index_points, observations=observations
)
updates, opt_state = optimizer.update(grads, opt_state, params)
params = cast(KernelParams, optax.apply_updates(params, updates))
return params, opt_state, loss
# Training loop
full_nlls: list[tuple[int, float]] = [] # Full data NLL for plotting
nb_iterations = 2500
for i in tqdm(range(nb_iterations)):
params, opt_state, loss = optimization_step(
params, opt_state, index_points=X_observed, observations=y_observed
)
if i % 10 == 0 or i == nb_iterations - 1:
full_nlls.append((i, float(loss)))
# Plot NLL over iterations
nll_source = bokeh.models.ColumnDataSource(
data={
"iteration": [iteration for iteration, _ in full_nlls],
"nll": [value for _, value in full_nlls],
}
)
nll_values = [value for _, value in full_nlls]
fig = bokeh.plotting.figure(
width=600,
height=350,
x_range=bokeh.models.Range1d(start=0, end=nb_iterations),
y_range=bokeh.models.Range1d(start=min(nll_values) - 20, end=max(nll_values) + 20),
)
fig.add_layout(
bokeh.models.Title(
text="Negative Log-Likelihood (NLL) during training", text_font_size="14pt"
),
"above",
)
fig.xaxis.axis_label = "iteration"
fig.yaxis.axis_label = "NLL all observed data"
nll_renderer = fig.line(
x="iteration",
y="nll",
source=nll_source,
legend_label="All observed data",
line_width=2,
line_color="midnightblue",
)
fig.add_tools(
bokeh.models.HoverTool(
renderers=[nll_renderer],
tooltips=[("Iteration", "@iteration"), ("NLL", "@nll{0.00}")],
mode="vline",
line_policy="nearest",
)
)
fig.legend.location = "top_right"
fig.toolbar.autohide = True
bokeh.plotting.show(fig)
#
Despite the complexity term in the log marginal likelihood, it is still possible to overfit the data. We did not guard against this further in this post, but you could prevent overfitting by adding a regularization term on the hyperparameters ⁽²⁾, or by splitting the observations into training and validation sets and selecting, among models fit on the training data, the one that scores best on the validation data.
# Show values of parameters found
fitted_params = transform_params(params)
data = [(name, float(getattr(fitted_params, name))) for name in KernelParams._fields]
df_variables = pd.DataFrame(data=data, columns=pd.Index(["Hyperparameters", "Value"]))
display(HTML(df_variables.to_html(index=False, float_format=lambda x: f"{x:.4f}")))
#
The fitted kernel and its components are illustrated in more detail in a follow-up post.
Posterior predictions
With the fitted kernel we can use the Gaussian process conditioning equations directly to make posterior predictions on points after 2010.
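For reference, writing $K = k_{\theta}(X, X) + \sigma_n^2 I$ for the noisy covariance of the observed inputs, $K_* = k_{\theta}(X, X_*)$ for the cross-covariance with the prediction inputs $X_*$, and $K_{**} = k_{\theta}(X_*, X_*)$, these conditioning equations are:

$$\mu_* = m(X_*) + K_*^\top K^{-1} (\mathbf{y} - m(X))$$

$$\Sigma_* = K_{**} - K_*^\top K^{-1} K_*$$

The function below computes exactly these quantities, using a Cholesky factorization of $K$ instead of forming the inverse explicitly.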
The posterior predictions conditioned on the observed data before 2010 are plotted in the figure below together with their 95% prediction interval.
Notice that our model captures the different dataset characteristics, such as the trend and seasonality, quite well. The predictions deviate more the further they are from the observed data the model was conditioned on, and the prediction interval widens accordingly.
# Posterior predictions
def gp_posterior_predict(
params: KernelParams,
observation_index_points: Array,
observations: Array,
prediction_index_points: Array,
) -> tuple[Array, Array]:
"""Predict fitted GP mean and uncertainty at new input locations."""
# Convert learned log-space parameters to positive kernel hyperparameters.
positive_params = transform_params(params=params)
# Build the noisy covariance among observed points.
observation_covariance = combined_kernel(
xa=observation_index_points,
xb=observation_index_points,
positive_params=positive_params,
)
observation_covariance = observation_covariance + (
positive_params.observation_noise_variance + JITTER
) * jnp.eye(observation_index_points.shape[0], dtype=jnp.float64)
# Build the cross-covariance between observations and predictions.
cross_covariance = combined_kernel(
xa=observation_index_points,
xb=prediction_index_points,
positive_params=positive_params,
)
# Build the prior covariance among prediction points.
prediction_covariance = combined_kernel(
xa=prediction_index_points,
xb=prediction_index_points,
positive_params=positive_params,
)
# Center observations around the prior mean before conditioning.
centered_observations = observations - mean_fn(
index_points=observation_index_points
)
# Factor the observed covariance and solve K^-1 y stably.
cholesky = jnp.linalg.cholesky(observation_covariance)
alpha = solve_cholesky(cholesky=cholesky, rhs=centered_observations)
# Add the conditional correction to the prior mean.
posterior_mean = (
mean_fn(index_points=prediction_index_points) + cross_covariance.T @ alpha
)
# Remove explained variance from the prediction prior covariance.
v = jsp.linalg.solve_triangular(a=cholesky, b=cross_covariance, lower=True)
posterior_covariance = prediction_covariance - v.T @ v
# Keep tiny numerical negatives from becoming invalid standard deviations.
posterior_variance = jnp.clip(jnp.diag(posterior_covariance), min=0.0)
posterior_std = jnp.sqrt(posterior_variance)
return posterior_mean, posterior_std
posterior_mean_predict, posterior_std_predict = gp_posterior_predict(
params,
observation_index_points=X_observed,
observations=y_observed,
prediction_index_points=X_predict,
)
#
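As a simple quantitative check (a sketch added here; the original post only inspects the plot below), the posterior mean can be compared with the held-out measurements from 2010 onwards:
# Sketch: root-mean-square error of the posterior mean on the held-out data.
y_predict_true = jnp.asarray(df_predict.CO2.values, dtype=jnp.float64)
rmse = jnp.sqrt(jnp.mean((posterior_mean_predict - y_predict_true) ** 2))
print(f"RMSE on the 2010+ measurements: {float(rmse):.2f} ppm")
#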
# Plot posterior predictions
# Get posterior predictions
μ = np.asarray(posterior_mean_predict)
σ = np.asarray(posterior_std_predict)
prediction_source = bokeh.models.ColumnDataSource(
data={
"date": df_predict.Date.to_numpy(),
"mean": μ,
"lower": μ - 2 * σ,
"upper": μ + 2 * σ,
}
)
reference_source = bokeh.models.ColumnDataSource(
data={"date": co2_df.Date.to_numpy(), "co2": co2_df.CO2.to_numpy()}
)
# Plot
fig = bokeh.plotting.figure(
width=600,
height=400,
x_range=bokeh.models.Range1d(start=2010, end=2021.3),
y_range=bokeh.models.Range1d(start=384, end=418),
)
fig.xaxis.axis_label = "Date"
fig.yaxis.axis_label = "CO₂ (ppm)"
fig.add_layout(
bokeh.models.Title(
text="Posterior predictions conditioned on observations before 2010.",
text_font_style="italic",
),
"above",
)
fig.add_layout(
bokeh.models.Title(text="Atmospheric CO₂ concentrations", text_font_size="14pt"),
"above",
)
reference_renderer = fig.scatter(
x="date",
y="co2",
source=reference_source,
legend_label="True data (reference)",
size=2,
marker="circle",
line_color="midnightblue",
)
mean_renderer = fig.line(
x="date",
y="mean",
source=prediction_source,
legend_label="μ (predictions)",
line_width=2,
line_color="firebrick",
)
# Prediction interval
interval_renderer = fig.varea(
x="date",
y1="lower",
y2="upper",
source=prediction_source,
color="firebrick",
alpha=0.4,
legend_label="2σ",
)
fig.add_tools(
bokeh.models.HoverTool(
renderers=[reference_renderer],
tooltips=[("Date", "@date{0.00}"), ("CO₂", "@co2{0.00} ppm")],
),
bokeh.models.HoverTool(
renderers=[mean_renderer, interval_renderer],
tooltips=[
("Date", "@date{0.00}"),
("μ", "@mean{0.00} ppm"),
("lower 2σ", "@lower{0.00} ppm"),
("upper 2σ", "@upper{0.00} ppm"),
],
mode="vline",
line_policy="nearest",
),
)
fig.legend.location = "top_left"
fig.toolbar.autohide = True
bokeh.plotting.show(fig)
#
This post illustrated using JAX to combine multiple kernels and fit their hyperparameters on observed data. The fitted model was then used to make posterior predictions.
This post was the second part of a series on Gaussian processes:
Sidenotes
- Note that the determinant $\lvert \Sigma \rvert$ is equal to the product of its eigenvalues, and that $\lvert \Sigma \rvert$ can be interpreted as the volume spanned by the covariance matrix $\Sigma$. Reducing $\lvert \Sigma \rvert$ will thus decrease the dispersion of points drawn from the distribution with covariance matrix $\Sigma$ and reduce the complexity.
- A fully Bayesian extension could define prior distributions for the hyperparameters and sample from the posterior with a JAX-based probabilistic programming library such as NumPyro.
References
- Gaussian Processes for Machine Learning. Chapter 5: Model Selection and Adaptation of Hyperparameters by Carl Edward Rasmussen and Christopher K. I. Williams.
- Gaussian Processes for Regression by Christopher K. I. Williams and Carl Edward Rasmussen.
# Python package versions used
%load_ext watermark
%watermark --python
%watermark --iversions
#
This post is generated from an IPython notebook file. Link to the full IPython notebook file