Hacker News
by jaykmody 1231 days ago
Hey y'all, author here!

Thank you for all the nice and constructive comments!

For clarity, this is ONLY the forward pass of the model. There's no training code, batching, kv cache for efficiency, GPU support, etc ...

The goal here was to provide a simple yet complete technical introduction to the GPT as an educational tool. Tried to make the first two sections something any programmer can understand, but yeah, beyond that you're gonna need to know some deep learning.

Btw, I tried to make the implementation as hackable as possible. For example, if you change the import from `import numpy as np` to `import jax.numpy as np`, the code becomes end-to-end differentiable:

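    import jax
    import jax.numpy as np

    def lm_loss(params, inputs, n_head) -> float:
        # inputs: 1-D integer array of token ids
        x, y = inputs[:-1], inputs[1:]
        logits = gpt2(x, **params, n_head=n_head)  # [seq_len - 1, vocab]
        log_probs = jax.nn.log_softmax(logits, axis=-1)
        # pick the log-probability assigned to each target token
        loss = -np.mean(log_probs[np.arange(len(y)), y])
        return loss

    grads = jax.grad(lm_loss)(params, inputs, n_head)

You can even support batching with `jax.vmap` (https://jax.readthedocs.io/en/latest/_autosummary/jax.vmap.h...):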

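    # bind the (unbatched) params, then map over the batch axis of the inputs only
    gpt2_batched = jax.vmap(lambda x: gpt2(x, **params, n_head=n_head))
    gpt2_batched(batched_inputs) # [batch, seq_len] -> [batch, seq_len, vocab]

Of course, with JAX comes built-in GPU and even TPU support!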

As for training code and a KV cache for inference efficiency, I'll leave those as an exercise for the reader lol
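
For anyone attempting the KV cache exercise, here is a minimal single-head sketch in plain NumPy. It is not picoGPT code: `causal_attention` and `KVCache` are hypothetical names, and a real implementation would need one cache per layer and per head. The idea is that keys and values of earlier tokens are stored, so each new token only computes its own query, key, and value.

```python
import numpy as np

def softmax(x):
    e = np.exp(x - np.max(x, axis=-1, keepdims=True))
    return e / np.sum(e, axis=-1, keepdims=True)

def causal_attention(q, k, v):
    # Full recomputation: q, k, v are [seq_len, d].
    # The mask blocks attention to future positions.
    n = q.shape[0]
    mask = (1 - np.tri(n)) * -1e10
    scores = q @ k.T / np.sqrt(q.shape[-1]) + mask
    return softmax(scores) @ v

class KVCache:
    """Hypothetical helper: keeps keys/values of tokens already processed."""

    def __init__(self, d):
        self.k = np.zeros((0, d))
        self.v = np.zeros((0, d))

    def step(self, q_new, k_new, v_new):
        # q_new, k_new, v_new: [d] vectors for the newest token only.
        self.k = np.vstack([self.k, k_new])
        self.v = np.vstack([self.v, v_new])
        scores = self.k @ q_new / np.sqrt(q_new.shape[-1])  # [tokens_so_far]
        return softmax(scores) @ self.v  # [d]

# Check that token-by-token decoding matches the full causal pass.
rng = np.random.default_rng(0)
q, k, v = rng.normal(size=(3, 5, 4))
full = causal_attention(q, k, v)
cache = KVCache(d=4)
incremental = np.stack([cache.step(q[t], k[t], v[t]) for t in range(5)])
assert np.allclose(full, incremental)
```

No mask is needed in `step` because the cache only ever contains the current and earlier tokens.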

8 comments

"hackable" and "simple yet complete technical introduction"

Music to my ears, well done, and don't worry too much about the negative comments! They'll come out for anything you do, I think.

I saw a tweet from someone the other day talking about how they massively increased their training speed by changing part of their architecture to have dimensions that were a multiple of 64 rather than a prime-like kind of number.
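
Rounding a dimension up to the next multiple of 64 takes one line. A minimal sketch, using GPT-2's vocabulary size of 50257 purely as an illustration (the thread does not say which dimension the tweet changed); `pad_to_multiple` is a hypothetical helper name:

```python
def pad_to_multiple(n: int, multiple: int = 64) -> int:
    """Round n up to the nearest multiple of `multiple`."""
    return ((n + multiple - 1) // multiple) * multiple

vocab_size = 50257  # GPT-2's vocabulary size, used here only as an example
padded_vocab_size = pad_to_multiple(vocab_size)
print(padded_vocab_size)  # 50304
```

The extra rows are never used as targets; they only exist so the matrix shapes line up with sizes the hardware handles efficiently.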

One of the comments below it? ~"Seems very architecture specific."

lol.

So don't sweat it! <3 Great work and thanks for putting yourself out there, super job! :D :D :D :D :)))))) <3 :D :D :fireworks:

We do GPU-specific training and inference speedups at CentML.
Grats, well deserved.
This is beautiful. Having worked with everything from nanoGPT to Megatron, sitting down and reading through picoGPT.py was clear and refreshing with just the essential details. Nothing left to add, nothing left to take away: perfection.
This looks like something Peter Norvig would write, and that’s about the highest compliment I can give.
> GPU support

If you haven't tried cuNumeric [1], you really ought to. It's a drop-in NumPy wrapper for distributed GPU acceleration. Would be interesting to see if it works for this.

[1]: https://github.com/nv-legate/cunumeric
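
A rough sketch of what "drop-in" would look like here, assuming cuNumeric is installed and imported under its documented `cunumeric` module name; the NumPy fallback and the `backend` variable are illustrative additions, and whether the whole model runs unchanged is untested:

```python
try:
    import cunumeric as np  # assumption: cuNumeric is installed
    backend = "cunumeric"
except ImportError:
    import numpy as np  # fall back to plain NumPy on the CPU
    backend = "numpy"

def softmax(x):
    # Only standard array operations, so the same code runs on either backend.
    e = np.exp(x - np.max(x, axis=-1, keepdims=True))
    return e / np.sum(e, axis=-1, keepdims=True)

probs = softmax(np.array([1.0, 2.0, 3.0]))
print(backend, probs)
```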

The problem with drop-in replacements between CPU and GPU code is that performant GPU code often requires rethinking the dataflow, so even if the code itself is a drop-in, the "make it good" part still requires some rewriting.

I'd be curious how that library compares to other numeric Python GPU libraries.

> For clarity, this is ONLY the forward pass of the model. There's no training code, batching, kv cache for efficiency, GPU support, etc ...

Neat, but please add one-line comments/docstrings where these missing bits would go.

Hi there, thank you for putting this together!

I want to commend you for one of the best-written introductions in this space that I've seen, especially the excellent use of hyperlinking that points to really good resources at exactly the right time!

I hope this goes the way open-source Go AI did. AlphaGo came and went; we needed an open one, and now open source has one. I hope the same happens here.