Granted that PyTorch and TensorFlow both lean heavily on the same CUDA/cuDNN components under the hood (with TF also bundling a billion other components that aren't deep-learning-centric), I think one of the primary reasons PyTorch is getting such heavy adoption is that it is a Python library first and foremost. There are maybe all of two "surprises" I've encountered in all my time using it, if that (1. gradients accumulate in each tensor's .grad until you zero them, 2. nn.Module does funky things with attribute assignment, so use something like nn.ModuleDict if you're going to be setting modules dynamically). Everything else works like a dream, and works almost exactly how you'd expect.
Model parameters? .parameters() gives you a generator of tensors (and .named_parameters() gives dict-friendly name/tensor pairs).
Model state? .state_dict() is a dictionary.
Loading model state? load_state_dict(state_dict)... just loads a dictionary.
Reusing modules across different modules? Just assign them!
Determining what parameters to optimize? Just ... give the list of parameters to the optimizer.
You can use all your usual Python development and debugging tools, and it feels 100% natural. I can fit it into other Python workflows without making the whole program centered around TensorFlow. TensorFlow is undoubtedly powerful, and if you have the time/resources to put into a static-ish TensorFlow-centric workflow, it could pay off many times over. But it definitely feels like learning an entirely new language, with an entirely different debugging pattern. And furthermore, a language whose patterns and best practices are constantly changing, outside of the super-standard Keras examples. To put that into context, even running the official TensorFlow models repository produces deprecation warnings. Whereas torchvision works seamlessly and reads like a reference for writing PyTorch model code. There is just a developer-centric focus to PyTorch that makes it a joy to use.
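
For anyone who hasn't used it, here is a minimal sketch of the points listed above. `TinyNet`, its layer sizes, and the head names `"a"`/`"b"` are made up purely for illustration; only the `torch`/`nn` calls are real API.

```python
import types

import torch
from torch import nn


class TinyNet(nn.Module):
    """Hypothetical toy model, used only to illustrate the points above."""

    def __init__(self, shared_encoder: nn.Module):
        super().__init__()
        # Reusing a module across models is plain attribute assignment.
        self.encoder = shared_encoder
        # Dynamically keyed submodules go in nn.ModuleDict so they are registered.
        self.heads = nn.ModuleDict({"a": nn.Linear(8, 2), "b": nn.Linear(8, 3)})

    def forward(self, x, head):
        return self.heads[head](torch.relu(self.encoder(x)))


encoder = nn.Linear(4, 8)
net_one = TinyNet(encoder)
net_two = TinyNet(encoder)  # same encoder object, so encoder weights are shared

# .parameters() is a generator of tensors; .named_parameters() adds the names.
params_is_generator = isinstance(net_one.parameters(), types.GeneratorType)
param_names = [name for name, _ in net_one.named_parameters()]

# .state_dict() is a dictionary mapping parameter names to tensors.
state = net_one.state_dict()
state_is_dict = isinstance(state, dict)

# load_state_dict() copies a dictionary of tensors back into a model.
load_result = net_two.load_state_dict(state)

# Choosing what to optimize: hand the optimizer the parameters you want.
optimizer = torch.optim.SGD(net_one.heads["a"].parameters(), lr=0.1)

# Surprise 1: gradients accumulate across backward() calls until zeroed.
x = torch.ones(1, 4)
net_one(x, "a").sum().backward()
first_grad = net_one.heads["a"].bias.grad.clone()
net_one(x, "a").sum().backward()
second_grad = net_one.heads["a"].bias.grad.clone()  # twice first_grad
optimizer.zero_grad()  # clear the accumulated gradients before the next step
```

Because everything here is an ordinary Python object, you can poke at `state`, `param_names`, or `second_grad` in a debugger or REPL like any other value.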
I was able to create a custom detection network for a 3-class problem, load up the COCO pretrained weights for the network, strip out the weights at the "head" for every COCO class except the "person" class, and then fine-tune the model on my custom 3-class dataset. The resulting model generalized exceptionally well on people, as it was still able to retain a lot of its performance from the COCO pretraining. It was so easy to do all of this. Literally, maybe 10 lines of code, and so easy to figure out since I could introspect the state_dict and the weights file directly in my PyCharm interpreter while working out how to do this.
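
A rough sketch of that kind of state_dict surgery, not the actual code: `ToyDetector`, the feature size, the class counts, and the two "person" indices are all hypothetical stand-ins, and a real detection head would also have per-class box-regression weights that need the same treatment.

```python
import torch
from torch import nn

# All sizes and indices below are hypothetical, chosen only for illustration.
FEATURES, PRETRAINED_CLASSES, CUSTOM_CLASSES = 16, 81, 3
PERSON_INDEX_PRETRAINED = 1  # assumed position of "person" in the pretrained head
PERSON_INDEX_CUSTOM = 0      # assumed position of "person" in the new head


class ToyDetector(nn.Module):
    """Toy stand-in for a detection network: a backbone plus a class head."""

    def __init__(self, num_classes: int):
        super().__init__()
        self.backbone = nn.Linear(FEATURES, FEATURES)
        self.head = nn.Linear(FEATURES, num_classes)

    def forward(self, x):
        return self.head(torch.relu(self.backbone(x)))


# Stands in for loading the pretrained weights file from disk.
pretrained_state = ToyDetector(PRETRAINED_CLASSES).state_dict()

model = ToyDetector(CUSTOM_CLASSES)
# Clone so the edits below touch a plain dict of tensors, not live parameters.
new_state = {name: t.clone() for name, t in model.state_dict().items()}

for name, tensor in pretrained_state.items():
    if name.startswith("head."):
        # Keep only the "person" row; the other class rows stay freshly initialized.
        new_state[name][PERSON_INDEX_CUSTOM] = tensor[PERSON_INDEX_PRETRAINED]
    else:
        # Backbone weights carry over unchanged.
        new_state[name] = tensor

model.load_state_dict(new_state)
```

From here the model would be fine-tuned as usual; the point is just that the weights are a dictionary you can inspect and edit by name.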