JAX

JAX is a young ML ecosystem that provides simple tools that can be composed to solve complex problems. It provides NumPy-like functionality on the GPU, along with automatic differentiation (grad) and vectorization (vmap). Below we'll use JAX to estimate a market model.
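
As a quick illustration (not part of the market-model code below), jax.numpy mirrors the NumPy API while jax.grad differentiates a plain Python function:

import jax
import jax.numpy as jnp

f = lambda x: jnp.square(x)   # f(x) = x^2
df = jax.grad(f)              # f'(x) = 2x

print(df(3.0))  # 6.0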

Load libraries:

import pandas as pd
import yfinance as yf
import jax
import jax.numpy as jnp
from sklearn.model_selection import train_test_split

Get price data from Yahoo Finance, convert the prices into returns, and split the returns into training and test data sets.

market = yf.Ticker("SPY")
market_data = market.history(period="1y")

target = yf.Ticker("GME")
target_data = target.history(period="1y")

# daily returns; [1:] drops the NaN that pct_change() leaves in the first row
market_returns = market_data["Close"].pct_change()[1:]
target_returns = target_data["Close"].pct_change()[1:]

X, X_test, y, y_test = train_test_split(market_returns.values, target_returns.values)
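
Note that train_test_split shuffles the observations by default, so the estimated beta will vary slightly from run to run. Passing a seed (an optional tweak, not in the original code) makes the split reproducible:

X, X_test, y, y_test = train_test_split(
    market_returns.values, target_returns.values, random_state=0
)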

Model

The market model is a linear regression of the target's returns on the market's returns, y = w * x + b, where the slope w is the stock's beta.

# model weights: scalar slope (beta) and intercept
params = {
    'w': 1.,
    'b': 0.
}


def forward(params, X):
    return jnp.dot(X, params['w']) + params['b']  # calculates y_hat


def loss_fn(params, X, y):
    err = forward(params, X) - y  # y_hat - y
    return jnp.mean(jnp.square(err))  # mse


grad_fn = jax.grad(loss_fn)
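
As an aside, jax.value_and_grad builds a function that returns the loss and its gradient in a single pass, which avoids evaluating the model twice if you want to monitor the loss while training (a sketch, not the code used below):

loss_and_grad_fn = jax.value_and_grad(loss_fn)
loss, grads = loss_and_grad_fn(params, X, y)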


def update(params, grads):
    # one gradient-descent step (learning rate 0.9) applied to each leaf of the params pytree
    return jax.tree_util.tree_map(lambda p, g: p - 0.9 * g, params, grads)


# the main training loop
for _ in range(50000):
    loss = loss_fn(params, X_test, y_test)  # out-of-sample loss; uncomment the print below to monitor it
#    print(loss)

    grads = grad_fn(params, X, y)
    params = update(params, grads)
#    print(params)
print(params)
{'b': DeviceArray(0.00072088, dtype=float32, weak_type=True), 'w': DeviceArray(2.2803657, dtype=float32, weak_type=True)}

This gives a beta of 2.28, which is a reasonable estimate for a volatile stock like GME.
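
As an aside, the loop above re-evaluates the loss and gradient in pure Python on every iteration. Wrapping the per-step work in jax.jit compiles it with XLA, which typically speeds up this kind of loop considerably (a sketch under the same setup, not the original code):

@jax.jit
def train_step(params, X, y):
    # one compiled gradient-descent step
    grads = grad_fn(params, X, y)
    return update(params, grads)

for _ in range(50000):
    params = train_step(params, X, y)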

Check results

from sklearn.linear_model import LinearRegression
reg = LinearRegression().fit(X.reshape(-1, 1), y)  # sklearn expects a 2-D feature matrix
print(reg.coef_)
[2.28085816]

So linear regression also gives a beta of 2.28, matching the JAX estimate.
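
Beta can also be computed directly as the covariance between the target's and the market's returns divided by the variance of the market's returns. The snippet below (an extra check, not in the original) gives the same slope on the training data:

import numpy as np

# beta = cov(target, market) / var(market)
beta = np.cov(y, X)[0, 1] / np.var(X, ddof=1)
print(beta)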

Author: Matt Brigida, Ph.D.

Created: 2022-08-10 Wed 11:50
