- Scaled Dot Product Attention FP16/BF16 Forward
- Scaled Dot Product Attention FP16/BF16 Backward
- Scaled Dot Product Attention FP8 Forward
- Scaled Dot Product Attention FP8 Backward
- Supported Tensor Layouts
This operation computes the scaled dot product attention (SDPA) using the FlashAttention-2 algorithm, as described in the paper FlashAttention-2: Faster Attention with Better Parallelism and Work Partitioning. It is applicable to both training and inference phases, with an option to generate a stats tensor to be used for the backward training computation.
- Python sample: samples/python/50_scaled_dot_product_attention.ipynb
- Python sample with paged caches: samples/python/52_scaled_dot_product_attention_with_paged_caches.ipynb
- C++ sample: samples/cpp/sdpa
- Python tests: test/python/test_mhas.py
- Attention scale (`attn_scale`): Applies a scaling factor to attention scores before the softmax, such as $\frac{1}{\sqrt{d}}$. Set to 1.0 by default.
- Bias mask: Applies an additive bias mask to attention scores. Users must pass a bias tensor as specified in the tensors section below. The dimensions that are passed as 1 will apply a broadcasted mask over attention scores.
- Alibi mask: Attention with Linear Biases (ALiBi) is an additive mask applied to the attention scores as described in the paper Train Short, Test Long: Attention with Linear Biases Enables Input Length Extrapolation.
- Padding mask: Also called variable sequence length, this option masks out padded time steps to ignore them in computation. Users must pass a per-batch sequence length as specified in the tensors section below.
- Causal mask: Fills the upper triangular matrix of attention scores with negative infinity.
- Sliding window mask: Allows computation of attention scores from (pos - sliding_window_length, pos] for every position pos, and fills the rest of the entries in the matrix with negative infinity.
- Dropout: Randomly zeros some of the attention weights after the softmax as a form of regularization. Users can configure dropout in two ways:
    - To use the more performant Philox RNG dropout implementation, users must provide:
        - An RNG seed, passed as a cuDNN tensor.
        - An RNG offset, passed as a cuDNN tensor.
        - A float representing the dropout probability, which is the probability that any given weight is set to zero.
        - (Debug only) An output RNG dump generated by the Philox RNG, passed as a cuDNN tensor.
    - To use a user-provided dropout mask, users must provide:
        - A `dropout mask` that matches the attention weights' dimensions, indicating which weights to drop. The dimensions that are passed as 1 will apply a broadcasted dropout mask.
        - A `dropout scale` used to adjust the scale of the remaining weights accordingly, such as $1 / (1 - \text{dropout probability})$.
- Packed layout: With packed layout, the query, key, value, and output tensors should be ragged tensors, which are tensors with nested variable-length lists as inner dimensions. Users must pass an additional ragged offset tensor using the `Tensor_attributes.set_ragged_offset()` method. The ragged offset tensor must be a tensor of size $(B + 1, 1, 1, 1)$ that contains the nested tensor's offsets in terms of number of elements (not bytes). The last value of the offset tensor specifies the offset of the past-the-end element of the ragged tensor. See Appendix A for more information on the supported layouts.
- Paged attention: With paged K and/or V caches, the K/V blocks no longer need to be contiguous, allowing users to better utilize memory by avoiding fragmentation. Users must therefore:
    - Pass a `page table k` tensor containing offsets to the container with K blocks. This is optional, and only needed if the K cache is paged.
    - Pass a `page table v` tensor containing offsets to the container with V blocks. This is optional, and only needed if the V cache is paged.
    - Pass anything required for the padding mask above (i.e., per-batch sequence lengths for both the K and V caches). This is needed if at least one of the K/V caches is paged.
    - Optionally, but recommended, pass the maximum sequence length for the K/V caches. When omitted, it will be (over)estimated, which could result in a corrupted graph in some corner cases.

  Offsets into the K/V containers are calculated as

  $K_{cache}[b, h, s, d] = K[page\ table\ k[b, 1, \lfloor s / bs_{k} \rfloor, 1], h, s \bmod bs_{k}, d]$

  $V_{cache}[b, h, s, d] = V[page\ table\ v[b, 1, \lfloor s / bs_{v} \rfloor, 1], h, s \bmod bs_{v}, d]$

  (see the indexing sketch below). See also the PagedAttention paper.
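As a minimal host-side indexing sketch of the mapping above (illustration only; cuDNN performs this lookup internally, and the assumption that the K container stores blocks contiguously as $(num\_blocks_{k}, H_{k}, bs_{k}, D_{qk})$ is illustrative):

```cpp
#include <cstdint>
#include <vector>

// Flat element index of Kcache[b, h, s, d] inside a paged K container whose blocks
// are assumed to be laid out as (num_blocks_k, H_k, bs_k, D_qk).
int64_t paged_k_index(const std::vector<int32_t>& page_table_k,  // flattened (B, 1, ceil(S_kv / bs_k), 1)
                      int64_t pages_per_seq,                     // ceil(S_kv / bs_k)
                      int64_t b, int64_t h, int64_t s, int64_t d,
                      int64_t H_k, int64_t bs_k, int64_t D_qk) {
    int64_t block  = page_table_k[b * pages_per_seq + s / bs_k];  // page table k[b, 1, s / bs_k, 1]
    int64_t within = s % bs_k;                                    // s mod bs_k
    return ((block * H_k + h) * bs_k + within) * D_qk + d;        // K[block, h, s mod bs_k, d]
}
```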
Tensor Name | Device | Data Type | Dimensions |
---|---|---|---|
Q | GPU | FP16 or BF16 | $(B, H_{q}, S_{q}, D_{qk})$ |
K | GPU | FP16 or BF16 | $(B, H_{k}, S_{kv}, D_{qk})$, or $(num\_blocks_{k}, H_{k}, bs_{k}, D_{qk})$ when the K cache is paged |
V | GPU | FP16 or BF16 | $(B, H_{v}, S_{kv}, D_{v})$, or $(num\_blocks_{v}, H_{v}, bs_{v}, D_{v})$ when the V cache is paged |
(Bias mask) Bias Mask | GPU | FP16 or BF16 | $(1 \text{ or } B, 1 \text{ or } H_{q}, S_{q}, S_{kv})$ |
(Padding mask/Paged Caches) Sequence Length Q | GPU | INT32 | $(B, 1, 1, 1)$ |
(Padding mask/Paged Caches) Sequence Length KV | GPU | INT32 | $(B, 1, 1, 1)$ |
(Philox RNG Dropout) Seed | CPU or GPU | INT32 or INT64 | $(1, 1, 1, 1)$ |
(Philox RNG Dropout) Offset | CPU or GPU | INT32 or INT64 | $(1, 1, 1, 1)$ |
(Custom Dropout Mask) Mask | GPU | FP16 or BF16 | $(1 \text{ or } B, 1 \text{ or } H_{q}, S_{q}, S_{kv})$ |
(Custom Dropout Mask) Scale | GPU | FP32 | $(1, 1, 1, 1)$ |
(Packed Layout) Ragged Offset | GPU | INT32 | $(B + 1, 1, 1, 1)$ |
(Paged Attention) Page Table K | GPU | INT32 | $(B, 1, \lceil S_{kv} / bs_{k} \rceil, 1)$ |
(Paged Attention) Page Table V | GPU | INT32 | $(B, 1, \lceil S_{kv} / bs_{v} \rceil, 1)$ |
(Paged Attention) Max Sequence Length KV | CPU | INT32 or INT64 | scalar |

Tensor Name | Device | Data Type | Dimensions |
---|---|---|---|
O | GPU | FP16 or BF16 | $(B, H_{q}, S_{q}, D_{v})$ |
Stats (training only) | GPU | FP32 | $(B, H_{q}, S_{q}, 1)$ |
(Philox RNG Dropout) RNG Dump | GPU | FP32 | $(B, H_{q}, S_{q}, S_{kv})$ |
Where,

- $B$ is the batch size
- $H_{q}$ is the number of query heads
- $H_{k}$ is the number of key heads
- $H_{v}$ is the number of value heads
- $S_{q}$ is the sequence length of the query
- $S_{kv}$ is the sequence length of the key and value
- $D_{qk}$ is the embedding dimension per head of query and key
- $D_{v}$ is the embedding dimension per head of value
- $bs_{k}$ is the (power of 2) block size of the K container
- $bs_{v}$ is the (power of 2) block size of the V container
- $num\_blocks_{k}$ is the number of blocks in the K container
- $num\_blocks_{v}$ is the number of blocks in the V container
- As described in the paper GQA: Training Generalized Multi-Query Transformer Models from Multi-Head Checkpoints,
    - When $H_{k}$ and $H_{v}$ are less than $H_{q}$ and are factors of $H_{q}$, this operation will perform group-query attention (GQA) computation.
    - When $H_{k}$ and $H_{v}$ are both set to 1, this operation performs multi-query attention (MQA) computation.
- All input and output tensor datatypes must be float16 or bfloat16, except the softmax stats output tensor, which must be float32.
- The embedding dimension per head, $D_{qk}$ and $D_{v}$, must be a multiple of 8 with a maximum value of 128.
- The stride of the embedding dimension per head, $D_{qk}$ and $D_{v}$, must be 1 for all the tensors above.
- This operation is only supported on GPUs with NVIDIA Ampere architecture (SM80) or newer.
// returns [output, softmax_stats]
std::array<std::shared_ptr<Tensor_attributes>, 2>
sdpa(std::shared_ptr<Tensor_attributes> q,
std::shared_ptr<Tensor_attributes> k,
std::shared_ptr<Tensor_attributes> v,
SDPA_attributes options);
The `options` parameter of type `SDPA_attributes` is used to control the attributes of the forward operation, as detailed below:
SDPA_attributes&
set_is_inference(bool const value);
SDPA_attributes&
set_attn_scale(std::shared_ptr<Tensor_attributes> value);
SDPA_attributes&
set_attn_scale(float const value);
SDPA_attributes&
set_bias(std::shared_ptr<Tensor_attributes> value);
SDPA_attributes&
set_alibi_mask(bool const value);
SDPA_attributes&
set_padding_mask(bool const value);
SDPA_attributes&
set_seq_len_q(std::shared_ptr<Tensor_attributes> value);
SDPA_attributes&
set_seq_len_kv(std::shared_ptr<Tensor_attributes> value);
SDPA_attributes&
set_causal_mask(bool const value);
SDPA_attributes &
set_sliding_window_length(int const value);
SDPA_attributes &
set_dropout(float const probability,
std::shared_ptr<Tensor_attributes> seed,
std::shared_ptr<Tensor_attributes> offset);
// for debugging dropout mask
SDPA_attributes&
set_rng_dump(std::shared_ptr<Tensor_attributes> value);
SDPA_attributes&
set_dropout(std::shared_ptr<Tensor_attributes> mask,
std::shared_ptr<Tensor_attributes> scale);
SDPA_attributes&
set_compute_data_type(DataType_t value);
SDPA_attributes&
set_paged_attention_k_table(std::shared_ptr<Tensor_attributes> value);
SDPA_attributes&
set_paged_attention_v_table(std::shared_ptr<Tensor_attributes> value);
SDPA_attributes&
set_paged_attention_max_seq_len_kv(int const value);
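For illustration, below is a minimal, hedged sketch of building a forward SDPA node with the C++ API above. The problem sizes, BHSD strides, tensor names, and graph-level data-type setters are illustrative assumptions; graph validation, plan building, and execution are omitted.

```cpp
#include <cudnn_frontend.h>
#include <cmath>

namespace fe = cudnn_frontend;

// Builds (but does not validate or execute) an SDPA forward node for training.
void build_sdpa_forward_graph() {
    constexpr int64_t B = 4, H = 8, S = 1024, D = 64;  // hypothetical problem sizes

    fe::graph::Graph graph;
    graph.set_io_data_type(fe::DataType_t::HALF)
         .set_intermediate_data_type(fe::DataType_t::FLOAT)
         .set_compute_data_type(fe::DataType_t::FLOAT);

    // Q, K, V in a dense BHSD layout (stride of the last dimension must be 1).
    auto make_bhsd = [&](const char* name) {
        return graph.tensor(fe::graph::Tensor_attributes()
                                .set_name(name)
                                .set_dim({B, H, S, D})
                                .set_stride({H * S * D, S * D, D, 1}));
    };
    auto Q = make_bhsd("Q");
    auto K = make_bhsd("K");
    auto V = make_bhsd("V");

    auto options = fe::graph::SDPA_attributes()
                       .set_name("flash_attention")
                       .set_is_inference(false)  // training: also produce the stats tensor
                       .set_causal_mask(true)
                       .set_attn_scale(1.0f / std::sqrt(static_cast<float>(D)));

    auto [O, stats] = graph.sdpa(Q, K, V, options);
    O->set_output(true);
    stats->set_output(true).set_data_type(fe::DataType_t::FLOAT);
}
```

The Python API documented next exposes the same options as keyword arguments.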
Args:
q (cudnn_tensor): The query data.
k (cudnn_tensor): The key data. When page_table_k is provided, 'k' is a container of non-contiguous key data.
v (cudnn_tensor): The value data. When page_table_v is provided, 'v' is a container of non-contiguous value data.
is_inference (bool): Whether it is an inference step or training step.
attn_scale (Optional[Union[float, cudnn_tensor]]): The scale factor for attention. Default is None.
bias (Optional[cudnn_tensor]): The bias data for attention. Default is None.
use_alibi_mask (Optional[bool]): Whether to use alibi mask. Default is False.
use_padding_mask (Optional[bool]): Whether to use padding mask. Default is False.
seq_len_q (Optional[cudnn_tensor]): The sequence length of the query.
seq_len_kv (Optional[cudnn_tensor]): The sequence length of the key.
use_causal_mask (Optional[bool]): Whether to use causal mask. Default is False.
use_causal_mask_bottom_right (Optional[bool]): Whether to use bottom right aligned causal mask. Default is False.
dropout (Optional[Union[Tuple[(probability: float, seed: cudnn_tensor, offset: cudnn_tensor)], Tuple[mask: cudnn_tensor, scale: cudnn_tensor]]]): Whether to do dropout. Default is None.
rng_dump (Optional[cudnn_tensor]): Debug tensor used to output the Philox RNG dropout mask
paged_attention_k_table (Optional[cudnn_tensor]): The page table to look up offsets into 'k'
paged_attention_v_table (Optional[cudnn_tensor]): The page table to look up offsets into 'v'
paged_attention_max_seq_len_kv (Optional[integer]): The maximum sequence length for k/v caches when paged attention is active.
compute_data_type (Optional[cudnn.data_type]): The data type for computation. Default is NOT_SET.
name (Optional[str]): The name of the operation.
Returns:
o (cudnn_tensor): The output data.
stats (Optional[cudnn_tensor]): The softmax statistics in case the operation is in a training step.
This operation computes gradient tensors for scaled dot product attention (SDPA) using the FlashAttention-2 algorithm as described in the paper FlashAttention-2: Faster Attention with Better Parallelism and Work Partitioning. The user is required to pass the stats tensor from the forward operation to the backward operation as input.
- Python sample: samples/python/51_scaled_dot_product_attention_backward.ipynb
- C++ sample: samples/cpp/sdpa
- Python tests: test/python/test_mhas.py
All the options mentioned in the forward operation, including ragged tensors and GQA/MQA, are applicable in the backward operation as well.
All the tensor requirements described in the forward operation are applicable in the backward operation as well. The gradient tensors for query, key, value, output, and bias should have the same properties as their non-gradient counterparts.
Tensor Name | Device | Data Type | Dimensions |
---|---|---|---|
dO | GPU | FP16 or BF16 | $(B, H_{q}, S_{q}, D_{v})$ |

Tensor Name | Device | Data Type | Dimensions |
---|---|---|---|
dQ | GPU | FP16 or BF16 | $(B, H_{q}, S_{q}, D_{qk})$ |
dK | GPU | FP16 or BF16 | $(B, H_{k}, S_{kv}, D_{qk})$ |
dV | GPU | FP16 or BF16 | $(B, H_{v}, S_{kv}, D_{v})$ |
All the limitations mentioned in the forward operation are applicable in the backward operation as well.
// returns [dQ, dK, dV]
std::array<std::shared_ptr<Tensor_attributes>, 3>
sdpa_backward(std::shared_ptr<Tensor_attributes> q,
std::shared_ptr<Tensor_attributes> k,
std::shared_ptr<Tensor_attributes> v,
std::shared_ptr<Tensor_attributes> o,
std::shared_ptr<Tensor_attributes> dO,
std::shared_ptr<Tensor_attributes> stats,
SDPA_backward_attributes);
The `options` parameter of type `SDPA_backward_attributes` is used to control the attributes of the backward operation, as detailed below:
SDPA_backward_attributes&
set_attn_scale(std::shared_ptr<Tensor_attributes> value);
SDPA_backward_attributes&
set_attn_scale(float const value);
SDPA_backward_attributes&
set_bias(std::shared_ptr<Tensor_attributes> value);
SDPA_backward_attributes&
set_dbias(std::shared_ptr<Tensor_attributes> value);
SDPA_backward_attributes&
set_alibi_mask(bool const value);
SDPA_backward_attributes&
set_padding_mask(bool const value);
SDPA_backward_attributes&
set_seq_len_q(std::shared_ptr<Tensor_attributes> value);
SDPA_backward_attributes&
set_seq_len_kv(std::shared_ptr<Tensor_attributes> value);
SDPA_backward_attributes&
set_causal_mask(bool const value);
SDPA_backward_attributes &
set_sliding_window_length(int const value);
SDPA_backward_attributes&
set_dropout(float const probability,
std::shared_ptr<Tensor_attributes> seed,
std::shared_ptr<Tensor_attributes> offset);
// for debugging dropout mask
SDPA_backward_attributes&
set_rng_dump(std::shared_ptr<Tensor_attributes> value);
SDPA_backward_attributes&
set_dropout(std::shared_ptr<Tensor_attributes> mask,
std::shared_ptr<Tensor_attributes> scale,
std::shared_ptr<Tensor_attributes> scale_inv);
SDPA_backward_attributes&
set_deterministic_algorithm(bool const value);
SDPA_backward_attributes&
set_compute_data_type(DataType_t const value);
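A corresponding hedged sketch for the backward node, continuing the conventions of the forward sketch above. Here `Q`, `K`, `V`, `O`, and `stats` are assumed to be tensors on the graph being built; in practice, forward and backward are typically built as separate graphs, with `O` and `Stats` declared as inputs to the backward graph. Mask and scale settings must match the forward pass.

```cpp
// Gradient of the output, with the same (illustrative) layout as O.
auto dO = graph.tensor(fe::graph::Tensor_attributes()
                           .set_name("dO")
                           .set_dim({B, H, S, D})
                           .set_stride({H * S * D, S * D, D, 1}));

auto bwd_options = fe::graph::SDPA_backward_attributes()
                       .set_name("flash_attention_backward")
                       .set_causal_mask(true)   // must match the forward mask settings
                       .set_attn_scale(1.0f / std::sqrt(static_cast<float>(D)));

// stats is the softmax statistics tensor produced by the forward sdpa() call.
auto [dQ, dK, dV] = graph.sdpa_backward(Q, K, V, O, dO, stats, bwd_options);
dQ->set_output(true);
dK->set_output(true);
dV->set_output(true);
```

The Python API documented next exposes the same options as keyword arguments.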
Args:
q (cudnn_tensor): The query data.
k (cudnn_tensor): The key data.
v (cudnn_tensor): The value data.
o (cudnn_tensor): The output data.
dO (cudnn_tensor): The output loss gradient.
stats (cudnn_tensor): The softmax statistics from the forward pass.
attn_scale (Optional[Union[float, cudnn_tensor]]): The scale factor for attention. Default is None.
bias (Optional[cudnn_tensor]): The bias data for attention. Default is None.
dBias (Optional[cudnn_tensor]): The dBias output for attention. Default is None.
use_alibi_mask (Optional[bool]): Whether to use alibi mask. Default is False.
use_padding_mask (Optional[bool]): Whether to use padding mask. Default is False.
seq_len_q (Optional[cudnn_tensor]): The sequence length of the query.
seq_len_kv (Optional[cudnn_tensor]): The sequence length of the key.
use_causal_mask (Optional[bool]): Whether to use causal mask. Default is False.
use_causal_mask_bottom_right (Optional[bool]): Whether to use bottom right aligned causal mask. Default is False.
sliding_window_length (Optional[int]): The length of sliding window. Default is None.
dropout (Optional[Union[Tuple[(probability: float, seed: cudnn_tensor, offset: cudnn_tensor)],
Tuple[mask: cudnn_tensor, scale: cudnn_tensor, scale_inv: cudnn_tensor]]]):
Whether to do dropout. Default is None.
rng_dump (Optional[cudnn_tensor]): Debug tensor used to output the Philox RNG dropout mask
compute_data_type (Optional[cudnn.data_type]): The data type for computation. Default is NOT_SET.
name (Optional[str]): The name of the operation.
Returns:
dQ (cudnn_tensor): The query gradient data.
dK (cudnn_tensor): The key gradient data.
dV (cudnn_tensor): The value gradient data.
This operation computes the scaled dot product attention (SDPA) in the 8-bit floating point (FP8) datatype, using the FlashAttention-2 algorithm as described in the paper FlashAttention-2: Faster Attention with Better Parallelism and Work Partitioning. It is applicable for both training and inference phases, with an option to generate a stats tensor to be used for backwards training computation.
The FP8 datatype consists of two encodings:

- `FP8_E4M3` (1 sign bit, 4 exponent bits, and 3 mantissa bits)
- `FP8_E5M2` (1 sign bit, 5 exponent bits, and 2 mantissa bits)
Due to the limited numerical precision of FP8 data type, for practical use cases, users must scale values computed in FP32 format before storing them in FP8 format, and descale the values stored in FP8 format before performing computations on them. For more information, refer to the Transformer Engine FP8 Primer.
The suggested value for the scaling factor is computed as: (Max representable value in the fp8 format) / (Max absolute value seen in the tensor for the previous layer).
- For E4M3, the suggested scaling factor is `448.f / prev_layer_tensor_amax` (rounded to the nearest lower power of two)
- For E5M2, the suggested scaling factor is `57344.f / prev_layer_tensor_amax` (rounded to the nearest lower power of two)
The suggested value for the descale factor is the reciprocal of the scale factor.
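As a minimal host-side sketch of this recipe (hypothetical helper function; in practice the amax value would come from the previous iteration's amax output tensors):

```cpp
#include <cmath>

// Hypothetical host-side helper illustrating the scale/descale recipe above.
struct Fp8ScalePair {
    float scale;
    float descale;
};

Fp8ScalePair suggest_fp8_scales(float prev_layer_tensor_amax, bool is_e4m3) {
    const float fp8_max = is_e4m3 ? 448.f : 57344.f;       // max representable value of the encoding
    float raw   = fp8_max / prev_layer_tensor_amax;        // assumes amax > 0
    float scale = std::exp2(std::floor(std::log2(raw)));   // round down to the nearest power of two
    return {scale, 1.f / scale};                           // descale is the reciprocal of scale
}
```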
Since scaling and descaling are critical for convergence with FP8 datatype, users are required to pass scaling and descaling input tensors, as well as amax output tensors.
- C++ sample: samples/cpp/sdpa
The current FP8 support is a subset of the options supported for FP16 and BF16. We are actively working on expanding FP8 support.
- Attention scale (`attn_scale`): Applies a scaling factor to attention scores before the softmax, such as $\frac{1}{\sqrt{d}}$. Set to 1.0 by default.
- Causal mask: Fills the upper triangular matrix of attention scores with negative infinity.
The tensors in the forward operation are defined as follows:

Tensor Name | Device | Data Type | Dimensions |
---|---|---|---|
Q | GPU | E4M3 or E5M2 | $(B, H_{q}, S_{q}, D_{qk})$ |
K | GPU | E4M3 or E5M2 | $(B, H_{k}, S_{kv}, D_{qk})$ |
V | GPU | E4M3 or E5M2 | $(B, H_{v}, S_{kv}, D_{v})$ |
Descale Q | GPU | FP32 | $(1, 1, 1, 1)$ |
Descale K | GPU | FP32 | $(1, 1, 1, 1)$ |
Descale V | GPU | FP32 | $(1, 1, 1, 1)$ |
(Bias mask) Bias Mask | GPU | E4M3 or E5M2 | $(1 \text{ or } B, 1 \text{ or } H_{q}, S_{q}, S_{kv})$ |
(Padding mask) Sequence Length Q | GPU | INT32 | $(B, 1, 1, 1)$ |
(Padding mask) Sequence Length KV | GPU | INT32 | $(B, 1, 1, 1)$ |
(Philox RNG Dropout) Seed | CPU or GPU | INT32 or INT64 | $(1, 1, 1, 1)$ |
(Philox RNG Dropout) Offset | CPU or GPU | INT32 or INT64 | $(1, 1, 1, 1)$ |
(Custom Dropout Mask) Mask | GPU | E4M3 or E5M2 | $(1 \text{ or } B, 1 \text{ or } H_{q}, S_{q}, S_{kv})$ |
(Custom Dropout Mask) Scale | GPU | FP32 | $(1, 1, 1, 1)$ |
Descale S | GPU | FP32 | $(1, 1, 1, 1)$ |
Scale S | GPU | FP32 | $(1, 1, 1, 1)$ |

Tensor Name | Device | Data Type | Dimensions |
---|---|---|---|
O | GPU | E4M3 or E5M2 | $(B, H_{q}, S_{q}, D_{v})$ |
Stats (training only) | GPU | FP32 | $(B, H_{q}, S_{q}, 1)$ |
AMax S | GPU | FP32 | $(1, 1, 1, 1)$ |
AMax O | GPU | FP32 | $(1, 1, 1, 1)$ |
Where,

- $B$ is the batch size
- $H_{q}$ is the number of query heads
- $H_{k}$ is the number of key heads
- $H_{v}$ is the number of value heads
- $S_{q}$ is the sequence length of the query
- $S_{kv}$ is the sequence length of the key and value
- $D_{qk}$ is the embedding dimension per head of query and key
- $D_{v}$ is the embedding dimension per head of value
- As described in the paper GQA: Training Generalized Multi-Query Transformer Models from Multi-Head Checkpoints,
    - When $H_{k}$ and $H_{v}$ are less than $H_{q}$ and are factors of $H_{q}$, this operation will perform group-query attention (GQA) computation.
    - When $H_{k}$ and $H_{v}$ are both set to 1, this operation performs multi-query attention (MQA) computation.
- The embedding dimension per head, $D_{qk}$ and $D_{v}$, must be a multiple of 8 with a maximum value of 128.
- The stride of the embedding dimension per head, $D_{qk}$ and $D_{v}$, must be 1 for all the tensors above.
- This operation is only supported on GPUs with NVIDIA Hopper architecture (SM90) or newer.
// returns [o, stats, amax_s, amax_o]
std::array<std::shared_ptr<Tensor_attributes>, 4>
Graph::sdpa_fp8(std::shared_ptr<Tensor_attributes> q,
std::shared_ptr<Tensor_attributes> k,
std::shared_ptr<Tensor_attributes> v,
std::shared_ptr<Tensor_attributes> descale_q,
std::shared_ptr<Tensor_attributes> descale_k,
std::shared_ptr<Tensor_attributes> descale_v,
std::shared_ptr<Tensor_attributes> descale_s,
std::shared_ptr<Tensor_attributes> scale_s,
std::shared_ptr<Tensor_attributes> scale_o,
SDPA_fp8_attributes attributes);
The `options` parameter of type `SDPA_fp8_attributes` is used to control the attributes of the forward operation, as detailed below:
SDPA_fp8_attributes&
set_is_inference(bool const value);
SDPA_fp8_attributes&
set_attn_scale(std::shared_ptr<Tensor_attributes> value);
SDPA_fp8_attributes&
set_attn_scale(float const value);
SDPA_fp8_attributes&
set_causal_mask(bool const value);
SDPA_fp8_attributes&
set_bias(std::shared_ptr<Tensor_attributes> value);
SDPA_fp8_attributes&
set_padding_mask(bool const value);
SDPA_fp8_attributes&
set_seq_len_q(std::shared_ptr<Tensor_attributes> value);
SDPA_fp8_attributes&
set_seq_len_kv(std::shared_ptr<Tensor_attributes> value);
SDPA_fp8_attributes&
set_dropout(float const probability,
std::shared_ptr<Tensor_attributes> seed,
std::shared_ptr<Tensor_attributes> offset);
SDPA_fp8_attributes&
set_dropout(std::shared_ptr<Tensor_attributes> mask,
std::shared_ptr<Tensor_attributes> scale);
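A hedged sketch of building the FP8 forward node. It assumes a graph whose `Q`, `K`, `V` tensors were created with an FP8 I/O data type (following the same pattern as the earlier FP16 sketch); the scale/descale tensors are single-element FP32 tensors on the graph, and dims and names are illustrative.

```cpp
// Per-tensor scale/descale factors are (1, 1, 1, 1) FP32 tensors on the graph.
auto make_scalar = [&](const char* name) {
    return graph.tensor(fe::graph::Tensor_attributes()
                            .set_name(name)
                            .set_dim({1, 1, 1, 1})
                            .set_stride({1, 1, 1, 1})
                            .set_data_type(fe::DataType_t::FLOAT));
};
auto descale_q = make_scalar("descale_q"), descale_k = make_scalar("descale_k"),
     descale_v = make_scalar("descale_v"), descale_s = make_scalar("descale_s"),
     scale_s   = make_scalar("scale_s"),   scale_o   = make_scalar("scale_o");

auto fp8_options = fe::graph::SDPA_fp8_attributes()
                       .set_name("sdpa_fp8")
                       .set_is_inference(false)
                       .set_causal_mask(true)
                       .set_attn_scale(1.0f / std::sqrt(static_cast<float>(D)));

// Q, K, V are FP8 (E4M3 or E5M2) tensors created like the earlier Q/K/V tensors.
auto [O, stats, amax_s, amax_o] =
    graph.sdpa_fp8(Q, K, V, descale_q, descale_k, descale_v, descale_s, scale_s, scale_o, fp8_options);
O->set_output(true);
stats->set_output(true).set_data_type(fe::DataType_t::FLOAT);
amax_s->set_output(true).set_data_type(fe::DataType_t::FLOAT);
amax_o->set_output(true).set_data_type(fe::DataType_t::FLOAT);
```

The Python API documented next exposes the same options as keyword arguments.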
Args:
q (cudnn_tensor): The query data.
k (cudnn_tensor): The key data.
v (cudnn_tensor): The value data.
descale_q (cudnn_tensor): Descale factor for query.
descale_k (cudnn_tensor): Descale factor for key.
descale_v (cudnn_tensor): Descale factor for value.
descale_s (cudnn_tensor): Descale factor for S tensor.
scale_s (cudnn_tensor): Scale factor for S tensor.
scale_o (cudnn_tensor): Scale factor for output.
is_inference (bool): Whether it is an inference step or training step.
attn_scale (Optional[Union[float, cudnn_tensor]]): The scale factor for attention. Default is None.
use_causal_mask (Optional[bool]): Whether to use causal mask. Default is False.
compute_data_type (Optional[cudnn.data_type]): The data type for computation. Default is NOT_SET.
name (Optional[str]): The name of the operation.
Returns:
o (cudnn_tensor): The output data.
stats (Optional[cudnn_tensor]): The softmax statistics in case the operation is in a training step.
amax_s (cudnn_tensor): The absolute maximum of S tensor.
amax_o (cudnn_tensor): The absolute maximum of output tensor.
This operation computes the gradients for scaled dot product attention (SDPA) in the 8-bit floating point (FP8) datatype, using the FlashAttention-2 algorithm as described in the paper FlashAttention-2: Faster Attention with Better Parallelism and Work Partitioning. The user is required to pass the stats tensor from the forward operation to the backward operation as input.
- C++ sample: samples/cpp/sdpa
All the options mentioned in the forward FP8 operation, including ragged tensors and GQA/MQA, are applicable in the backward operation as well.
The tensors in the backward operation are defined as follows:

Tensor Name | Device | Data Type | Dimensions |
---|---|---|---|
Q | GPU | E4M3 or E5M2 | $(B, H_{q}, S_{q}, D_{qk})$ |
K | GPU | E4M3 or E5M2 | $(B, H_{k}, S_{kv}, D_{qk})$ |
V | GPU | E4M3 or E5M2 | $(B, H_{v}, S_{kv}, D_{v})$ |
O | GPU | E4M3 or E5M2 | $(B, H_{q}, S_{q}, D_{v})$ |
dO | GPU | E4M3 or E5M2 | $(B, H_{q}, S_{q}, D_{v})$ |
Stats | GPU | FP32 | $(B, H_{q}, S_{q}, 1)$ |
Descale Q | GPU | FP32 | $(1, 1, 1, 1)$ |
Descale K | GPU | FP32 | $(1, 1, 1, 1)$ |
Descale V | GPU | FP32 | $(1, 1, 1, 1)$ |
Descale O | GPU | FP32 | $(1, 1, 1, 1)$ |
Descale dO | GPU | FP32 | $(1, 1, 1, 1)$ |
Descale S | GPU | FP32 | $(1, 1, 1, 1)$ |
Descale dP | GPU | FP32 | $(1, 1, 1, 1)$ |
Scale S | GPU | FP32 | $(1, 1, 1, 1)$ |
Scale dQ | GPU | FP32 | $(1, 1, 1, 1)$ |
Scale dK | GPU | FP32 | $(1, 1, 1, 1)$ |
Scale dV | GPU | FP32 | $(1, 1, 1, 1)$ |
Scale dP | GPU | FP32 | $(1, 1, 1, 1)$ |

Tensor Name | Device | Data Type | Dimensions |
---|---|---|---|
dQ | GPU | E4M3 or E5M2 | $(B, H_{q}, S_{q}, D_{qk})$ |
dK | GPU | E4M3 or E5M2 | $(B, H_{k}, S_{kv}, D_{qk})$ |
dV | GPU | E4M3 or E5M2 | $(B, H_{v}, S_{kv}, D_{v})$ |
Amax dQ | GPU | FP32 | $(1, 1, 1, 1)$ |
Amax dK | GPU | FP32 | $(1, 1, 1, 1)$ |
Amax dV | GPU | FP32 | $(1, 1, 1, 1)$ |
Amax dP | GPU | FP32 | $(1, 1, 1, 1)$ |
Where,

- $B$ is the batch size
- $H_{q}$ is the number of query heads
- $H_{k}$ is the number of key heads
- $H_{v}$ is the number of value heads
- $S_{q}$ is the sequence length of the query
- $S_{kv}$ is the sequence length of the key and value
- $D_{qk}$ is the embedding dimension per head of query and key
- $D_{v}$ is the embedding dimension per head of value
All the limitations mentioned in the forward operation are applicable in the backward operation as well.
// returns [dQ, dK, dV, amax_dQ, amax_dK, amax_dV, amax_dP]
std::array<std::shared_ptr<Tensor_attributes>, 7>
Graph::sdpa_fp8_backward(std::shared_ptr<Tensor_attributes> q,
std::shared_ptr<Tensor_attributes> k,
std::shared_ptr<Tensor_attributes> v,
std::shared_ptr<Tensor_attributes> o,
std::shared_ptr<Tensor_attributes> dO,
std::shared_ptr<Tensor_attributes> Stats,
std::shared_ptr<Tensor_attributes> descale_q,
std::shared_ptr<Tensor_attributes> descale_k,
std::shared_ptr<Tensor_attributes> descale_v,
std::shared_ptr<Tensor_attributes> descale_o,
std::shared_ptr<Tensor_attributes> descale_do,
std::shared_ptr<Tensor_attributes> descale_s,
std::shared_ptr<Tensor_attributes> descale_dp,
std::shared_ptr<Tensor_attributes> scale_s,
std::shared_ptr<Tensor_attributes> scale_dq,
std::shared_ptr<Tensor_attributes> scale_dk,
std::shared_ptr<Tensor_attributes> scale_dv,
std::shared_ptr<Tensor_attributes> scale_dp,
SDPA_fp8_backward_attributes attributes);
The `options` parameter of type `SDPA_fp8_backward_attributes` is used to control the attributes of the backward operation, as detailed below:
SDPA_fp8_backward_attributes&
set_attn_scale(std::shared_ptr<Tensor_attributes> value);
SDPA_fp8_backward_attributes&
set_attn_scale(float const value);
SDPA_fp8_backward_attributes&
set_causal_mask(bool const value);
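And a hedged sketch of the FP8 backward node. The argument order follows the signature above; all tensors (`q`, `k`, `v`, `o`, `dO`, `stats`, and the FP32 scale/descale scalars) are assumed to have been created on the graph as in the previous sketches.

```cpp
auto fp8_bwd_options = fe::graph::SDPA_fp8_backward_attributes()
                           .set_name("sdpa_fp8_backward")
                           .set_causal_mask(true)   // must match the forward mask settings
                           .set_attn_scale(1.0f / std::sqrt(static_cast<float>(D)));

// q, k, v, o, dO are FP8 tensors; stats is the FP32 stats tensor from the forward
// pass; the descale_* / scale_* arguments are (1, 1, 1, 1) FP32 tensors as above.
auto [dQ, dK, dV, amax_dQ, amax_dK, amax_dV, amax_dP] =
    graph.sdpa_fp8_backward(q, k, v, o, dO, stats,
                            descale_q, descale_k, descale_v, descale_o, descale_dO,
                            descale_s, descale_dP,
                            scale_s, scale_dQ, scale_dK, scale_dV, scale_dP,
                            fp8_bwd_options);

// Mark outputs (gradients are FP8; amax outputs are FP32, set via set_data_type as needed).
for (const auto& t : {dQ, dK, dV, amax_dQ, amax_dK, amax_dV, amax_dP}) {
    t->set_output(true);
}
```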
Args:
q (cudnn_tensor): The query data.
k (cudnn_tensor): The key data.
v (cudnn_tensor): The value data.
o (cudnn_tensor): The output data.
dO (cudnn_tensor): The output gradient data.
stats (cudnn_tensor): The softmax statistics in case the operation is in a training step.
descale_q (cudnn_tensor): Descale factor for query.
descale_k (cudnn_tensor): Descale factor for key.
descale_v (cudnn_tensor): Descale factor for value.
descale_o (cudnn_tensor): Descale factor for output.
descale_dO (cudnn_tensor): Descale factor for output gradient.
descale_s (cudnn_tensor): Descale factor for S tensor.
descale_dP (cudnn_tensor): Descale factor for P gradient tensor.
scale_s (cudnn_tensor): Scale factor for S tensor.
scale_dQ (cudnn_tensor): Scale factor for query gradient.
scale_dK (cudnn_tensor): Scale factor for key gradient.
scale_dV (cudnn_tensor): Scale factor for value gradient.
scale_dP (cudnn_tensor): Scale factor for dP gradient.
attn_scale (Optional[Union[float, cudnn_tensor]]): The scale factor for attention. Default is None.
use_causal_mask (Optional[bool]): Whether to use causal mask. Default is False.
compute_data_type (Optional[cudnn.data_type]): The data type for computation. Default is NOT_SET.
name (Optional[str]): The name of the operation.
Returns:
dQ (cudnn_tensor): The query gradient data.
dK (cudnn_tensor): The key gradient data.
dV (cudnn_tensor): The value gradient data.
amax_dQ (cudnn_tensor): The absolute maximum of query gradient tensor.
amax_dK (cudnn_tensor): The absolute maximum of key gradient tensor.
amax_dV (cudnn_tensor): The absolute maximum of value gradient tensor.
amax_dP (cudnn_tensor): The absolute maximum of dP tensor.
cuDNN expresses the layout of the $Q$, $K$, $V$, and $O$ tensors through their dimensions and strides. Below we will go through the standard usage of the attention tensors and how they can be expressed in cuDNN.
Using the notation defined above:

- Case 1: $Q$, $K$, $V$, $O$ are tensors in dense, non-overlapping memory

  This is the basic case where the user can specify dims and strides for each of $Q$, $K$, $V$, $O$ in any stride order. The only limitation is that the stride of the last dimension, the embedding dimension per head $D_{qk}$ and $D_v$, be 1.

  For instance, for $Q$ with dimensions = $[B, H_q, S_q, D_{qk}]$, cuDNN support includes (but is not limited to):
    - stride = $[S_q \times H_q \times D_{qk}, D_{qk}, H_q \times D_{qk}, 1]$, a.k.a. BSHD layout
    - stride = $[H_q \times D_{qk}, D_{qk}, B \times H_q \times D_{qk}, 1]$, a.k.a. SBHD layout
- Case 2: $Q$, $K$, $V$ are tensors in a dense interleaved layout

  In some cases, users may need to interleave the $Q$, $K$, $V$ tensors together to simplify the matrix multiplication preceding the scaled dot product operation. For instance, users can allocate a single tensor of size = $3 \times B \times H \times S \times D$, specify the $Q$, $K$, $V$ dimensions = $[B, H, S, D]$, and cuDNN support includes (but is not limited to):
    - stride = $[S \times 3 \times H \times D, D, 3 \times H \times D, 1]$, a.k.a. BS3HD,
      with $QKV$ variant pack pointers offset as
      $Q_{ptr} = Storage_{ptr}$,
      $K_{ptr} = Storage_{ptr} + 1 \times H \times D$,
      $V_{ptr} = Storage_{ptr} + 2 \times H \times D$
    - stride = $[H \times 3 \times D, 3 \times D, B \times H \times 3 \times D, 1]$, a.k.a. SBH3D,
      with $QKV$ variant pack pointers offset as
      $Q_{ptr} = Storage_{ptr}$,
      $K_{ptr} = Storage_{ptr} + 1 \times D$,
      $V_{ptr} = Storage_{ptr} + 2 \times D$
- Case 3: $Q$, $K$, $V$ are tensors where not all tokens are valid

  Consider a $Q$ tensor with two batches ($B = 2$) of sequences of different lengths ["aa", "bbb"]. Let the maximum sequence length $S = 8$, and the number of heads $H = 1$. In this case, users should indicate the actual sequence lengths for each batch using the sequence length tensor `seq_len = [2, 3]`, and pass it to the SDPA node using `set_seq_len_q()` and `set_seq_len_kv()`. Note that every element in the sequence length tensor should always be smaller than the maximum sequence length $S$.

  cuDNN layout support for variable sequence length includes (but is not limited to):
    - Fully padded layout

      Q[b=0] = aa000000
      Q[b=1] = bbb00000

      dimension = $[B=2, H=1, S=8, D=64]$
      stride = $[SHD=512, D=64, HD=64, 1]$

      cuDNN reads the data based on the strides.
    - Fully packed layout, a.k.a. THD, where T = sum(seq_len)

      Q = aabbb000

      dimension = $[B=2, H=1, S=8, D=64]$
      stride = $[SHD=512, D=64, HD=64, 1]$

      The strides remain the same, but they are now incorrect because the second batch actually begins at element $64 \times 2$. Therefore, users must set the ragged offset tensor using the `<tensor>.set_ragged_offset(<ragged_offset_tensor>)` API, which is a $B + 1$ sized integer tensor indicating where each batch begins; the $(B + 1)$-th element is where the last batch ends. For this case, the ragged offset should be `[0, 2 * H * D, (2+3) * H * D] = [0, 128, 320]` (see the sketch after this list).
    - Valid tokens in a batch are packed together

      Q = aa00bbb0

      For this case, set the ragged offset to `[0, 4 * H * D, (4+3) * H * D] = [0, 256, 448]`
    - Valid tokens are not packed together

      Q = a0abbb00bb000000

      Ragged offset is insufficient to represent this. This case is NOT supported.
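As referenced in the fully packed (THD) example above, here is a minimal host-side sketch (a hypothetical helper function, not part of the cuDNN API) for building the ragged offset values from per-batch sequence lengths:

```cpp
#include <cstdint>
#include <vector>

// Builds the (B + 1)-element ragged offset tensor for a fully packed (THD) layout,
// in units of elements (not bytes).
std::vector<int32_t> build_ragged_offset(const std::vector<int32_t>& seq_len,
                                         int64_t H, int64_t D) {
    std::vector<int32_t> ragged_offset(seq_len.size() + 1, 0);
    for (size_t b = 0; b < seq_len.size(); ++b) {
        ragged_offset[b + 1] = ragged_offset[b] + static_cast<int32_t>(seq_len[b] * H * D);
    }
    return ragged_offset;
}

// For the example above: seq_len = {2, 3}, H = 1, D = 64  ->  {0, 128, 320}
```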