If possible, it would be interesting to explore ways to overcome the memory constraints and run a JIT-compiled version. This could potentially lead to further performance improvements.
+1, we still have a lot of performance we can extract! On our list: JIT-compiled train steps, more optimized data loading and sharding, gradient accumulation, and activation checkpointing. We will keep building and will publish another blog post soon after implementing these improvements!
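For anyone following along, gradient accumulation is probably the item on that list most directly aimed at memory limits: you run several small micro-batches, sum their gradients, and apply one update, so peak memory tracks the micro-batch size rather than the full batch. Here is a minimal, framework-agnostic sketch in plain Python. The toy 1-D linear model and every name in it are hypothetical and are not taken from the actual training code; it only illustrates that the accumulated gradient matches the full-batch gradient for a mean loss.

```python
# Hypothetical toy setup: a 1-D linear model y = w * x with mean squared error.
# None of these names come from the original training code.

def grad_mse(w, xs, ys):
    """Gradient of mean((w*x - y)^2) with respect to w over one batch."""
    n = len(xs)
    return sum(2.0 * (w * x - y) * x for x, y in zip(xs, ys)) / n


def accumulated_grad(w, xs, ys, micro_batch_size):
    """Accumulate micro-batch gradients into one full-batch gradient."""
    n = len(xs)
    total = 0.0
    for start in range(0, n, micro_batch_size):
        mx = xs[start:start + micro_batch_size]
        my = ys[start:start + micro_batch_size]
        # Weight each micro-batch by its size so a ragged final chunk
        # contributes in proportion to the samples it holds.
        total += grad_mse(w, mx, my) * len(mx)
    return total / n


xs = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0]
ys = [2.0, 4.1, 5.9, 8.2, 9.8, 12.1]
w = 0.5

full = grad_mse(w, xs, ys)                               # one pass over all 6 samples
accum = accumulated_grad(w, xs, ys, micro_batch_size=2)  # three passes over 2 samples each
print(abs(full - accum) < 1e-9)
```

In a real framework the same idea would wrap the optimizer step: call backward on each micro-batch, then step and zero the gradients once per accumulated batch. Whether it combines cleanly with a JIT-compiled step depends on the framework, so treat this as an illustration rather than a recipe.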