Introduction to JAX

6. Introduction to JAX

import jax.numpy as jnp
import numpy as np
from jax import config, jit, value_and_grad, vmap
from matplotlib import pyplot as plt
# Global JAX settings for this notebook: run everything on the CPU backend,
# and enable 64-bit floats (JAX uses 32-bit precision unless x64 is enabled).
config.update("jax_platform_name", "cpu")
config.update("jax_enable_x64", True)
@jit
def loss_lik(mu, v):
    """Return the closed-form objective as a function of ``mu`` and ``v``.

    Algebraically, the two polynomial pieces below equal
    ``E[(x - a1)**2]`` and ``E[(x - a2)**3]`` for a random variable ``x``
    with mean ``mu`` and variance ``v`` whose third central moment is zero
    (e.g. a Gaussian). They are weighted by ``b1`` and ``b2``, and the
    additional term ``4.0 / v`` is added.

    Args:
        mu: Mean parameter (scalar or array; all operations are elementwise).
        v: Second parameter entering as a variance; must be non-zero.

    Returns:
        ``b1 * E[(x - a1)**2] + b2 * E[(x - a2)**3] + 4.0 / v``.
    """
    # Fixed problem constants: term weights and term centres.
    b1, b2 = 0.5, 0.01
    a1, a2 = 2.0, 5.0

    # Quadratic piece: E[x**2] - 2*a1*E[x] + a1**2.
    quad = mu**2 + v - 2.0 * a1 * mu + a1**2
    # Cubic piece: E[x**3] - 3*a2*E[x**2] + 3*a2**2*E[x] - a2**3.
    cubic = (
        (mu**3 + 3.0 * mu * v)
        - 3.0 * a2 * (mu**2 + v)
        + 3.0 * (a2**2) * mu
        - a2**3
    )
    return b1 * quad + b2 * cubic + 4.0 / v
@jit
def loss_pre(params):
    """Evaluate the objective in the ``(mu, s)`` parametrisation.

    ``s`` is passed to ``loss_lik`` through its reciprocal (``v = 1 / s``),
    i.e. it plays the role of a precision. A ``log(s) / 2`` term is added —
    presumably the negative Gaussian entropy up to a constant; confirm
    against the accompanying derivation.

    Args:
        params: Two-element array-like ``[mu, s]`` (unpacked, so exactly two
            entries along the first axis); ``s`` must be positive for the log.

    Returns:
        ``loss_lik(mu, 1 / s) + log(s) / 2``.
    """
    mu, s = params
    v = 1.0 / s  # loss_lik works with the reciprocal of s
    log_term = jnp.log(s) / 2
    return loss_lik(mu, v) + log_term


# Compiled callable returning ``(loss, gradient w.r.t. params)`` in one pass;
# this is the ``loss_fun`` handed to ``gd`` and ``ngd_pre`` below.
loss_f_pre = jit(value_and_grad(loss_pre, argnums=0))
def gd(init_params, loss_fun, step_size, num_iters):
    """Run plain (Euclidean) gradient descent and record the trajectory.

    Args:
        init_params: Starting point; any array-like (list, tuple, NumPy or JAX
            array) whose first two entries are ``mu`` and ``s``.
        loss_fun: Callable mapping the parameter array to ``(loss, gradient)``,
            e.g. ``jit(value_and_grad(loss_pre))``.
        step_size: Learning rate multiplying the gradient.
        num_iters: Number of update steps; must be non-negative.

    Returns:
        ``(J_history, mu_hist, s_hist)``: NumPy arrays of length
        ``num_iters + 1`` holding the loss, ``mu`` and ``s`` at every iterate,
        from the initial point through the point after the last update.

    Raises:
        ValueError: If ``num_iters`` is negative.
    """
    if num_iters < 0:
        raise ValueError(f"num_iters must be non-negative, got {num_iters}")

    J_history = np.zeros(num_iters + 1)
    mu_hist, s_hist = np.zeros(num_iters + 1), np.zeros(num_iters + 1)  # For plotting

    # Convert once so lists/tuples work with the vectorised update below
    # (``float * list`` or ``list - array`` would otherwise raise TypeError).
    cur_params = jnp.asarray(init_params)
    # num_iters + 1 evaluations: every iterate is recorded, but only the
    # first num_iters are followed by an update.
    for i in range(num_iters + 1):
        val, g = loss_fun(cur_params)  # Euclidean gradient
        J_history[i] = val
        mu_hist[i] = cur_params[0]
        s_hist[i] = cur_params[1]
        if i < num_iters:
            cur_params = cur_params - step_size * g  # GD

    return J_history, mu_hist, s_hist
