Hacker News
by germanjoey 641 days ago
How are you verifying accuracy for your JAX port of Llama 3.1?

IMHO, the main reason to use PyTorch is actually that the original model used PyTorch. Logic that looks identical between different model versions may actually cause model drift when infinitesimal floating point errors accumulate due to the huge scale of the data. My experience is that debugging accuracy mismatches like this in a big model is a torturous ordeal beyond the 10th circle of hell.
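To make the accumulation point concrete, here is a minimal sketch (stdlib Python, with float32 emulated via `struct`; the data is made up for illustration) of how the same sum, evaluated in a different order, can give a different answer. A port that merely reorders a reduction can diverge in this way without any logic bug.

```python
import math
import struct

def to_f32(x: float) -> float:
    """Round a Python float (float64) to the nearest float32 value."""
    return struct.unpack("f", struct.pack("f", x))[0]

def f32_sum(values):
    """Accumulate left to right, rounding to float32 after every add."""
    acc = 0.0
    for v in values:
        acc = to_f32(acc + to_f32(v))
    return acc

# Hypothetical data: one large term and many small ones.
values = [1.0e8] + [1.0] * 1000

forward = f32_sum(values)             # large term first: each +1.0 is rounded away
backward = f32_sum(reversed(values))  # small terms first: they survive
exact = math.fsum(values)             # correctly rounded float64 reference

print(forward, backward, exact)
```

Near 1e8 the spacing between adjacent float32 values is 8, so adding 1.0 to the running total changes nothing, while summing the small terms first preserves them.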

3 comments

Good question. We used a new AI+math-based testing tool (benchify.com) to run comparison tests, but we are working on building more robust infrastructure for this. Translating models from PyTorch to JAX is core to our strategy.
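For readers wondering what a comparison test might look like, here is a rough sketch in plain Python. It is not how benchify.com works; the function names, the tolerance, and the logits are all assumptions made up for illustration. It reports the worst absolute difference between the two implementations' logits and how often their top-1 predictions agree.

```python
def max_abs_diff(ref, port):
    """Largest elementwise absolute difference between two logit vectors."""
    return max(abs(r - p) for r, p in zip(ref, port))

def argmax(xs):
    """Index of the largest element (first one wins on ties)."""
    return max(range(len(xs)), key=xs.__getitem__)

def compare_logits(ref_batch, port_batch, atol=1e-3):
    """Compare per-position logits from a reference and a ported model."""
    worst = 0.0
    agree = 0
    for ref, port in zip(ref_batch, port_batch):
        worst = max(worst, max_abs_diff(ref, port))
        agree += argmax(ref) == argmax(port)
    return {
        "max_abs_diff": worst,
        "top1_agreement": agree / len(ref_batch),
        "within_atol": worst <= atol,
    }

# Made-up logits standing in for the outputs of the two implementations.
ref_batch = [[2.0, 0.5, -1.0], [0.1, 0.3, 0.2999]]
port_batch = [[2.0005, 0.4998, -1.0], [0.1, 0.2998, 0.3001]]

report = compare_logits(ref_batch, port_batch)
print(report)
```

In this toy data every logit is within the tolerance, yet the second position picks a different top-1 token, which is why checking an elementwise tolerance alone may not be enough.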

That said, this path is not uncommon (translating from one framework to another). HuggingFace translates Google's Gemma family models from JAX to PyTorch, and a ton of people use it.

When you say "model versions", do you mean different quantizations of the model? Then it's not floating point errors that accumulate. Different quantizations of the model are different models. People will call such a model something like Meta-Llama-3.1-8B-Instruct--q4_0, claiming that it's just a "version" of Meta-Llama-3.1-8B-Instruct. But that's just a lie. It's not the same model, and you should not expect the same results. There is no reason to debug the differences: what exactly would you expect to find, and what action would you take once you found it? However, is the quantized version still a useful LLM? Absolutely. Most people don't have an A100 to run the original model, so a quantized version is better than nothing.
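A toy sketch of why a quantized model is a different model rather than the same one with rounding noise. This is a simplified symmetric 4-bit scheme with one shared scale, not the actual q4_0 format, and the weights are made up for illustration.

```python
def quantize_4bit(weights):
    """Map floats to integer codes in [-7, 7] using one shared scale."""
    scale = max(abs(w) for w in weights) / 7
    codes = [round(w / scale) for w in weights]
    return codes, scale

def dequantize(codes, scale):
    """Recover approximate float weights from the integer codes."""
    return [c * scale for c in codes]

# Hypothetical weights, not taken from any real model.
weights = [0.7, -0.33, 0.12, 0.01]

codes, scale = quantize_4bit(weights)
restored = dequantize(codes, scale)
errors = [abs(w - r) for w, r in zip(weights, restored)]

print(codes)        # [7, -3, 1, 0]
print(max(errors))  # on the order of 1e-2, far above float rounding error
```

The smallest weight collapses to exactly zero and the others move by a few hundredths, so the weights themselves have changed by many orders of magnitude more than floating point rounding would explain.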

Very fascinating, can you explain more about a time when this happened?

Like, what area was affected by FP errors, why were they introduced (was it a refactoring of the PyTorch code?), and how was this determined to be the cause?