15  Introduction to JAX


import jax.numpy as jnp
import jax
from jax import grad, jit, vmap
from jax import random
from jax import device_put
key = random.PRNGKey(0)
x = random.normal(key, (10,))
print(x)
[-0.3721109   0.26423115 -0.18252768 -0.7368197  -0.44030377 -0.1521442
 -0.67135346 -0.5908641   0.73168886  0.5673026 ]
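
JAX handles randomness through explicit PRNG keys: calling random.normal again with the same key reproduces exactly the same draw. A minimal sketch of getting fresh values by splitting the key:

key, subkey = random.split(key)    # derive an independent key
x2 = random.normal(subkey, (10,))  # new values, different from x above
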
size = 3000
x = random.normal(key, (size, size), dtype=jnp.float32)
x = device_put(x) # transfer x to the default device (e.g. GPU/TPU); values are copied back to the CPU only when needed
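
A plausible use for this matrix (an assumption, not part of the original) is timing a matrix multiply on the device; block_until_ready() forces JAX's asynchronous dispatch to finish before timing stops:

%timeit jnp.dot(x, x.T).block_until_ready()
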
def selu(x, alpha=1.67, lmbda=1.05):  # scaled exponential linear unit (SELU)
    return lmbda * jnp.where(x > 0, x, alpha * jnp.exp(x) - alpha)

x = random.normal(key, (1000000,))
%timeit selu(x).block_until_ready()
1.63 ms ± 182 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)
selu_jit = jit(selu)  # compile selu with XLA; the first call triggers tracing and compilation
%timeit selu_jit(x).block_until_ready()
360 µs ± 11.6 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)
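
vmap, imported above, has not been used yet. A minimal sketch, with mat and bx as illustrative names: vmap maps a per-example function over a leading batch axis without an explicit Python loop.

mat = random.normal(key, (10, 10))
bx = random.normal(key, (32, 10))              # a batch of 32 vectors
batched = vmap(lambda v: jnp.dot(mat, v))(bx)  # applies the dot product per row
print(batched.shape)  # (32, 10)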

15.1 Automatic differentiation

  1. grad(): reverse-mode gradients of scalar-valued functions
  2. jax.vjp: reverse-mode vector-Jacobian products
  3. jax.jvp: forward-mode Jacobian-vector products (all three are sketched below)
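
A minimal sketch comparing the three on a scalar function (f and x0 are illustrative names):

f = lambda x: x**3 + 2. * x
x0 = 2.0
print(grad(f)(x0))                      # 3*x0**2 + 2 = 14.0
y, tangent = jax.jvp(f, (x0,), (1.0,))  # forward mode: f(x0) and J @ v
y, vjp_fn = jax.vjp(f, x0)              # reverse mode: f(x0) and v -> v^T J
print(tangent, vjp_fn(1.0)[0])          # both 14.0 for a scalar function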

15.2 Index update

x = jnp.arange(10)
y = x.at[0].set(0)  # out-of-place update: returns a new array; x itself is unchanged
jax_array = jnp.ones((5,6))
print(jax_array)
[[1. 1. 1. 1. 1. 1.]
 [1. 1. 1. 1. 1. 1.]
 [1. 1. 1. 1. 1. 1.]
 [1. 1. 1. 1. 1. 1.]
 [1. 1. 1. 1. 1. 1.]]
new_jax_array = jax_array.at[::2, 3:].add(7.)  # add 7. to columns 3+ of every other row
print(new_jax_array)
[[1. 1. 1. 8. 8. 8.]
 [1. 1. 1. 1. 1. 1.]
 [1. 1. 1. 8. 8. 8.]
 [1. 1. 1. 1. 1. 1.]
 [1. 1. 1. 8. 8. 8.]]
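
The .at updates are out-of-place, so jax_array above still holds all ones. The same API supports more than set and add; a brief sketch using the arrays already defined:

print(jax_array)                            # unchanged by the .add above
print(new_jax_array.at[0, 0].multiply(2.))  # copy with one element scaled
print(new_jax_array.at[0].get())            # indexed read, handy under jit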

15.3 Linear regression

import numpy as np
import matplotlib.pyplot as plt
xs = np.random.normal(size=(100,))
noise = np.random.normal(scale=0.1, size=(100,))
ys = xs * 3 - 1 + noise

plt.scatter(xs, ys);

def model(theta, x):
    w, b = theta
    return w*x + b

def loss_fn(theta, x, y):
    prediction = model(theta,x)
    return jnp.mean((prediction - y)**2)
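
As a hedged sanity check (not in the original), jax.grad differentiates loss_fn with respect to its first argument, theta, which is exactly what the update step below relies on:

print(jax.grad(loss_fn)(jnp.array([1., 1.]), xs, ys))  # d loss / d theta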

def update(theta, x, y, lr=0.1):
    return theta - lr * jax.grad(loss_fn)(theta, x, y)
theta = jnp.array([1.,1.])

for _ in range(1000):
    theta = update(theta, xs, ys)
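
A hedged optimization, reusing the functions above: compile the update step with jit and rerun the loop, so the gradient and parameter update compile into a single XLA computation.

theta = jnp.array([1., 1.])  # reset parameters
update_jit = jit(update)     # compiled on first call
for _ in range(1000):
    theta = update_jit(theta, xs, ys)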

plt.scatter(xs, ys)
plt.plot(xs, model(theta, xs))

w, b = theta
print(f"w: {w:.2f}, b: {b:.2f}")  # should be close to the true values w = 3, b = -1