def ngd_pre(init_params, loss_fun, step_size, num_iters):
    """Run natural-gradient descent in the ``(mu, s)`` parametrisation.

    Each step rescales the Euclidean gradient ``(g_mu, g_s)`` returned by
    ``loss_fun`` to ``(g_mu / s, 2 * s**2 * g_s)``. That scaling equals the
    inverse Fisher information of a Gaussian with mean ``mu`` and precision
    ``s`` (consistent with ``loss_pre`` passing ``1 / s`` as the variance).

    Args:
        init_params: Starting point; any array-like (list, tuple, NumPy or JAX
            array) holding ``[mu, s]`` with ``s`` non-zero.
        loss_fun: Callable mapping the parameter array to ``(loss, gradient)``
            where the gradient unpacks into ``(g_mu, g_s)``.
        step_size: Learning rate multiplying the natural gradient.
        num_iters: Number of update steps; must be non-negative.

    Returns:
        ``(J_history, mu_hist, s_hist)``: NumPy arrays of length
        ``num_iters + 1`` holding the loss, ``mu`` and ``s`` at every iterate,
        from the initial point through the point after the last update.

    Raises:
        ValueError: If ``num_iters`` is negative.
    """
    if num_iters < 0:
        raise ValueError(f"num_iters must be non-negative, got {num_iters}")

    J_history = np.zeros(num_iters + 1)
    mu_hist, s_hist = np.zeros(num_iters + 1), np.zeros(num_iters + 1)  # For plotting

    # Convert once so lists/tuples work with the vectorised update below
    # (``list - array`` would otherwise raise TypeError).
    cur_params = jnp.asarray(init_params)
    # num_iters + 1 evaluations: every iterate is recorded, but only the
    # first num_iters are followed by an update.
    for i in range(num_iters + 1):
        val, g = loss_fun(cur_params)
        J_history[i] = val
        mu_hist[i] = cur_params[0]
        s_hist[i] = cur_params[1]
        if i < num_iters:
            g_mu, g_s = g
            s = cur_params[1]
            ng = jnp.array([g_mu / s, 2.0 * (s**2) * g_s])  # Natural gradient
            cur_params = cur_params - step_size * ng  # NGD

    return J_history, mu_hist, s_hist
# Grid of (mu, s) values for the contour plot: 200 mu values spaced linearly
# in [-10, 10] and 800 s values spaced logarithmically in [10**-1, 10**0.2].
mu_list, s_list = np.meshgrid(np.linspace(-10, 10, 200), np.logspace(-1, 0.2, 800))

# Evaluate the loss at every grid point in one vectorised call: vmap maps
# loss_pre over the rows of an (N, 2) array of [mu, s] pairs, instead of
# dispatching N separate jitted calls from a Python loop.
zs = np.asarray(
    vmap(loss_pre)(jnp.stack([np.ravel(mu_list), np.ravel(s_list)], axis=1))
)
Z = zs.reshape(mu_list.shape)
# Shared starting point and iteration budget for both optimizers.
mu_0, s_0 = -8.0, 1.0
max_num_iters = 200
# JAX arrays are immutable, so both runs can safely share this start point.
init_params = jnp.array([mu_0, s_0])

# Euclidean gradient descent. The per-step displacements (dx, dy) become the
# arrow components of the quiver plot.
gd_pre_history, mu_gd_pre_hist, s_gd_pre_hist = gd(
    init_params, loss_f_pre, step_size=1e-2, num_iters=max_num_iters
)
anglesx_gd_pre = np.diff(mu_gd_pre_hist)
anglesy_gd_pre = np.diff(s_gd_pre_hist)

# Natural-gradient descent from the same starting point.
ngd_pre_history, mu_ngd_pre_hist, s_ngd_pre_hist = ngd_pre(
    init_params, loss_f_pre, step_size=1e-2, num_iters=max_num_iters
)
anglesx_ngd_pre = np.diff(mu_ngd_pre_hist)
anglesy_ngd_pre = np.diff(s_ngd_pre_hist)
fig = plt.figure(figsize=(16, 8))

# Left panel: loss contours over the (mu, s) grid, with one arrow per
# optimizer step. angles="xy", scale_units="xy", scale=1 make each arrow run
# exactly from one iterate to the next in data coordinates.
ax = fig.add_subplot(1, 2, 1)

ax.contour(mu_list, s_list, Z, 50, cmap="jet")

ax.quiver(
    mu_gd_pre_hist[:-1],
    s_gd_pre_hist[:-1],
    anglesx_gd_pre,
    anglesy_gd_pre,
    label=r"GD $(\mu,s)$",
    scale_units="xy",
    angles="xy",
    scale=1,
    color="g",
    alpha=0.9,
)

ax.quiver(
    mu_ngd_pre_hist[:-1],
    s_ngd_pre_hist[:-1],
    anglesx_ngd_pre,
    anglesy_ngd_pre,
    label=r"NGD $(\mu,s)$",
    scale_units="xy",
    angles="xy",
    scale=1,
    color="r",
    alpha=0.9,
)

ax.set_xlabel(r"$\mu$")
ax.set_ylabel("$s$")
ax.legend(loc="upper right")

# Right panel: loss at each recorded iterate (0 .. max_num_iters inclusive).
iters = np.arange(max_num_iters + 1)
ax = fig.add_subplot(1, 2, 2)
ax.plot(
    iters,
    gd_pre_history,
    label=r"GD $(\mu,s)$",
    color="g",
)
ax.plot(
    iters,
    ngd_pre_history,
    label=r"NGD $(\mu,s)$",
    color="r",
)
ax.legend(loc="upper right")
ax.set_xlabel("Number of iters")
ax.set_ylabel("loss")
# When running as a plain script rather than in a notebook, call
# plt.tight_layout() and plt.show() here to lay out and display the figure.
Text(0, 0.5, 'loss')
../_images/1e1242eb86649107b15d6e8b80e28e141d8245679fb7f5250f83cc0c71e2644a.png