Hey y'all, author here! Thank you for all the nice and constructive comments! For clarity, this is ONLY the forward pass of the model. There's no training code, batching, KV cache for efficiency, GPU support, etc. The goal here was to provide a simple yet complete technical introduction to the GPT as an educational tool. I tried to make the first two sections something any programmer can understand, but yeah, beyond that you're gonna need to know some deep learning.

Btw, I tried to make the implementation as hackable as possible. For example, if you change the import from `import numpy as np` to `import jax.numpy as np`, the code becomes end-to-end differentiable:

```python
import jax
import jax.numpy as np

def lm_loss(params, inputs, n_head) -> float:
    x, y = inputs[:-1], inputs[1:]  # inputs: 1-D int array of token ids; y is x shifted by one
    logits = gpt2(x, **params, n_head=n_head)  # [seq_len - 1, vocab], unnormalized
    log_probs = jax.nn.log_softmax(logits)
    return -np.mean(log_probs[np.arange(len(y)), y])  # mean NLL of the true next tokens

grads = jax.grad(lm_loss)(params, inputs, n_head)
```
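To make the shapes in the loss concrete, here's a minimal NumPy sketch of next-token cross-entropy. It assumes, as with the post's `gpt2`, that the model returns unnormalized logits of shape `[seq_len, vocab]`; `next_token_loss` and the random `toy_logits` matrix are hypothetical stand-ins for illustration, not part of the post's code.

```python
import numpy as np

def log_softmax(logits: np.ndarray) -> np.ndarray:
    # Subtract the row max for numerical stability before exponentiating.
    shifted = logits - logits.max(axis=-1, keepdims=True)
    return shifted - np.log(np.exp(shifted).sum(axis=-1, keepdims=True))

def next_token_loss(logits: np.ndarray, targets: np.ndarray) -> float:
    # logits:  [seq_len, vocab] predictions at each position
    # targets: [seq_len] token ids that actually come next at each position
    log_probs = log_softmax(logits)
    # Pick log p(target_t) at each position t: one entry per row, not whole rows.
    picked = log_probs[np.arange(len(targets)), targets]
    return float(-picked.mean())

# Hypothetical example: 6 tokens from a vocabulary of 10.
rng = np.random.default_rng(0)
inputs = np.array([3, 1, 4, 1, 5, 9])
x, y = inputs[:-1], inputs[1:]               # predict token t+1 from tokens <= t
toy_logits = rng.normal(size=(len(x), 10))   # stand-in for gpt2(x, **params, n_head=n_head)
print(next_token_loss(toy_logits, y))
```

A handy sanity check: with all-zero logits every token gets probability 1/vocab, so the loss comes out to exactly log(vocab).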
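You can even support batching with `jax.vmap` (https://jax.readthedocs.io/en/latest/_autosummary/jax.vmap.h...):

```python
# bind the (unbatched) params so vmap only maps over the leading batch axis of the tokens
gpt2_batched = jax.vmap(lambda inputs: gpt2(inputs, **params, n_head=n_head))
gpt2_batched(batched_inputs)  # [batch, seq_len] -> [batch, seq_len, vocab]
```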
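For intuition about what `jax.vmap` buys you: with the default `in_axes=0` and `out_axes=0`, `jax.vmap(f)(xs)` computes the same values as running `f` on each row of `xs` and stacking the results, just as one vectorized computation instead of a Python loop. The NumPy sketch below spells out that loop version; `toy_model`, `wte`, and `batched_by_loop` are hypothetical stand-ins (a real GPT would run its transformer blocks between the embedding and the output projection).

```python
import numpy as np

VOCAB, EMBD = 10, 4
rng = np.random.default_rng(0)
wte = rng.normal(size=(VOCAB, EMBD))  # hypothetical token-embedding table

def toy_model(tokens: np.ndarray) -> np.ndarray:
    # Unbatched: [seq_len] token ids -> [seq_len, vocab] logits.
    x = wte[tokens]       # [seq_len, embd]
    return x @ wte.T      # [seq_len, vocab]

def batched_by_loop(batched_tokens: np.ndarray) -> np.ndarray:
    # The values jax.vmap would compute, written as an explicit loop:
    # [batch, seq_len] -> [batch, seq_len, vocab]
    return np.stack([toy_model(seq) for seq in batched_tokens])

batched_inputs = np.array([[1, 2, 3], [4, 5, 6]])
print(batched_by_loop(batched_inputs).shape)  # (2, 3, 10)
```

In JAX you'd write the model with `jax.numpy` and let `jax.vmap` replace the loop, so the per-example code never has to know about the batch dimension.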
Of course, with JAX you also get built-in GPU and even TPU support! As for the training code and a KV cache for inference efficiency, I leave those as an exercise for the reader lol
Music to my ears! Well done, and don't worry too much about the negative comments; they'll come out for anything you do, I think.
I saw a tweet from someone the other day about how they massively increased their training speed by changing part of their architecture to use dimensions that were a multiple of 64 rather than some prime-like number.
One of the comments below it? ~"Seems very architecture specific."
lol.
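The trick from that tweet boils down to padding a dimension up to the next multiple of 64. As a quick worked example, GPT-2's vocabulary has 50257 tokens; the hypothetical helper below rounds such a size up, giving 50304 (e.g. by adding unused rows to the embedding table). Whether that actually speeds anything up depends on your hardware and kernels, which is kind of the point of the reply.

```python
def round_up_to_multiple(n: int, multiple: int = 64) -> int:
    # Smallest value >= n that is divisible by `multiple` (integer ceiling division).
    return ((n + multiple - 1) // multiple) * multiple

print(round_up_to_multiple(50257))  # 50304
print(round_up_to_multiple(768))    # 768, already a multiple of 64
```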
So don't sweat it! <3 Great work and thanks for putting yourself out there, super job! :D :D :D :D :)))))) <3 :D :D :fireworks: