Going by the docs, in practice you usually want this to call out to FlashAttention-2 (FA2). Its fused GPU kernels compute the softmax over the causal (lower-triangular) attention matrix in tiles, so the full score matrix is never materialized, and that cuts down on shuttling large blocks of floating-point values back and forth between the GPU's high-bandwidth memory and its on-chip SRAM.
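To make the tiling idea concrete, here is a rough pure-Python sketch of the "online softmax" trick that this kind of kernel relies on. It is not the FA2 kernel itself, just an illustration: the function names `naive_causal_attention` and `tiled_causal_attention` and the `block_size` parameter are made up for this example, and the real thing runs blocks in parallel on the GPU rather than in Python loops.

```python
import math


def naive_causal_attention(q, k, v):
    """Reference version: build each full row of scores, then softmax it."""
    scale = 1.0 / math.sqrt(len(q[0]))
    dv = len(v[0])
    out = []
    for i, qi in enumerate(q):
        # Causal mask: query i only attends to keys 0..i.
        scores = [scale * sum(a * b for a, b in zip(qi, k[j]))
                  for j in range(i + 1)]
        m = max(scores)
        weights = [math.exp(s - m) for s in scores]
        total = sum(weights)
        out.append([sum(w * v[j][c] for j, w in enumerate(weights)) / total
                    for c in range(dv)])
    return out


def tiled_causal_attention(q, k, v, block_size=2):
    """Same result, but keys/values are visited one block at a time.

    Per query we keep only a running max, a running sum of exponentials
    and a running weighted sum of V (hypothetical illustration only).
    """
    scale = 1.0 / math.sqrt(len(q[0]))
    dv = len(v[0])
    out = []
    for i, qi in enumerate(q):
        running_max = float("-inf")
        running_sum = 0.0
        acc = [0.0] * dv
        # Blocks that lie entirely above the diagonal are never visited.
        for start in range(0, i + 1, block_size):
            end = min(start + block_size, i + 1)
            scores = [scale * sum(a * b for a, b in zip(qi, k[j]))
                      for j in range(start, end)]
            new_max = max(running_max, max(scores))
            # Rescale everything accumulated under the previous max.
            correction = math.exp(running_max - new_max)
            running_sum *= correction
            acc = [x * correction for x in acc]
            for offset, s in enumerate(scores):
                w = math.exp(s - new_max)
                running_sum += w
                row = v[start + offset]
                for c in range(dv):
                    acc[c] += w * row[c]
            running_max = new_max
        out.append([x / running_sum for x in acc])
    return out


if __name__ == "__main__":
    q = [[0.1, 0.2], [0.3, -0.4], [0.5, 0.6], [-0.7, 0.8], [0.9, 1.0]]
    k = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0], [-1.0, 0.5], [0.5, -1.0]]
    v = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0], [9.0, 10.0]]
    ref = naive_causal_attention(q, k, v)
    tiled = tiled_causal_attention(q, k, v, block_size=2)
    print(all(abs(a - b) < 1e-9
              for r, t in zip(ref, tiled) for a, b in zip(r, t)))
```

The two functions should agree up to floating-point rounding. The point is that the tiled version only ever holds one block of scores at a time, which is what lets a fused kernel keep its working set in fast on-chip memory.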