Hacker News
by shoyer 1973 days ago
As someone who builds neural networks routinely, this sort of non-reproducibility sounds troubling to me. We expect small differences for floating point arithmetic between platforms, but integer math is typically exact.

This is all the more concerning for 8-bit quantized arithmetic, where off-by-one means a relative error of about half a percent. If individual layers in a quantized neural net have off-by-one errors with a consistent bias, I can imagine these errors accumulating into significant losses in model quality in deep networks. There isn't a huge margin for error in quantized neural nets.
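To make the accumulation concern concrete, here is a toy sketch, not the article's scheme. It assumes every layer ends with a fixed-point rescale by 255/256 and compares a hypothetical device that truncates with one that rounds to nearest. The names `rescale_truncate`, `rescale_round`, and `run_layers` are made up for illustration.

```python
# Toy model: every "layer" rescales an 8-bit activation by 255/256
# using an integer multiply followed by a right shift of 8 bits.

def rescale_truncate(x):
    """Hypothetical device A: the shift drops the low bits (floor)."""
    return (x * 255) >> 8


def rescale_round(x):
    """Hypothetical device B: add half the divisor first (round to nearest)."""
    return (x * 255 + 128) >> 8


def run_layers(x, depth, rescale):
    """Apply the same rescale `depth` times, like a stack of identical layers."""
    for _ in range(depth):
        x = rescale(x)
    return x


if __name__ == "__main__":
    start, depth = 100, 10
    a = run_layers(start, depth, rescale_truncate)
    b = run_layers(start, depth, rescale_round)
    print(f"truncate: {a}, round: {b}")
```

Under these assumptions the truncating device drifts from 100 down to 90 after ten layers, while the rounding device stays at 100. The per-layer difference is a single step, but it always points the same way, so it never cancels.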

One concern about the article: it uses the word "non-deterministic" in a slightly misleading way. I assume any specific hardware is still expected to produce consistent results when run twice on the same input. So it's more non-reproducible than non-deterministic. Compensating for inconsistent arithmetic on different devices sounds much more feasible than compensating for stochastic arithmetic.

3 comments

Thanks for your comments. Regarding determinism, potentially a fair point. Here are a few comments: (1) A driver which randomly produces different output when running the network would be valid according to these restrictions. (2) It is conceivable that a driver would produce non-deterministic output with the same hardware. One commonly known example is that TensorFlow will run multiple different convolution kernels and then choose the fastest one. In that case, you can run the same network on the same hardware and get slightly different results. It's not that hard to imagine that a mobile driver could do something similar. (3) It's not true that specific hardware will produce consistent results on the same input. You can run a model today, the driver gets updated, and tomorrow you get different output. This happens frequently.
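A minimal sketch of why picking a different kernel can change results, assuming (as an illustration only) that two kernels differ in the order in which they accumulate floating-point terms. Floating-point addition is not associative, so the same inputs can give slightly different sums. The function names are hypothetical.

```python
def sum_left_to_right(values):
    """Stand-in for a kernel that accumulates in the given order."""
    total = 0.0
    for v in values:
        total += v
    return total


def sum_right_to_left(values):
    """Stand-in for a kernel that accumulates in the reverse order."""
    total = 0.0
    for v in reversed(values):
        total += v
    return total


if __name__ == "__main__":
    values = [0.1, 0.2, 0.3]
    left = sum_left_to_right(values)
    right = sum_right_to_left(values)
    # Same inputs, same hardware, different accumulation order.
    print(left, right, left == right)
```

Both results are valid roundings of the same mathematical sum. Which one you get depends only on which kernel happened to be chosen.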
All good points! "Non-deterministic" behavior within the same program/process is still a bridge I would not want to cross. This could result in subtle glitches, e.g., when a user hits "refresh" with the same inputs, and could make reproducing bugs impossible.

I am a strong believer in always using a seed for random number generation for exactly these sorts of reasons. (Side note: deterministic RNGs are one of my favorite features of JAX.)
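As a small illustration of the seeding habit, here is a sketch using Python's standard `random` module rather than JAX's API. The helper `sample_noise` is hypothetical. It builds a private generator from an explicit seed, so the result depends only on its arguments and not on hidden global state.

```python
import random


def sample_noise(seed, n=5):
    """Draw n floats from a generator created from an explicit seed."""
    rng = random.Random(seed)  # private generator, no shared global state
    return [rng.random() for _ in range(n)]


if __name__ == "__main__":
    first = sample_noise(42)
    second = sample_noise(42)
    # Same seed, same sequence: a "refresh" reproduces the earlier run.
    print(first == second)
```

Passing the seed explicitly makes it easy to record alongside a bug report and replay the exact run later.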

You are paying for the determinism with performance. Any DL framework can be made deterministic (just add a few lines of configuration), not just JAX.
Regarding bias: This is exactly true, especially with the author's method, as the learned quantization ranges are fixed and accumulating biases would lead to the entire batch being clipped to 0 or 255, depending on the direction of the biases. Luckily the bias parameters are kept in int32, so the overall bias produced by them will be much smaller than 2 percent. The arithmetic errors of the int8 matmuls are summed within the matmul, and are therefore an unbiased estimate of the true entry in the result matrix.
> We expect small differences for floating point arithmetic between platforms, but integer math is typically exact.

Then perhaps think of the integers as fixed point numbers.
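A minimal sketch of that fixed-point view, assuming a format with 8 fractional bits (chosen here purely for illustration). An integer `q` stands for the real value `q / 256`, and a multiply needs a rescaling shift whose rounding mode is a design choice. The helper names are hypothetical.

```python
FRACTIONAL_BITS = 8
SCALE = 1 / (1 << FRACTIONAL_BITS)  # one integer step equals 2**-8


def to_real(q):
    """Interpret the integer q as a fixed-point real number."""
    return q * SCALE


def fixed_mul_truncate(a, b):
    """Multiply two fixed-point values, dropping the low bits."""
    return (a * b) >> FRACTIONAL_BITS


def fixed_mul_round(a, b):
    """Multiply two fixed-point values, rounding to nearest."""
    return (a * b + (1 << (FRACTIONAL_BITS - 1))) >> FRACTIONAL_BITS


if __name__ == "__main__":
    a = b = 150  # represents 150 / 256 = 0.5859375
    t = fixed_mul_truncate(a, b)
    r = fixed_mul_round(a, b)
    print(t, r, to_real(t), to_real(r), to_real(a) * to_real(b))
```

Here the two rounding choices give 87 and 88. Both are legitimate approximations of the same real product, which is the same kind of small platform-dependent difference people already accept for floating point.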