JAX
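JAX is a young ML ecosystem that provides simple tools that can be composed to solve complex problems. It offers a NumPy-like API that runs on CPU, GPU, and TPU, together with composable function transformations such as automatic differentiation (jax.grad) and vectorization (jax.vmap). Below we'll use JAX to estimate a market model, regressing a stock's daily returns on the market's daily returns.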
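As a minimal illustration of how these transformations compose, the sketch below differentiates a toy scalar function with jax.grad and then maps that derivative over a batch of inputs with jax.vmap. The function f is a made-up example and is not part of the market model.

```python
import jax
import jax.numpy as jnp

def f(x):
    # Toy scalar function: f(x) = x^3 + 2x, so f'(x) = 3x^2 + 2
    return x ** 3 + 2.0 * x

df = jax.grad(f)            # derivative of f for a single scalar input
batched_df = jax.vmap(df)   # the same derivative applied element-wise over a batch

xs = jnp.array([0.0, 1.0, 2.0])
print(df(1.0))          # 5.0
print(batched_df(xs))   # values 2, 5, 14
```

Neither transformation needs to know about the other: vmap simply receives a function (here, the output of grad) and vectorizes it.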
Load libraries:
```python
import pandas as pd
import yfinance as yf
import jax
import jax.numpy as jnp
from sklearn.model_selection import train_test_split
```
Get one year of daily prices from Yahoo Finance, convert the closing prices into daily returns, and split them into training and test sets.
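```python
# Download one year of daily prices for the market proxy (SPY) and the stock (GME).
market = yf.Ticker("SPY")
market_data = market.history(period="1y")
target = yf.Ticker("GME")
target_data = target.history(period="1y")

# Daily simple returns; pct_change leaves a NaN in the first row, so drop it.
market_returns = market_data["Close"].pct_change().dropna()
target_returns = target_data["Close"].pct_change().dropna()

# Randomly split the paired observations into training and test sets (75/25 by default).
X, X_test, y, y_test = train_test_split(market_returns.values, target_returns.values)
```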
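To make the price-to-return step concrete, here is a small sketch on made-up closing prices (the dates and numbers are hypothetical, not Yahoo Finance data). It computes simple returns r_t = P_t / P_{t-1} - 1 with pct_change and drops the undefined first row.

```python
import pandas as pd

# Hypothetical closing prices for four consecutive business days (not real data).
dates = pd.date_range("2024-01-01", periods=4, freq="B")
close = pd.DataFrame(
    {"market": [100.0, 101.0, 99.99, 100.9899],
     "target": [20.0, 22.0, 19.8, 21.78]},
    index=dates,
)

# Simple daily returns r_t = P_t / P_{t-1} - 1; the first row is NaN and is dropped.
returns = close.pct_change().dropna()
print(returns)  # market: +1%, -1%, +1%; target: +10%, -10%, +10%
```

Keeping both closes in one date-indexed DataFrame is one way to ensure that each row pairs the market and target return for the same day.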
Model
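```python
# Model weights: slope w (the beta) and intercept b (the alpha).
# X is 1-D here, so w is a scalar; with several regressors use jnp.zeros(X.shape[1:]).
params = {
    'w': 1.,
    'b': 0.,
}

def forward(params, X):
    # Predicted returns: y_hat = X * w + b
    return jnp.dot(X, params['w']) + params['b']

def loss_fn(params, X, y):
    err = forward(params, X) - y  # y_hat - y
    return jnp.mean(jnp.square(err))  # mean squared error

grad_fn = jax.grad(loss_fn)

def update(params, grads, lr=0.9):
    # Plain gradient descent on every leaf of the params dict.
    # jax.tree_multimap was removed in newer JAX releases; tree_map accepts several trees.
    return jax.tree_util.tree_map(lambda p, g: p - lr * g, params, grads)

@jax.jit
def train_step(params, X, y):
    # Compiled once, then reused on every iteration.
    return update(params, grad_fn(params, X, y))

# The main training loop (full-batch gradient descent on the training set)
for _ in range(50000):
    params = train_step(params, X, y)

# Held-out mean squared error
test_loss = loss_fn(params, X_test, y_test)
```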
```python
print(params)
```
{'b': DeviceArray(0.00072088, dtype=float32, weak_type=True), 'w': DeviceArray(2.2803657, dtype=float32, weak_type=True)}
This gives a beta of about 2.28: on average, a 1% move in SPY is associated with a 2.28% move in GME, which is reasonable for such a volatile stock.
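The updates in the training loop are driven by the gradient of the mean squared error. For L(w, b) = mean((w x + b - y)^2), the partial derivatives are dL/dw = 2 mean((w x + b - y) x) and dL/db = 2 mean(w x + b - y). The sketch below checks those hand-derived formulas against jax.grad on a few made-up returns; mse_loss, analytic_grad, x_toy, y_toy, and p_toy are hypothetical names chosen so they do not overwrite the notebook's own variables.

```python
import jax
import jax.numpy as jnp

def mse_loss(p, x, y):
    # Same model and loss as above: mean((x * w + b - y)^2)
    return jnp.mean(jnp.square(x * p['w'] + p['b'] - y))

def analytic_grad(p, x, y):
    # Hand-derived partial derivatives of the MSE with respect to w and b.
    err = x * p['w'] + p['b'] - y
    return {'w': 2.0 * jnp.mean(err * x), 'b': 2.0 * jnp.mean(err)}

# Made-up returns for illustration only (not SPY/GME data).
x_toy = jnp.array([0.01, -0.02, 0.015, 0.005])
y_toy = jnp.array([0.03, -0.05, 0.02, 0.0])
p_toy = {'w': 1.0, 'b': 0.0}

auto = jax.grad(mse_loss)(p_toy, x_toy, y_toy)    # automatic differentiation
manual = analytic_grad(p_toy, x_toy, y_toy)       # closed-form derivatives
print(auto)
print(manual)
```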
Check results
```python
from sklearn.linear_model import LinearRegression

# scikit-learn expects a 2-D feature matrix, so reshape X to (n_samples, 1).
reg = LinearRegression().fit(X.reshape(-1, 1), y)
print(reg.coef_)
```
[2.28085816]
Ordinary least squares also gives a beta of 2.28, so the gradient-descent estimate agrees with scikit-learn to two decimal places.
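With a single regressor, the least-squares slope also has a closed form: beta = Cov(r_target, r_market) / Var(r_market), with alpha = mean(r_target) - beta * mean(r_market). A minimal sketch on simulated returns with an assumed true beta of 2.0 and alpha of 0.001 (not GME data); ols_beta_alpha, x_sim, and y_sim are hypothetical names:

```python
import jax
import jax.numpy as jnp

def ols_beta_alpha(x, y):
    # Closed-form simple-regression estimates: slope = cov(x, y) / var(x).
    x_c = x - x.mean()
    y_c = y - y.mean()
    beta = jnp.sum(x_c * y_c) / jnp.sum(x_c ** 2)
    alpha = y.mean() - beta * x.mean()
    return beta, alpha

# Simulated daily returns: y = 0.001 + 2.0 * x + noise (assumed parameters).
key_x, key_e = jax.random.split(jax.random.PRNGKey(0))
x_sim = 0.01 * jax.random.normal(key_x, (250,))
y_sim = 0.001 + 2.0 * x_sim + 0.005 * jax.random.normal(key_e, (250,))

beta, alpha = ols_beta_alpha(x_sim, y_sim)
print(beta, alpha)  # beta should land near 2.0, alpha near 0.001
```

Applied to the training arrays X and y, the same formula should closely match the scikit-learn coefficient, since both compute the ordinary least-squares slope.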