After my cuDNN Flash Attention implementation was integrated yesterday, I spent some time profiling and trying to figure out how much more we can improve performance short/medium-term, while also thinking of longer-term improvements I'd be keen to work on (probably after my exams at the end of the month unfortunately).
This is the Nsight Compute profile data for an A100 40GB SXM4 (on Lambda) as of the last commit yesterday (4dd1ab4), obtained from profile_gpt2cu.py with a forced batch size of 24. Both BF16 and cuDNN are enabled. Kernels are listed from worst to best DRAM/memory utilisation.
But this doesn’t tell us the relative importance of these kernels, as some run 4x per layer (e.g. matmul_backward_bias) while others run only once per training step (e.g. fused classifier). Also, caches are flushed between kernels, so these numbers are a “worst case” (although not flushing the caches would be way too optimistic).
The profile_gpt2cu.py script tries to generate a report that accounts for this, but it is fragile and needs fixing due to cuDNN changing the number of kernels (in a GPU-dependent way). For the most accurate timing data, it is best (but more time consuming) to manually analyse performance using Nsight Systems.
This is the Nsight Systems timeline for train_gpt2cu under similar conditions (batch size was probably 16 rather than 24 iirc, but it doesn’t make a huge difference) taking ~100ms per step:
Here's the forward pass which takes ~1.6ms per layer (with 0.15ms for the gelu):
And the backward pass which takes ~4.6ms per layer (with ~0.25ms to ~0.3ms for each of the 4 matmul_backward_bias):
Our short-term target should be a strict minimum of 70% of peak compute or 70% of peak DRAM bandwidth for all kernels, with a stretch goal of >85%. The x128 changes (allowing us to pack 8 BF16 inputs into a single load instruction and loop iteration) will help a lot with many kernels but won’t always be enough. But I can't think of any good reason why we shouldn't be able to achieve that target for all the key kernels, although some optimisations risk making the kernel a lot longer and harder to understand.
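To make the x128 point concrete, here is a minimal sketch of the arithmetic only (not the actual kernel code). The channel count of 768 and the one-warp-per-row layout are assumptions for illustration; neither is stated in this post, and the function names are made up.

```python
# Sketch of the arithmetic behind 128-bit ("x128") packed loads.
# All names here are illustrative, not taken from the llm.c source.

def elements_per_load(load_bits: int, element_bits: int) -> int:
    """How many elements fit in a single load of `load_bits` bits."""
    return load_bits // element_bits


def loop_iterations(row_length: int, threads: int, elems_per_thread: int) -> int:
    """Iterations each thread needs to sweep one row (ceiling division)."""
    per_pass = threads * elems_per_thread
    return -(-row_length // per_pass)


C = 768    # ASSUMPTION: channel count of GPT-2 124M
WARP = 32  # ASSUMPTION: one 32-thread warp processes one row

bf16_pack = elements_per_load(128, 16)  # 8 BF16 values per 128-bit load
fp32_pack = elements_per_load(128, 32)  # 4 FP32 values per 128-bit load

scalar_iters = loop_iterations(C, WARP, 1)          # one element per thread per pass
packed_iters = loop_iterations(C, WARP, bf16_pack)  # eight elements per thread per pass

print(bf16_pack, fp32_pack, scalar_iters, packed_iters)  # 8 4 24 3
```

Under these assumptions the per-row loop shrinks from 24 iterations to 3, which is where the reduction in loop overhead comes from.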
Where possible, the required amount of compute and DRAM bandwidth can also be reduced by improving the kernels, e.g.:
Reduce overhead: For layernorm forward, the ~76% compute utilisation is dominated by loop overhead. The x128 datatype with loop unrolling will significantly reduce compute %, which means the actual goal is improving Memory % from ~33% to ~75%, i.e. ~2.2x performance. Achieving this goal is likely to require a little more than just x128, but x128 alone may be enough to get close.
More kernel fusion: GELU and other element-wise operations can theoretically be merged into the matmul, and it may be possible to merge other functions like residual forward and layernorm forward to save more DRAM bandwidth. Eventually if we make our own fully custom matmuls (or use CUTLASS/cuBLASDx) then it might be possible to do unusually complicated (partial?) fusions of matmuls with row operations, e.g. matmul+layernorm.
Smarter use of L2 cache: By changing the cache persistence settings and/or the order in which operations are computed, it may be possible to hit in the L2 cache some of the time even for large kernels. For example, with a 200MiB working set and a 40MiB cache, the right cache controls might save ~10% of DRAM traffic, and ordering things the right way could potentially save much more. This is an advanced topic that will be discussed separately.
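The reasoning behind the first two items can be sketched as simple ratios. This assumes a purely bandwidth-bound kernel whose runtime scales inversely with achieved DRAM utilisation, and counts DRAM traffic in whole "tensor passes"; both are simplifications, and the function names are illustrative.

```python
# Back-of-the-envelope model for bandwidth-bound kernels (a sketch, not a profiler).

def speedup_from_utilisation(current_pct: float, target_pct: float) -> float:
    """Speed-up when the same bytes are moved at a higher fraction of peak DRAM bandwidth."""
    return target_pct / current_pct


def speedup_from_traffic(passes_before: int, passes_after: int) -> float:
    """Speed-up when fusion removes whole reads/writes of equally sized tensors."""
    return passes_before / passes_after


# layernorm forward: Memory % from ~33% to ~75%
layernorm = speedup_from_utilisation(33.0, 75.0)

# GELU fused into the matmul: the standalone kernel's read disappears,
# its write remains (2 tensor passes become 1)
gelu_fused = speedup_from_traffic(2, 1)

print(f"layernorm: {layernorm:.2f}x, gelu fused: {gelu_fused:.1f}x")
```

The first ratio comes out at about 2.27, in line with the roughly 2.2x quoted above.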
Kernel By Kernel (Short-Term Improvements)
matmul_backward_bias_kernel4: >3x speed-up possible, potentially up to ~10% performance improvement (definitely >>5%). This is the most complex kernel to optimise: it needs x128 but also structural changes, including global atomics, because right now it launches only a very small number of threadblocks, which leaves most SMs completely idle.
fused_classifier_kernel3: >2x speed-up possible, ~3% performance improvement, x128/unroll will help but may need more tricks to hit DRAM peak.
residual+layernorm_forward: ~2x speed-up possible if we get to 70% of DRAM peak, or >2.5x with fusion (save 1 read but no writes), up to ~2% performance improvement.
gelu_backward: ~2x speed-up by getting to 70% of DRAM peak (x128/loop unrolling should lift compute bottleneck), >2% performance improvement.
gelu_forward: ~2x speed-up by merging into matmul (saving read but not write), ~1% performance improvement.
TOTAL: Up to ~20% performance that we should be able to unlock very quickly!
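As a sanity check, the per-kernel estimates can be reproduced from the timeline numbers quoted earlier. This is a rough sketch under two assumptions that are not stated in the post: the model has 12 layers (GPT-2 124M), and each "performance improvement" figure means the fraction of step time saved.

```python
# Rough estimate of how per-kernel speed-ups translate into step-time savings.

LAYERS = 12      # ASSUMPTION: GPT-2 124M
STEP_MS = 100.0  # ~100 ms per training step, from the Nsight Systems timeline


def share_of_step(ms_per_call, calls_per_layer, layers=LAYERS, step_ms=STEP_MS):
    """Fraction of one training step spent in a kernel that runs in every layer."""
    return ms_per_call * calls_per_layer * layers / step_ms


def step_time_saved(share, speedup):
    """Fraction of the step removed if a kernel with this share runs `speedup`x faster."""
    return share * (1.0 - 1.0 / speedup)


# matmul_backward_bias: 4 calls per layer at ~0.25 to ~0.3 ms each, >3x speed-up
bias_low = step_time_saved(share_of_step(0.25, 4), 3.0)   # about 0.08
bias_high = step_time_saved(share_of_step(0.30, 4), 3.0)  # about 0.096

# gelu_forward: 0.15 ms per layer, ~2x speed-up
gelu_fwd = step_time_saved(share_of_step(0.15, 1), 2.0)   # about 0.009

# Summing the per-kernel figures quoted in the list above
quoted = [0.10, 0.03, 0.02, 0.02, 0.01]
total_saved = sum(quoted)                    # about 0.18
throughput_gain = 1.0 / (1.0 - total_saved)  # about 1.22x

print(f"bias: {bias_low:.1%} to {bias_high:.1%}, gelu_forward: {gelu_fwd:.1%}")
print(f"total saved: {total_saved:.0%}, throughput: {throughput_gain:.2f}x")
```

Under these assumptions matmul_backward_bias alone accounts for roughly 8% to 10% of the step, and the quoted figures sum to about 18% of step time, consistent with "up to ~20%".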
After that, we might want to focus on many-GPU efficiency by parallelising NCCL communication with individual layers of backward pass, but unfortunately this will probably require changing the memory layout of the weights/gradients.
Medium/Long-Term Ideas
Besides the obvious goal of eventually supporting FP8 for the entire network and associated tricks to achieve good training loss with it, there's still plenty of things we could do to improve BF16/FP16 performance!
These are just my own ideas for things I might personally be interested in working on:
Automatic sweeping of L2 cache residency per-kernel to find the optimal settings that minimise the required DRAM bandwidth (+ use of cuda::discard_memory?)
FP8 for attention only with dynamic scaling of tensors, should be easier & less fragile than FP8 for the entire network (still built on top of cuDNN for now...)
Replace all cuBLAS Matrix Multiplication and cuDNN Flash Attention kernels by fully custom H100-only optimised kernels with inline PTX (+SASS?!)
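For the FP8 attention idea, "dynamic scaling of tensors" could look like the per-tensor amax recipe sketched below. This is one common approach and an assumption on my part, not a description of what would actually be implemented; the function names are hypothetical and rounding to the FP8 grid is deliberately left out.

```python
# Sketch of per-tensor dynamic scaling for FP8 (E4M3), without the actual FP8 rounding.

E4M3_MAX = 448.0  # largest finite magnitude representable in FP8 E4M3


def dynamic_scale(values, fmt_max=E4M3_MAX):
    """Scale factor that maps the tensor's largest magnitude onto the format's maximum."""
    amax = max(abs(v) for v in values)
    return fmt_max / amax if amax > 0.0 else 1.0


def scale_and_clamp(values, scale, fmt_max=E4M3_MAX):
    """Apply the scale and clamp to the representable range (rounding omitted)."""
    return [max(-fmt_max, min(fmt_max, v * scale)) for v in values]


tensor = [0.5, -2.0, 1.0]
scale = dynamic_scale(tensor)            # 448 / 2 = 224
scaled = scale_and_clamp(tensor, scale)  # [112.0, -448.0, 224.0]
restored = [v / scale for v in scaled]   # dividing by the scale undoes it

print(scale, scaled, restored)
```

The scale has to be recomputed (or tracked) as the tensor's range changes during training, which is what makes this "dynamic" as opposed to a fixed factor.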
And some more random ideas off the top of my head which also seem interesting:
Investigate specialising kernels at compilation time for higher efficiency (e.g. knowing vocabulary size and even batch size).
Find ways to merge even more into matmuls, e.g. “row operations” like forward layernorm merged into preceding matmul.
Implement “unit scaling” with static per-tensor scaling factors determined mathematically (rather than dynamically).
Lossless HW compressible memory support with lossy kernels that dynamically round numbers very close to 0.
FP6 or FP4 support for the encodings to optimise 128K Llama3 vocab on small networks (emulated pre-B100).
… although I’m very keen to actually play with some more researchy/algorithmic ideas I’ve had rather than only focus on raw performance for standard architectures :) We are definitely getting to the point where this should already be useful for real research work on small networks!
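To illustrate the "unit scaling" item in the list above: the idea, as I understand it, is to pick a fixed scale from the tensor shapes so that outputs keep roughly unit variance, instead of measuring ranges at runtime. Below is a minimal pure-Python sketch for a single dot product with a 1/sqrt(fan_in) scale; the fan-in of 256, the sample count, and the function names are arbitrary choices for illustration.

```python
# Sketch: a static 1/sqrt(fan_in) scale keeps a dot product of unit-variance inputs
# at roughly unit variance, with no runtime range tracking.
import math
import random


def unit_scaled_dot(x, w):
    """Dot product followed by a static scale derived only from the fan-in."""
    fan_in = len(x)
    return sum(a * b for a, b in zip(x, w)) / math.sqrt(fan_in)


def variance(samples):
    """Population variance of a list of numbers."""
    mean = sum(samples) / len(samples)
    return sum((s - mean) ** 2 for s in samples) / len(samples)


rng = random.Random(0)  # fixed seed so the experiment is repeatable
FAN_IN = 256            # arbitrary illustrative size

outputs = []
for _ in range(2000):
    x = [rng.gauss(0.0, 1.0) for _ in range(FAN_IN)]
    w = [rng.gauss(0.0, 1.0) for _ in range(FAN_IN)]
    outputs.append(unit_scaled_dot(x, w))

print(f"output variance: {variance(outputs):.2f}")  # should land close to 1
```

Without the scale the output variance would be close to the fan-in (256 here), which is the kind of growth a narrow format like FP8 cannot absorb.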