by jampekka
1936 days ago
OTOH PyTorch seems to be highly explosive if you try to use it outside its mainstream use case (i.e. neural networks). Sadly, there's no performant autodiff system for general-purpose Python. Numba is fine for performance but does not support autodiff. JAX aims to be sort of general purpose, but in practice it is also quite explosive when you do something other than neural networks. A lot of this is probably due to supporting CPUs and GPUs through the same interface. CPUs and GPUs are programmed in profoundly different ways, so a shared interface tends to restrict the more "CPU-oriented" approaches in particular. I have nothing against supporting GPUs (although I think their use is overrated and most people would do fine with CPUs), but Python really needs a general-purpose, high-performance autodiff.
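To make "general-purpose autodiff" concrete, here is a minimal sketch of forward-mode autodiff via operator overloading in plain Python. The names `Dual`, `derivative`, and `f` are made up for illustration and do not come from any library. It differentiates straight through ordinary loops and data-dependent branches, but every operation runs at interpreter speed, which is roughly the performance gap described above.

```python
class Dual:
    """A value paired with its derivative (forward-mode autodiff). Hypothetical helper."""

    def __init__(self, val, dot=0.0):
        self.val = val  # function value
        self.dot = dot  # derivative with respect to the input

    @staticmethod
    def lift(x):
        # Treat plain numbers as constants (derivative 0).
        return x if isinstance(x, Dual) else Dual(float(x))

    def __add__(self, other):
        other = Dual.lift(other)
        return Dual(self.val + other.val, self.dot + other.dot)

    __radd__ = __add__

    def __mul__(self, other):
        other = Dual.lift(other)
        # Product rule: (uv)' = u'v + uv'
        return Dual(self.val * other.val,
                    self.dot * other.val + self.val * other.dot)

    __rmul__ = __mul__

    def __lt__(self, other):
        # Comparisons look only at the value, so branches behave as usual.
        return self.val < Dual.lift(other).val


def derivative(f, x):
    """Return df/dx at x by seeding the input derivative with 1."""
    return f(Dual(float(x), 1.0)).dot


def f(x):
    # Plain Python control flow: a loop with a data-dependent branch.
    y = x
    for _ in range(3):
        y = y * y if y < 2.0 else y + 1.0
    return y


print(derivative(f, 0.5))  # all three steps square, so f(x) = x**8 here
print(derivative(f, 1.5))  # one squaring, then two "+ 1.0" steps
```

The same function would need restructuring (e.g. replacing the Python `if` with a framework-specific conditional) before a tracing compiler could handle it, which is one flavor of the friction mentioned above.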
|
As someone who works with machine learning models day-to-day (yes, some deep NNs, but also other stuff), GPUs really seem unbeatable to me for anything related to gradient-based optimization of matrices (i.e. like 80% of what I do). Even inference with a relatively simple image classification net takes an order of magnitude longer on CPU than on GPU, and that is on the smallest dataset I'm working with.
Was this a comment about specific models that have a reputation for being harder to optimize on the GPU (like tree-based models, although Microsoft is working in this space)? Or am I genuinely missing some optimization techniques that might let me make more use of our CPU compute?