diff --git a/CMakeLists.txt b/CMakeLists.txt index 2507ccf5..c1f6e93d 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -1,6 +1,6 @@ cmake_minimum_required(VERSION 3.17) -project(cudnn_frontend VERSION 1.6.1) +project(cudnn_frontend VERSION 1.7.0) option(CUDNN_FRONTEND_SKIP_JSON_LIB "Defines whether FE should not include nlohmann/json.hpp." OFF) option(CUDNN_FRONTEND_BUILD_SAMPLES "Defines if samples are built or not." ON) diff --git a/README.FE.1.0.md b/README.FE.1.0.md index badbdc01..a7e42253 100644 --- a/README.FE.1.0.md +++ b/README.FE.1.0.md @@ -51,8 +51,8 @@ FE v1.0 API follows a functional style of building a graph. Operations take in i | [Scale dot product attention](docs/operations/Attention.md) | sdpa
SDPA_attributes | sdpa | | [Scale dot product attention backward](docs/operations/Attention.md) | sdpa_backward
SDPA_backward_attributes | sdpa_backward | | [Scale dot product attention FP8](docs/operations/Attention.md) | sdpa_fp8
SDPA_fp8_attributes | sdpa_fp8 | -| [Scale dot product attention backward FP8](docs/operations/Attention.md) | sdpa_fp8_backward
SDPA_fp8_backward_attributes | sdpa_fp8_backward -| Slice | slice
Slice_attributes | slice | +| [Scale dot product attention backward FP8](docs/operations/Attention.md) | sdpa_fp8_backward
SDPA_fp8_backward_attributes | sdpa_fp8_backward | +| [Slice](docs/operations/Slice.md) | slice
Slice_attributes | slice | ### Creating the Graph Instantiate an object of class `cudnn_frontend::graph::Graph` which will house tensors and operations. @@ -156,6 +156,7 @@ cudnn_frontend::graph::Graph& cudnn_frontend::graph::Plans::select_behavior_note cudnn_frontend::graph::Graph& cudnn_frontend::graph::Plans::deselect_numeric_notes(std::vector const&); cudnn_frontend::graph::Graph& cudnn_frontend::graph::Plans::deselect_behavior_notes(std::vector const&); cudnn_frontend::graph::Graph& cudnn_frontend::graph::Plans::deselect_workspace_greater_than(int64_t const workspace); +cudnn_frontend::graph::Graph& cudnn_frontend::graph::Plans::deselect_shared_mem_greater_than(int64_t const shared_memory); ``` ### Autotuning @@ -205,6 +206,23 @@ Get workspace to run autotune on all plans. `get_autotune_workspace_size() const` + +### Serialization + +Frontend v1.0 API provides two flavors of serialization. One is to checkpoint after the initial graph specification (before calling validate), and the other is after building the execution plan (to save on plan creation). + +`void serialize(json &j) const` +`void deserialize(const json &j)` +The above two APIs are meant to capture the user-specified input tensors and nodes into the graph. This can be used to generate the log (for debugging) or to visualize the graph being created. + +`error_t serialize(std::vector &data) const` +`error_t deserialize(cudnnHandle_t handle, std::vector const &data)` + +A fully built graph can be serialized into a binary blob of data with the above two APIs. +Note: +1. Not all engine configs support serialization. +2. It is the user's responsibility to make sure the UIDs of the tensors being passed to the variant pack remain consistent before and after serialization. + ### Error handling The C++ API returns an error object which has an error code and an error message. diff --git a/benchmark/Dockerfile b/benchmark/Dockerfile index 469b4e3f..171c90be 100755 --- a/benchmark/Dockerfile +++ b/benchmark/Dockerfile @@ -1,4 +1,4 @@ -FROM nvcr.io/nvidia/pytorch:24.03-py3 +FROM nvcr.io/nvidia/pytorch:24.07-py3 RUN apt-get update && \ wget https://developer.download.nvidia.com/compute/cuda/repos/ubuntu2204/x86_64/cuda-keyring_1.1-1_all.deb && \ @@ -9,12 +9,12 @@ RUN apt-get update && \ RUN pip uninstall -y cudnn -RUN CMAKE_BUILD_PARALLEL_LEVEL=16 pip install git+https://github.com/NVIDIA/cudnn-frontend.git -v +RUN pip install nvidia-cudnn-frontend COPY benchmark_flash_attention.py . 
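Returning to the Serialization APIs listed above, the following is a minimal C++ sketch of the binary (plan-level) flow. The `std::vector<uint8_t>` element type, the `save_and_restore` wrapper, and the error-propagation style are illustrative assumptions; only `serialize`, `deserialize`, and `error_t::is_bad()` come from the API surface shown in this diff.

```cpp
// Sketch only: persist a fully built graph's execution plan and restore it
// later so the restored graph skips plan creation entirely.
#include <cudnn_frontend.h>

#include <cstdint>
#include <vector>

cudnn_frontend::error_t
save_and_restore(cudnnHandle_t handle,
                 cudnn_frontend::graph::Graph &built_graph,
                 cudnn_frontend::graph::Graph &restored_graph) {
    std::vector<uint8_t> blob;  // assumed element type of the binary blob

    // Serialize the built plan; not every engine config supports this.
    auto status = built_graph.serialize(blob);
    if (status.is_bad()) {
        return status;
    }

    // Rebuild the plan from the blob. Tensor UIDs used in the variant pack
    // must stay consistent with the UIDs used before serialization.
    return restored_graph.deserialize(handle, blob);
}
```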
ENV LD_LIBRARY_PATH=/usr/lib/x86_64-linux-gnu/:$LD_LIBRARY_PATH -CMD ["python", "/workspace/benchmark_flash_attention.py"] +CMD ["python", "benchmark_flash_attention.py"] WORKDIR /workspace diff --git a/benchmark/benchmark_flash_attention.py b/benchmark/benchmark_flash_attention.py index c3fe1741..72e59e94 100755 --- a/benchmark/benchmark_flash_attention.py +++ b/benchmark/benchmark_flash_attention.py @@ -186,18 +186,21 @@ def time_fwd(func, *args, **kwargs): dtype = torch.bfloat16 bs_seqlen_vals = [ - (32, 512), - (16, 1024), - (8, 2048), + # (32, 512), + # (16, 1024), + # (8, 2048), (4, 4096), (2, 8192), (1, 16384), (1, 32768), (1, 65536), + # (1, 262144), ] causal_vals = [False, True] headdim_vals = [128] -dim = 2048 +# headdim_vals = [128, 256] +# n_heads = 16, 32, 64 +n_heads = [16] dropout_p = 0.0 fields = [ @@ -215,11 +218,19 @@ def time_fwd(func, *args, **kwargs): "cudnn BF16 (TFlops/s fwd + bwd)", ] +if cudnn.backend_version() >= 90100: + fields += [ + "cudnn FP8 (TFlops/s fwd)", + "cudnn FP8 (TFlops/s bwd)", + "cudnn FP8 (TFlops/s fwd + bwd)", + ] csvwriter.writerow(fields) methods = ["Pytorch"] if cudnn is not None: methods += ["cudnn_bf16"] + if cudnn.backend_version() >= 90100: + methods += ["cudnn_fp8"] time_f = {} time_b = {} @@ -228,12 +239,24 @@ def time_fwd(func, *args, **kwargs): speed_b = {} speed_f_b = {} -for causal, headdim, bs_seqlen in itertools.product( - causal_vals, headdim_vals, bs_seqlen_vals +for causal, headdim, bs_seqlen, nheads in itertools.product( + causal_vals, headdim_vals, bs_seqlen_vals, n_heads ): batch_size, seqlen = bs_seqlen config = (causal, headdim, batch_size, seqlen) - nheads = dim // headdim + # nheads = dim // headdim + + if (seqlen >= 262144) and (nheads > 16): + continue + + if (seqlen >= 262144) and (headdim > 128): + continue + + print( + "Running bs={}, seqlen={}, d={}, h={}, causal={}".format( + batch_size, seqlen, headdim, nheads, causal + ) + ) if "Pytorch" in methods: qkv = torch.randn( @@ -266,8 +289,6 @@ def time_fwd(func, *args, **kwargs): and device == "cuda" and cudnn is not None ): - os.environ["CUDNN_FRONTEND_ATTN_DP_WORKSPACE_LIMIT"] = "0" - is_causal = causal is_dropout = False if (abs(dropout_p - 0.0) < 1e-6) else True is_infer = False @@ -391,11 +412,13 @@ def time_fwd(func, *args, **kwargs): dK_bwd.set_output(True).set_dim(dK_gpu.size()).set_stride(dK_gpu.stride()) dV_bwd.set_output(True).set_dim(dV_gpu.size()).set_stride(dV_gpu.stride()) - graph_bwd.validate() - graph_bwd.build_operation_graph() - graph_bwd.create_execution_plans([cudnn.heur_mode.A]) - graph_bwd.check_support() - graph_bwd.build_plans() + # cuDNN Flash Attention doesn't support bprop for d=256 + if headdim != 256: + graph_bwd.validate() + graph_bwd.build_operation_graph() + graph_bwd.create_execution_plans([cudnn.heur_mode.A]) + graph_bwd.check_support() + graph_bwd.build_plans() variant_pack_fwd = { q_fwd: q_gpu, @@ -434,20 +457,298 @@ def time_fwd(func, *args, **kwargs): repeats=repeats, verbose=False, ) - b = time_fwd( - graph_bwd.execute, - variant_pack_bwd, + if headdim != 256: + b = time_fwd( + graph_bwd.execute, + variant_pack_bwd, + workspace, + repeats=repeats, + verbose=False, + ) + else: + b = 100000 + + time_f[config, "cudnn_bf16"] = f + time_b[config, "cudnn_bf16"] = b + + print("cudnn_fp16 done") + if "cudnn_fp8" in methods and device == "cuda" and cudnn is not None: + is_causal = causal + is_dropout = False if (abs(dropout_p - 0.0) < 1e-6) else True + is_infer = False + input_type = dtype + attn_scale = headdim ** (-0.5) + dropout_prob = 
dropout_p if is_dropout else 0.0 + + shape_qkvo = (batch_size, nheads, seqlen, headdim) + stride_qkv = (seqlen * 3 * nheads * headdim, headdim, 3 * nheads * headdim, 1) + stride_o = (seqlen * nheads * headdim, headdim, nheads * headdim, 1) + offset_q, offset_k, offset_v = [nheads * headdim * i for i in range(3)] + + qkv_gpu = torch.randint( + 256, + (batch_size * seqlen * 3 * nheads * headdim,), + dtype=torch.uint8, + device="cuda", + ) + q_gpu, k_gpu, v_gpu = [ + torch.as_strided(qkv_gpu, shape_qkvo, stride_qkv, storage_offset=offset) + for offset in [offset_q, offset_k, offset_v] + ] + o_gpu = torch.empty(*shape_qkvo, dtype=torch.uint8, device="cuda").as_strided( + shape_qkvo, stride_o + ) + dQ_gpu, dK_gpu, dV_gpu = [ + torch.empty_like(tensor) for tensor in [q_gpu, k_gpu, v_gpu] + ] + dO_gpu = torch.randint_like(o_gpu, 256) + + stats_gpu = ( + torch.empty( + batch_size, nheads, seqlen, 1, dtype=torch.float32, device="cuda" + ) + if not is_infer + else None + ) + + descale_q_gpu = torch.ones(1, 1, 1, 1, dtype=torch.float, device="cuda") + descale_k_gpu = torch.ones(1, 1, 1, 1, dtype=torch.float, device="cuda") + descale_v_gpu = torch.ones(1, 1, 1, 1, dtype=torch.float, device="cuda") + descale_s_gpu = torch.ones(1, 1, 1, 1, dtype=torch.float, device="cuda") + descale_o_gpu = torch.ones(1, 1, 1, 1, dtype=torch.float, device="cuda") + descale_dO_gpu = torch.ones(1, 1, 1, 1, dtype=torch.float, device="cuda") + descale_dP_gpu = torch.ones(1, 1, 1, 1, dtype=torch.float, device="cuda") + + scale_s_gpu = torch.ones(1, 1, 1, 1, dtype=torch.float, device="cuda") + scale_o_gpu = torch.ones(1, 1, 1, 1, dtype=torch.float, device="cuda") + scale_dQ_gpu = torch.ones(1, 1, 1, 1, dtype=torch.float, device="cuda") + scale_dK_gpu = torch.ones(1, 1, 1, 1, dtype=torch.float, device="cuda") + scale_dV_gpu = torch.ones(1, 1, 1, 1, dtype=torch.float, device="cuda") + scale_dP_gpu = torch.ones(1, 1, 1, 1, dtype=torch.float, device="cuda") + + amax_s_gpu = torch.zeros(1, 1, 1, 1, dtype=torch.float, device="cuda") + amax_o_gpu = torch.zeros(1, 1, 1, 1, dtype=torch.float, device="cuda") + amax_dQ_gpu = torch.zeros(1, 1, 1, 1, dtype=torch.float, device="cuda") + amax_dK_gpu = torch.zeros(1, 1, 1, 1, dtype=torch.float, device="cuda") + amax_dV_gpu = torch.zeros(1, 1, 1, 1, dtype=torch.float, device="cuda") + amax_dP_gpu = torch.zeros(1, 1, 1, 1, dtype=torch.float, device="cuda") + + # cudnn graph forward + graph_fwd = cudnn.pygraph( + io_data_type=cudnn.data_type.FP8_E4M3, + intermediate_data_type=cudnn.data_type.FLOAT, + compute_data_type=cudnn.data_type.FLOAT, + ) + + q_fwd = graph_fwd.tensor_like(q_gpu).set_data_type(cudnn.data_type.FP8_E4M3) + k_fwd = graph_fwd.tensor_like(k_gpu).set_data_type(cudnn.data_type.FP8_E4M3) + v_fwd = graph_fwd.tensor_like(v_gpu).set_data_type(cudnn.data_type.FP8_E4M3) + + descale_q_fwd = graph_fwd.tensor_like(descale_q_gpu) + descale_k_fwd = graph_fwd.tensor_like(descale_k_gpu) + descale_v_fwd = graph_fwd.tensor_like(descale_v_gpu) + descale_s_fwd = graph_fwd.tensor_like(descale_s_gpu) + + scale_s_fwd = graph_fwd.tensor_like(scale_s_gpu) + scale_o_fwd = graph_fwd.tensor_like(scale_o_gpu) + + o_fwd, stats_fwd, amax_s_fwd, amax_o_fwd = graph_fwd.sdpa_fp8( + q=q_fwd, + k=k_fwd, + v=v_fwd, + descale_q=descale_q_fwd, + descale_k=descale_k_fwd, + descale_v=descale_v_fwd, + descale_s=descale_s_fwd, + scale_s=scale_s_fwd, + scale_o=scale_o_fwd, + is_inference=is_infer, + attn_scale=attn_scale, + use_causal_mask=is_causal, + ) + + o_fwd.set_output(True).set_dim(o_gpu.size()).set_stride( + 
o_gpu.stride() + ).set_data_type(cudnn.data_type.FP8_E4M3) + ( + stats_fwd.set_output(True) + .set_dim(stats_gpu.size()) + .set_stride(stats_gpu.stride()) + .set_data_type(cudnn.data_type.FLOAT) + if not is_infer + else None + ) + amax_s_fwd.set_output(True).set_dim(amax_s_gpu.size()).set_stride( + amax_s_gpu.stride() + ).set_data_type(cudnn.data_type.FLOAT) + amax_o_fwd.set_output(True).set_dim(amax_o_gpu.size()).set_stride( + amax_o_gpu.stride() + ).set_data_type(cudnn.data_type.FLOAT) + + graph_fwd.validate() + graph_fwd.build_operation_graph() + graph_fwd.create_execution_plans([cudnn.heur_mode.A]) + graph_fwd.check_support() + graph_fwd.build_plans() + + # cudnn graph backward + graph_bwd = cudnn.pygraph( + io_data_type=cudnn.data_type.FP8_E4M3, + intermediate_data_type=cudnn.data_type.FLOAT, + compute_data_type=cudnn.data_type.FLOAT, + ) + + q_bwd = graph_bwd.tensor_like(q_gpu).set_data_type(cudnn.data_type.FP8_E4M3) + k_bwd = graph_bwd.tensor_like(k_gpu).set_data_type(cudnn.data_type.FP8_E4M3) + v_bwd = graph_bwd.tensor_like(v_gpu).set_data_type(cudnn.data_type.FP8_E4M3) + o_bwd = graph_bwd.tensor_like(o_gpu).set_data_type(cudnn.data_type.FP8_E4M3) + dO_bwd = graph_bwd.tensor_like(dO_gpu).set_data_type(cudnn.data_type.FP8_E4M3) + stats_bwd = graph_bwd.tensor_like(stats_gpu) + + descale_q_bwd = graph_bwd.tensor_like(descale_q_gpu) + descale_k_bwd = graph_bwd.tensor_like(descale_k_gpu) + descale_v_bwd = graph_bwd.tensor_like(descale_v_gpu) + descale_o_bwd = graph_bwd.tensor_like(descale_o_gpu) + descale_dO_bwd = graph_bwd.tensor_like(descale_dO_gpu) + descale_s_bwd = graph_bwd.tensor_like(descale_s_gpu) + descale_dP_bwd = graph_bwd.tensor_like(descale_dP_gpu) + + scale_s_bwd = graph_bwd.tensor_like(scale_s_gpu) + scale_dQ_bwd = graph_bwd.tensor_like(scale_dQ_gpu) + scale_dK_bwd = graph_bwd.tensor_like(scale_dK_gpu) + scale_dV_bwd = graph_bwd.tensor_like(scale_dV_gpu) + scale_dP_bwd = graph_bwd.tensor_like(scale_dP_gpu) + + dQ_bwd, dK_bwd, dV_bwd, amax_dQ_bwd, amax_dK_bwd, amax_dV_bwd, amax_dP_bwd = ( + graph_bwd.sdpa_fp8_backward( + q=q_bwd, + k=k_bwd, + v=v_bwd, + o=o_bwd, + dO=dO_bwd, + stats=stats_bwd, + descale_q=descale_q_bwd, + descale_k=descale_k_bwd, + descale_v=descale_v_bwd, + descale_o=descale_o_bwd, + descale_dO=descale_dO_bwd, + descale_s=descale_s_bwd, + descale_dP=descale_dP_bwd, + scale_s=scale_s_bwd, + scale_dQ=scale_dQ_bwd, + scale_dK=scale_dK_bwd, + scale_dV=scale_dV_bwd, + scale_dP=scale_dP_bwd, + attn_scale=attn_scale, + use_causal_mask=is_causal, + ) + ) + + dQ_bwd.set_output(True).set_dim(dQ_gpu.size()).set_stride( + dQ_gpu.stride() + ).set_data_type(cudnn.data_type.FP8_E4M3) + dK_bwd.set_output(True).set_dim(dK_gpu.size()).set_stride( + dK_gpu.stride() + ).set_data_type(cudnn.data_type.FP8_E4M3) + dV_bwd.set_output(True).set_dim(dV_gpu.size()).set_stride( + dV_gpu.stride() + ).set_data_type(cudnn.data_type.FP8_E4M3) + amax_dQ_bwd.set_output(True).set_dim(amax_dQ_gpu.size()).set_stride( + amax_dQ_gpu.stride() + ).set_data_type(cudnn.data_type.FLOAT) + amax_dK_bwd.set_output(True).set_dim(amax_dK_gpu.size()).set_stride( + amax_dK_gpu.stride() + ).set_data_type(cudnn.data_type.FLOAT) + amax_dV_bwd.set_output(True).set_dim(amax_dV_gpu.size()).set_stride( + amax_dV_gpu.stride() + ).set_data_type(cudnn.data_type.FLOAT) + amax_dP_bwd.set_output(True).set_dim(amax_dP_gpu.size()).set_stride( + amax_dP_gpu.stride() + ).set_data_type(cudnn.data_type.FLOAT) + + # cuDNN Flash Attention fp8 only support bprop for d=128 + if headdim == 128: + graph_bwd.validate() + 
graph_bwd.build_operation_graph() + graph_bwd.create_execution_plans([cudnn.heur_mode.A]) + graph_bwd.check_support() + graph_bwd.build_plans() + + variant_pack_fwd = { + q_fwd: q_gpu, + k_fwd: k_gpu, + v_fwd: v_gpu, + o_fwd: o_gpu, + stats_fwd: stats_gpu, + descale_q_fwd: descale_q_gpu, + descale_k_fwd: descale_k_gpu, + descale_v_fwd: descale_v_gpu, + descale_s_fwd: descale_s_gpu, + scale_s_fwd: scale_s_gpu, + scale_o_fwd: scale_o_gpu, + amax_s_fwd: amax_s_gpu, + amax_o_fwd: amax_o_gpu, + } + + variant_pack_bwd = { + q_bwd: q_gpu, + k_bwd: k_gpu, + v_bwd: v_gpu, + o_bwd: o_gpu, + dQ_bwd: dQ_gpu, + dK_bwd: dK_gpu, + dV_bwd: dV_gpu, + dO_bwd: dO_gpu, + stats_bwd: stats_gpu, + descale_q_bwd: descale_q_gpu, + descale_k_bwd: descale_k_gpu, + descale_v_bwd: descale_v_gpu, + descale_o_bwd: descale_o_gpu, + descale_s_bwd: descale_s_gpu, + descale_dP_bwd: descale_dP_gpu, + descale_dO_bwd: descale_dO_gpu, + scale_s_bwd: scale_s_gpu, + scale_dQ_bwd: scale_dQ_gpu, + scale_dK_bwd: scale_dK_gpu, + scale_dV_bwd: scale_dV_gpu, + scale_dP_bwd: scale_dP_gpu, + amax_dQ_bwd: amax_dQ_gpu, + amax_dK_bwd: amax_dK_gpu, + amax_dV_bwd: amax_dV_gpu, + amax_dP_bwd: amax_dP_gpu, + } + + workspace = torch.empty( + max(graph_fwd.get_workspace_size(), graph_bwd.get_workspace_size()), + device="cuda", + dtype=torch.uint8, + ) + + f = time_fwd( + graph_fwd.execute, + variant_pack_fwd, workspace, repeats=repeats, verbose=False, ) + # cuDNN Flash Attention doesn't support bprop for d=256 + if headdim == 128: + b = time_fwd( + graph_bwd.execute, + variant_pack_bwd, + workspace, + repeats=repeats, + verbose=False, + ) + else: + b = 100000 - time_f[config, "cudnn_bf16"] = f - time_b[config, "cudnn_bf16"] = b + time_f[config, "cudnn_fp8"] = f + time_b[config, "cudnn_fp8"] = b row = [] row.append(str(batch_size)) - row.append(str(2048 // headdim)) + row.append(str(nheads)) row.append(str(seqlen)) row.append(str(headdim)) row.append(str(causal)) @@ -480,4 +781,6 @@ def time_fwd(func, *args, **kwargs): row.append(str(speed_f_b[config, method])) csvwriter.writerow(row) + print(row) + csvfile.close() diff --git a/benchmark/benchmark_results.csv b/benchmark/benchmark_results.csv index be9e57e3..4ad15c73 100644 --- a/benchmark/benchmark_results.csv +++ b/benchmark/benchmark_results.csv @@ -1,11 +1,11 @@ -Batch,Number of heads,Sequence length,Head dim,causal,dropout_p,pytorch BF16 (TFlops/s fwd),pytorch BF16 (TFlops/s bwd),pytorch BF16 (TFlops/s fwd + bwd),cudnn BF16 (TFlops/s fwd),cudnn BF16 (TFlops/s bwd),cudnn BF16 (TFlops/s fwd + bwd) -4,16,4096,128,False,0.0,354.2397874736458,287.35625860243357,303.74170280246364,654.153086109015,517.2529813957277,550.1484627007782 -2,16,8192,128,False,0.0,355.4034034630014,303.29112763235514,316.55274512148907,668.7044485311396,530.3655776577822,563.6834495118723 -1,16,16384,128,False,0.0,357.1488673077193,308.87083378477416,321.27920765691954,641.9087721400894,538.4599964919703,564.4501969340928 -1,16,32768,128,False,0.0,349.2466644830486,312.4916673813486,322.17920925495895,612.5588090524702,536.2343244414764,556.0288630093239 -1,16,65536,128,False,0.0,353.12476619851515,321.90637315990637,330.2480740982632,602.1520105575843,535.4105193525595,552.920447817744 -4,16,4096,128,True,0.0,314.06003719208184,248.76414009628874,264.4746092756496,542.3175668926519,443.6033150228602,467.9392251734923 -2,16,8192,128,True,0.0,325.13744767953693,277.4434188414532,289.5800342073893,572.8628678140092,463.5408134793645,490.27251570807965 
-1,16,16384,128,True,0.0,324.7708057371761,305.18030632852344,310.5321955338218,586.8151359973481,497.67958736481233,520.2583778032041 -1,16,32768,128,True,0.0,345.3025795381865,317.28192894405663,324.81276931149955,577.3478523124705,499.6635003500654,519.6405324459428 -1,16,65536,128,True,0.0,342.65454026558206,323.1646217405224,328.50319349561096,567.0220986793877,503.80652594122864,520.3824870851007 \ No newline at end of file +Batch,Heads,Sequence length,Head dim,Causal,Pytorch FWD (TFLOPs/s),Pytorch BWD (TFLOPs/s),Pytorch FWD+BWD (TFLOPs/s),cuDNN BF16 FWD (TFLOPs/s),cuDNN BF16 BWD (TFLOPs/s),cuDNN BF16 FWD+BWD (TFLOPs/s),cuDNN FP8 FWD (TFLOPs/s),cuDNN FP8 BWD (TFLOPs/s),cuDNN FP8 FWD+BWD (TFLOPs/s) +4,16,4096,128,False,362.03,290.47,307.86,704.41,535.06,574.53,1020.77,696.09,765.67 +2,16,8192,128,False,364.03,305.47,320.18,696.52,536.95,574.56,1042.20,715.65,786.02 +1,16,16384,128,False,360.50,317.34,328.58,687.97,546.53,580.64,1037.43,736.88,803.38 +1,16,32768,128,False,356.15,320.44,329.89,648.63,546.31,572.09,992.24,731.45,790.84 +1,16,65536,128,False,360.62,330.07,338.26,635.00,549.29,571.32,972.50,736.03,790.99 +4,16,4096,128,True,296.52,245.18,257.94,613.49,438.45,477.36,889.73,434.33,508.73 +2,16,8192,128,True,321.94,275.40,287.26,645.95,461.89,502.82,985.71,602.80,678.05 +1,16,16384,128,True,327.17,305.19,311.17,649.99,488.56,525.88,970.39,634.72,704.33 +1,16,32768,128,True,345.64,316.92,324.62,636.08,504.14,535.90,984.84,641.10,712.11 +1,16,65536,128,True,338.85,323.98,328.09,631.20,508.07,538.06,962.72,631.44,700.29 \ No newline at end of file diff --git a/benchmark/images/bprop.png b/benchmark/images/bprop.png index 2ee06cba..32a7de1f 100644 Binary files a/benchmark/images/bprop.png and b/benchmark/images/bprop.png differ diff --git a/benchmark/images/forward.png b/benchmark/images/forward.png index 4d922449..45eb64f6 100644 Binary files a/benchmark/images/forward.png and b/benchmark/images/forward.png differ diff --git a/benchmark/images/fwd_bprop.png b/benchmark/images/fwd_bprop.png index 28f66049..d3eb66e0 100644 Binary files a/benchmark/images/fwd_bprop.png and b/benchmark/images/fwd_bprop.png differ diff --git a/docs/dynamic_kernel_cache.md b/docs/dynamic_kernel_cache.md new file mode 100644 index 00000000..6235c5f0 --- /dev/null +++ b/docs/dynamic_kernel_cache.md @@ -0,0 +1,28 @@ +## Table of Contents +1. [Dynamic Shapes APIs](#Dynamic-Shapes) +2. [Kernel Cache APIs](#Kernel-Cache) + +### Dynamic Shapes +Causes other APIs (such as the kernel cache) to treat the graph as a dynamic shape graph. + +The API to achieve the above is: +```cpp +graph.set_dynamic_shape_enabled(true) +``` + +### Kernel Cache +The kernel cache significantly reduces plan build time by re-using a previously compiled kernel for a given execution plan. Kernel caching is enabled only for dynamic shape graphs. + +If a graph's kernel cache attribute is set, the kernel cache will store the kernel which was compiled for the graph's execution plan. +On future same-topology operation graphs, the kernel cache may bind the previously compiled kernel to the execution plan to avoid recompilation. 
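Before the individual APIs are listed below, a minimal usage sketch may help. It assumes the usual v1.0 build sequence, elides the tensor and operation setup, and the `build_step` wrapper and `fe` alias are illustrative only.

```cpp
// Sketch only: one kernel cache shared across same-topology, dynamic-shape
// graphs built for different problem sizes (requires cuDNN >= 9.4 at runtime).
#include <cudnn_frontend.h>

#include <memory>

namespace fe = cudnn_frontend;

void build_step(cudnnHandle_t handle, std::shared_ptr<fe::KernelCache> kernel_cache) {
    // kernel_cache is created once, e.g. via std::make_shared<fe::KernelCache>(),
    // and passed to every graph of the same topology.
    fe::graph::Graph graph;
    graph.set_dynamic_shape_enabled(true)  // mark the graph as dynamic-shape
        .set_kernel_cache(kernel_cache);   // reuse previously compiled kernels

    // ... declare tensors and operations for this iteration's shapes ...

    if (graph.validate().is_bad()) return;
    if (graph.build_operation_graph(handle).is_bad()) return;
    if (graph.create_execution_plans({fe::HeurMode_t::A}).is_bad()) return;
    if (graph.check_support(handle).is_bad()) return;
    if (graph.build_plans(handle).is_bad()) return;  // compiles or reuses the kernel
}
```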
+ +The API to create a kernel cache is: +```cpp +auto kernel_cache = std::make_shared(); +``` + +The API to set a dynamic shape graph's kernel cache is: +```cpp +graph.set_kernel_cache(kernel_cache) +``` + diff --git a/docs/operations/Attention.md b/docs/operations/Attention.md index 9758a652..e62428d8 100644 --- a/docs/operations/Attention.md +++ b/docs/operations/Attention.md @@ -356,6 +356,13 @@ $O = SV$ | Descale Q | GPU | FP32 | $(1, 1, 1, 1)$ | | Descale K | GPU | FP32 | $(1, 1, 1, 1)$ | | Descale V | GPU | FP32 | $(1, 1, 1, 1)$ | +| (Bias mask) Bias Mask | GPU | E4M3 or E5M2 | $(1, 1, S_{q}, S_{kv})$, $(1, H_{q}, S_{q}, S_{kv})$, $(B, 1, S_{q}, S_{kv})$, or $(B, H_{q}, S_{q}, S_{kv})$ | +| (Padding mask) Sequence Length Q | GPU | INT32 | $(B, 1, 1, 1)$ | +| (Padding mask) Sequence Length KV | GPU | INT32 | $(B, 1, 1, 1)$ | +| (Philoc RNG Dropout) Seed | CPU or GPU | INT32 or INT64 | $(1, 1, 1, 1)$ | +| (Philoc RNG Dropout) Offset | CPU or GPU | INT32 or INT64 | $(1, 1, 1, 1)$ | +| (Custom Dropout Mask) Mask | GPU | E4M3 or E5M2 | $(1, 1, S_{q}, S_{kv})$, $(1, H_{q}, S_{q}, S_{kv})$, $(B, 1, S_{q}, S_{kv})$, or $(B, H_{q}, S_{q}, S_{kv})$ | +| (Custom Dropout Mask) Scale | GPU | FP32 | $(1, 1, 1, 1)$ | | Descale S | GPU | FP32 | $(1, 1, 1, 1)$ | | Scale S | GPU | FP32 | $(1, 1, 1, 1)$ | @@ -421,8 +428,30 @@ set_attn_scale(float const value); SDPA_fp8_attributes& set_causal_mask(bool const value); + +SDPA_fp8_attributes& +set_bias(std::shared_ptr value); + +SDPA_fp8_attributes& +set_padding_mask(bool const value); + +SDPA_fp8_attributes& +set_seq_len_q(std::shared_ptr value); + +SDPA_fp8_attributes& +set_seq_len_kv(std::shared_ptr value); + +SDPA_fp8_attributes& +set_dropout(float const probability, + std::shared_ptr seed, + std::shared_ptr offset); + +SDPA_fp8_attributes& +set_dropout(std::shared_ptr mask, + std::shared_ptr scale); ``` + #### Python API: ``` Args: diff --git a/docs/operations/Convolutions.md b/docs/operations/Convolutions.md index 51e99d28..54bd0dc1 100644 --- a/docs/operations/Convolutions.md +++ b/docs/operations/Convolutions.md @@ -1,8 +1,8 @@ ## Table of Contents -1. [Fprop](#Convolution Fprop) -2. [Dgrad](#Convolution Dgrad) -3. [Wgrad](#Convolution Wgrad) +1. [Fprop](#Convolution-Fprop) +2. [Dgrad](#Convolution-Dgrad) +3. 
[Wgrad](#Convolution-Wgrad) ### Convolution Fprop Convolution fprop computes: @@ -15,7 +15,7 @@ std::shared_ptr conv_fprop(std::shared_ptr Conv_fprop_attributes); ``` -Conv_fprop attributes is a lightweight structure with setters: +Conv_fprop_attributes is a lightweight structure with setters: ``` Conv_fprop_attributes& set_padding(std::vector) @@ -31,6 +31,9 @@ set_name(std::string const&) Conv_fprop_attributes& set_compute_data_type(DataType_t value) + +Conv_fprop_attributes& +set_convolution_mode(ConvolutionMode_t mode_) ``` Python API: @@ -53,7 +56,7 @@ std::shared_ptr conv_dgrad(std::shared_ptr Conv_dgrad_attributes); ``` -Conv_dgrad attributes is a lightweight structure with setters: +Conv_dgrad_attributes is a lightweight structure with setters: ``` Conv_dgrad_attributes& set_padding(std::vector) @@ -69,6 +72,9 @@ set_name(std::string const&) Conv_dgrad_attributes& set_compute_data_type(DataType_t value) + +Conv_dgrad_attributes& +set_convolution_mode(ConvolutionMode_t mode_) ``` Python API: @@ -91,7 +97,7 @@ std::shared_ptr conv_wgrad(std::shared_ptr Conv_wgrad_attributes); ``` -Conv_wgrad attributes is a lightweight structure with setters: +Conv_wgrad_attributes is a lightweight structure with setters: ``` Conv_wgrad_attributes& set_padding(std::vector) @@ -107,6 +113,9 @@ set_name(std::string const&) Conv_wgrad_attributes& set_compute_data_type(DataType_t value) + +Conv_wgrad_attributes& +set_convolution_mode(ConvolutionMode_t mode_) ``` Python API: diff --git a/include/cudnn_frontend.h b/include/cudnn_frontend.h index 2d078f1b..e3f1ec87 100644 --- a/include/cudnn_frontend.h +++ b/include/cudnn_frontend.h @@ -122,6 +122,7 @@ #include "cudnn_frontend/graph_interface.h" #include "cudnn_frontend/utils/serialize.h" +#include "cudnn_frontend/backend/kernel_cache.h" #include "cudnn_frontend_version.h" diff --git a/include/cudnn_frontend/backend/backend_descriptor.h b/include/cudnn_frontend/backend/backend_descriptor.h index dc7ad253..47387a18 100644 --- a/include/cudnn_frontend/backend/backend_descriptor.h +++ b/include/cudnn_frontend/backend/backend_descriptor.h @@ -2,6 +2,7 @@ #include +#include "../graph_helpers.h" #include "cudnn.h" namespace cudnn_frontend::detail { @@ -64,7 +65,11 @@ class backend_descriptor { * * Destroys the `cudnnBackendDescriptor_t` object and frees the associated resources. */ - ~backend_descriptor() { detail::destroy_descriptor(desc); } + ~backend_descriptor() { + if (desc) { + detail::destroy_descriptor(desc); + } + } /** * @brief Deleted copy constructor and assignment operator. @@ -76,6 +81,27 @@ class backend_descriptor { backend_descriptor& operator=(backend_descriptor const&) = delete; + /** + * @brief Initializes a `backend_descriptor` object. + * + * @param type The type of the backend descriptor to create. + */ + error_t + initialize(cudnnBackendDescriptorType_t type) { + CHECK_CUDNN_ERROR(detail::create_descriptor(type, &desc)); + return {error_code_t::OK, ""}; + } + + /** + * @brief Finalizes a `backend_descriptor` object. + * + */ + error_t + finalize() { + CHECK_CUDNN_ERROR(detail::finalize(desc)); + return {error_code_t::OK, ""}; + } + /** * @brief Accessor for the underlying `cudnnBackendDescriptor_t` object. 
* @@ -96,6 +122,14 @@ class backend_descriptor { return status; } + protected: + /** + * @brief Constructs a default `backend_descriptor` object, but without initializing descriptor + * + * Used to return an error code to user for incorrect cuDNN version + */ + backend_descriptor() {} + private: cudnnBackendDescriptor_t desc = nullptr; //!< Raw pointer to the backend descriptor. cudnnStatus_t status = CUDNN_STATUS_SUCCESS; //!< Status of the descriptor creation operation. diff --git a/include/cudnn_frontend/backend/kernel_cache.h b/include/cudnn_frontend/backend/kernel_cache.h new file mode 100644 index 00000000..d500d5f1 --- /dev/null +++ b/include/cudnn_frontend/backend/kernel_cache.h @@ -0,0 +1,105 @@ +/* + * Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining a + * copy of this software and associated documentation files (the "Software"), + * to deal in the Software without restriction, including without limitation + * the rights to use, copy, modify, merge, publish, distribute, sublicense, + * and/or sell copies of the Software, and to permit persons to whom the + * Software is furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL + * THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING + * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER + * DEALINGS IN THE SOFTWARE. 
+ */ + +#pragma once + +#include +#include +#include +#include +#include +#include + +#include "../graph_helpers.h" +#include "backend_descriptor.h" + +namespace cudnn_frontend { +namespace graph { +class Graph; +} // namespace graph +/// +/// KernelCache Class +/// Wraps the kernel_cache backend descriptor +/// Wraps backend utility functions for user's convenience +/// Backend accessor functions: size() +/// Contains internal utilities for kernel cache finalization and operation graph attributes +/// +class KernelCache : public detail::backend_descriptor { + public: + friend class graph::Graph; + // Uses the default backend constructor so that we can check for initialization error during build() + KernelCache() : backend_descriptor() {} + + std::string + describe() const { + std::stringstream ss; + ss << "CUDNN_BACKEND_KERNEL_CACHE_DESCRIPTOR : " << std::endl; + return ss.str(); + } + + bool + is_finalized() { + return finalized; + } + + // Used to check kernel cache status (particularly after initialization) + error_t + status() { + if (get_status() != CUDNN_STATUS_SUCCESS) { + return {error_code_t::CUDNN_BACKEND_API_FAILED, + "CUDNN_BACKEND_KERNEL_CACHE_DESCRIPTOR: Check CUDNN_VERSION >= 9.4"}; + } + return {error_code_t::OK, ""}; + } + + private: + // Responsible for initializing, setting operation graph attribute, and finalizing kernel cache + // Check for both compile-time and runtime cuDNN version + error_t + build(cudnnBackendDescriptor_t op_graph) { +#if (CUDNN_VERSION >= 90400) + RETURN_CUDNN_FRONTEND_ERROR_IF(detail::get_backend_version() < 90400, + error_code_t::GRAPH_NOT_SUPPORTED, + "CUDNN_BACKEND_KERNEL_CACHE_DESCRIPTOR is only available starting 9.4."); + CHECK_CUDNN_FRONTEND_ERROR(initialize(CUDNN_BACKEND_KERNEL_CACHE_DESCRIPTOR)); +#if (CUDNN_VERSION >= 90500) + RETURN_CUDNN_FRONTEND_ERROR_IF(detail::get_backend_version() < 90500, + error_code_t::GRAPH_NOT_SUPPORTED, + "CUDNN_ATTR_KERNEL_CACHE_OPERATION_GRAPH is only available starting 9.5."); + CHECK_CUDNN_ERROR(detail::set_attribute( + get_ptr(), CUDNN_ATTR_KERNEL_CACHE_OPERATION_GRAPH, CUDNN_TYPE_BACKEND_DESCRIPTOR, 1, &op_graph)); +#else + (void)op_graph; +#endif + CHECK_CUDNN_FRONTEND_ERROR(finalize()); + finalized = true; + return {error_code_t::OK, ""}; +#else + (void)op_graph; + return {error_code_t::CUDNN_BACKEND_API_FAILED, + "CUDNN_BACKEND_KERNEL_CACHE_DESCRIPTOR is only available starting 9.4."}; +#endif + } + + bool finalized = false; +}; +} // namespace cudnn_frontend \ No newline at end of file diff --git a/include/cudnn_frontend/cudnn_interface.h b/include/cudnn_frontend/cudnn_interface.h index 72930240..fc61de6c 100644 --- a/include/cudnn_frontend/cudnn_interface.h +++ b/include/cudnn_frontend/cudnn_interface.h @@ -40,6 +40,9 @@ class ICudnn { graph::Execution_plan_list plans; + bool is_dynamic_shape_enabled = false; + std::shared_ptr kernel_cache = nullptr; + void assign_uid(graph::Tensor_attributes* const tensor, int64_t& potential_uid, @@ -120,8 +123,9 @@ class ICudnn { } auto&& cudnn_operation_graph_builder = cudnn_frontend::OperationGraphBuilder(); - cudnn_operation_graph_builder.setHandle(handle).setOperationGraph(cudnn_operations.size(), - cudnn_operations.data()); + cudnn_operation_graph_builder.setHandle(handle) + .setOperationGraph(cudnn_operations.size(), cudnn_operations.data()) + .setIsDynamicShapeEnabled(is_dynamic_shape_enabled); for (auto& op : raw_operations) { cudnn_operation_graph_builder.addOperation(op); } diff --git a/include/cudnn_frontend/graph_interface.h 
b/include/cudnn_frontend/graph_interface.h index 335bcc1a..1d84850b 100644 --- a/include/cudnn_frontend/graph_interface.h +++ b/include/cudnn_frontend/graph_interface.h @@ -26,6 +26,7 @@ #include "plans.h" #include "graph_helpers.h" +#include "backend/kernel_cache.h" namespace cudnn_frontend::graph { @@ -35,6 +36,7 @@ class Graph : public INode { std::unordered_set used_uids; int64_t fe_workspace_size = 0; + std::unordered_set> deserialized_tensor_properties; std::unordered_map deserialized_pass_by_value; std::unordered_map>> deserialized_workspace_modifications; @@ -69,6 +71,10 @@ class Graph : public INode { error_t pre_validate_node() const override final { + RETURN_CUDNN_FRONTEND_ERROR_IF( + (is_dynamic_shape_enabled || kernel_cache != nullptr) && detail::get_backend_version() < 90400, + error_code_t::GRAPH_NOT_SUPPORTED, + "Dynamic shapes or kernel caching enabled, but cuDNN version < 9.4!"); return {error_code_t::OK, ""}; } @@ -160,18 +166,6 @@ class Graph : public INode { return {error_code_t::OK, ""}; } - int64_t - get_cudnn_workspace_size(int64_t plan_index) const { - int64_t cudnn_workspace_size = 0; - - auto status = get_cudnn_workspace_size_node(plan_index, cudnn_workspace_size); - if (status.is_bad()) { - CUDNN_FE_LOG_LABEL_ENDL("ERROR: Querying workspace failed."); - } - - return cudnn_workspace_size; - } - int64_t get_max_cudnn_workspace_size() const { return get_max_cudnn_workspace_size_node(); @@ -208,6 +202,24 @@ class Graph : public INode { return {error_code_t::OK, ""}; } + size_t + key(bool remove_shape) { +#ifndef CUDNN_FRONTEND_SKIP_JSON_LIB + json j; + serialize(j); + if (remove_shape) { + for (auto &tensor : j["tensors"]) { + tensor["dim"].clear(); + tensor["stride"].clear(); + } + } + return std::hash{}(j); +#else + CUDNN_FRONTEND_UNUSED(remove_shape); + return 1; +#endif + } + public: Graph() : INode(detail::Context{}) {} @@ -264,25 +276,55 @@ class Graph : public INode { // The method here fuses all operations. There will be 1 operation graph in total. 
CHECK_CUDNN_FRONTEND_ERROR(create_cudnn_operation_graph(handle)); + if (is_dynamic_shape_enabled && kernel_cache && !kernel_cache->is_finalized()) { + CHECK_CUDNN_FRONTEND_ERROR(kernel_cache->build(operation_graph->get_raw_desc())); + } + return {error_code_t::OK, ""}; } - int64_t - get_workspace_size() const { + error_t + get_plan_name(std::string &name) const { + return get_plan_name_at_index(plans.candidate, name); + } + + error_t + get_plan_name_at_index(int64_t plan_index, std::string &name) const { + auto ret_val = plans.get_name_at_index(plan_index, name); + CUDNN_FE_LOG_LABEL_ENDL("INFO: get_plan_name_at_index(" << plan_index << ") is " + name); + return ret_val; + } + + error_t + get_workspace_size(int64_t &cudnn_workspace_size) const { + return get_workspace_size_plan_at_index(plans.candidate, cudnn_workspace_size); + } + + error_t + get_workspace_size_plan_at_index(int64_t plan_index, int64_t &cudnn_workspace_size) const { // There are two workspaces: // - cudnn execution plan workspace // - FE node workspace (example: alibiSlope for fmha) + int64_t cudnn_ws = 0; + CHECK_CUDNN_FRONTEND_ERROR(get_cudnn_workspace_size_node(plan_index, cudnn_ws)); + cudnn_workspace_size = cudnn_ws + fe_workspace_size; + CUDNN_FE_LOG_LABEL_ENDL("INFO: get_workspace_size() is " << cudnn_workspace_size); + return {error_code_t::OK, ""}; + } + + int64_t + get_workspace_size() const { return get_workspace_size_plan_at_index(plans.candidate); } int64_t get_workspace_size_plan_at_index(int64_t plan_index) const { - // There are two workspaces: - // - cudnn execution plan workspace - // - FE node workspace (example: alibiSlope for fmha) - CUDNN_FE_LOG_LABEL_ENDL("INFO: get_workspace_size() is " - << fe_workspace_size + get_cudnn_workspace_size(plan_index)); - return fe_workspace_size + get_cudnn_workspace_size(plan_index); + int64_t cudnn_workspace = 0; + auto status = get_workspace_size_plan_at_index(plan_index, cudnn_workspace); + if (status.is_bad()) { + CUDNN_FE_LOG_LABEL_ENDL("ERROR: Querying workspace failed."); + } + return cudnn_workspace; } int64_t @@ -476,6 +518,15 @@ class Graph : public INode { #ifndef CUDNN_FRONTEND_SKIP_JSON_LIB json j = json::from_ubjson(data); + if (j.contains("tensors")) { + auto tensor_map = j["tensors"].get>(); + for (const auto &tensor_info : tensor_map) { + auto tensor_attributes = std::make_shared(); + from_json(tensor_info.second, *tensor_attributes); + deserialized_tensor_properties.insert(tensor_attributes); + } + } + auto serialized_plan = j["cudnn_backend_data"]; CHECK_CUDNN_FRONTEND_ERROR(plans.build_plans(handle, serialized_plan)); @@ -509,7 +560,11 @@ class Graph : public INode { Graph & set_compute_data_type(DataType_t type); Graph & + set_dynamic_shape_enabled(bool is_enabled); + Graph & set_sm_count(int32_t type); + Graph & + set_kernel_cache(std::shared_ptr cache); Graph & set_name(std::string const &name) { @@ -517,6 +572,9 @@ class Graph : public INode { return *this; } + error_t + query_tensor_attributes_of_uid(int64_t const uid, Tensor_attributes &tensor) const; + std::shared_ptr tensor(Tensor_attributes const &tensor); @@ -828,6 +886,11 @@ class Graph : public INode { }; #endif + size_t + key() override final { + return key(is_dynamic_shape_enabled); + } + // TODO: temparorily placed in graphs class. This function needs to be a free standing function. 
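As a usage note for the new error_t-returning queries introduced in this header: a short sketch follows; the `report_candidate_plan` helper and its output formatting are illustrative assumptions, not part of the API.

```cpp
// Sketch only: prefer the new error_t overloads over the legacy
// int64_t-returning getters so query failures are visible to the caller.
#include <cudnn_frontend.h>

#include <cstdint>
#include <iostream>
#include <string>

namespace fe = cudnn_frontend;

fe::error_t
report_candidate_plan(fe::graph::Graph const &built_graph) {
    int64_t workspace_bytes = 0;
    auto status = built_graph.get_workspace_size(workspace_bytes);
    if (status.is_bad()) {
        return status;
    }

    std::string plan_name;
    status = built_graph.get_plan_name(plan_name);
    if (status.is_bad()) {
        return status;
    }

    std::cout << plan_name << " needs " << workspace_bytes << " bytes of workspace\n";
    return {fe::error_code_t::OK, ""};
}
```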
#ifndef CUDNN_FRONTEND_SKIP_JSON_LIB error_t @@ -1002,6 +1065,7 @@ Graph::create_execution_plans(std::vector const &mode) { plans.set_tag(operation_graph->getTag()); plans.set_engine_configs(op_graph_to_configs); + plans.set_kernel_cache(kernel_cache); CUDNN_FE_LOG_LABEL_ENDL("INFO: Querying engine config properties."); CHECK_CUDNN_FRONTEND_ERROR(plans.query_properties()); @@ -1052,6 +1116,18 @@ Graph::set_compute_data_type(DataType_t const type) { return *this; } +inline Graph & +Graph::set_dynamic_shape_enabled(bool is_enabled) { + is_dynamic_shape_enabled = is_enabled; + return *this; +} + +inline Graph & +Graph::set_kernel_cache(std::shared_ptr cache) { + kernel_cache = cache; + return *this; +} + inline Graph & Graph::set_sm_count(int32_t count) { context.set_target_sm_count(count); @@ -1065,6 +1141,32 @@ Graph::tensor(Tensor_attributes const &tensor) { return tensor_ptr; } +inline error_t +Graph::query_tensor_attributes_of_uid(int64_t const uid, Tensor_attributes &tensor) const { + for (auto const &o_tensor : full_graph_outputs) { + if (uid == o_tensor->get_uid()) { + tensor = *o_tensor; + return {error_code_t::OK, ""}; + } + } + + for (auto const &i_tensor : full_graph_inputs) { + if (uid == i_tensor->get_uid()) { + tensor = *i_tensor; + return {error_code_t::OK, ""}; + } + } + + for (auto const &d_tensor : deserialized_tensor_properties) { + if (uid == d_tensor->get_uid()) { + tensor = *d_tensor; + return {error_code_t::OK, ""}; + } + } + + return {error_code_t::INVALID_VALUE, "No matching tensor for this UID"}; +} + // tensor_like is meant to create "useable" copies of a tensor. // By usable, it means not copying over the uids, as uids are FE-level(internal) detail. // It also means not copying over names, which are user-level(external) detail. 
But user is given option to provide a diff --git a/include/cudnn_frontend/graph_properties.h b/include/cudnn_frontend/graph_properties.h index 46a3a56c..748681b3 100644 --- a/include/cudnn_frontend/graph_properties.h +++ b/include/cudnn_frontend/graph_properties.h @@ -156,6 +156,15 @@ class Tensor_attributes { return *this; } + int64_t + get_volume() const { + int64_t volume = 1ul; + for (int64_t d : dim) { + volume *= d; + } + return volume; + } + std::vector get_stride() const { return stride; @@ -712,12 +721,30 @@ class Matmul_fp8_attributes : public Attributes { double padding_value = 0.0; public: - enum class input_names { Descale_A, Descale_B, A, B, Scale_C }; + enum class input_names { Descale_A, Descale_B, A, B, M_override, N_override, K_override, Scale_C }; std::unordered_map> inputs; enum class output_names { C, Amax_C }; std::unordered_map> outputs; NLOHMANN_DEFINE_TYPE_INTRUSIVE(Matmul_fp8_attributes, name, compute_data_type, inputs, outputs) + Matmul_fp8_attributes& + set_m_override(std::shared_ptr const& value) { + inputs[input_names::M_override] = value; + return *this; + } + + Matmul_fp8_attributes& + set_n_override(std::shared_ptr const& value) { + inputs[input_names::N_override] = value; + return *this; + } + + Matmul_fp8_attributes& + set_k_override(std::shared_ptr const& value) { + inputs[input_names::K_override] = value; + return *this; + } + Matmul_fp8_attributes& set_padding(double const padding_val) { padding_value = padding_val; @@ -1509,7 +1536,9 @@ class SDPA_fp8_attributes : public Attributes { friend class Graph; std::optional is_inference; - bool causal_mask = false; + bool padding_mask = false; + bool causal_mask = false; + std::optional dropout_probability; std::optional attn_scale_value; public: @@ -1518,6 +1547,14 @@ class SDPA_fp8_attributes : public Attributes { K, V, Attn_scale, + Bias, + SEQ_LEN_Q, + SEQ_LEN_KV, + Seed, + Offset, + Dropout_mask, + Dropout_scale, + Descale_Q, Descale_K, Descale_V, @@ -1535,7 +1572,9 @@ class SDPA_fp8_attributes : public Attributes { inputs, outputs, is_inference, + padding_mask, causal_mask, + dropout_probability, attn_scale_value) SDPA_fp8_attributes& @@ -1556,11 +1595,52 @@ class SDPA_fp8_attributes : public Attributes { return *this; } + SDPA_fp8_attributes& + set_bias(std::shared_ptr value) { + inputs[SDPA_fp8_attributes::input_names::Bias] = value; + return *this; + } + + SDPA_fp8_attributes& + set_padding_mask(bool const value) { + padding_mask = value; + return *this; + } + + SDPA_fp8_attributes& + set_seq_len_q(std::shared_ptr value) { + inputs[SDPA_fp8_attributes::input_names::SEQ_LEN_Q] = value; + return *this; + } + + SDPA_fp8_attributes& + set_seq_len_kv(std::shared_ptr value) { + inputs[SDPA_fp8_attributes::input_names::SEQ_LEN_KV] = value; + return *this; + } + SDPA_fp8_attributes& set_causal_mask(bool const value) { causal_mask = value; return *this; } + + SDPA_fp8_attributes& + set_dropout(float const probability, + std::shared_ptr seed, + std::shared_ptr offset) { + dropout_probability = probability; + inputs[SDPA_fp8_attributes::input_names::Seed] = seed; + inputs[SDPA_fp8_attributes::input_names::Offset] = offset; + return *this; + } + + SDPA_fp8_attributes& + set_dropout(std::shared_ptr mask, std::shared_ptr scale) { + inputs[SDPA_fp8_attributes::input_names::Dropout_mask] = mask; + inputs[SDPA_fp8_attributes::input_names::Dropout_scale] = scale; + return *this; + } }; class SDPA_backward_attributes : public Attributes { @@ -1718,7 +1798,10 @@ class SDPA_fp8_backward_attributes : public Attributes 
dropout_probability; std::optional attn_scale_value; public: @@ -1730,6 +1813,15 @@ class SDPA_fp8_backward_attributes : public Attributes value) { + inputs[SDPA_fp8_backward_attributes::input_names::Bias] = value; + return *this; + } + + SDPA_fp8_backward_attributes& + set_padding_mask(bool const value) { + padding_mask = value; + return *this; + } + + SDPA_fp8_backward_attributes& + set_seq_len_q(std::shared_ptr value) { + inputs[SDPA_fp8_backward_attributes::input_names::SEQ_LEN_Q] = value; + return *this; + } + + SDPA_fp8_backward_attributes& + set_seq_len_kv(std::shared_ptr value) { + inputs[SDPA_fp8_backward_attributes::input_names::SEQ_LEN_KV] = value; + return *this; + } + SDPA_fp8_backward_attributes& set_causal_mask(bool const value) { causal_mask = value; return *this; } + + SDPA_fp8_backward_attributes& + set_dropout(float const probability, + std::shared_ptr seed, + std::shared_ptr offset) { + dropout_probability = probability; + inputs[SDPA_fp8_backward_attributes::input_names::Seed] = seed; + inputs[SDPA_fp8_backward_attributes::input_names::Offset] = offset; + return *this; + } + + SDPA_fp8_backward_attributes& + set_dropout(std::shared_ptr mask, + std::shared_ptr scale, + std::shared_ptr scale_inv) { + inputs[SDPA_fp8_backward_attributes::input_names::Dropout_mask] = mask; + inputs[SDPA_fp8_backward_attributes::input_names::Dropout_scale] = scale; + inputs[SDPA_fp8_backward_attributes::input_names::Dropout_scale_inv] = scale_inv; + return *this; + } }; using Scaled_dot_product_flash_attention_attributes [[deprecated]] = SDPA_attributes; diff --git a/include/cudnn_frontend/node/conv_dgrad.h b/include/cudnn_frontend/node/conv_dgrad.h index 780e5588..2e256899 100644 --- a/include/cudnn_frontend/node/conv_dgrad.h +++ b/include/cudnn_frontend/node/conv_dgrad.h @@ -52,6 +52,10 @@ class DgradNode : public NodeCRTP { auto const dy_tensor_dim = DY->get_dim(); auto dx_tensor_dim = DX->get_dim(); + RETURN_CUDNN_FRONTEND_ERROR_IF(DX->get_dim().empty(), + error_code_t::ATTRIBUTE_NOT_SET, + "For dgrad node, output dimension inferencing is not possible."); + // No dim inferencing as inverse mapping from DY, W to DX is not unique. 
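The check added above makes the requirement explicit: DX dimensions must be supplied by the caller. A minimal conforming sketch follows; the shapes, tensor names, and the `build_dgrad` helper are illustrative.

```cpp
// Sketch only: dgrad output dims cannot be inferred from DY and W, so the
// caller sets them explicitly on the returned DX tensor.
#include <cudnn_frontend.h>

#include <memory>

namespace fe = cudnn_frontend;

void build_dgrad(fe::graph::Graph &graph,
                 std::shared_ptr<fe::graph::Tensor_attributes> DY,
                 std::shared_ptr<fe::graph::Tensor_attributes> W) {
    auto dgrad_attrs = fe::graph::Conv_dgrad_attributes()
                           .set_padding({1, 1})
                           .set_stride({1, 1})
                           .set_dilation({1, 1});

    auto DX = graph.conv_dgrad(DY, W, dgrad_attrs);

    // Illustrative NCHW shape; leaving set_dim() out now trips the
    // ATTRIBUTE_NOT_SET check added above. Strides may still be inferred.
    DX->set_dim({8, 64, 56, 56}).set_output(true);
}
```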
// Only infer strides if user did not set them if (DX->get_stride().empty()) { diff --git a/include/cudnn_frontend/node/scaled_dot_product_flash_attention.h b/include/cudnn_frontend/node/scaled_dot_product_flash_attention.h index bb1f52ae..e2ef388d 100644 --- a/include/cudnn_frontend/node/scaled_dot_product_flash_attention.h +++ b/include/cudnn_frontend/node/scaled_dot_product_flash_attention.h @@ -1,5 +1,7 @@ #pragma once +#include + #include "../../cudnn_frontend_Heuristics.h" #include "../../cudnn_frontend_Logging.h" @@ -154,7 +156,10 @@ class SDPANode : public NodeCRTP { RETURN_CUDNN_FRONTEND_ERROR_IF(attributes.sliding_window_length.has_value() && attributes.sliding_window_length.value() < 0, error_code_t::INVALID_VALUE, "Sliding window length should be greater than or equals to zero when set."); - + + RETURN_CUDNN_FRONTEND_ERROR_IF(attributes.sliding_window_length.has_value() && s_q > s_kv, + error_code_t::GRAPH_NOT_SUPPORTED, + "Sliding window attention is only supported with s_q <= s_kv."); RETURN_CUDNN_FRONTEND_ERROR_IF(attributes.sliding_window_length.has_value() && (attributes.padding_mask || !attributes.causal_mask || is_dropout || is_bias || is_ragged), error_code_t::GRAPH_NOT_SUPPORTED, @@ -823,6 +828,9 @@ class SDPABackwardNode : public NodeCRTP { error_code_t::INVALID_VALUE, "Sliding window length should be greater than or equals to zero when set."); + RETURN_CUDNN_FRONTEND_ERROR_IF(attributes.sliding_window_length.has_value() && s_q > s_kv, + error_code_t::GRAPH_NOT_SUPPORTED, + "Sliding window attention is only supported with s_q <= s_kv."); RETURN_CUDNN_FRONTEND_ERROR_IF(attributes.sliding_window_length.has_value() && (attributes.padding_mask || !attributes.causal_mask || is_dropout || is_bias || is_ragged), error_code_t::GRAPH_NOT_SUPPORTED, @@ -858,6 +866,14 @@ class SDPABackwardNode : public NodeCRTP { error_code_t::GRAPH_NOT_SUPPORTED, "For cuDNN version below 9.2.0, sliding window attention is not supported"); + RETURN_CUDNN_FRONTEND_ERROR_IF(detail::get_backend_version() < 90500 && is_bias && attributes.padding_mask, + error_code_t::GRAPH_NOT_SUPPORTED, + "For cuDNN version below 9.5.0, dBias with variable sequence lengths is not supported"); + + RETURN_CUDNN_FRONTEND_ERROR_IF(detail::get_backend_version() < 90500 && is_bias && ((s_q % 64 != 0) || (s_kv % 64 != 0)), + error_code_t::GRAPH_NOT_SUPPORTED, + "For cuDNN version below 9.5.0, dBias not support s_q/s_kv which aren't multiple of 64"); + // validate that datatype is set for the graph RETURN_CUDNN_FRONTEND_ERROR_IF(context.get_intermediate_data_type() == DataType_t::NOT_SET, error_code_t::ATTRIBUTE_NOT_SET, @@ -979,10 +995,10 @@ class SDPABackwardNode : public NodeCRTP { // allow setting the upper limit with envvars char* env_dp_workspace_limit_char = std::getenv("CUDNN_FRONTEND_ATTN_DP_WORKSPACE_LIMIT"); if (env_dp_workspace_limit_char) { - try { - std::string env_dp_workspace_limit_str(env_dp_workspace_limit_char); - max_dp_workspace_bytes = static_cast(std::stoll(env_dp_workspace_limit_str)); - } catch (...) 
{ + char* end_ptr = nullptr; + max_dp_workspace_bytes = std::strtoll(env_dp_workspace_limit_char, &end_ptr, 10); + + if (*end_ptr != '\0') { RETURN_CUDNN_FRONTEND_ERROR_IF(true, error_code_t::ATTRIBUTE_NOT_SET, "Invalid argument for CUDNN_FRONTEND_ATTN_DP_WORKSPACE_LIMIT " diff --git a/include/cudnn_frontend/node/sdpa_fp8.h b/include/cudnn_frontend/node/sdpa_fp8.h index ec340fd6..7dc70e4e 100644 --- a/include/cudnn_frontend/node/sdpa_fp8.h +++ b/include/cudnn_frontend/node/sdpa_fp8.h @@ -17,6 +17,8 @@ class SDPAFP8Node : public NodeCRTP { using input_names = SDPA_fp8_attributes::input_names; using output_names = SDPA_fp8_attributes::output_names; + std::shared_ptr rng_output; + public: SDPA_fp8_attributes attributes; @@ -81,20 +83,85 @@ class SDPAFP8Node : public NodeCRTP { #undef CUDNN_FE_SDPA_VALIDATE_DIM_STRIDE + // validate backend limitations for the operation + // int64_t s_q = attributes.inputs.at(input_names::Q)->get_dim()[2]; + // int64_t s_kv = attributes.inputs.at(input_names::K)->get_dim()[2]; + int64_t h_q = attributes.inputs.at(input_names::Q)->get_dim()[1]; + int64_t h_k = attributes.inputs.at(input_names::K)->get_dim()[1]; + int64_t h_v = attributes.inputs.at(input_names::V)->get_dim()[1]; int64_t d_qk = attributes.inputs.at(input_names::Q)->get_dim()[3]; int64_t d_v = attributes.inputs.at(input_names::V)->get_dim()[3]; - if (detail::get_backend_version() >= 90101) { - RETURN_CUDNN_FRONTEND_ERROR_IF( - (d_qk > 256) || (d_qk % 8 != 0) || (d_v > 256) || (d_v % 8 != 0), - error_code_t::GRAPH_NOT_SUPPORTED, - "Num hidden_dim shoud be less than 256 and hidden_dim should be multiple of 8"); + // bool const is_ragged = attributes.inputs.at(input_names::Q)->get_ragged_offset() || + // attributes.inputs.at(input_names::K)->get_ragged_offset() || + // attributes.inputs.at(input_names::V)->get_ragged_offset() || + // attributes.outputs.at(output_names::O)->get_ragged_offset(); + + auto const& bias_mask = attributes.inputs.find(input_names::Bias); + bool const is_bias = (bias_mask != attributes.inputs.end() && bias_mask->second != nullptr); + + auto const& dropout_mask = attributes.inputs.find(input_names::Dropout_mask); + bool const is_dropout_custom = (dropout_mask != attributes.inputs.end()) && (dropout_mask->second != nullptr); + // bool const is_dropout = attributes.dropout_probability.has_value() || is_dropout_custom; + + // validation TODO: + // - validate stats has valid dims + + // validate basic dimension requirements + if (prop.major >= 10) { + RETURN_CUDNN_FRONTEND_ERROR_IF((d_qk > 128) || (d_qk % 16 != 0) || (d_v > 128) || (d_v % 16 != 0), + error_code_t::GRAPH_NOT_SUPPORTED, + "hidden_dim shoud be less than 128 and hidden_dim should be multiple of 16"); } else { - RETURN_CUDNN_FRONTEND_ERROR_IF( - (d_qk > 128) || (d_qk % 8 != 0) || (d_v > 128) || (d_v % 8 != 0), - error_code_t::GRAPH_NOT_SUPPORTED, - "Num hidden_dim shoud be less than 128 and hidden_dim should be multiple of 8"); + RETURN_CUDNN_FRONTEND_ERROR_IF((d_qk > 256) || (d_qk % 16 != 0) || (d_v > 256) || (d_v % 16 != 0), + error_code_t::GRAPH_NOT_SUPPORTED, + "hidden_dim shoud be less than 256 and hidden_dim should be multiple of 16"); } + RETURN_CUDNN_FRONTEND_ERROR_IF((h_q % h_k != 0) || (h_q % h_v != 0), + error_code_t::GRAPH_NOT_SUPPORTED, + "For group-query attention, number of heads for key and query must be a factor " + "of number of heads for query"); + + // validate options for attn_scale + auto const& attn_scale = attributes.inputs.find(input_names::Attn_scale); + bool const has_attn_scale = (attn_scale 
!= attributes.inputs.end()) && (attn_scale->second != nullptr); + RETURN_CUDNN_FRONTEND_ERROR_IF(has_attn_scale && attributes.attn_scale_value.has_value(), + error_code_t::ATTRIBUTE_NOT_SET, + "attn_scale with tensor and value cannot be set at the same time."); + + // validate options for bias mask + RETURN_CUDNN_FRONTEND_ERROR_IF(is_bias && (bias_mask->second->get_data_type() == DataType_t::BOOLEAN), + error_code_t::GRAPH_NOT_SUPPORTED, + "Bias mask data type cannot be boolean"); + + // validate options for padding mask + auto const& seq_len_q = attributes.inputs.find(input_names::SEQ_LEN_Q); + bool const has_seq_len_q = (seq_len_q != attributes.inputs.end()) && (seq_len_q->second != nullptr); + auto const& seq_len_kv = attributes.inputs.find(input_names::SEQ_LEN_KV); + bool const has_seq_len_kv = (seq_len_kv != attributes.inputs.end()) && (seq_len_kv->second != nullptr); + RETURN_CUDNN_FRONTEND_ERROR_IF(attributes.padding_mask && (!has_seq_len_q || !has_seq_len_kv), + error_code_t::ATTRIBUTE_NOT_SET, + "Padding mask requires seq_len_q and seq_len_kv to be set."); + RETURN_CUDNN_FRONTEND_ERROR_IF((!attributes.padding_mask) && (has_seq_len_q || has_seq_len_kv), + error_code_t::ATTRIBUTE_NOT_SET, + "seq_len_q and seq_len_kv needs to be set only if padding mask is enabled."); + + // validate options for dropout mask + RETURN_CUDNN_FRONTEND_ERROR_IF( + attributes.dropout_probability.has_value() && is_dropout_custom, + error_code_t::ATTRIBUTE_NOT_SET, + "Using both, custom dropout mask and internal-mask generation using dropout probability, is ill-formed."); + + RETURN_CUDNN_FRONTEND_ERROR_IF( + attributes.dropout_probability.has_value() && attributes.dropout_probability.value() == 1.0, + error_code_t::ATTRIBUTE_NOT_SET, + "Dropout probability cannot be 1 as corresponding scale wont be well formed."); + + // validate that datatype is set for the graph + RETURN_CUDNN_FRONTEND_ERROR_IF(context.get_intermediate_data_type() == DataType_t::NOT_SET, + error_code_t::ATTRIBUTE_NOT_SET, + "Intermediate tensor data type needs to be set as internal tensors require it."); + return {error_code_t::OK, ""}; } @@ -126,12 +193,12 @@ class SDPAFP8Node : public NodeCRTP { attributes.fill_from_context(context); // Gather dim to fill properties of virtual tensors - // auto const& q_dim = attributes.inputs[input_names::Q]->get_dim(); - // auto b = q_dim[0]; - // auto h = q_dim[1]; - // auto s_q = q_dim[2]; - // auto const& k_dim = attributes.inputs[input_names::K]->get_dim(); - // auto s_kv = k_dim[2]; + auto const& q_dim = attributes.inputs[input_names::Q]->get_dim(); + auto b = q_dim[0]; + auto h = q_dim[1]; + auto s_q = q_dim[2]; + auto const& k_dim = attributes.inputs[input_names::K]->get_dim(); + auto s_kv = k_dim[2]; // cuDNN frontend API attention requires Q, K, V where // Q = {b, h_q, s_q, d_qk} @@ -157,7 +224,14 @@ class SDPAFP8Node : public NodeCRTP { auto mul_attributes = Pointwise_attributes().set_mode(PointwiseMode_t::MUL); //// Q * K - auto bmm1_attributes = Matmul_attributes().set_name("bmm1").set_padding(0.0); + auto bmm1_attributes = Matmul_attributes() + .set_name("bmm1") + .set_m_override(attributes.inputs[input_names::SEQ_LEN_Q]) + .set_n_override(attributes.inputs[input_names::SEQ_LEN_KV]); + + if (attributes.padding_mask) { + bmm1_attributes.set_padding(0.0); + } last_output = matmul(attributes.inputs[input_names::Q], attributes.inputs[input_names::K], bmm1_attributes); //// Optional Attn scale @@ -168,11 +242,11 @@ class SDPAFP8Node : public NodeCRTP { } // If attn scale present, add a 
pointwise mul node - if (attributes.inputs[input_names::Attn_scale]) { + if (auto attn_scale_it = attributes.inputs.find(input_names::Attn_scale); + attn_scale_it != attributes.inputs.end()) { mul_attributes.set_name("attn_scale"); - auto const& attn_scale_output = - pointwise(last_output, attributes.inputs[input_names::Attn_scale], mul_attributes); - last_output = attn_scale_output; + auto const& attn_scale_output = pointwise(last_output, attn_scale_it->second, mul_attributes); + last_output = attn_scale_output; } //// Descales @@ -184,6 +258,64 @@ class SDPAFP8Node : public NodeCRTP { mul_attributes.set_name("descale_k"); last_output = pointwise(last_output, attributes.inputs.at(input_names::Descale_K), mul_attributes); + // Optional bias + if (auto bias_it = attributes.inputs.find(input_names::Bias); bias_it != attributes.inputs.end()) { + auto add_attributes = Pointwise_attributes().set_name("bias").set_mode(PointwiseMode_t::ADD); + auto const& bias_output = pointwise(last_output, bias_it->second, add_attributes); + last_output = bias_output; + } + + if (attributes.padding_mask) { + auto row_index_attributes = Pointwise_attributes() + .set_name("gen_row_index") + .set_mode(PointwiseMode_t::GEN_INDEX) + .set_axis(2) + .set_compute_data_type(DataType_t::INT32); + auto const& row_index_output = pointwise(last_output, row_index_attributes); + row_index_output->set_data_type(DataType_t::INT32); + + auto col_index_attributes = Pointwise_attributes() + .set_name("gen_col_index") + .set_mode(PointwiseMode_t::GEN_INDEX) + .set_axis(3) + .set_compute_data_type(DataType_t::INT32); + auto const& col_index_output = pointwise(last_output, col_index_attributes); + col_index_output->set_data_type(DataType_t::INT32); + + auto row_less_seq_q_attributes = Pointwise_attributes() + .set_name("row_less_seq_q") + .set_mode(PointwiseMode_t::CMP_LT) + .set_compute_data_type(DataType_t::INT32); + auto const& row_less_seq_q_output = + pointwise(row_index_output, attributes.inputs[input_names::SEQ_LEN_Q], row_less_seq_q_attributes); + row_less_seq_q_output->set_data_type(DataType_t::INT32); + + auto col_less_seq_kv_attributes = Pointwise_attributes() + .set_name("col_less_seq_kv") + .set_mode(PointwiseMode_t::CMP_LT) + .set_compute_data_type(DataType_t::INT32); + auto const& col_less_seq_kv_output = + pointwise(col_index_output, attributes.inputs[input_names::SEQ_LEN_KV], col_less_seq_kv_attributes); + col_less_seq_kv_output->set_data_type(DataType_t::INT32); + + auto logical_and_attributes = Pointwise_attributes() + .set_name("logical_and") + .set_mode(PointwiseMode_t::LOGICAL_AND) + .set_compute_data_type(DataType_t::BOOLEAN); + auto const& logical_and_output = + pointwise(row_less_seq_q_output, col_less_seq_kv_output, logical_and_attributes); + logical_and_output->set_data_type(DataType_t::BOOLEAN); + + // Lower attributes to binary select attributes + auto negative_inf_padding = std::make_shared(std::numeric_limits::lowest()); + + auto binary_select_attributes = + Pointwise_attributes().set_name("binary_select").set_mode(PointwiseMode_t::BINARY_SELECT); + auto const& padding_mask_output = + pointwise(last_output, negative_inf_padding, logical_and_output, binary_select_attributes); + last_output = padding_mask_output; + } + //// Optional causal masking if (attributes.causal_mask) { auto row_index_attributes = @@ -230,6 +362,61 @@ class SDPAFP8Node : public NodeCRTP { softmax(last_output, softmax_attributes, softmax_output, softmax_stats); last_output = softmax_output; + // Two cases for training: dropout 
present or not + bool dropout_present = false; + auto const& dropout_mask = attributes.inputs.find(input_names::Dropout_mask); + bool const is_dropout_custom = (dropout_mask != attributes.inputs.end()) && (dropout_mask->second != nullptr); + if (attributes.dropout_probability.has_value()) { + dropout_present = true; + // Special case: Skip dropout when 0.0 probability. + if (attributes.dropout_probability.value() == 0.0) { + dropout_present = false; + } + } else if (is_dropout_custom) { + dropout_present = true; + } + + if (dropout_present) { + if (is_dropout_custom) { + auto dropout_scale_attributes = + Pointwise_attributes().set_name("dropout_scale_mul").set_mode(PointwiseMode_t::MUL); + auto const& dropout_scale_output = + pointwise(last_output, attributes.inputs[input_names::Dropout_scale], dropout_scale_attributes); + + auto mask_attributes = + Pointwise_attributes().set_name("dropout_mask_mul").set_mode(PointwiseMode_t::MUL); + auto const& dropout_mask_output = + pointwise(dropout_scale_output, dropout_mask->second, mask_attributes); + last_output = dropout_mask_output; + } else { + rng_output = rng(attributes.inputs[input_names::Seed], + attributes.inputs[input_names::Offset], + Rng_attributes() + .set_name("rng") + .set_distribution(RngDistribution_t::BERNOULLI) + .set_bernoulli_probability(1.0 - attributes.dropout_probability.value())); + rng_output + // Hard coding dim and strides as rng output has no inputs to infer them from. + ->set_dim({b, h, s_q, s_kv}) + .set_stride({h * s_q * s_kv, s_q * s_kv, s_kv, 1}); + + auto mask_attributes = + Pointwise_attributes().set_name("dropout_mask_mul").set_mode(PointwiseMode_t::MUL); + auto const& dropout_mask_output = pointwise(last_output, rng_output, mask_attributes); + last_output = dropout_mask_output; + + std::shared_ptr dropout_scale = nullptr; + + float dropout_scale_value = (1.0f / (1.0f - attributes.dropout_probability.value())); + dropout_scale = std::make_shared(dropout_scale_value); + + auto dropout_scale_attributes = + Pointwise_attributes().set_name("dropout_scale").set_mode(PointwiseMode_t::MUL); + auto const& dropout_scale_output = pointwise(last_output, dropout_scale, dropout_scale_attributes); + last_output = dropout_scale_output; + } + } + // Amax S auto amax_attributes = Reduction_attributes().set_name("amax_s").set_mode(ReductionMode_t::AMAX); // Special non-functional-style call. Needed because output already created and provided to user. @@ -240,8 +427,15 @@ class SDPAFP8Node : public NodeCRTP { last_output = pointwise(last_output, attributes.inputs.at(input_names::Scale_S), mul_attributes); last_output->set_data_type(attributes.inputs.at(input_names::Q)->get_data_type()); + // Lower attributes to bmm2 attributes + // Requirement by cudnn backend to take in bmm2 aType as i/o type. + last_output->set_data_type(attributes.inputs[input_names::Q]->get_data_type()); + //// S * V - auto bmm2_attributes = Matmul_fp8_attributes().set_name("bmm2"); + auto bmm2_attributes = Matmul_fp8_attributes() + .set_name("bmm2") + .set_m_override(attributes.inputs[input_names::SEQ_LEN_Q]) + .set_k_override(attributes.inputs[input_names::SEQ_LEN_KV]); // Special non-functional-style call. Needed because output already created and provided to user.
matmul_fp8(last_output, attributes.inputs.at(input_names::V), diff --git a/include/cudnn_frontend/node/sdpa_fp8_bwd.h b/include/cudnn_frontend/node/sdpa_fp8_bwd.h index 04a942b2..e82e055c 100644 --- a/include/cudnn_frontend/node/sdpa_fp8_bwd.h +++ b/include/cudnn_frontend/node/sdpa_fp8_bwd.h @@ -77,12 +77,79 @@ class SDPAFP8BackwardNode : public NodeCRTP { #undef CUDNN_FE_SDPA_VALIDATE_DIM_STRIDE + // validate backend limitations for the operation + // clang-format off + // int64_t s_q = attributes.inputs.at(input_names::Q)->get_dim()[2]; + // int64_t s_kv = attributes.inputs.at(input_names::K)->get_dim()[2]; + int64_t h_q = attributes.inputs.at(input_names::Q)->get_dim()[1]; + int64_t h_k = attributes.inputs.at(input_names::K)->get_dim()[1]; + int64_t h_v = attributes.inputs.at(input_names::V)->get_dim()[1]; + int64_t d_qk = attributes.inputs.at(input_names::Q)->get_dim()[3]; + int64_t d_v = attributes.inputs.at(input_names::V)->get_dim()[3]; + + auto const& bias_mask = attributes.inputs.find(input_names::Bias); + bool const is_bias = (bias_mask != attributes.inputs.end() && bias_mask->second != nullptr); + + auto const& dropout_mask = attributes.inputs.find(input_names::Dropout_mask); + bool const is_dropout_custom = (dropout_mask != attributes.inputs.end()) && (dropout_mask->second != nullptr); + + // validation TODO: + // - validate stats has valid dims + + // validate basic dimension requirements + if(prop.major >= 10) { + RETURN_CUDNN_FRONTEND_ERROR_IF((d_qk > 128) || (d_qk % 16 != 0) || (d_v > 128) || (d_v % 16 != 0), + error_code_t::GRAPH_NOT_SUPPORTED, + "hidden_dim should be less than or equal to 128 and a multiple of 16"); + } + else { + RETURN_CUDNN_FRONTEND_ERROR_IF((d_qk != 128) || (d_qk % 16 != 0) || (d_v != 128) || (d_v % 16 != 0), + error_code_t::GRAPH_NOT_SUPPORTED, + "hidden_dim should be equal to 128 and a multiple of 16"); + } + RETURN_CUDNN_FRONTEND_ERROR_IF((h_q % h_k != 0) || (h_q % h_v != 0), + error_code_t::GRAPH_NOT_SUPPORTED, + "For group-query attention, number of heads for key and value must be a factor of number of heads for query"); + // validate options for attn_scale auto const& attn_scale = attributes.inputs.find(input_names::Attn_scale); bool const has_attn_scale = (attn_scale != attributes.inputs.end()) && (attn_scale->second != nullptr); RETURN_CUDNN_FRONTEND_ERROR_IF(has_attn_scale && attributes.attn_scale_value.has_value(), error_code_t::ATTRIBUTE_NOT_SET, "attn_scale with tensor and value cannot be set at the same time."); + + // validate options for bias mask + RETURN_CUDNN_FRONTEND_ERROR_IF(is_bias && (bias_mask->second->get_data_type() == DataType_t::BOOLEAN), + error_code_t::GRAPH_NOT_SUPPORTED, + "Bias mask data type cannot be boolean"); + + // validate options for padding mask + auto const& seq_len_q = attributes.inputs.find(input_names::SEQ_LEN_Q); + bool const has_seq_len_q = (seq_len_q != attributes.inputs.end()) && (seq_len_q->second != nullptr); + auto const& seq_len_kv = attributes.inputs.find(input_names::SEQ_LEN_KV); + bool const has_seq_len_kv = (seq_len_kv != attributes.inputs.end()) && (seq_len_kv->second != nullptr); + RETURN_CUDNN_FRONTEND_ERROR_IF(attributes.padding_mask && (!has_seq_len_q || !has_seq_len_kv), + error_code_t::ATTRIBUTE_NOT_SET, + "Padding mask requires seq_len_q and seq_len_kv to be set."); + RETURN_CUDNN_FRONTEND_ERROR_IF((!attributes.padding_mask) && (has_seq_len_q || has_seq_len_kv), + error_code_t::ATTRIBUTE_NOT_SET, + "seq_len_q and seq_len_kv need to be set only if padding mask is
enabled."); + + // validate options for dropout mask + RETURN_CUDNN_FRONTEND_ERROR_IF( + attributes.dropout_probability.has_value() && is_dropout_custom, + error_code_t::ATTRIBUTE_NOT_SET, + "Using both a custom dropout mask and internal mask generation via dropout probability is ill-formed."); + + RETURN_CUDNN_FRONTEND_ERROR_IF( + attributes.dropout_probability.has_value() && attributes.dropout_probability.value() == 1.0, + error_code_t::ATTRIBUTE_NOT_SET, + "Dropout probability cannot be 1 as the corresponding scale won't be well formed."); + + // validate that datatype is set for the graph + RETURN_CUDNN_FRONTEND_ERROR_IF(context.get_intermediate_data_type() == DataType_t::NOT_SET, + error_code_t::ATTRIBUTE_NOT_SET, + "Intermediate tensor data type needs to be set as internal tensors require it."); return {error_code_t::OK, ""}; } @@ -90,6 +157,7 @@ class SDPAFP8BackwardNode : public NodeCRTP { infer_properties_node() override final { return {error_code_t::OK, ""}; } + error_t expand_node() override final { CUDNN_FE_LOG_LABEL_ENDL("INFO: Inferrencing properties for Scaled_dot_product_flash_attention node " @@ -137,8 +205,37 @@ class SDPAFP8BackwardNode : public NodeCRTP { std::swap(temp_vec[2], temp_vec[3]); attributes.inputs[input_names::V]->set_stride(temp_vec); + std::shared_ptr rng_output; + auto mul_attributes = Pointwise_attributes().set_mode(PointwiseMode_t::MUL); + // if dropout_prob is used, then the node passes scale and scale inverse + // if dropout_mask is used, then the user passes scale and scale_inverse + bool is_dropout_prob = (attributes.dropout_probability.has_value()); + bool is_dropout_mask = (attributes.inputs[input_names::Dropout_mask] != nullptr); + if (is_dropout_prob) { + float dropout_scale_value = 1.0f / (1.0f - attributes.dropout_probability.value()); + float dropout_scale_inv_value = (1.0f - attributes.dropout_probability.value()); + + attributes.inputs[input_names::Dropout_scale] = std::make_shared(dropout_scale_value); + attributes.inputs[input_names::Dropout_scale_inv] = + std::make_shared(dropout_scale_inv_value); + } + + // --------------RNG node-------------------- + + if (is_dropout_prob) { + rng_output = rng(attributes.inputs[input_names::Seed], + attributes.inputs[input_names::Offset], + Rng_attributes() + .set_name("rng") + .set_distribution(RngDistribution_t::BERNOULLI) + .set_bernoulli_probability(1.0f - attributes.dropout_probability.value())); + rng_output->set_dim({b, h_q, s_q, s_kv}).set_stride({h_q * s_q * s_kv, s_q * s_kv, s_kv, 1}); + } else if (is_dropout_mask) { + rng_output = attributes.inputs[input_names::Dropout_mask]; + } + //// dO * O mul_attributes.set_name("mul_dO_O"); auto last_output = @@ -156,11 +253,20 @@ class SDPAFP8BackwardNode : public NodeCRTP { // Descale O mul_attributes.set_name("descale_O"); - auto softmax_sum = pointwise(last_output, attributes.inputs.at(input_names::Descale_O), mul_attributes); - softmax_sum->set_dim({b, h_q, s_q, 1}).set_stride({h_q * s_q, s_q, 1, 1}); + last_output = pointwise(last_output, attributes.inputs.at(input_names::Descale_O), mul_attributes); + + // softmax_sum = last_output * dropout_scale_inv + if(attributes.inputs[input_names::Dropout_scale_inv]) { + last_output = pointwise(last_output, + attributes.inputs[input_names::Dropout_scale_inv], + Pointwise_attributes().set_name("scale_dropout_inv").set_mode(PointwiseMode_t::MUL)); + } + auto softmax_sum = last_output; //// Q * K - auto bmm_Q_K_attributes = Matmul_attributes().set_name("bmm_Q_K"); + auto bmm_Q_K_attributes =
Matmul_attributes().set_name("bmm_Q_K") + .set_m_override(attributes.inputs[input_names::SEQ_LEN_Q]) + .set_n_override(attributes.inputs[input_names::SEQ_LEN_KV]); auto last_dV = matmul(attributes.inputs[input_names::Q], attributes.inputs[input_names::K], bmm_Q_K_attributes); //// Optional Attn scale @@ -171,9 +277,9 @@ class SDPAFP8BackwardNode : public NodeCRTP { } // If attn scale present, add a pointwise mul node - if (attributes.inputs[input_names::Attn_scale]) { + if (auto attn_scale_it = attributes.inputs.find(input_names::Attn_scale); attn_scale_it != attributes.inputs.end()) { mul_attributes.set_name("attn_scale"); - last_dV = pointwise(last_dV, attributes.inputs[input_names::Attn_scale], mul_attributes); + last_dV = pointwise(last_dV, attn_scale_it->second, mul_attributes); } //// Descales @@ -185,6 +291,63 @@ class SDPAFP8BackwardNode : public NodeCRTP { mul_attributes.set_name("descale_k"); last_dV = pointwise(last_dV, attributes.inputs.at(input_names::Descale_K), mul_attributes); + // (optional) last_dV = last_dV + bias + if (auto bias_it = attributes.inputs.find(input_names::Bias); bias_it != attributes.inputs.end()) { + last_dV = pointwise(last_dV, + bias_it->second, + Pointwise_attributes().set_name("add_bias").set_mode(PointwiseMode_t::ADD)); + } + + // (optional) Apply padding mask + if (attributes.padding_mask) { + auto row_idx_output = pointwise(last_dV, + Pointwise_attributes() + .set_name("gen_row_idx_padding") + .set_mode(PointwiseMode_t::GEN_INDEX) + .set_axis(2) + .set_compute_data_type(DataType_t::INT32)); + row_idx_output->set_data_type(DataType_t::INT32); + + auto col_idx_output = pointwise(last_dV, + Pointwise_attributes() + .set_name("gen_col_idx_padding") + .set_mode(PointwiseMode_t::GEN_INDEX) + .set_axis(3) + .set_compute_data_type(DataType_t::INT32)); + col_idx_output->set_data_type(DataType_t::INT32); + + auto row_mask_output = pointwise(row_idx_output, + attributes.inputs[input_names::SEQ_LEN_Q], + Pointwise_attributes() + .set_name("lt_row_sq_padding") + .set_mode(PointwiseMode_t::CMP_LT) + .set_compute_data_type(DataType_t::BOOLEAN)); + row_mask_output->set_data_type(DataType_t::BOOLEAN); + + auto col_mask_output = pointwise(col_idx_output, + attributes.inputs[input_names::SEQ_LEN_KV], + Pointwise_attributes() + .set_name("lt_col_skv_padding") + .set_mode(PointwiseMode_t::CMP_LT) + .set_compute_data_type(DataType_t::BOOLEAN)); + col_mask_output->set_data_type(DataType_t::BOOLEAN); + + auto padding_mask_output = pointwise(row_mask_output, + col_mask_output, + Pointwise_attributes() + .set_name("and_row_col_padding") + .set_mode(PointwiseMode_t::LOGICAL_AND) + .set_compute_data_type(DataType_t::BOOLEAN)); + padding_mask_output->set_data_type(DataType_t::BOOLEAN); + auto negative_inf_padding = std::make_shared(std::numeric_limits::lowest()); + + last_dV = + pointwise(last_dV, + negative_inf_padding, + padding_mask_output, + Pointwise_attributes().set_name("select_padding").set_mode(PointwiseMode_t::BINARY_SELECT)); + } + //// Optional causal masking if (attributes.causal_mask) { auto row_index_attributes = @@ -221,6 +384,22 @@ class SDPAFP8BackwardNode : public NodeCRTP { last_dV = pointwise(last_dV, Pointwise_attributes().set_name("exp_dV").set_mode(PointwiseMode_t::EXP)); auto exp_S = last_dV; + // (optional) last_dV = last_dV * dropout rng_output + if (is_dropout_prob || is_dropout_mask) { + last_dV = + pointwise(last_dV, + rng_output, + Pointwise_attributes().set_name("mul_p_dropout_mask").set_mode(PointwiseMode_t::MUL)); + } + + // (optional) last_dV 
= last_dV * dropout_scale + if (attributes.inputs[input_names::Dropout_scale]) { + last_dV = + pointwise(last_dV, + attributes.inputs[input_names::Dropout_scale], + Pointwise_attributes().set_name("mul_dS_dropout_scale").set_mode(PointwiseMode_t::MUL)); + } + // Scale S mul_attributes.set_name("scale_S"); last_dV = pointwise(last_dV, attributes.inputs.at(input_names::Scale_S), mul_attributes); @@ -238,12 +417,16 @@ class SDPAFP8BackwardNode : public NodeCRTP { attributes.inputs[input_names::Descale_S], attributes.inputs[input_names::Descale_dO], attributes.inputs[input_names::Scale_dV], - Matmul_fp8_attributes().set_name("bmm_S_T_dO"), + Matmul_fp8_attributes().set_name("bmm_S_T_dO") + .set_m_override(attributes.inputs[input_names::SEQ_LEN_KV]) + .set_k_override(attributes.inputs[input_names::SEQ_LEN_Q]), attributes.outputs[output_names::dV], attributes.outputs[output_names::Amax_dV]); //// dO * V_T - auto bmm_dO_V_T_attributes = Matmul_attributes().set_name("bmm_dO_V_T"); + auto bmm_dO_V_T_attributes = Matmul_attributes().set_name("bmm_dO_V_T") + .set_m_override(attributes.inputs[input_names::SEQ_LEN_Q]) + .set_n_override(attributes.inputs[input_names::SEQ_LEN_KV]); last_output = matmul(attributes.inputs[input_names::dO], attributes.inputs[input_names::V], bmm_dO_V_T_attributes); @@ -265,10 +448,24 @@ class SDPAFP8BackwardNode : public NodeCRTP { mul_attributes.set_name("mul_dP_exp_S"); dP = pointwise(dP, exp_S, mul_attributes); + // (optional) dP = dP * dropout_scale + if (attributes.inputs[input_names::Dropout_scale]) { + dP = + pointwise(dP, + attributes.inputs[input_names::Dropout_scale], + Pointwise_attributes().set_name("mul_dS_dropout_scale").set_mode(PointwiseMode_t::MUL)); + } + + // if (attributes.outputs[output_names::dBias]) { + // reduction(dP, + // Reduction_attributes().set_name("red_dP_dBias").set_mode(ReductionMode_t::ADD), + // attributes.outputs[output_names::dBias]); + // } + // (optional) dP = dP * attn_scale - if (attributes.inputs[input_names::Attn_scale]) { + if (auto attn_scale_it = attributes.inputs.find(input_names::Attn_scale); attn_scale_it != attributes.inputs.end()) { mul_attributes.set_name("mul_dS_attn_scale"); - dP = pointwise(dP, attributes.inputs[input_names::Attn_scale], mul_attributes); + dP = pointwise(dP, attn_scale_it->second, mul_attributes); } // Amax dP @@ -289,7 +486,9 @@ class SDPAFP8BackwardNode : public NodeCRTP { K->set_dim({kt_dim[0], kt_dim[1], kt_dim[3], kt_dim[2]}) .set_stride({kt_stride[0], kt_stride[1], kt_stride[3], kt_stride[2]}); - auto bmm_dP_K_attributes = Matmul_fp8_attributes().set_name("bmm_dP_K"); + auto bmm_dP_K_attributes = Matmul_fp8_attributes().set_name("bmm_dP_K") + .set_m_override(attributes.inputs[input_names::SEQ_LEN_Q]) + .set_k_override(attributes.inputs[input_names::SEQ_LEN_KV]); // Special non-functional-style call. Needed because output already created and provided to user. matmul_fp8(dP, K, @@ -306,7 +505,9 @@ class SDPAFP8BackwardNode : public NodeCRTP { dP_T->set_data_type(attributes.inputs.at(input_names::dO)->get_data_type()); dP_T->set_name("dP_T").set_dim({b, h_q, s_kv, s_q}).set_stride({h_q * s_q * s_kv, s_q * s_kv, 1, s_kv}); - auto bmm_dP_T_Q_attributes = Matmul_fp8_attributes().set_name("bmm_dP_T_Q"); + auto bmm_dP_T_Q_attributes = Matmul_fp8_attributes().set_name("bmm_dP_T_Q") + .set_m_override(attributes.inputs[input_names::SEQ_LEN_KV]) + .set_k_override(attributes.inputs[input_names::SEQ_LEN_Q]); // Special non-functional-style call. Needed because output already created and provided to user. 
matmul_fp8(dP_T, attributes.inputs[input_names::Q], diff --git a/include/cudnn_frontend/node/softmax.h b/include/cudnn_frontend/node/softmax.h index 3730465c..b2ba31b2 100644 --- a/include/cudnn_frontend/node/softmax.h +++ b/include/cudnn_frontend/node/softmax.h @@ -88,9 +88,6 @@ class SoftmaxNode : public NodeCRTP { auto add_attributes = Pointwise_attributes().set_name("add").set_mode(PointwiseMode_t::ADD); // Special non-functional-style call. Needed because output already created and provided to user. - attributes.outputs[Softmax_attributes::output_names::Stats] - ->set_dim({b, h, s_q, 1}) - .set_stride({h * s_q, s_q, 1, 1}); pointwise( max_output, log_output, add_attributes, attributes.outputs[Softmax_attributes::output_names::Stats]); } diff --git a/include/cudnn_frontend/node_interface.h b/include/cudnn_frontend/node_interface.h index d149b289..7e39cbb7 100644 --- a/include/cudnn_frontend/node_interface.h +++ b/include/cudnn_frontend/node_interface.h @@ -319,7 +319,7 @@ class INode : public ICudnn { serialize(json& j) const = 0; #endif - size_t + virtual size_t key() { #ifndef CUDNN_FRONTEND_SKIP_JSON_LIB json j; diff --git a/include/cudnn_frontend/plans.h b/include/cudnn_frontend/plans.h index a0819c06..155f27f8 100644 --- a/include/cudnn_frontend/plans.h +++ b/include/cudnn_frontend/plans.h @@ -129,10 +129,11 @@ inline error_t create_cudnn_execution_plan(std::shared_ptr& plan, ManagedOpaqueDescriptor const& config, std::string const& operation_graph_tag, + std::shared_ptr kernel_cache, cudnnHandle_t handle) { auto&& plan_builder = cudnn_frontend::ExecutionPlanBuilder(); - plan_builder.setHandle(handle).setEngineConfig(config, operation_graph_tag); + plan_builder.setHandle(handle).setEngineConfig(config, operation_graph_tag).setKernelCache(kernel_cache); #ifdef NV_CUDNN_DISABLE_EXCEPTION // disable exception macro is defined. Calling build will not throw. @@ -171,6 +172,7 @@ class Execution_plan_list { std::vector> numeric_notes; std::vector> behavior_notes; std::vector barred_indices; + std::shared_ptr kernel_cache; int64_t max_workspace_allowed = std::numeric_limits::max(); int64_t max_shared_mem_allowed = 1024 * 1024 * 1024; // Crazy high number (2GB) which will never be hit @@ -182,7 +184,7 @@ class Execution_plan_list { _build_plan_at_index_impl(cudnnHandle_t handle, int64_t index) { if (execution_plans[index] == nullptr) { CHECK_CUDNN_FRONTEND_ERROR(detail::create_cudnn_execution_plan( - execution_plans[index], engine_configs[index], operation_tag, handle)); + execution_plans[index], engine_configs[index], operation_tag, kernel_cache, handle)); } auto is_blocked = [](std::string const& full_name, std::vector const& blocked_names) -> bool { @@ -233,6 +235,10 @@ class Execution_plan_list { set_engine_configs(EngineConfigList list) { engine_configs = list; } + void + set_kernel_cache(std::shared_ptr kernel_cache_) { + kernel_cache = kernel_cache_; + } std::vector>& get_execution_plans() { @@ -373,6 +379,12 @@ class Execution_plan_list { return barred_engine_configs; } + error_t + get_name_at_index(int64_t index, std::string& name) const { + name = detail::get_engine_tag(engine_configs[index]); + return {error_code_t::OK, ""}; + } + error_t check_support_at_index(cudnnHandle_t handle, int64_t index) { // Ignore if the engine config was deselected. 
diff --git a/include/cudnn_frontend/utils/serialize.h b/include/cudnn_frontend/utils/serialize.h index b242e693..b89a0986 100644 --- a/include/cudnn_frontend/utils/serialize.h +++ b/include/cudnn_frontend/utils/serialize.h @@ -220,6 +220,9 @@ NLOHMANN_JSON_SERIALIZE_ENUM(Matmul_fp8_attributes::input_names, {Matmul_fp8_attributes::input_names::B, "B"}, {Matmul_fp8_attributes::input_names::Descale_A, "Descale_A"}, {Matmul_fp8_attributes::input_names::Descale_B, "Descale_B"}, + {Matmul_fp8_attributes::input_names::M_override, "M_override"}, + {Matmul_fp8_attributes::input_names::N_override, "N_override"}, + {Matmul_fp8_attributes::input_names::K_override, "K_override"}, {Matmul_fp8_attributes::input_names::Scale_C, "Scale_C"}, }) @@ -327,6 +330,14 @@ NLOHMANN_JSON_SERIALIZE_ENUM(SDPA_fp8_attributes::input_names, {SDPA_fp8_attributes::input_names::K, "K"}, {SDPA_fp8_attributes::input_names::V, "V"}, {SDPA_fp8_attributes::input_names::Attn_scale, "Attn_scale"}, + {SDPA_fp8_attributes::input_names::Bias, "Bias"}, + {SDPA_fp8_attributes::input_names::SEQ_LEN_Q, "SEQ_LEN_Q"}, + {SDPA_fp8_attributes::input_names::SEQ_LEN_KV, "SEQ_LEN_KV"}, + {SDPA_fp8_attributes::input_names::Seed, "Seed"}, + {SDPA_fp8_attributes::input_names::Offset, "Offset"}, + {SDPA_fp8_attributes::input_names::Dropout_mask, "Dropout_mask"}, + {SDPA_fp8_attributes::input_names::Dropout_scale, "Dropout_scale"}, + {SDPA_fp8_attributes::input_names::Descale_Q, "Descale_Q"}, {SDPA_fp8_attributes::input_names::Descale_K, "Descale_K"}, {SDPA_fp8_attributes::input_names::Descale_V, "Descale_V"}, @@ -389,6 +400,15 @@ NLOHMANN_JSON_SERIALIZE_ENUM(SDPA_fp8_backward_attributes::input_names, {SDPA_fp8_backward_attributes::input_names::dO, "dO"}, {SDPA_fp8_backward_attributes::input_names::Stats, "Stats"}, {SDPA_fp8_backward_attributes::input_names::Attn_scale, "Attn_scale"}, + {SDPA_fp8_backward_attributes::input_names::Bias, "Bias"}, + {SDPA_fp8_backward_attributes::input_names::SEQ_LEN_Q, "SEQ_LEN_Q"}, + {SDPA_fp8_backward_attributes::input_names::SEQ_LEN_KV, "SEQ_LEN_KV"}, + {SDPA_fp8_backward_attributes::input_names::Seed, "Seed"}, + {SDPA_fp8_backward_attributes::input_names::Offset, "Offset"}, + {SDPA_fp8_backward_attributes::input_names::Dropout_mask, "Dropout_mask"}, + {SDPA_fp8_backward_attributes::input_names::Dropout_scale, "Dropout_scale"}, + {SDPA_fp8_backward_attributes::input_names::Dropout_scale_inv, "Dropout_scale_inv"}, + {SDPA_fp8_backward_attributes::input_names::Descale_Q, "Descale_Q"}, {SDPA_fp8_backward_attributes::input_names::Descale_K, "Descale_K"}, {SDPA_fp8_backward_attributes::input_names::Descale_V, "Descale_V"}, diff --git a/include/cudnn_frontend_ExecutionPlan.h b/include/cudnn_frontend_ExecutionPlan.h index 589aa94c..22bbc676 100644 --- a/include/cudnn_frontend_ExecutionPlan.h +++ b/include/cudnn_frontend_ExecutionPlan.h @@ -33,6 +33,7 @@ #include "cudnn_frontend_EngineConfig.h" #include "cudnn_frontend_Engine.h" #include "cudnn_frontend_utils.h" +#include "cudnn_frontend/backend/kernel_cache.h" namespace cudnn_frontend { /// @@ -321,7 +322,8 @@ class ExecutionPlan_v8 : public BackendDescriptor { std::array behavior_notes; std::vector behavior_notes_vec; - float execution_time_ms = 0.0f; + float execution_time_ms = 0.0f; + std::shared_ptr kernel_cache = nullptr; }; /// @@ -347,6 +349,12 @@ class ExecutionPlanBuilder_v8 { return *this; } + auto + setKernelCache(std::shared_ptr kernel_cache) -> ExecutionPlanBuilder_v8 & { + m_execution_plan.kernel_cache = kernel_cache; + return *this; + } + //! 
Set engine Config for the Plan auto setEngineConfig(ManagedOpaqueDescriptor &desc, std::string const &opGraphTag_ = "") -> ExecutionPlanBuilder_v8 & { @@ -415,6 +423,22 @@ class ExecutionPlanBuilder_v8 { "CUDNN_BACKEND_EXECUTION_PLAN_DESCRIPTOR: SetAttribute CUDNN_ATTR_EXECUTION_PLAN_HANDLE Failed"); return std::move(m_execution_plan); } +#if (CUDNN_VERSION >= 90400) + if (m_execution_plan.kernel_cache) { + status = detail::set_attribute(m_execution_plan.pointer->get_backend_descriptor(), + CUDNN_ATTR_EXECUTION_PLAN_KERNEL_CACHE, + CUDNN_TYPE_BACKEND_DESCRIPTOR, + 1, + &m_execution_plan.kernel_cache->get_ptr()); + if (status != CUDNN_STATUS_SUCCESS) { + set_error_and_throw_exception(&m_execution_plan, + status, + "CUDNN_BACKEND_EXECUTION_PLAN_DESCRIPTOR: SetAttribute " + "CUDNN_ATTR_EXECUTION_PLAN_KERNEL_CACHE Failed"); + return std::move(m_execution_plan); + } + } +#endif // Finalizing the descriptor status = detail::finalize(m_execution_plan.pointer->get_backend_descriptor()); if (status != CUDNN_STATUS_SUCCESS) { diff --git a/include/cudnn_frontend_OperationGraph.h b/include/cudnn_frontend_OperationGraph.h index 837ffd56..48a19f42 100644 --- a/include/cudnn_frontend_OperationGraph.h +++ b/include/cudnn_frontend_OperationGraph.h @@ -32,7 +32,6 @@ #include "cudnn_frontend_Operation.h" #include "cudnn_frontend_utils.h" - // Compile time constant for max ops in a op graph constexpr int64_t MAX_OPGRAPH_OPS = 50; @@ -131,6 +130,7 @@ class OperationGraph_v8 : public BackendDescriptor { int64_t numOps = -1; std::string opGraphTag = ""; std::vector feature_vectors; + bool is_dynamic_shape_enabled = false; }; /// @@ -182,6 +182,12 @@ class OperationGraphBuilder_v8 { } /** @} */ + auto + setIsDynamicShapeEnabled(bool is_enabled) -> OperationGraphBuilder_v8 & { + m_operationGraph.is_dynamic_shape_enabled = is_enabled; + return *this; + } + //! constructs the OperationGraph_v8 by calling the cudnn API //! 
Throws the appropriate error message OperationGraph_v8 && @@ -245,6 +251,22 @@ class OperationGraphBuilder_v8 { "CUDNN_BACKEND_OPERATIONGRAPH_DESCRIPTOR: SetAttribute CUDNN_ATTR_OPERATIONGRAPH_HANDLE Failed"); return std::move(m_operationGraph); } +#if (CUDNN_VERSION >= 90400) + if (m_operationGraph.is_dynamic_shape_enabled) { + status = detail::set_attribute(m_operationGraph.pointer->get_backend_descriptor(), + CUDNN_ATTR_OPERATIONGRAPH_IS_DYNAMIC_SHAPE_ENABLED, + CUDNN_TYPE_BOOLEAN, + 1, + &m_operationGraph.is_dynamic_shape_enabled); + if (status != CUDNN_STATUS_SUCCESS) { + set_error_and_throw_exception(&m_operationGraph, + status, + "CUDNN_BACKEND_OPERATIONGRAPH_DESCRIPTOR: SetAttribute " + "CUDNN_ATTR_OPERATIONGRAPH_IS_DYNAMIC_SHAPE_ENABLED Failed"); + return std::move(m_operationGraph); + } + } +#endif // Finalizing the descriptor status = detail::finalize(m_operationGraph.pointer->get_backend_descriptor()); diff --git a/include/cudnn_frontend_shim.h b/include/cudnn_frontend_shim.h index f277a89e..9c12a9ef 100644 --- a/include/cudnn_frontend_shim.h +++ b/include/cudnn_frontend_shim.h @@ -273,7 +273,7 @@ get_error_string(cudnnStatus_t status) { inline void get_last_error_string(char *message, size_t size) { - if (detail::get_backend_version() > 90000 && detail::get_compiled_version() < 90000) { + if (detail::get_backend_version() >= 90000 && detail::get_compiled_version() >= 90000) { #if CUDNN_VERSION >= 90000 NV_FE_CALL_TO_BACKEND(get_last_error_string, cudnnGetLastErrorString, message, size); #endif diff --git a/include/cudnn_frontend_utils.h b/include/cudnn_frontend_utils.h index f53d4df2..6ead3a19 100644 --- a/include/cudnn_frontend_utils.h +++ b/include/cudnn_frontend_utils.h @@ -104,7 +104,7 @@ struct nlohmann::adl_serializer> static void from_json(const nlohmann::json& j, std::variant& data) { if (!j.is_object() || !j.contains("index") || !j.contains("value")) { - throw std::invalid_argument("Invalid JSON format for std::variant"); + return; } size_t type_index = j.at("index").get(); @@ -117,7 +117,7 @@ struct nlohmann::adl_serializer> } else if (type_index == 3) { data = j.at("value").get(); } else { - throw std::out_of_range("Variant index out of range"); + return; } } }; diff --git a/include/cudnn_frontend_version.h b/include/cudnn_frontend_version.h index 11ad9dad..002d7482 100644 --- a/include/cudnn_frontend_version.h +++ b/include/cudnn_frontend_version.h @@ -23,7 +23,7 @@ #pragma once #define CUDNN_FRONTEND_MAJOR_VERSION 1 -#define CUDNN_FRONTEND_MINOR_VERSION 6 -#define CUDNN_FRONTEND_PATCH_VERSION 1 +#define CUDNN_FRONTEND_MINOR_VERSION 7 +#define CUDNN_FRONTEND_PATCH_VERSION 0 #define CUDNN_FRONTEND_VERSION \ ((CUDNN_FRONTEND_MAJOR_VERSION * 10000) + (CUDNN_FRONTEND_MINOR_VERSION * 100) + CUDNN_FRONTEND_PATCH_VERSION) diff --git a/python/cudnn/__init__.py b/python/cudnn/__init__.py index 64ba2afd..fa823359 100644 --- a/python/cudnn/__init__.py +++ b/python/cudnn/__init__.py @@ -25,7 +25,7 @@ from .datatypes import _library_type, _is_torch_tensor -__version__ = "1.6.1" +__version__ = "1.7.0" def _tensor( diff --git a/python/pygraph/pygraph.cpp b/python/pygraph/pygraph.cpp index 67cbf5d0..4575531f 100644 --- a/python/pygraph/pygraph.cpp +++ b/python/pygraph/pygraph.cpp @@ -363,7 +363,22 @@ PyGraph::check_support() { int64_t PyGraph::get_workspace_size() { - return graph.get_workspace_size(); + int64_t workspace = 0; + + auto status = graph.get_workspace_size(workspace); + throw_if(status.is_bad(), status.get_code(), status.get_message()); + + return workspace; +} + 
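// A minimal caller-side sketch (illustrative only, mirroring the pattern used throughout the
// updated C++ samples) of the status-returning workspace query that the binding above wraps;
// Surface is the samples' device-buffer helper, and the error handling shown here is just one
// possible choice rather than a prescribed one.
int64_t workspace_size = 0;
auto status = graph.get_workspace_size(workspace_size);
if (status.is_bad()) {
    std::cerr << status.get_message() << std::endl;  // or propagate status.get_code()
}
Surface<int8_t> workspace(workspace_size, false);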
+int64_t +PyGraph::get_workspace_size_plan_at_index(int64_t index) { + int64_t workspace; + + auto status = graph.get_workspace_size_plan_at_index(index, workspace); + throw_if(status.is_bad(), status.get_code(), status.get_message()); + + return workspace; } std::vector @@ -430,6 +445,14 @@ PyGraph::execute_plan_at_index(std::unordered_map var_pa return; } +std::shared_ptr +PyGraph::query_tensor_attributes_of_uid(int64_t const uid) const { + graph::Tensor_attributes tensor; + auto status = graph.query_tensor_attributes_of_uid(uid, tensor); + throw_if(status.is_bad(), status.get_code(), status.get_message()); + return std::make_shared(tensor); +} + std::vector default_vector(void) { return {}; @@ -729,6 +752,15 @@ init_pygraph_submodule(py::module_& m) { index (int): The index of the plan to get workspace from. If the graph is not built at the index, this will return 0. )pbdoc") + .def("query_tensor_attributes_of_uid", + &PyGraph::query_tensor_attributes_of_uid, + py::arg("uid"), + R"pbdoc( + Get tensor_attributes for a given UID + Args: + uid (int): The uid of tensor to be queried + If the graph does not have the UID, this will raise an error + )pbdoc") .def("_execute", &PyGraph::execute) .def("serialize", &PyGraph::serialize) .def("deserialize", &PyGraph::deserialize) diff --git a/python/pygraph/pygraph.h b/python/pygraph/pygraph.h index c73e3519..55667f03 100644 --- a/python/pygraph/pygraph.h +++ b/python/pygraph/pygraph.h @@ -432,9 +432,10 @@ class PyGraph { } int64_t - get_workspace_size_plan_at_index(int64_t index) const { - return graph.get_workspace_size_plan_at_index(index); - } + get_workspace_size_plan_at_index(int64_t index); + + std::shared_ptr + query_tensor_attributes_of_uid(int64_t const uid) const; }; } // namespace cudnn_frontend::python_bindings \ No newline at end of file diff --git a/samples/cpp/convolution/dgrads.cpp b/samples/cpp/convolution/dgrads.cpp index c1f23797..3265e007 100644 --- a/samples/cpp/convolution/dgrads.cpp +++ b/samples/cpp/convolution/dgrads.cpp @@ -65,7 +65,11 @@ TEST_CASE("Convolution Dgrad", "[dgrad][graph]") { Surface w_tensor(64 * 32 * 3 * 3, false); Surface dx_tensor(4 * 32 * 16 * 16, false); - Surface workspace(graph.get_workspace_size(), false); + int64_t workspace_size; + REQUIRE(graph.get_workspace_size(workspace_size).is_good()); + + Surface workspace(workspace_size, false); + std::unordered_map, void*> variant_pack = { {DY, dy_tensor.devPtr}, {W, w_tensor.devPtr}, {DX, dx_tensor.devPtr}}; REQUIRE(graph.execute(handle, variant_pack, workspace.devPtr).is_good()); @@ -118,7 +122,10 @@ TEST_CASE("Dgrad Drelu Graph", "[dgrad][graph]") { Surface x_tensor(4 * 32 * 16 * 16, false); Surface dx_tensor(4 * 32 * 16 * 16, false); - Surface workspace(graph.get_workspace_size(), false); + int64_t workspace_size; + REQUIRE(graph.get_workspace_size(workspace_size).is_good()); + Surface workspace(workspace_size, false); + std::unordered_map, void*> variant_pack = { {DY, dy_tensor.devPtr}, {W, w_tensor.devPtr}, {X, x_tensor.devPtr}, {DX, dx_tensor.devPtr}}; REQUIRE(graph.execute(handle, variant_pack, workspace.devPtr).is_good()); @@ -227,7 +234,10 @@ TEST_CASE("Dgrad Drelu DBNweight Graph", "[dgrad][graph]") { Surface eq_scale_x_tensor(1 * 32 * 1 * 1, false); Surface eq_bias_tensor(1 * 32 * 1 * 1, false); - Surface workspace(graph.get_workspace_size(), false); + int64_t workspace_size; + REQUIRE(graph.get_workspace_size(workspace_size).is_good()); + Surface workspace(workspace_size, false); + std::unordered_map, void*> variant_pack = { {DY, 
dy_tensor.devPtr}, {W, w_tensor.devPtr}, diff --git a/samples/cpp/convolution/fp8_fprop.cpp b/samples/cpp/convolution/fp8_fprop.cpp index e9785820..4d0b6efe 100644 --- a/samples/cpp/convolution/fp8_fprop.cpp +++ b/samples/cpp/convolution/fp8_fprop.cpp @@ -115,7 +115,10 @@ TEST_CASE("Convolution fp8 precision", "[conv][graph]") { Surface Y_scale_gpu(1, false); Surface amax_gpu(1, false); - Surface workspace(graph->get_workspace_size(), false); + int64_t workspace_size; + REQUIRE(graph->get_workspace_size(workspace_size).is_good()); + Surface workspace(workspace_size, false); + std::unordered_map, void*> variant_pack = { {X, X_gpu.devPtr}, {W, W_gpu.devPtr}, diff --git a/samples/cpp/convolution/fprop.cpp b/samples/cpp/convolution/fprop.cpp index 3fe139f9..73493cc0 100644 --- a/samples/cpp/convolution/fprop.cpp +++ b/samples/cpp/convolution/fprop.cpp @@ -77,10 +77,12 @@ TEST_CASE("Convolution fprop", "[conv][graph][caching]") { Surface w_tensor(k * c * r * s, false); Surface y_tensor(n * k * h * w, false); // Should be p, q. - std::unordered_map variant_pack = { + std::unordered_map variant_pack = { {X->get_uid(), x_tensor.devPtr}, {W->get_uid(), w_tensor.devPtr}, {Y->get_uid(), y_tensor.devPtr}}; - Surface workspace(graph->get_workspace_size(), false); + int64_t workspace_size; + REQUIRE(graph->get_workspace_size(workspace_size).is_good()); + Surface workspace(workspace_size, false); std::cout << *graph << std::endl; @@ -88,6 +90,129 @@ TEST_CASE("Convolution fprop", "[conv][graph][caching]") { cudnnDestroy(handle); } +TEST_CASE("Convolution fprop dynamic shape", "[conv][graph][dynamic_shape]") { + namespace fe = cudnn_frontend; + + if (cudnnGetCudartVersion() < 12000) { + SKIP("Test requires cuda toolkit 12.0 or above"); + } + + // clang-format off + struct { + int64_t n, c, h, w, k, r, s; + } conv_shapes[] = { + { 16, 128, 56, 56, 256, 3, 3}, + { 16, 128, 64, 64, 256, 3, 3}, + { 16, 128, 80, 64, 256, 3, 3}, + { 32, 128, 80, 80, 256, 3, 3}, + { 32, 256, 32, 32, 256, 3, 3}, + }; + // clang-format on + + constexpr int conv_shapes_count = sizeof(conv_shapes) / sizeof(conv_shapes[0]); + int64_t max_x_volume = 0, max_w_volume = 0, max_y_volume = 0; + for (int idx_shape = 0; idx_shape < conv_shapes_count; ++idx_shape) { + const auto &conv_shape = conv_shapes[idx_shape]; + max_x_volume = std::max(max_x_volume, conv_shape.n * conv_shape.c * conv_shape.h * conv_shape.w); + max_w_volume = std::max(max_w_volume, conv_shape.k * conv_shape.c * conv_shape.r * conv_shape.s); + max_y_volume = std::max(max_y_volume, conv_shape.n * conv_shape.k * conv_shape.h * conv_shape.w); + } + + auto kernel_cache = std::make_shared(); + + const auto build_new_graph = [&conv_shapes, &kernel_cache](cudnnHandle_t handle, int idx_shape) { + const auto &conv_shape = conv_shapes[idx_shape]; + + auto graph = std::make_shared(); + graph->set_io_data_type(fe::DataType_t::HALF) + .set_compute_data_type(fe::DataType_t::FLOAT) + .set_dynamic_shape_enabled(true) + .set_kernel_cache(kernel_cache); + + auto X = graph->tensor( + fe::graph::Tensor_attributes() + .set_name("image") + .set_dim({conv_shape.n, conv_shape.c, conv_shape.h, conv_shape.w}) + .set_stride( + {conv_shape.c * conv_shape.h * conv_shape.w, 1, conv_shape.c * conv_shape.w, conv_shape.c})); + + auto W = graph->tensor( + fe::graph::Tensor_attributes() + .set_name("filter") + .set_dim({conv_shape.k, conv_shape.c, conv_shape.r, conv_shape.s}) + .set_stride( + {conv_shape.c * conv_shape.r * conv_shape.s, 1, conv_shape.c * conv_shape.s, conv_shape.c})); + + auto conv_options 
= fe::graph::Conv_fprop_attributes() + .set_pre_padding({1, 1}) // padding such that P=H, Q=W + .set_post_padding({0, 0}) + .set_stride({1, 1}) + .set_dilation({1, 1}); + + auto Y1 = graph->conv_fprop(X, W, conv_options); + Y1->set_data_type(fe::DataType_t::HALF); + + auto Y = graph->pointwise(Y1, + fe::graph::Pointwise_attributes() + .set_mode(fe::PointwiseMode_t::RELU_FWD) + .set_compute_data_type(fe::DataType_t::FLOAT)); + + Y->set_output(true); + auto status = graph->validate(); + if (cudnnGetVersion() >= 90400) { + REQUIRE(status.is_good()); + } else { + REQUIRE(status.is_bad()); + SKIP("Dynamic shapes not supported pre 9.4"); + } + + status = graph->build_operation_graph(handle); + if (cudnnGetVersion() >= 90400) { + REQUIRE(status.is_good()); + } else { + REQUIRE(status.is_bad()); + SKIP("Kernel cache not supported pre 9.4"); + } + + REQUIRE(graph->create_execution_plans({fe::HeurMode_t::A}).is_good()); + + REQUIRE(graph->check_support(handle).is_good()); + + REQUIRE(graph->build_plans(handle).is_good()); + + return std::make_tuple(graph, X, W, Y); + }; + + const auto execute_graph = [&max_x_volume, &max_w_volume, &max_y_volume](cudnnHandle_t handle, + const fe::graph::Graph *graph, + const fe::graph::Tensor_attributes *X, + const fe::graph::Tensor_attributes *W, + const fe::graph::Tensor_attributes *Y) { + Surface x_tensor(max_x_volume, false); + Surface w_tensor(max_w_volume, false); + Surface y_tensor(max_y_volume, false); + + std::unordered_map variant_pack = { + {X->get_uid(), x_tensor.devPtr}, {W->get_uid(), w_tensor.devPtr}, {Y->get_uid(), y_tensor.devPtr}}; + + Surface workspace(graph->get_workspace_size(), false); + + std::cout << *graph << std::endl; + + REQUIRE(graph->execute(handle, variant_pack, workspace.devPtr).is_good()); + }; + + cudnnHandle_t handle; + checkCudnnErr(cudnnCreate(&handle)); + + for (int idx_shape = 0; idx_shape < conv_shapes_count; ++idx_shape) { + auto [graph, X, W, Y] = build_new_graph(handle, idx_shape); + execute_graph(handle, graph.get(), X.get(), W.get(), Y.get()); + } + + cudnnDestroy(handle); +} + TEST_CASE("CSBR Graph", "[conv][graph][caching]") { namespace fe = cudnn_frontend; @@ -178,19 +303,22 @@ TEST_CASE("CSBR Graph", "[conv][graph][caching]") { Surface b_tensor(k, false); Surface y_tensor(n * k * h * w, false); // Should be p, q. 
- Surface workspace(graph->get_workspace_size(), false); - std::unordered_map, void*> variant_pack = { + int64_t workspace_size; + REQUIRE(graph->get_workspace_size(workspace_size).is_good()); + Surface workspace(workspace_size, false); + + std::unordered_map, void *> variant_pack = { {X, x_tensor.devPtr}, {W, w_tensor.devPtr}, {S, s_tensor.devPtr}, {B, b_tensor.devPtr}, {Y, y_tensor.devPtr}}; REQUIRE(graph->execute(handle, variant_pack, workspace.devPtr).is_good()); auto [graph_, X_, W_, B_, S_, Y_] = lookup_cache_or_build_graph(handle); - std::unordered_map, void*> variant_pack_ = {{X_, x_tensor.devPtr}, - {W_, w_tensor.devPtr}, - {S_, s_tensor.devPtr}, - {B_, b_tensor.devPtr}, - {Y_, y_tensor.devPtr}}; + std::unordered_map, void *> variant_pack_ = {{X_, x_tensor.devPtr}, + {W_, w_tensor.devPtr}, + {S_, s_tensor.devPtr}, + {B_, b_tensor.devPtr}, + {Y_, y_tensor.devPtr}}; REQUIRE(graph_->execute(handle, variant_pack_, workspace.devPtr).is_good()); @@ -199,6 +327,136 @@ TEST_CASE("CSBR Graph", "[conv][graph][caching]") { cudnnDestroy(handle); } +TEST_CASE("CSBR Graph dynamic shape", "[conv][graph][dynamic_shape]") { + namespace fe = cudnn_frontend; + + if (cudnnGetCudartVersion() < 12000) { + SKIP("Test requires cuda toolkit 12.0 or above"); + } + + // clang-format off + struct { + int64_t n, c, h, w, k, r, s; + } conv_shapes[] = { + { 8, 32, 16, 16, 64, 3, 3}, + { 8, 32, 24, 24, 64, 3, 3}, + { 16, 32, 32, 32, 64, 3, 3}, + { 16, 64, 32, 32, 64, 3, 3}, + { 16, 16, 64, 64, 16, 3, 3}, + }; + // clang-format on + + constexpr int conv_shapes_count = sizeof(conv_shapes) / sizeof(conv_shapes[0]); + int64_t max_x_volume = 0, max_w_volume = 0, max_y_volume = 0, max_k = 0; + for (int idx_shape = 0; idx_shape < conv_shapes_count; ++idx_shape) { + const auto &conv_shape = conv_shapes[idx_shape]; + max_x_volume = std::max(max_x_volume, conv_shape.n * conv_shape.c * conv_shape.h * conv_shape.w); + max_w_volume = std::max(max_w_volume, conv_shape.k * conv_shape.c * conv_shape.r * conv_shape.s); + max_y_volume = std::max(max_y_volume, conv_shape.n * conv_shape.k * conv_shape.h * conv_shape.w); + max_k = std::max(max_k, conv_shape.k); + } + + auto kernel_cache = std::make_shared(); + + auto lookup_cache_or_build_graph = [&conv_shapes, &kernel_cache](cudnnHandle_t handle, int idx_shape) { + const auto &conv_shape = conv_shapes[idx_shape]; + + auto graph = std::make_shared(); + graph->set_io_data_type(fe::DataType_t::HALF) + .set_intermediate_data_type(fe::DataType_t::FLOAT) + .set_compute_data_type(fe::DataType_t::FLOAT) + .set_dynamic_shape_enabled(true) + .set_kernel_cache(kernel_cache); + + auto X = graph->tensor( + fe::graph::Tensor_attributes() + .set_name("image") + .set_dim({conv_shape.n, conv_shape.c, conv_shape.h, conv_shape.w}) + .set_stride( + {conv_shape.c * conv_shape.h * conv_shape.w, 1, conv_shape.c * conv_shape.w, conv_shape.c})); + + auto W = graph->tensor( + fe::graph::Tensor_attributes() + .set_name("filter") + .set_dim({conv_shape.k, conv_shape.c, conv_shape.r, conv_shape.s}) + .set_stride( + {conv_shape.c * conv_shape.r * conv_shape.s, 1, conv_shape.c * conv_shape.s, conv_shape.c})); + + auto conv_options = + fe::graph::Conv_fprop_attributes().set_padding({1, 1}).set_stride({1, 1}).set_dilation({1, 1}); + + auto conv_output = graph->conv_fprop(X, W, conv_options); + + auto S = graph->tensor(fe::graph::Tensor_attributes() + .set_name("scale") + .set_dim({1, conv_shape.k, 1, 1}) + .set_stride({conv_shape.k, 1, conv_shape.k, conv_shape.k})); + + auto scale_options = 
fe::graph::Pointwise_attributes().set_mode(fe::PointwiseMode_t::MUL); + auto scale_output = graph->pointwise(conv_output, S, scale_options); + + auto B = graph->tensor(fe::graph::Tensor_attributes() + .set_name("bias") + .set_dim({1, conv_shape.k, 1, 1}) + .set_stride({conv_shape.k, 1, conv_shape.k, conv_shape.k})); + + auto bias_options = fe::graph::Pointwise_attributes().set_mode(fe::PointwiseMode_t::ADD); + auto bias_output = graph->pointwise(scale_output, B, bias_options); + + auto relu_options = fe::graph::Pointwise_attributes().set_mode(fe::PointwiseMode_t::RELU_FWD); + auto Y = graph->pointwise(bias_output, relu_options); + Y->set_output(true); + + auto status = graph->validate(); + if (cudnnGetVersion() >= 90400) { + REQUIRE(status.is_good()); + } else { + REQUIRE(status.is_bad()); + SKIP("Dynamic shapes not supported pre 9.4"); + } + + status = graph->build_operation_graph(handle); + if (cudnnGetVersion() >= 90400) { + REQUIRE(status.is_good()); + } else { + REQUIRE(status.is_bad()); + SKIP("Kernel cache not supported pre 9.4"); + } + + REQUIRE(graph->create_execution_plans({fe::HeurMode_t::A}).is_good()); + + REQUIRE(graph->check_support(handle).is_good()); + + REQUIRE(graph->build_plans(handle).is_good()); + + return std::make_tuple(graph, X, W, S, B, Y); + }; + + cudnnHandle_t handle; + checkCudnnErr(cudnnCreate(&handle)); + + for (int idx_shape = 0; idx_shape < conv_shapes_count; idx_shape++) { + auto [graph, X, W, B, S, Y] = lookup_cache_or_build_graph(handle, idx_shape); + + Surface x_tensor(max_x_volume, false); + Surface w_tensor(max_w_volume, false); + Surface s_tensor(max_k, false); + Surface b_tensor(max_k, false); + Surface y_tensor(max_y_volume, false); // Should be p, q. + + Surface workspace(graph->get_workspace_size(), false); + std::unordered_map, void *> variant_pack = {{X, x_tensor.devPtr}, + {W, w_tensor.devPtr}, + {S, s_tensor.devPtr}, + {B, b_tensor.devPtr}, + {Y, y_tensor.devPtr}}; + + REQUIRE(graph->execute(handle, variant_pack, workspace.devPtr).is_good()); + } + + cudnnDestroy(handle); +} + TEST_CASE("SBRCS", "[conv][genstats][graph]") { namespace fe = cudnn_frontend; @@ -279,7 +537,7 @@ TEST_CASE("SBRCS", "[conv][genstats][graph]") { Surface sum_tensor(k, false); Surface sq_sum_tensor(k, false); - std::unordered_map, void*> variant_pack = { + std::unordered_map, void *> variant_pack = { {X, x_tensor.devPtr}, {S, s_tensor.devPtr}, {B, b_tensor.devPtr}, @@ -288,7 +546,10 @@ TEST_CASE("SBRCS", "[conv][genstats][graph]") { {SUM, sum_tensor.devPtr}, {SQ_SUM, sq_sum_tensor.devPtr}}; - Surface workspace(graph->get_workspace_size(), false); + int64_t workspace_size; + REQUIRE(graph->get_workspace_size(workspace_size).is_good()); + Surface workspace(workspace_size, false); + REQUIRE(graph->execute(handle, variant_pack, workspace.devPtr).is_good()); cudnnDestroy(handle); } @@ -386,19 +647,22 @@ TEST_CASE("CBR Graph NCHW", "[conv][graph][caching]") { Surface y_tensor(n * k * h * w, false); // Should be p, q. Surface z_tensor(n * k * h * w, false); // Should be p, q. 
- Surface workspace(graph->get_workspace_size(), false); - std::unordered_map, void*> variant_pack = { + int64_t workspace_size; + REQUIRE(graph->get_workspace_size(workspace_size).is_good()); + Surface workspace(workspace_size, false); + + std::unordered_map, void *> variant_pack = { {X, x_tensor.devPtr}, {W, w_tensor.devPtr}, {B, b_tensor.devPtr}, {Z, z_tensor.devPtr}, {Y, y_tensor.devPtr}}; REQUIRE(graph->execute(handle, variant_pack, workspace.devPtr).is_good()); auto [graph_, X_, W_, Z_, B_, Y_] = lookup_cache_or_build_graph(handle); - std::unordered_map, void*> variant_pack_ = {{X_, x_tensor.devPtr}, - {W_, w_tensor.devPtr}, - {B_, b_tensor.devPtr}, - {Z_, z_tensor.devPtr}, - {Y_, y_tensor.devPtr}}; + std::unordered_map, void *> variant_pack_ = {{X_, x_tensor.devPtr}, + {W_, w_tensor.devPtr}, + {B_, b_tensor.devPtr}, + {Z_, z_tensor.devPtr}, + {Y_, y_tensor.devPtr}}; REQUIRE(graph_->execute(handle, variant_pack_, workspace.devPtr).is_good()); @@ -463,10 +727,12 @@ TEST_CASE("Convolution fprop large", "[conv][graph][caching]") { Surface w_tensor(k * c * t * r * s, false); Surface y_tensor(n * k * d * h * w, false); // Should be p, q. - std::unordered_map variant_pack = { + std::unordered_map variant_pack = { {X->get_uid(), x_tensor.devPtr}, {W->get_uid(), w_tensor.devPtr}, {Y->get_uid(), y_tensor.devPtr}}; - Surface workspace(graph->get_workspace_size(), false); + int64_t workspace_size; + REQUIRE(graph->get_workspace_size(workspace_size).is_good()); + Surface workspace(workspace_size, false); std::cout << *graph << std::endl; diff --git a/samples/cpp/convolution/int8_fprop.cpp b/samples/cpp/convolution/int8_fprop.cpp index 7586d2ff..233e569f 100644 --- a/samples/cpp/convolution/int8_fprop.cpp +++ b/samples/cpp/convolution/int8_fprop.cpp @@ -94,7 +94,10 @@ TEST_CASE("Conv with Int8 datatypes", "[conv][graph][caching]") { std::unordered_map, void*> variant_pack = { {X, x_tensor.devPtr}, {W, w_tensor.devPtr}, {Y, y_tensor.devPtr}}; - Surface workspace(graph->get_workspace_size(), false); + int64_t workspace_size; + REQUIRE(graph->get_workspace_size(workspace_size).is_good()); + Surface workspace(workspace_size, false); + REQUIRE(graph->execute(handle, variant_pack, workspace.devPtr).is_good()); cudnnDestroy(handle); } diff --git a/samples/cpp/convolution/wgrads.cpp b/samples/cpp/convolution/wgrads.cpp index 7aace2b7..12cb72ed 100644 --- a/samples/cpp/convolution/wgrads.cpp +++ b/samples/cpp/convolution/wgrads.cpp @@ -64,7 +64,10 @@ TEST_CASE("Convolution Wgrad", "[wgrad][graph][wgrad][Conv_wgrad]") { Surface dy_tensor(4 * 64 * 16 * 16, false); Surface dw_tensor(64 * 64 * 3 * 3, false); - Surface workspace(graph.get_workspace_size(), false); + int64_t workspace_size; + REQUIRE(graph.get_workspace_size(workspace_size).is_good()); + Surface workspace(workspace_size, false); + std::unordered_map, void*> variant_pack = { {X, x_tensor.devPtr}, {DY, dy_tensor.devPtr}, {DW, dw_tensor.devPtr}}; REQUIRE(graph.execute(handle, variant_pack, workspace.devPtr).is_good()); @@ -130,7 +133,10 @@ TEST_CASE("Wgrad Graph", "[wgrad][graph][scale-bias-relu-wgrad][ConvBNwgrad]") { Surface dy_tensor(4 * 64 * 16 * 16, false); Surface dw_tensor(64 * 64 * 3 * 3, false); - Surface workspace(graph.get_workspace_size(), false); + int64_t workspace_size; + REQUIRE(graph.get_workspace_size(workspace_size).is_good()); + Surface workspace(workspace_size, false); + std::unordered_map, void*> variant_pack = {{X, x_tensor.devPtr}, {S, s_tensor.devPtr}, {B, b_tensor.devPtr}, diff --git a/samples/cpp/matmul/fp8_matmul.cpp 
b/samples/cpp/matmul/fp8_matmul.cpp index 9b334c36..62f63d79 100644 --- a/samples/cpp/matmul/fp8_matmul.cpp +++ b/samples/cpp/matmul/fp8_matmul.cpp @@ -115,7 +115,10 @@ TEST_CASE("Matmul fp8 precision", "[matmul][graph]") { REQUIRE(graph.build_plans(handle, fe::BuildPlanPolicy_t::HEURISTICS_CHOICE).is_good()); Surface C_gpu(b * m * n, false); - Surface workspace(graph.get_workspace_size(), false); + int64_t workspace_size; + REQUIRE(graph.get_workspace_size(workspace_size).is_good()); + Surface workspace(workspace_size, false); + std::unordered_map, void*> variant_pack = { {A, A_gpu.devPtr}, {B, B_gpu.devPtr}, diff --git a/samples/cpp/matmul/int8_matmul.cpp b/samples/cpp/matmul/int8_matmul.cpp index 4c551420..788b49f4 100644 --- a/samples/cpp/matmul/int8_matmul.cpp +++ b/samples/cpp/matmul/int8_matmul.cpp @@ -104,7 +104,10 @@ TEST_CASE("Int8 Matmul", "[matmul][graph]") { // note this is a bf16 tensor, but half is used just for memory allocation Surface C_gpu(b * m * n, false); Surface Bias_gpu(b * m * n, false); - Surface workspace(graph.get_workspace_size(), false); + int64_t workspace_size; + REQUIRE(graph.get_workspace_size(workspace_size).is_good()); + Surface workspace(workspace_size, false); + std::unordered_map, void*> variant_pack = { {A, A_gpu.devPtr}, {B, B_gpu.devPtr}, {C_after_add, C_gpu.devPtr}, {Bias, Bias_gpu.devPtr}}; diff --git a/samples/cpp/matmul/matmuls.cpp b/samples/cpp/matmul/matmuls.cpp index 33d4af47..ef79c429 100644 --- a/samples/cpp/matmul/matmuls.cpp +++ b/samples/cpp/matmul/matmuls.cpp @@ -28,9 +28,177 @@ #include +void +matmul_dynamic_shapes(bool use_abs = false, bool use_bias = false) { + if (is_arch_supported_by_cudnn() == false) { + SKIP("Architecture is not supported by current cudnn version"); + } + namespace fe = cudnn_frontend; + + // clang-format off + struct { + int64_t b, m, n, k; + } matmul_shapes[] = { + { 16, 32, 32, 128}, + { 16, 64, 64, 128}, + { 16, 80, 80, 128}, + { 32, 128, 128, 256}, + { 32, 64, 64, 256}, + }; + // clang-format on + + constexpr int matmul_shapes_count = sizeof(matmul_shapes) / sizeof(matmul_shapes[0]); + int64_t max_a_volume = 0, max_b_volume = 0, max_c_volume = 0, max_bias_volume = 0; + for (int idx_shape = 0; idx_shape < matmul_shapes_count; ++idx_shape) { + const auto& matmul_shape = matmul_shapes[idx_shape]; + max_a_volume = std::max(max_a_volume, matmul_shape.b * matmul_shape.m * matmul_shape.k); + max_b_volume = std::max(max_b_volume, matmul_shape.b * matmul_shape.k * matmul_shape.n); + max_c_volume = std::max(max_c_volume, matmul_shape.b * matmul_shape.m * matmul_shape.n); + max_bias_volume = std::max(max_bias_volume, matmul_shape.b * matmul_shape.m); + } + + auto kernel_cache = std::make_shared(); + + const auto build_new_graph = [&matmul_shapes, &kernel_cache, &use_abs, &use_bias](cudnnHandle_t handle, + int idx_shape) { + const auto& matmul_shape = matmul_shapes[idx_shape]; + + // Make cudnn graph + fe::graph::Graph graph{}; + + graph.set_dynamic_shape_enabled(true).set_kernel_cache(kernel_cache); + + // Create the two non-virtual input tensors A and B. + // There are read from global memory. 
+ auto A_attributes = fe::graph::Tensor_attributes() + .set_name("A") + .set_dim({matmul_shape.b, matmul_shape.m, matmul_shape.k}) + .set_stride({matmul_shape.m * matmul_shape.k, matmul_shape.k, 1}) + .set_data_type(fe::DataType_t::BFLOAT16); + auto A = graph.tensor(A_attributes); + + auto B_attributes = fe::graph::Tensor_attributes() + .set_name("B") + .set_dim({matmul_shape.b, matmul_shape.k, matmul_shape.n}) + .set_stride({matmul_shape.k * matmul_shape.n, matmul_shape.n, 1}) + .set_data_type(fe::DataType_t::BFLOAT16); + auto B = graph.tensor(B_attributes); + + auto matmul_attributes = fe::graph::Matmul_attributes().set_compute_data_type(fe::DataType_t::FLOAT); + + std::shared_ptr C; + std::shared_ptr Bias; + + if (use_abs) { + // Add abs operation + auto pw_0_attributes = fe::graph::Pointwise_attributes() + .set_name("pw0_Abs") + .set_mode(fe::PointwiseMode_t::ABS) + .set_compute_data_type(fe::DataType_t::FLOAT); + + auto A_after_pw_0 = graph.pointwise(A, pw_0_attributes); + A_after_pw_0->set_data_type(fe::DataType_t::BFLOAT16); + + C = graph.matmul(A_after_pw_0, B, matmul_attributes); + } else if (use_bias) { + // Create Bias vector + auto Bias_attributes = fe::graph::Tensor_attributes() + .set_name("Bias") + .set_dim({matmul_shape.b, matmul_shape.m, 1}) + .set_stride({matmul_shape.m, 1, 1}) + .set_data_type(fe::DataType_t::BFLOAT16); + Bias = graph.tensor(Bias_attributes); + + // Add ADD operation + auto pw_0_attributes = fe::graph::Pointwise_attributes() + .set_name("pw0_Add") + .set_mode(fe::PointwiseMode_t::ADD) + .set_compute_data_type(fe::DataType_t::FLOAT); + + auto A_after_pw_0 = graph.pointwise(A, Bias, pw_0_attributes); + A_after_pw_0->set_data_type(fe::DataType_t::BFLOAT16); + + C = graph.matmul(A_after_pw_0, B, matmul_attributes); + } else { + C = graph.matmul(A, B, matmul_attributes); + } + C->set_output(true).set_data_type(fe::DataType_t::FLOAT); + + std::cout << graph << std::endl; + auto status = graph.validate(); + if (cudnnGetVersion() >= 90400) { + REQUIRE(status.is_good()); + } else { + REQUIRE(status.is_bad()); + SKIP("Dynamic shapes not supported pre 9.4"); + } + + status = graph.build_operation_graph(handle); + if (cudnnGetVersion() >= 90400) { + REQUIRE(status.is_good()); + } else { + REQUIRE(status.is_bad()); + SKIP("Kernel cache not supported pre 9.4"); + } + + REQUIRE(graph.create_execution_plans({fe::HeurMode_t::A}).is_good()); + + REQUIRE(graph.check_support(handle).is_good()); + + REQUIRE(graph.build_plans(handle, fe::BuildPlanPolicy_t::ALL).is_good()); + + return std::make_tuple(graph, A, B, C, Bias); + }; + + // Run cudnn graph + cudnnHandle_t handle; + checkCudnnErr(cudnnCreate(&handle)); + + for (int idx_shape = 0; idx_shape < matmul_shapes_count; idx_shape++) { + auto [graph, A, B, C, Bias] = build_new_graph(handle, idx_shape); + // Initialize input tensors + Surface A_gpu(max_a_volume, false); + Surface B_gpu(max_b_volume, false); + Surface C_gpu(max_c_volume, false); + Surface Bias_gpu(max_bias_volume, false); + Surface workspace(graph.get_workspace_size(), false); + + std::unordered_map, void*> variant_pack; + if (use_bias) { + variant_pack = {{A, A_gpu.devPtr}, {B, B_gpu.devPtr}, {C, C_gpu.devPtr}, {Bias, Bias_gpu.devPtr}}; + } else { + variant_pack = {{A, A_gpu.devPtr}, {B, B_gpu.devPtr}, {C, C_gpu.devPtr}}; + } + REQUIRE(graph.execute(handle, variant_pack, workspace.devPtr).is_good()); + } + + checkCudnnErr(cudnnDestroy(handle)); +} + +TEST_CASE("Matmul dynamic shape", "[matmul][graph][dynamic_shape]") { + if (cudnnGetCudartVersion() < 12000) { 
+ SKIP("Test requires cuda toolkit 12.0 or above"); + } + matmul_dynamic_shapes(false, false); // Matmul dynamic shape, no abs or bias +} + +TEST_CASE("Abs + Matmul dynamic shape", "[matmul][graph][dynamic_shape]") { + if (cudnnGetCudartVersion() < 12000) { + SKIP("Test requires cuda toolkit 12.0 or above"); + } + matmul_dynamic_shapes(true, false); // Matmul with abs +} + +TEST_CASE("Bias + Matmul dynamic shape", "[matmul][graph][dynamic_shape]") { + if (cudnnGetCudartVersion() < 12000) { + SKIP("Test requires cuda toolkit 12.0 or above"); + } + matmul_dynamic_shapes(false, true); // Matmul with bias +} + TEST_CASE("Matmul", "[matmul][graph]") { if (is_arch_supported_by_cudnn() == false) { - SKIP("Architecture is not supported by currend cudnn version"); + SKIP("Architecture is not supported by current cudnn version"); } namespace fe = cudnn_frontend; @@ -82,7 +250,10 @@ TEST_CASE("Matmul", "[matmul][graph]") { // Run cudnn graph Surface C_gpu(b * m * n, false); - Surface workspace(graph.get_workspace_size(), false); + int64_t workspace_size; + REQUIRE(graph.get_workspace_size(workspace_size).is_good()); + Surface workspace(workspace_size, false); + std::unordered_map, void*> variant_pack = { {A, A_gpu.devPtr}, {B, B_gpu.devPtr}, {C, C_gpu.devPtr}}; REQUIRE(graph.execute(handle, variant_pack, workspace.devPtr).is_good()); @@ -139,6 +310,7 @@ TEST_CASE("Abs + Matmul", "[matmul][graph]") { checkCudnnErr(cudnnCreate(&handle)); REQUIRE(graph.build_operation_graph(handle).is_good()); + REQUIRE(graph.create_execution_plans({fe::HeurMode_t::A}).is_good()); REQUIRE(graph.check_support(handle).is_good()); @@ -147,7 +319,10 @@ TEST_CASE("Abs + Matmul", "[matmul][graph]") { // Run cudnn graph Surface C_gpu(b * m * n, false); - Surface workspace(graph.get_workspace_size(), false); + int64_t workspace_size; + REQUIRE(graph.get_workspace_size(workspace_size).is_good()); + Surface workspace(workspace_size, false); + std::unordered_map, void*> variant_pack = { {A, A_gpu.devPtr}, {B, B_gpu.devPtr}, {C, C_gpu.devPtr}}; REQUIRE(graph.execute(handle, variant_pack, workspace.devPtr).is_good()); @@ -364,7 +539,9 @@ TEST_CASE("Matmul SBR Graph", "[matmul][graph]") { auto [graph, A, B, bias, scale, O] = lookup_cache_or_build_graph( handle, x_tensor.devPtr, w_tensor.devPtr, s_tensor.devPtr, b_tensor.devPtr, y_tensor.devPtr); - Surface workspace(graph->get_workspace_size(), false); + int64_t workspace_size; + REQUIRE(graph->get_workspace_size(workspace_size).is_good()); + Surface workspace(workspace_size, false); std::unordered_map, void*> variant_pack = {{A, x_tensor.devPtr}, {B, w_tensor.devPtr}, @@ -429,7 +606,10 @@ TEST_CASE("Matmul with restricted shared memory", "[matmul][graph]") { // Run cudnn graph Surface C_gpu(b * m * n, false); - Surface workspace(graph.get_workspace_size(), false); + int64_t workspace_size; + REQUIRE(graph.get_workspace_size(workspace_size).is_good()); + Surface workspace(workspace_size, false); + std::unordered_map, void*> variant_pack = { {A, A_gpu.devPtr}, {B, B_gpu.devPtr}, {C, C_gpu.devPtr}}; REQUIRE(graph.execute(handle, variant_pack, workspace.devPtr).is_good()); diff --git a/samples/cpp/matmul/mixed_matmul.cpp b/samples/cpp/matmul/mixed_matmul.cpp index 6a72b679..956f88f5 100644 --- a/samples/cpp/matmul/mixed_matmul.cpp +++ b/samples/cpp/matmul/mixed_matmul.cpp @@ -96,7 +96,10 @@ TEST_CASE("Mixed Precision Matmul", "[matmul][graph]") { //// Run cudnn graph // note this is a bf16 tensor, but half is used just for memory allocation Surface C_gpu(b * m * n, false); - Surface 
workspace(graph.get_workspace_size(), false); + int64_t workspace_size; + REQUIRE(graph.get_workspace_size(workspace_size).is_good()); + Surface workspace(workspace_size, false); + std::unordered_map, void*> variant_pack = { {A, A_gpu.devPtr}, {B, B_gpu.devPtr}, {C, C_gpu.devPtr}}; diff --git a/samples/cpp/misc/autotuning.cpp b/samples/cpp/misc/autotuning.cpp index 4e52e11b..bd61ac1c 100644 --- a/samples/cpp/misc/autotuning.cpp +++ b/samples/cpp/misc/autotuning.cpp @@ -149,7 +149,9 @@ TEST_CASE("Matmul autotuning", "[matmul][graph][autotuning]") { auto candidate_index = autotune(); - std::cout << "Successful candidate is at index " << candidate_index << std::endl; + std::string name; + REQUIRE(graph.get_plan_name_at_index(candidate_index, name).is_good()); + std::cout << "Successful candidate " << name << " is at index " << candidate_index << std::endl; REQUIRE(graph.build_plan_at_index(handle, candidate_index).is_good()); diff --git a/samples/cpp/misc/pointwise.cpp b/samples/cpp/misc/pointwise.cpp index b3b1e052..e3801b18 100644 --- a/samples/cpp/misc/pointwise.cpp +++ b/samples/cpp/misc/pointwise.cpp @@ -51,7 +51,10 @@ TEST_CASE("Reduction", "[reduction]") { Surface C_gpu(n * n * n * n, false); std::unordered_map, void*> variant_pack = {{A, A_gpu.devPtr}, {C, C_gpu.devPtr}}; - Surface workspace(graph.get_workspace_size(), false); + int64_t workspace_size; + REQUIRE(graph.get_workspace_size(workspace_size).is_good()); + Surface workspace(workspace_size, false); + REQUIRE(graph.execute(handle, variant_pack, workspace.devPtr).is_good()); checkCudnnErr(cudnnDestroy(handle)); } @@ -85,7 +88,9 @@ TEST_CASE("Fused scalar", "[scalar][graph]") { std::unordered_map, void*> variant_pack = {{A, A_gpu.devPtr}, {C, C_gpu.devPtr}}; - Surface workspace(graph.get_workspace_size(), false); + int64_t workspace_size; + REQUIRE(graph.get_workspace_size(workspace_size).is_good()); + Surface workspace(workspace_size, false); REQUIRE(graph.execute(handle, variant_pack, workspace.devPtr).is_good()); @@ -143,7 +148,10 @@ TEST_CASE("Fused Amax Reduction and type conversion", "[reduction]") { std::unordered_map, void*> variant_pack = { {A, A_gpu.devPtr}, {scale, scale_gpu.devPtr}, {amax, amax_gpu.devPtr}, {C, C_gpu.devPtr}}; - Surface workspace(graph.get_workspace_size(), false); + int64_t workspace_size; + REQUIRE(graph.get_workspace_size(workspace_size).is_good()); + Surface workspace(workspace_size, false); + REQUIRE(graph.execute(handle, variant_pack, workspace.devPtr).is_good()); checkCudnnErr(cudnnDestroy(handle)); } \ No newline at end of file diff --git a/samples/cpp/misc/resample.cpp b/samples/cpp/misc/resample.cpp index a13f065d..ac3acb9b 100644 --- a/samples/cpp/misc/resample.cpp +++ b/samples/cpp/misc/resample.cpp @@ -69,7 +69,10 @@ TEST_CASE("Resample Max Pooling NHWC Inference", "[resample][pooling][max][graph Surface Y_gpu(N * H * W * C, false); std::unordered_map, void*> variant_pack = {{X, X_gpu.devPtr}, {Y, Y_gpu.devPtr}}; - Surface workspace(graph.get_workspace_size(), false); + int64_t workspace_size; + REQUIRE(graph.get_workspace_size(workspace_size).is_good()); + Surface workspace(workspace_size, false); + REQUIRE(graph.execute(handle, variant_pack, workspace.devPtr).is_good()); checkCudnnErr(cudnnDestroy(handle)); @@ -129,7 +132,10 @@ TEST_CASE("Resample Max Pooling NHWC Training", "[resample][pooling][max][graph] Surface Index_gpu(N * H * W * C / 8, false); std::unordered_map, void*> variant_pack = { {X, X_gpu.devPtr}, {Y, Y_gpu.devPtr}, {Index, Index_gpu.devPtr}}; - Surface 
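The get_plan_name_at_index() call added in the autotuning hunk above can also be used to log every candidate before one is picked. A small sketch follows; plan_count is an assumed stand-in for the number of candidate plans (its accessor is not part of this diff), and graph is the finalized graph from the sample.

```cpp
for (int64_t i = 0; i < plan_count; ++i) {
    std::string plan_name;
    if (graph.get_plan_name_at_index(i, plan_name).is_good()) {
        std::cout << "Candidate " << i << ": " << plan_name << std::endl;
    }
}
// Then build only the winning candidate, exactly as the sample does:
// REQUIRE(graph.build_plan_at_index(handle, candidate_index).is_good());
```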
workspace(graph.get_workspace_size(), false); + int64_t workspace_size; + REQUIRE(graph.get_workspace_size(workspace_size).is_good()); + Surface workspace(workspace_size, false); + REQUIRE(graph.execute(handle, variant_pack, workspace.devPtr).is_good()); checkCudnnErr(cudnnDestroy(handle)); @@ -180,7 +186,10 @@ TEST_CASE("Resample Avg Pooling", "[resample][pooling][average][graph]") { Surface Y_gpu(N * H * W * C, false); std::unordered_map, void*> variant_pack = {{X, X_gpu.devPtr}, {Y, Y_gpu.devPtr}}; - Surface workspace(graph.get_workspace_size(), false); + int64_t workspace_size; + REQUIRE(graph.get_workspace_size(workspace_size).is_good()); + Surface workspace(workspace_size, false); + REQUIRE(graph.execute(handle, variant_pack, workspace.devPtr).is_good()); checkCudnnErr(cudnnDestroy(handle)); diff --git a/samples/cpp/misc/serialization.cpp b/samples/cpp/misc/serialization.cpp index 9484027f..97c885ee 100644 --- a/samples/cpp/misc/serialization.cpp +++ b/samples/cpp/misc/serialization.cpp @@ -168,13 +168,19 @@ TEST_CASE("CSBR Graph with serialization", "[conv][graph][serialization]") { // Deserialize the graph and execute auto graph = deserialize(handle, serialize_data); + cudnn_frontend::graph::Tensor_attributes tensor_attr; + auto result = graph->query_tensor_attributes_of_uid(x_tensor, tensor_attr); + REQUIRE(result.is_good()); + Surface x_device_memory(n * c * h * w, false); Surface w_device_memory(k * c * r * s, false); Surface s_device_memory(k, false); Surface b_device_memory(k, false); Surface y_device_memory(n * k * h * w, false); // Should be p, q. - Surface workspace(graph->get_workspace_size(), false); + int64_t workspace_size; + REQUIRE(graph->get_workspace_size(workspace_size).is_good()); + Surface workspace(workspace_size, false); std::unordered_map variant_pack = {{x_tensor, x_device_memory.devPtr}, {w_tensor, w_device_memory.devPtr}, @@ -395,9 +401,11 @@ TEST_CASE("SDPA Graph with serialization", "[sdpa][graph][serialization]") { Surface dropoutSeed(scaleSize, false, seed_value); Surface dropoutOffset(scaleSize, false, (int32_t)1); - Surface workspace(graph->get_workspace_size(), false); + int64_t workspace_size; + REQUIRE(graph->get_workspace_size(workspace_size).is_good()); + Surface workspace(workspace_size, false); - std::cout << "Graph requires workspace " << graph->get_workspace_size() << std::endl; + std::cout << "Graph requires workspace " << workspace_size << std::endl; std::unordered_map variant_pack = {{uid_Q, devPtrQ}, {uid_K, devPtrK}, diff --git a/samples/cpp/misc/slice.cpp b/samples/cpp/misc/slice.cpp index 6a78f326..e35ff458 100644 --- a/samples/cpp/misc/slice.cpp +++ b/samples/cpp/misc/slice.cpp @@ -80,7 +80,9 @@ TEST_CASE("Slice gemm", "[slice][gemm][graph][fusion]") { Surface C_gpu(B * M * N, false); std::unordered_map variant_pack = { {a_uid, A_gpu.devPtr}, {b_uid, B_gpu.devPtr}, {c_uid, C_gpu.devPtr}}; - Surface workspace(graph.get_workspace_size(), false); + int64_t workspace_size; + REQUIRE(graph.get_workspace_size(workspace_size).is_good()); + Surface workspace(workspace_size, false); fe::graph::Graph graph2; REQUIRE(graph2.deserialize(handle, serialized_data).is_good()); diff --git a/samples/cpp/misc/sm_carveout.cpp b/samples/cpp/misc/sm_carveout.cpp index 1176f1ba..e6464170 100644 --- a/samples/cpp/misc/sm_carveout.cpp +++ b/samples/cpp/misc/sm_carveout.cpp @@ -127,7 +127,10 @@ TEST_CASE("SGBN with SM carveout", "[batchnorm][graph][sm_carveout]") { Surface Peer_stats_0_tensor(2 * 4 * c, false, true); Surface Peer_stats_1_tensor(2 * 4 * c, 
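The serialization and slice hunks above share one shape: rebuild the graph from a binary blob, query its workspace, and execute with a UID-keyed variant pack. A condensed sketch of that flow is below; serialized_data is assumed to be the std::vector<uint8_t> produced when the graph was first built and serialized, and X_UID / Y_UID are the UIDs assigned at that time.

```cpp
namespace fe = cudnn_frontend;

fe::graph::Graph graph;
REQUIRE(graph.deserialize(handle, serialized_data).is_good());

int64_t workspace_size = 0;
REQUIRE(graph.get_workspace_size(workspace_size).is_good());
Surface<int8_t> workspace(workspace_size, false);

// Variant packs for a deserialized graph are keyed by tensor UID rather than
// by the shared_ptr handles of the original builder code.
std::unordered_map<int64_t, void*> variant_pack = {{X_UID, x_device.devPtr},
                                                   {Y_UID, y_device.devPtr}};
REQUIRE(graph.execute(handle, variant_pack, workspace.devPtr).is_good());
```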
false); - Surface workspace(graph.get_workspace_size(), false); + int64_t workspace_size; + REQUIRE(graph.get_workspace_size(workspace_size).is_good()); + Surface workspace(workspace_size, false); + std::unordered_map, void*> variant_pack = { {X, X_tensor.devPtr}, {mean, Mean_tensor.devPtr}, diff --git a/samples/cpp/norm/batchnorm.cpp b/samples/cpp/norm/batchnorm.cpp index 18c882ee..2c4f7a4c 100644 --- a/samples/cpp/norm/batchnorm.cpp +++ b/samples/cpp/norm/batchnorm.cpp @@ -107,7 +107,10 @@ TEST_CASE("BN Finalize Graph", "[batchnorm][graph]") { float MOMENTUM_scalar = 0.001f; int64_t nhw = 64; - Surface workspace(graph.get_workspace_size(), false); + int64_t workspace_size; + REQUIRE(graph.get_workspace_size(workspace_size).is_good()); + Surface workspace(workspace_size, false); + std::unordered_map, void*> variant_pack = { {sum, Sum_tensor.devPtr}, {sq_sum, Sq_sum_tensor.devPtr}, @@ -249,7 +252,10 @@ TEST_CASE("SGBN Add Relu Graph", "[batchnorm][graph]") { Surface Peer_stats_0_tensor(2 * 4 * 32, false, true); Surface Peer_stats_1_tensor(2 * 4 * 32, false); - Surface workspace(graph.get_workspace_size(), false); + int64_t workspace_size; + REQUIRE(graph.get_workspace_size(workspace_size).is_good()); + Surface workspace(workspace_size, false); + std::unordered_map, void*> variant_pack = { {X, X_tensor.devPtr}, {mean, Mean_tensor.devPtr}, @@ -369,7 +375,10 @@ TEST_CASE("DBN Add Relu Graph", "[BN][graph][backward]") { Surface Peer_stats_0_tensor(2 * 4 * 32, false, true); Surface Peer_stats_1_tensor(2 * 4 * 32, false); - Surface workspace(graph.get_workspace_size(), false); + int64_t workspace_size; + REQUIRE(graph.get_workspace_size(workspace_size).is_good()); + Surface workspace(workspace_size, false); + std::unordered_map, void*> variant_pack = { {X, X_tensor.devPtr}, {input_mask, Mask_tensor.devPtr}, @@ -474,7 +483,10 @@ TEST_CASE("BN_inference DRelu DBN Graph", "[Batchnorm][graph][backward]") { Surface Dbias_tensor(32, false); Surface DX_tensor(4 * 32 * 16 * 16, false); - Surface workspace(graph.get_workspace_size(), false); + int64_t workspace_size; + REQUIRE(graph.get_workspace_size(workspace_size).is_good()); + Surface workspace(workspace_size, false); + std::unordered_map, void*> variant_pack = { {BN_X, BN_X_tensor.devPtr}, {DY, DY_tensor.devPtr}, diff --git a/samples/cpp/norm/layernorm.cpp b/samples/cpp/norm/layernorm.cpp index 2cd5adf6..3446f537 100644 --- a/samples/cpp/norm/layernorm.cpp +++ b/samples/cpp/norm/layernorm.cpp @@ -92,7 +92,10 @@ TEST_CASE("LayerNorm Training", "[layernorm][graph]") { float epsilon_cpu = 1e-05f; Surface Y_tensor(batch_size * seq_length * hidden_size, false); - Surface workspace(graph.get_workspace_size(), false); + int64_t workspace_size; + REQUIRE(graph.get_workspace_size(workspace_size).is_good()); + Surface workspace(workspace_size, false); + std::unordered_map, void*> variant_pack = { {X, X_tensor.devPtr}, {mean, Mean_tensor.devPtr}, @@ -172,7 +175,10 @@ TEST_CASE("LayerNorm Inference", "[layernorm][graph]") { float epsilon_cpu = 1e-05f; Surface Y_tensor(batch_size * seq_length * hidden_size, false); - Surface workspace(graph.get_workspace_size(), false); + int64_t workspace_size; + REQUIRE(graph.get_workspace_size(workspace_size).is_good()); + Surface workspace(workspace_size, false); + std::unordered_map, void*> variant_pack = { {X, X_tensor.devPtr}, {scale, Scale_tensor.devPtr}, @@ -255,7 +261,10 @@ TEST_CASE("LayerNorm Backward", "[layernorm][graph]") { Surface Dbias_tensor(hidden_size, false); Surface DX_tensor(batch_size * seq_length * 
hidden_size, false); - Surface workspace(graph.get_workspace_size(), false); + int64_t workspace_size; + REQUIRE(graph.get_workspace_size(workspace_size).is_good()); + Surface workspace(workspace_size, false); + std::unordered_map, void*> variant_pack = { {X, X_tensor.devPtr}, {DY, DY_tensor.devPtr}, diff --git a/samples/cpp/norm/rmsnorm.cpp b/samples/cpp/norm/rmsnorm.cpp index 3b8ef52f..55871ddd 100644 --- a/samples/cpp/norm/rmsnorm.cpp +++ b/samples/cpp/norm/rmsnorm.cpp @@ -83,7 +83,10 @@ TEST_CASE("RmsNorm Training", "[rmsnorm][graph]") { float epsilon_cpu = 1e-05f; Surface Y_tensor(batch_size * seq_length * hidden_size, false); - Surface workspace(graph.get_workspace_size(), false); + int64_t workspace_size; + REQUIRE(graph.get_workspace_size(workspace_size).is_good()); + Surface workspace(workspace_size, false); + std::unordered_map, void*> variant_pack = { {X, X_tensor.devPtr}, {inv_variance, Var_tensor.devPtr}, @@ -161,7 +164,10 @@ TEST_CASE("RmsNorm Inference", "[rmsnorm][graph]") { float epsilon_cpu = 1e-05f; Surface Y_tensor(batch_size * seq_length * hidden_size, false); - Surface workspace(graph.get_workspace_size(), false); + int64_t workspace_size; + REQUIRE(graph.get_workspace_size(workspace_size).is_good()); + Surface workspace(workspace_size, false); + std::unordered_map, void*> variant_pack = { {X, X_tensor.devPtr}, {scale, Scale_tensor.devPtr}, @@ -239,7 +245,10 @@ TEST_CASE("RmsNorm Backward", "[rmsnorm][graph]") { Surface Dbias_tensor(hidden_size, false); Surface DX_tensor(batch_size * seq_length * hidden_size, false); - Surface workspace(graph.get_workspace_size(), false); + int64_t workspace_size; + REQUIRE(graph.get_workspace_size(workspace_size).is_good()); + Surface workspace(workspace_size, false); + std::unordered_map, void*> variant_pack = { {X, X_tensor.devPtr}, {DY, DY_tensor.devPtr}, diff --git a/samples/cpp/sdpa/fp16_bwd.cpp b/samples/cpp/sdpa/fp16_bwd.cpp index 857595d3..6168cf64 100644 --- a/samples/cpp/sdpa/fp16_bwd.cpp +++ b/samples/cpp/sdpa/fp16_bwd.cpp @@ -42,8 +42,9 @@ This example shows how to construct a sdpa backward graph-> #define O_UID 4 #define STATS_UID 5 #define BIAS_UID 6 -#define SEQ_LEN_Q_UID 7 -#define SEQ_LEN_KV_UID 8 +#define DBIAS_UID 7 +#define SEQ_LEN_Q_UID 8 +#define SEQ_LEN_KV_UID 9 #define DO_UID 101 #define DQ_UID 102 @@ -128,6 +129,13 @@ create_sdpa_backward_graph(int64_t const b, .set_dim({b, 1, s_q, s_kv}) .set_stride({s_q * s_kv, s_q * s_kv, s_kv, 1})); sdpa_options.set_bias(bias); + + auto dbias = graph->tensor(fe::graph::Tensor_attributes() + .set_name("dbias") + .set_uid(DBIAS_UID) + .set_dim({1, h_q, s_q, s_kv}) + .set_stride({s_q * s_kv * h_q, s_q * s_kv, s_kv, 1})); + sdpa_options.set_dbias(dbias); } // If padding mask is enabled, set sequence lengths @@ -182,7 +190,7 @@ TEST_CASE("Toy sdpa backward", "[graph][sdpa][flash][backward]") { bool causal_mask = true; bool padding_mask = (cudnnGetVersion() >= 8903); bool alibi_mask = (cudnnGetVersion() >= 8904); - bool has_attn_bias = (cudnnGetVersion() >= 8903); + bool has_attn_bias = (cudnnGetVersion() >= 90500); if (cudnnGetVersion() < 8903) { SKIP("Test requires cudnn 8.9.3 or above"); @@ -223,7 +231,8 @@ TEST_CASE("Toy sdpa backward", "[graph][sdpa][flash][backward]") { Surface dK_tensor(b * h_k * s_kv * d_qk, false); Surface dV_tensor(b * h_v * s_kv * d_v, false); - Surface bias_tensor(b * 1 * s_q * s_kv, false); + Surface bias_tensor(1 * h_q * s_q * s_kv, false); + Surface dbias_tensor(1 * h_q * s_q * s_kv, false); // Create variant pack with input and output tensors 
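The dbias support added above is spread across a UID define, a tensor in the graph, and an extra variant-pack entry further down. Gathered in one place it looks as follows; b, h_q, s_q, s_kv and sdpa_options are the sample's own variables, and the cuDNN 9.5 requirement comes from the has_attn_bias gate in the test.

```cpp
// Only meaningful when an attention bias is set on the backward options;
// gated to cuDNN >= 9.5 in the test below.
auto dbias = graph->tensor(fe::graph::Tensor_attributes()
                               .set_name("dbias")
                               .set_uid(DBIAS_UID)
                               .set_dim({1, h_q, s_q, s_kv})
                               .set_stride({s_q * s_kv * h_q, s_q * s_kv, s_kv, 1}));
sdpa_options.set_dbias(dbias);

// At execution time both the bias and its gradient need device buffers:
// variant_pack[BIAS_UID]  = bias_tensor.devPtr;
// variant_pack[DBIAS_UID] = dbias_tensor.devPtr;
```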
std::unordered_map variant_pack = {// inputs @@ -240,7 +249,8 @@ TEST_CASE("Toy sdpa backward", "[graph][sdpa][flash][backward]") { // If attention bias is provided, add it to the variant pack if (has_attn_bias) { - variant_pack[BIAS_UID] = bias_tensor.devPtr; + variant_pack[BIAS_UID] = bias_tensor.devPtr; + variant_pack[DBIAS_UID] = dbias_tensor.devPtr; } // If padding mask is enabled, add sequence lengths to the variant pack @@ -265,7 +275,10 @@ TEST_CASE("Toy sdpa backward", "[graph][sdpa][flash][backward]") { } // Allocate workspace - Surface workspace(graph->get_workspace_size(), false); + int64_t workspace_size; + REQUIRE(graph->get_workspace_size(workspace_size).is_good()); + Surface workspace(workspace_size, false); + REQUIRE(graph->execute(handle, variant_pack, workspace.devPtr).is_good()); checkCudaErr(cudaDeviceSynchronize()); diff --git a/samples/cpp/sdpa/fp16_cached.cpp b/samples/cpp/sdpa/fp16_cached.cpp index 570dcc69..10711180 100644 --- a/samples/cpp/sdpa/fp16_cached.cpp +++ b/samples/cpp/sdpa/fp16_cached.cpp @@ -146,7 +146,10 @@ TEST_CASE("Cached sdpa", "[graph][sdpa][flash]") { {O_UID, o_tensor.devPtr}, {STATS_UID, stats_tensor.devPtr}}; - Surface fwd_workspace(fwd_graph2->get_workspace_size(), false); + int64_t workspace_size; + REQUIRE(fwd_graph2->get_workspace_size(workspace_size).is_good()); + Surface fwd_workspace(workspace_size, false); + REQUIRE(fwd_graph2->execute(handle, variant_pack, fwd_workspace.devPtr).is_good()); checkCudaErr(cudaDeviceSynchronize()); @@ -166,7 +169,10 @@ TEST_CASE("Cached sdpa", "[graph][sdpa][flash]") { {DQ_UID, dQ_tensor.devPtr}, {DK_UID, dK_tensor.devPtr}, {DV_UID, dV_tensor.devPtr}}; - Surface bwd_workspace(bwd_graph2->get_workspace_size(), false); + + REQUIRE(bwd_graph2->get_workspace_size(workspace_size).is_good()); + Surface bwd_workspace(workspace_size, false); + REQUIRE(bwd_graph2->execute(handle, variant_pack, bwd_workspace.devPtr).is_good()); checkCudaErr(cudaDeviceSynchronize()); diff --git a/samples/cpp/sdpa/fp16_fwd.cpp b/samples/cpp/sdpa/fp16_fwd.cpp index d8b6f24a..66344025 100644 --- a/samples/cpp/sdpa/fp16_fwd.cpp +++ b/samples/cpp/sdpa/fp16_fwd.cpp @@ -210,7 +210,10 @@ TEST_CASE("Toy sdpa forward", "[graph][sdpa][flash][forward]") { variant_pack[STATS_UID] = statsTensor.devPtr; } - Surface workspace(graph->get_workspace_size(), false); + int64_t workspace_size; + REQUIRE(graph->get_workspace_size(workspace_size).is_good()); + Surface workspace(workspace_size, false); + REQUIRE(graph->execute(handle, variant_pack, workspace.devPtr).is_good()); checkCudaErr(cudaDeviceSynchronize()); diff --git a/samples/cpp/sdpa/fp16_fwd_with_custom_dropout.cpp b/samples/cpp/sdpa/fp16_fwd_with_custom_dropout.cpp index 86179e2d..5c70151a 100644 --- a/samples/cpp/sdpa/fp16_fwd_with_custom_dropout.cpp +++ b/samples/cpp/sdpa/fp16_fwd_with_custom_dropout.cpp @@ -178,7 +178,10 @@ TEST_CASE("Toy sdpa forward with dropout", "[graph][sdpa][flash][forward]") { variant_pack[STATS_UID] = statsTensor.devPtr; } - Surface workspace(graph->get_workspace_size(), false); + int64_t workspace_size; + REQUIRE(graph->get_workspace_size(workspace_size).is_good()); + Surface workspace(workspace_size, false); + REQUIRE(graph->execute(handle, variant_pack, workspace.devPtr).is_good()); checkCudaErr(cudaDeviceSynchronize()); diff --git a/samples/cpp/sdpa/fp8_bwd.cpp b/samples/cpp/sdpa/fp8_bwd.cpp index f1fb3a4c..487aed2d 100644 --- a/samples/cpp/sdpa/fp8_bwd.cpp +++ b/samples/cpp/sdpa/fp8_bwd.cpp @@ -124,9 +124,9 @@ TEST_CASE("sdpa_fp8_bprop", 
"[graph][sdpa][fp8][backward]") { scale_dP, sdpa_fp8_backwards_options); - dQ->set_output(true).set_dim(Q_dQ_O_dO_dims).set_stride(O_dO_strides); - dK->set_output(true).set_dim(Q_dQ_O_dO_dims).set_stride(O_dO_strides); - dV->set_output(true).set_dim(Q_dQ_O_dO_dims).set_stride(O_dO_strides); + dQ->set_output(true).set_dim(Q_dQ_O_dO_dims).set_stride(Q_dQ_strides); + dK->set_output(true).set_dim(Q_dQ_O_dO_dims).set_stride(Q_dQ_strides); + dV->set_output(true).set_dim(Q_dQ_O_dO_dims).set_stride(Q_dQ_strides); Amax_dQ->set_output(true).set_dim({1, 1, 1, 1}).set_stride({1, 1, 1, 1}).set_data_type(fe::DataType_t::FLOAT); Amax_dK->set_output(true).set_dim({1, 1, 1, 1}).set_stride({1, 1, 1, 1}).set_data_type(fe::DataType_t::FLOAT); Amax_dV->set_output(true).set_dim({1, 1, 1, 1}).set_stride({1, 1, 1, 1}).set_data_type(fe::DataType_t::FLOAT); @@ -214,7 +214,10 @@ TEST_CASE("sdpa_fp8_bprop", "[graph][sdpa][fp8][backward]") { {Amax_dV, AMax_dV_Tensor.devPtr}, {Amax_dP, AMax_dP_Tensor.devPtr}}; - Surface workspace(mha_graph.get_workspace_size(), false); + int64_t workspace_size; + REQUIRE(mha_graph.get_workspace_size(workspace_size).is_good()); + Surface workspace(workspace_size, false); + REQUIRE(mha_graph.execute(handle, variant_pack, workspace.devPtr).is_good()); checkCudaErr(cudaDeviceSynchronize()); @@ -382,7 +385,10 @@ TEST_CASE("sdpa_fp8_gqa_bprop", "[graph][sdpa][fp8][backward]") { {amax_dV, amax_dV_gpu.devPtr}, {amax_dP, amax_dP_gpu.devPtr}}; - Surface workspace(mha_graph.get_workspace_size(), false); + int64_t workspace_size; + REQUIRE(mha_graph.get_workspace_size(workspace_size).is_good()); + Surface workspace(workspace_size, false); + REQUIRE(mha_graph.execute(handle, variant_pack, workspace.devPtr).is_good()); checkCudaErr(cudaDeviceSynchronize()); diff --git a/samples/cpp/sdpa/fp8_fwd.cpp b/samples/cpp/sdpa/fp8_fwd.cpp index 3f73f8c7..0426d0a3 100644 --- a/samples/cpp/sdpa/fp8_fwd.cpp +++ b/samples/cpp/sdpa/fp8_fwd.cpp @@ -146,7 +146,10 @@ TEST_CASE("sdpa_fp8_fprop", "[graph][sdpa][fp8][forward]") { variant_pack[Stats] = stats_tensor.devPtr; } - Surface workspace(mha_graph.get_workspace_size(), false); + int64_t workspace_size; + REQUIRE(mha_graph.get_workspace_size(workspace_size).is_good()); + Surface workspace(workspace_size, false); + REQUIRE(mha_graph.execute(handle, variant_pack, workspace.devPtr).is_good()); checkCudaErr(cudaDeviceSynchronize()); diff --git a/samples/python/25_batchnorm.ipynb b/samples/python/25_batchnorm.ipynb new file mode 100644 index 00000000..b2bcee80 --- /dev/null +++ b/samples/python/25_batchnorm.ipynb @@ -0,0 +1,321 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "This notebook shows how to compute a batchnorm forward operation using cuDNN.\n", + "\n", + "$$\\text{BatchNorm}(x) = \\frac{x-\\mu}{\\sqrt{\\sigma^2 + \\epsilon}}\\cdot\\gamma+\\beta$$\n", + "\n", + "Where $\\mu = E[x]$ and $\\sigma^2 = Var[x]$ are taken over all inputs in a channel." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/NVIDIA/cudnn-frontend/blob/main/samples/python/01_matmul_bias.ipynb)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Prerequisites and Setup\n", + "This notebook requires an NVIDIA GPU. If `nvidia-smi` fails, go to Runtime -> Change runtime type -> Hardware accelerator and confirm a GPU is selected." 
+ ] + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [], + "source": [ + "# get_ipython().system('nvidia-smi')" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "If running on Colab, you will need to install the cudnn python interface." + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [], + "source": [ + "# get_ipython().system('pip install nvidia-cudnn-cu12')\n", + "# get_ipython().system('pip install nvidia-cudnn-frontend')\n", + "# get_ipython().system('pip3 install --pre torch --index-url https://download.pytorch.org/whl/nightly/cu121')" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "#### General Setup\n", + "Create a cudnn handle, which is a per device handle used to initialize cudnn context." + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Running with cudnn backend version: 90400\n" + ] + } + ], + "source": [ + "import cudnn\n", + "import torch\n", + "import sys\n", + "\n", + "torch.manual_seed(1)\n", + "handle = cudnn.create_handle()\n", + "\n", + "print(\"Running with cudnn backend version:\", cudnn.backend_version())\n", + "\n", + "assert torch.cuda.is_available()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Batchnorm Training Forward" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [], + "source": [ + "# batch size, channel size, height, width\n", + "n, c, h, w = 4, 16, 56, 56\n", + "input_type = torch.float16\n", + "\n", + "# Epsilon is a small number to prevent division by 0.\n", + "epsilon_value = 1e-3\n", + "# Momentum value is used in computing running stats during training where\n", + "# running_mean_next = (1 - momentum) * running_mean + momentum * local_mean\n", + "momentum_value = 1e-1" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Create input and output tensor buffers in PyTorch." 
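The momentum comment in the setup cell, written as an equation (this is the convention used by torch.nn.functional.batch_norm, which the notebook later uses as its reference):

$$\mu_{\text{running}} \leftarrow (1-m)\,\mu_{\text{running}} + m\,\mu_{\text{batch}},\qquad \sigma^2_{\text{running}} \leftarrow (1-m)\,\sigma^2_{\text{running}} + m\,\sigma^2_{\text{batch}}$$

with $m = \text{momentum\_value} = 0.1$; $\epsilon = 10^{-3}$ enters only the normalization, not the running statistics.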
+ ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [], + "source": [ + "# input tensors\n", + "x_gpu = torch.randn(n, c, h, w, dtype=input_type, device=\"cuda\")\n", + "x_gpu = x_gpu.to(memory_format=torch.channels_last)\n", + "scale_gpu = torch.randn(1, c, 1, 1, device=\"cuda\")\n", + "bias_gpu = torch.randn_like(scale_gpu)\n", + "running_mean_gpu = torch.randn_like(scale_gpu)\n", + "running_var_gpu = torch.randn_like(scale_gpu)\n", + "\n", + "comparison_gpu = torch.zeros_like(x_gpu, dtype=input_type, device=\"cuda\")\n", + "\n", + "epsilon_cpu = torch.full((1, 1, 1, 1), epsilon_value)\n", + "momentum_cpu = torch.full((1, 1, 1, 1), momentum_value)\n", + "\n", + "# output tensors\n", + "saved_mean_gpu = torch.empty_like(running_mean_gpu, device=\"cuda\")\n", + "saved_inv_var_gpu = torch.empty_like(running_var_gpu, device=\"cuda\")\n", + "y_gpu = torch.empty_like(x_gpu, dtype=input_type, device=\"cuda\")\n", + "mask_gpu = torch.empty_like(x_gpu, dtype=torch.bool, device=\"cuda\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Create cuDNN graph" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": {}, + "outputs": [], + "source": [ + "graph = cudnn.pygraph(\n", + " io_data_type=cudnn.data_type.HALF,\n", + " intermediate_data_type=cudnn.data_type.FLOAT,\n", + " compute_data_type=cudnn.data_type.FLOAT,\n", + " handle=handle,\n", + ")\n", + "\n", + "x = graph.tensor_like(x_gpu)\n", + "scale = graph.tensor_like(scale_gpu)\n", + "bias = graph.tensor_like(bias_gpu)\n", + "\n", + "in_running_mean = graph.tensor_like(running_mean_gpu)\n", + "in_running_var = graph.tensor_like(running_var_gpu)\n", + "epsilon = graph.tensor_like(epsilon_cpu)\n", + "momentum = graph.tensor_like(momentum_cpu)\n", + "comparison = graph.tensor_like(x_gpu)\n", + "\n", + "y_before_relu, saved_mean, saved_inv_var, out_running_mean, out_running_var = (\n", + " graph.batchnorm(\n", + " name=\"BN\",\n", + " input=x,\n", + " scale=scale,\n", + " bias=bias,\n", + " in_running_mean=in_running_mean,\n", + " in_running_var=in_running_var,\n", + " epsilon=epsilon,\n", + " momentum=momentum,\n", + " )\n", + ")\n", + "y = graph.relu(name=\"relu\", input=y_before_relu)\n", + "mask = graph.cmp_gt(name=\"cmp\", input=y, comparison=comparison)\n", + "\n", + "y.set_output(True)\n", + "saved_mean.set_output(True).set_data_type(cudnn.data_type.FLOAT)\n", + "saved_inv_var.set_output(True).set_data_type(cudnn.data_type.FLOAT)\n", + "out_running_mean.set_output(True).set_data_type(cudnn.data_type.FLOAT)\n", + "out_running_var.set_output(True).set_data_type(cudnn.data_type.FLOAT)\n", + "mask.set_output(True).set_data_type(cudnn.data_type.BOOLEAN)\n", + "pass" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Build the graph" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": {}, + "outputs": [], + "source": [ + "graph.validate()\n", + "graph.build_operation_graph()\n", + "graph.create_execution_plans([cudnn.heur_mode.A, cudnn.heur_mode.FALLBACK])\n", + "graph.check_support()\n", + "graph.build_plans()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Execute the graph" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "metadata": {}, + "outputs": [], + "source": [ + "variant_pack = {\n", + " x: x_gpu,\n", + " scale: scale_gpu,\n", + " bias: bias_gpu,\n", + " in_running_mean: running_mean_gpu,\n", + " in_running_var: running_var_gpu,\n", + " epsilon: epsilon_cpu,\n", + " 
momentum: momentum_cpu,\n", + " out_running_mean: running_mean_gpu,\n", + " out_running_var: running_var_gpu,\n", + " saved_mean: saved_mean_gpu,\n", + " saved_inv_var: saved_inv_var_gpu,\n", + " y: y_gpu,\n", + " comparison: comparison_gpu,\n", + " mask: mask_gpu,\n", + "}\n", + "workspace = torch.empty(graph.get_workspace_size(), device=\"cuda\", dtype=torch.uint8)\n", + "graph.execute(\n", + " variant_pack,\n", + " workspace,\n", + " handle=handle,\n", + ")\n", + "torch.cuda.synchronize()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Test cuDNN's output against PyTorch's and check correctness" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "metadata": {}, + "outputs": [], + "source": [ + "x_ref = x_gpu.clone().float()\n", + "running_mean_ref = running_mean_gpu.clone().float()\n", + "running_var_ref = running_var_gpu.clone().float()\n", + "\n", + "y_before_relu_ref = torch.nn.functional.batch_norm(\n", + " x_ref,\n", + " running_mean_ref, # running_mean is both input and output\n", + " running_var_ref, # running_var is both input and output\n", + " weight=scale_gpu,\n", + " bias=bias_gpu,\n", + " training=True,\n", + " momentum=momentum_cpu.item(),\n", + " eps=epsilon_cpu.item(),\n", + ")\n", + "\n", + "mean_ref = torch.mean(x_ref, dim=(0, 2, 3), keepdim=True)\n", + "inv_var_ref = torch.var(x_ref, dim=(0, 2, 3), keepdim=True)\n", + "inv_var_ref = torch.rsqrt(inv_var_ref + epsilon_value)\n", + "y_ref = torch.relu(y_before_relu_ref)\n", + "mask_ref = y_ref > 0\n", + "\n", + "torch.testing.assert_close(y_ref, y_gpu.float(), atol=1e-3, rtol=1e-3)\n", + "torch.testing.assert_close(mean_ref, saved_mean_gpu.float(), atol=1e-3, rtol=1e-3)\n", + "torch.testing.assert_close(inv_var_ref, saved_inv_var_gpu.float(), atol=1e-3, rtol=1e-3)\n", + "# torch.testing.assert_close(mask_ref, mask_gpu.float(), atol=1e-3, rtol=1e-3)" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "venv", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.8.10" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/test/python_fe/test_batchnorm.py b/test/python_fe/test_batchnorm.py index 38582171..6212f32c 100644 --- a/test/python_fe/test_batchnorm.py +++ b/test/python_fe/test_batchnorm.py @@ -6,18 +6,19 @@ from test_utils import torch_fork_set_rng -class SGBN(torch.nn.Module): - def forward(self, input, running_mean, running_var, weight, bias, eps, momentum): - return torch.nn.functional.batch_norm( - input, - running_mean, - running_var, - weight=weight, - bias=bias, - training=True, - momentum=momentum, - eps=eps, - ) +def convert_to_cudnn_type(torch_type): + if torch_type == torch.float16: + return cudnn.data_type.HALF + elif torch_type == torch.bfloat16: + return cudnn.data_type.BFLOAT16 + elif torch_type == torch.float32: + return cudnn.data_type.FLOAT + elif torch_type == torch.int32: + return cudnn.data_type.INT32 + elif torch_type == torch.int64: + return cudnn.data_type.INT64 + else: + raise ValueError("Unsupported tensor data type.") @pytest.mark.skipif( @@ -26,82 +27,56 @@ def forward(self, input, running_mean, running_var, weight, bias, eps, momentum) ) @torch_fork_set_rng(seed=0) def test_bn_relu_with_mask(cudnn_handle): + n, c, h, w = 4, 16, 56, 56 + input_type = torch.float16 - N, C, H, W = 4, 16, 56, 56 
- x_gpu = torch.randn( - N, C, H, W, requires_grad=False, device="cuda", dtype=torch.float16 - ).to(memory_format=torch.channels_last) - scale_gpu = torch.randn( - 1, C, 1, 1, requires_grad=False, device="cuda", dtype=torch.float32 - ) - bias_gpu = torch.randn( - 1, C, 1, 1, requires_grad=False, device="cuda", dtype=torch.float32 - ) - running_mean_gpu = torch.randn( - 1, C, 1, 1, requires_grad=False, device="cuda", dtype=torch.float32 - ) - running_var_gpu = torch.randn( - 1, C, 1, 1, requires_grad=False, device="cuda", dtype=torch.float32 - ) + epsilon_value = 1e-3 + momentum_value = 1e-1 - epsilon_value = 1e-03 - epsilon_cpu = torch.full( - (1, 1, 1, 1), - epsilon_value, - requires_grad=False, - device="cpu", - dtype=torch.float32, - ) - momentum_cpu = torch.full( - (1, 1, 1, 1), 0.1, requires_grad=False, device="cpu", dtype=torch.float32 - ) + # input tensors + x_gpu = torch.randn(n, c, h, w, dtype=input_type, device="cuda") + x_gpu = x_gpu.to(memory_format=torch.channels_last) + scale_gpu = torch.randn(1, c, 1, 1, device="cuda") + bias_gpu = torch.randn_like(scale_gpu) + running_mean_gpu = torch.randn_like(scale_gpu) + running_var_gpu = torch.randn_like(scale_gpu) + + comparison_gpu = torch.zeros_like(x_gpu, dtype=input_type, device="cuda") + + epsilon_cpu = torch.full((1, 1, 1, 1), epsilon_value) + momentum_cpu = torch.full((1, 1, 1, 1), momentum_value) + # output tensors + saved_mean_gpu = torch.empty_like(running_mean_gpu, device="cuda") + saved_inv_var_gpu = torch.empty_like(running_var_gpu, device="cuda") + y_gpu = torch.empty_like(x_gpu, dtype=input_type, device="cuda") + mask_gpu = torch.empty_like(x_gpu, dtype=torch.bool, device="cuda") + + # cudnn graph stream = torch.cuda.current_stream().cuda_stream cudnn.set_stream(handle=cudnn_handle, stream=stream) - # Cudnn code graph = cudnn.pygraph( - io_data_type=cudnn.data_type.FLOAT, + io_data_type=convert_to_cudnn_type(input_type), intermediate_data_type=cudnn.data_type.FLOAT, compute_data_type=cudnn.data_type.FLOAT, handle=cudnn_handle, ) - X = graph.tensor( - name="X", dim=x_gpu.size(), stride=x_gpu.stride(), data_type=x_gpu.dtype - ) - scale = graph.tensor(name="scale", dim=scale_gpu.size(), stride=scale_gpu.stride()) - bias = graph.tensor(name="bias", dim=bias_gpu.size(), stride=bias_gpu.stride()) - in_running_mean = graph.tensor( - name="in_running_mean", - dim=running_mean_gpu.size(), - stride=running_mean_gpu.stride(), - ) - in_running_var = graph.tensor( - name="in_running_var", - dim=running_var_gpu.size(), - stride=running_var_gpu.stride(), - ) - epsilon = graph.tensor( - name="epsilon", - dim=epsilon_cpu.size(), - stride=epsilon_cpu.stride(), - is_pass_by_value=True, - ) - momentum = graph.tensor( - name="momentum", - dim=momentum_cpu.size(), - stride=momentum_cpu.stride(), - is_pass_by_value=True, - ) - comparison = graph.tensor( - name="zeros", dim=x_gpu.size(), stride=x_gpu.stride(), data_type=x_gpu.dtype - ) + x = graph.tensor_like(x_gpu) + scale = graph.tensor_like(scale_gpu) + bias = graph.tensor_like(bias_gpu) + + in_running_mean = graph.tensor_like(running_mean_gpu) + in_running_var = graph.tensor_like(running_var_gpu) + epsilon = graph.tensor_like(epsilon_cpu) + momentum = graph.tensor_like(momentum_cpu) + comparison = graph.tensor_like(x_gpu) - (Y_before_relu, saved_mean, saved_inv_var, out_running_mean, out_running_var) = ( + y_before_relu, saved_mean, saved_inv_var, out_running_mean, out_running_var = ( graph.batchnorm( name="BN", - input=X, + input=x, scale=scale, bias=bias, in_running_mean=in_running_mean, 
@@ -110,14 +85,14 @@ def test_bn_relu_with_mask(cudnn_handle): momentum=momentum, ) ) - Y = graph.relu(name="relu", input=Y_before_relu) - Y.set_output(True).set_data_type(cudnn.data_type.HALF) + y = graph.relu(name="relu", input=y_before_relu) + mask = graph.cmp_gt(name="cmp", input=y, comparison=comparison) + + y.set_output(True) saved_mean.set_output(True).set_data_type(cudnn.data_type.FLOAT) saved_inv_var.set_output(True).set_data_type(cudnn.data_type.FLOAT) out_running_mean.set_output(True).set_data_type(cudnn.data_type.FLOAT) out_running_var.set_output(True).set_data_type(cudnn.data_type.FLOAT) - - mask = graph.cmp_gt(name="cmp", input=Y, comparison=comparison) mask.set_output(True).set_data_type(cudnn.data_type.BOOLEAN) graph.validate() @@ -126,160 +101,129 @@ def test_bn_relu_with_mask(cudnn_handle): graph.check_support() graph.build_plans() - # Reference code execution - model = SGBN().eval().to("cuda") - Y_expected_before_relu = model( - x_gpu, - running_mean_gpu, - running_var_gpu, - scale_gpu, - bias_gpu, - epsilon_cpu.item(), - momentum_cpu.item(), - ) - mean_expected = x_gpu.to(torch.float32).mean(dim=(0, 2, 3), keepdim=True) - inv_var_expected = torch.rsqrt( - torch.var(x_gpu.to(torch.float32), dim=(0, 2, 3), keepdim=True) + epsilon_value - ) - Y_expected = torch.relu(Y_expected_before_relu) - mask_expected = Y_expected > 0 - # cudnn graph execution - saved_mean_actual = torch.zeros_like(scale_gpu) - saved_inv_var_actual = torch.zeros_like(scale_gpu) - Y_actual = torch.zeros_like(Y_expected) - mask_actual = torch.empty( - N, C, H, W, requires_grad=False, device="cuda", dtype=torch.bool - ).to(memory_format=torch.channels_last) - - zeros = torch.zeros_like(Y_expected) - + variant_pack = { + x: x_gpu, + scale: scale_gpu, + bias: bias_gpu, + in_running_mean: running_mean_gpu, + in_running_var: running_var_gpu, + epsilon: epsilon_cpu, + momentum: momentum_cpu, + out_running_mean: running_mean_gpu, + out_running_var: running_var_gpu, + saved_mean: saved_mean_gpu, + saved_inv_var: saved_inv_var_gpu, + y: y_gpu, + comparison: comparison_gpu, + mask: mask_gpu, + } workspace = torch.empty( graph.get_workspace_size(), device="cuda", dtype=torch.uint8 ) - graph.execute( - { - X: x_gpu, - scale: scale_gpu, - bias: bias_gpu, - in_running_mean: running_mean_gpu, - in_running_var: running_var_gpu, - epsilon: epsilon_cpu, - momentum: momentum_cpu, - out_running_mean: running_mean_gpu, - out_running_var: running_var_gpu, - saved_mean: saved_mean_actual, - saved_inv_var: saved_inv_var_actual, - Y: Y_actual, - comparison: zeros, - mask: mask_actual, - }, + variant_pack, workspace, handle=cudnn_handle, ) - - # Compare torch.cuda.synchronize() - print("Comparing outputs") - torch.testing.assert_close(Y_expected, Y_actual, atol=1e-3, rtol=1e-3) - torch.testing.assert_close(mean_expected, saved_mean_actual, atol=1e-3, rtol=1e-3) - torch.testing.assert_close( - inv_var_expected, saved_inv_var_actual, atol=1e-3, rtol=1e-3 + + # reference computation + x_ref = x_gpu.clone().float() + running_mean_ref = running_mean_gpu.clone().float() + running_var_ref = running_var_gpu.clone().float() + + y_before_relu_ref = torch.nn.functional.batch_norm( + x_ref, + running_mean_ref, # running_mean is both input and output + running_var_ref, # running_var is both input and output + weight=scale_gpu, + bias=bias_gpu, + training=True, + momentum=momentum_cpu.item(), + eps=epsilon_cpu.item(), ) - # torch.testing.assert_close(mask_expected, mask_actual) + mean_ref = torch.mean(x_ref, dim=(0, 2, 3), keepdim=True) + 
inv_var_ref = torch.var(x_ref, dim=(0, 2, 3), keepdim=True) + inv_var_ref = torch.rsqrt(inv_var_ref + epsilon_value) + y_ref = torch.relu(y_before_relu_ref) + mask_ref = y_ref > 0 + + # Compare + # fmt: off + torch.testing.assert_close(y_ref, y_gpu.float(), atol=1e-3, rtol=1e-3) + torch.testing.assert_close(mean_ref, saved_mean_gpu.float(), atol=1e-3, rtol=1e-3) + torch.testing.assert_close(inv_var_ref, saved_inv_var_gpu.float(), atol=1e-3, rtol=1e-3) + # torch.testing.assert_close(mask_ref, mask_gpu.float(), atol=1e-3, rtol=1e-3) + # fmt: on +@pytest.mark.parametrize( + "dump_dX_dRelu", [True, False], ids=lambda p: f"dump_dX_dRelu{int(p)}" +) @pytest.mark.skipif( LooseVersion(cudnn.backend_version_string()) < "8.9", reason="DBN fusions not supported below cudnn 8.9", ) @torch_fork_set_rng(seed=0) -def test_drelu_dadd_dbn(cudnn_handle): - - # Tensors - N, C, H, W = 4, 16, 56, 56 - - x_gpu = torch.randn( - N, C, H, W, requires_grad=False, device="cuda", dtype=torch.float16 - ).to(memory_format=torch.channels_last) - scale_gpu = torch.randn( - 1, C, 1, 1, requires_grad=False, device="cuda", dtype=torch.float32 - ) - mean_gpu = torch.randn( - 1, C, 1, 1, requires_grad=False, device="cuda", dtype=torch.float32 - ) - inv_variance_gpu = torch.randn( - 1, C, 1, 1, requires_grad=False, device="cuda", dtype=torch.float32 - ) - dy_gpu = torch.randn( - N, C, H, W, requires_grad=False, device="cuda", dtype=torch.float16 - ).to(memory_format=torch.channels_last) - x_mask_gpu = torch.randint( - 0, 2, [N, C, H, W], requires_grad=False, device="cuda", dtype=torch.bool - ).to(memory_format=torch.channels_last) - +def test_drelu_dadd_dbn(dump_dX_dRelu, cudnn_handle): + n, c, h, w = 4, 16, 56, 56 + input_type = torch.float16 + + # input tensors + x_gpu = torch.randn(n, c, h, w, dtype=input_type, device="cuda") + x_gpu = x_gpu.to(memory_format=torch.channels_last) + x_mask_gpu = torch.randn_like(x_gpu) > 0.0 + scale_gpu = torch.randn(1, c, 1, 1, device="cuda") + mean_gpu = torch.randn_like(scale_gpu) + inv_var_gpu = torch.randn_like(scale_gpu) + dY_gpu = torch.randn_like(x_gpu) + + # output tensors + dScale_ref = torch.empty_like(scale_gpu) + dBias_ref = torch.empty_like(scale_gpu) + dX_ref = torch.empty_like(dY_gpu) + + if dump_dX_dRelu: + dX_dRelu_gpu = torch.empty_like(dY_gpu) + + # cudnn graph stream = torch.cuda.current_stream().cuda_stream cudnn.set_stream(handle=cudnn_handle, stream=stream) - # Cudnn code graph = cudnn.pygraph( - io_data_type=cudnn.data_type.HALF, + io_data_type=convert_to_cudnn_type(input_type), intermediate_data_type=cudnn.data_type.FLOAT, compute_data_type=cudnn.data_type.FLOAT, handle=cudnn_handle, ) - X = graph.tensor( - name="X", dim=x_gpu.size(), stride=x_gpu.stride(), data_type=x_gpu.dtype - ) - DY = graph.tensor( - name="DY", dim=dy_gpu.size(), stride=dy_gpu.stride(), data_type=dy_gpu.dtype - ) - scale = graph.tensor( - name="scale", - dim=scale_gpu.size(), - stride=scale_gpu.stride(), - data_type=scale_gpu.dtype, - ) - mean = graph.tensor( - name="mean", - dim=mean_gpu.size(), - stride=mean_gpu.stride(), - data_type=mean_gpu.dtype, - ) - inv_variance = graph.tensor( - name="inv_variance", - dim=inv_variance_gpu.size(), - stride=inv_variance_gpu.stride(), - data_type=inv_variance_gpu.dtype, - ) - X_mask = graph.tensor( - name="X_mask", - dim=x_mask_gpu.size(), - stride=x_mask_gpu.stride(), - data_type=x_mask_gpu.dtype, - ) + x = graph.tensor_like(x_gpu) + x_mask = graph.tensor_like(x_mask_gpu) + scale = graph.tensor_like(scale_gpu) + mean = graph.tensor_like(mean_gpu) + 
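For reference, the quantities the DBN node produces in this test are the standard batch-norm backward reductions. With $\hat{x} = (x-\mu)\cdot\text{inv\_var}$ and $dy$ the incoming gradient (here the dReLU-scaled dY):

$$d\beta_c = \sum_{n,h,w} dy_{n,c,h,w},\qquad d\gamma_c = \sum_{n,h,w} dy_{n,c,h,w}\,\hat{x}_{n,c,h,w}$$

dX additionally back-propagates through $\mu$ and $\text{inv\_var}$ by the chain rule; the updated test executes the graph without a numeric comparison.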
inv_var = graph.tensor_like(inv_var_gpu) + dY = graph.tensor_like(dY_gpu) - DX_drelu = graph.scale(name="drelu", input=DY, scale=X_mask) + dX_drelu = graph.scale(name="drelu", input=dY, scale=x_mask) + dX_drelu.set_data_type(cudnn.data_type.HALF) - # NOTE: Toggle DADD output to dump to gmem - should_dump_dx_drelu = False - DX_drelu.set_output(should_dump_dx_drelu).set_data_type(cudnn.data_type.HALF) + if dump_dX_dRelu: + dX_drelu.set_output(True) - (DX, DScale, DBias) = graph.batchnorm_backward( + dX, dScale, dBias = graph.batchnorm_backward( name="DBN", - grad=DX_drelu, - input=X, + grad=dX_drelu, + input=x, scale=scale, mean=mean, - inv_variance=inv_variance, + inv_variance=inv_var, ) - DX.set_output(True) - DScale.set_output(True).set_data_type(cudnn.data_type.FLOAT) - DBias.set_output(True).set_data_type(cudnn.data_type.FLOAT) + dX.set_output(True).set_data_type(cudnn.data_type.HALF) + dScale.set_output(True).set_data_type(cudnn.data_type.FLOAT) + dBias.set_output(True).set_data_type(cudnn.data_type.FLOAT) graph.validate() graph.build_operation_graph() @@ -287,29 +231,26 @@ def test_drelu_dadd_dbn(cudnn_handle): graph.check_support() graph.build_plans() - DScale_actual = torch.zeros_like(scale_gpu) - DBias_actual = torch.zeros_like(scale_gpu) - DX_actual = torch.zeros_like(dy_gpu) + variant_pack = { + x: x_gpu, + x_mask: x_mask_gpu, + dY: dY_gpu, + scale: scale_gpu, + mean: mean_gpu, + inv_var: inv_var_gpu, + dX: dX_ref, + dScale: dScale_ref, + dBias: dBias_ref, + } + if dump_dX_dRelu: + variant_pack[dX_drelu] = dX_dRelu_gpu workspace = torch.empty( graph.get_workspace_size(), device="cuda", dtype=torch.uint8 ) - device_buffers = { - X: x_gpu, - X_mask: x_mask_gpu, - DY: dy_gpu, - scale: scale_gpu, - mean: mean_gpu, - inv_variance: inv_variance_gpu, - DX: DX_actual, - DScale: DScale_actual, - DBias: DBias_actual, - } - if should_dump_dx_drelu is True: - DX_drelu_actual = torch.zeros_like(dy_gpu) - device_buffers[DX_drelu] = DX_drelu_actual - graph.execute(device_buffers, workspace, handle=cudnn_handle) + graph.execute(variant_pack, workspace, handle=cudnn_handle) + torch.cuda.synchronize() @pytest.mark.skipif( @@ -318,33 +259,27 @@ def test_drelu_dadd_dbn(cudnn_handle): ) @torch_fork_set_rng(seed=0) def test_bn_infer_drelu_dbn(cudnn_handle): - - # Tensors - N, C, H, W = 4, 16, 56, 56 - - bn_x_gpu = torch.randn( - N, C, H, W, requires_grad=False, device="cuda", dtype=torch.float16 - ).to(memory_format=torch.channels_last) - scale_gpu = torch.randn( - 1, C, 1, 1, requires_grad=False, device="cuda", dtype=torch.float32 - ) - bias_gpu = torch.randn( - 1, C, 1, 1, requires_grad=False, device="cuda", dtype=torch.float32 - ) - mean_gpu = torch.randn( - 1, C, 1, 1, requires_grad=False, device="cuda", dtype=torch.float32 - ) - inv_variance_gpu = torch.randn( - 1, C, 1, 1, requires_grad=False, device="cuda", dtype=torch.float32 - ) - dy_gpu = torch.randn( - N, C, H, W, requires_grad=False, device="cuda", dtype=torch.float16 - ).to(memory_format=torch.channels_last) - + n, c, h, w = 4, 16, 56, 56 + input_type = torch.float16 + + # input tensors + x_gpu = torch.randn(n, c, h, w, dtype=input_type, device="cuda") + x_gpu = x_gpu.to(memory_format=torch.channels_last) + scale_gpu = torch.randn(1, c, 1, 1, device="cuda") + bias_gpu = torch.randn_like(scale_gpu) + mean_gpu = torch.randn_like(scale_gpu) + inv_var_gpu = torch.randn_like(scale_gpu) + dY_gpu = torch.randn_like(x_gpu) + + # output tensors + dScale_gpu = torch.empty_like(scale_gpu) + dBias_gpu = torch.empty_like(scale_gpu) + dX_gpu = 
torch.empty_like(x_gpu) + + # cudnn graph stream = torch.cuda.current_stream().cuda_stream cudnn.set_stream(handle=cudnn_handle, stream=stream) - # Cudnn code graph = cudnn.pygraph( io_data_type=cudnn.data_type.HALF, intermediate_data_type=cudnn.data_type.FLOAT, @@ -352,15 +287,14 @@ def test_bn_infer_drelu_dbn(cudnn_handle): handle=cudnn_handle, ) - # Bool type is not supported by dlpack - BN_X = graph.tensor( - name="BN_X", - dim=bn_x_gpu.size(), - stride=bn_x_gpu.stride(), - data_type=bn_x_gpu.dtype, + x = graph.tensor( + name="x", + dim=x_gpu.size(), + stride=x_gpu.stride(), + data_type=x_gpu.dtype, ) - DY = graph.tensor( - name="DY", dim=dy_gpu.size(), stride=dy_gpu.stride(), data_type=dy_gpu.dtype + dY = graph.tensor( + name="dY", dim=dY_gpu.size(), stride=dY_gpu.stride(), data_type=dY_gpu.dtype ) scale = graph.tensor( name="scale", @@ -382,31 +316,31 @@ def test_bn_infer_drelu_dbn(cudnn_handle): ) inv_variance = graph.tensor( name="inv_variance", - dim=inv_variance_gpu.size(), - stride=inv_variance_gpu.stride(), - data_type=inv_variance_gpu.dtype, + dim=inv_var_gpu.size(), + stride=inv_var_gpu.stride(), + data_type=inv_var_gpu.dtype, ) - BN_Y = graph.batchnorm_inference( - input=BN_X, mean=mean, inv_variance=inv_variance, scale=scale, bias=bias + y = graph.batchnorm_inference( + input=x, mean=mean, inv_variance=inv_variance, scale=scale, bias=bias ) - DX_drelu = graph.relu_backward(loss=DY, input=BN_Y) + dX_dRelu = graph.relu_backward(loss=dY, input=y) - DX_drelu.set_data_type(cudnn.data_type.HALF) + dX_dRelu.set_data_type(cudnn.data_type.HALF) - (DX, DScale, DBias) = graph.batchnorm_backward( + dX, dScale, dBias = graph.batchnorm_backward( name="DBN", - grad=DX_drelu, - input=BN_X, + grad=dX_dRelu, + input=x, scale=scale, mean=mean, inv_variance=inv_variance, ) - DX.set_output(True) - DScale.set_output(True).set_data_type(cudnn.data_type.FLOAT) - DBias.set_output(True).set_data_type(cudnn.data_type.FLOAT) + dX.set_output(True) + dScale.set_output(True).set_data_type(cudnn.data_type.FLOAT) + dBias.set_output(True).set_data_type(cudnn.data_type.FLOAT) graph.validate() graph.build_operation_graph() @@ -414,29 +348,25 @@ def test_bn_infer_drelu_dbn(cudnn_handle): graph.check_support() graph.build_plans() - DScale_actual = torch.zeros_like(scale_gpu) - DBias_actual = torch.zeros_like(scale_gpu) - DX_actual = torch.zeros_like(dy_gpu) + variant_pack = { + x: x_gpu, + dY: dY_gpu, + scale: scale_gpu, + bias: bias_gpu, + mean: mean_gpu, + inv_variance: inv_var_gpu, + dX: dX_gpu, + dScale: dScale_gpu, + dBias: dBias_gpu, + } workspace = torch.empty( graph.get_workspace_size(), device="cuda", dtype=torch.uint8 ) - device_buffers = { - BN_X: bn_x_gpu, - DY: dy_gpu, - scale: scale_gpu, - bias: bias_gpu, - mean: mean_gpu, - inv_variance: inv_variance_gpu, - DX: DX_actual, - DScale: DScale_actual, - DBias: DBias_actual, - } - graph.execute(device_buffers, workspace, handle=cudnn_handle) + graph.execute(variant_pack, workspace, handle=cudnn_handle) + torch.cuda.synchronize() if __name__ == "__main__": - test_bn_relu_with_mask() - test_drelu_dadd_dbn() - test_bn_infer_drelu_dbn() + pytest.main([__file__]) diff --git a/test/python_fe/test_conv_bias.py b/test/python_fe/test_conv_bias.py index d5757d48..d0f1b44c 100644 --- a/test/python_fe/test_conv_bias.py +++ b/test/python_fe/test_conv_bias.py @@ -331,7 +331,7 @@ def dleaky_relu(grad: torch.Tensor, mask: torch.Tensor, negative_slope: float): @torch_fork_set_rng(seed=0) def test_conv_int8(cudnn_handle): - N, C, H, W = 1, 64, 32, 32 + N, C, H, W = 2, 
64, 32, 32 K, R, S = 4, 3, 3 padding = [1, 1] stride = [1, 1] diff --git a/test/python_fe/test_layernorm.py b/test/python_fe/test_layernorm.py index 154fff67..95c3a1c7 100644 --- a/test/python_fe/test_layernorm.py +++ b/test/python_fe/test_layernorm.py @@ -128,7 +128,7 @@ def test_layernorm(param_extract, cudnn_handle): graph.build_operation_graph() graph.create_execution_plans([cudnn.heur_mode.A, cudnn.heur_mode.FALLBACK]) graph.check_support() - graph.build_plans(cudnn.build_plan_policy.ALL) + graph.build_plans() Y_actual = torch.empty_like(x_gpu) mean_actual = torch.empty_like(mean_expected) @@ -197,7 +197,7 @@ def test_layernorm(param_extract, cudnn_handle): bwd_graph.build_operation_graph() bwd_graph.create_execution_plans([cudnn.heur_mode.A, cudnn.heur_mode.FALLBACK]) bwd_graph.check_support() - bwd_graph.build_plans(cudnn.build_plan_policy.ALL) + bwd_graph.build_plans() DX_actual = torch.empty_like(x_gpu) DScale_actual = torch.empty_like(scale_gpu) diff --git a/test/python_fe/test_mhas.py b/test/python_fe/test_mhas.py index a564cbda..9b7f0c2d 100644 --- a/test/python_fe/test_mhas.py +++ b/test/python_fe/test_mhas.py @@ -515,8 +515,6 @@ def test_sdpa( b = int(request.config.option.mha_b) if request.config.option.mha_b != None else b s_q = int(request.config.option.mha_s_q) if request.config.option.mha_s_q != None else s_q s_kv = int(request.config.option.mha_s_kv) if request.config.option.mha_s_kv != None else s_kv - if is_sliding_window: - s_kv = s_q d_qk = int(request.config.option.mha_d_qk) if request.config.option.mha_d_qk != None else d_qk d_v = int(request.config.option.mha_d_v) if request.config.option.mha_d_v != None else d_v h_q = int(request.config.option.mha_h_q) if request.config.option.mha_h_q != None else h_q @@ -529,6 +527,9 @@ def test_sdpa( if d_qk != d_v and is_ragged and cudnn_version < "9.1": pytest.skip("d_qk != d_v is not supported with ragged offset") + if s_q > s_kv and is_sliding_window: + pytest.skip("s_q > s_kv is not supported with sliding window attention") + print("\n=============== TEST CMD TO REPRODUCE ===============") print( f"pytest {request.node.nodeid} --mha_b={b} --mha_s_q={s_q} --mha_s_kv={s_kv} --mha_d_qk={d_qk} --mha_d_v={d_v} --mha_h_q={h_q} --mha_h_k={h_k} --mha_h_v={h_v}" @@ -823,10 +824,7 @@ def test_sdpa_backward( pytest.skip("dBias is only supported 8.9.6 onwards.") if is_bias and cudnn_version < "9" and torch.cuda.get_device_capability()[0] < 9: - pytest.skip("dBias is only supported on hopper onwards.") - - if is_bias and is_padding: - pytest.skip("dBias is not supported with padding mask") + pytest.skip("dBias is only supported on hopper before v9.") if is_alibi and not is_causal: pytest.skip("ALiBi mask is only supported with causal mask") @@ -914,11 +912,6 @@ def test_sdpa_backward( if d_qk != d_v and cudnn_version < "8.9.6": pytest.skip("d_qk != d_v is only supported on 8.9.6 onwards.") - if ((s_q % 64 != 0) or (s_kv % 64 != 0)) and is_bias: - pytest.skip( - "cudnn backend does not support bias with non-64-aligned seq_q or seq_kv." 
- ) - if d_qk != d_v and is_ragged and cudnn_version < "9.1": pytest.skip("d_qk != d_v is not supported with ragged offset") @@ -1320,9 +1313,6 @@ def test_sdpa_backward( dK_gpu[i, :, n:, :] = 0 dV_ref[i, :, n:, :] = 0 dV_gpu[i, :, n:, :] = 0 - if is_bias: - dBias_ref[i, :, m:, :] = 0 - dBias_ref[i, :, :, n:] = 0 torch.cuda.synchronize() diff --git a/test/unit_tests/CMakeLists.txt b/test/unit_tests/CMakeLists.txt index 1044f735..633fb5ec 100644 --- a/test/unit_tests/CMakeLists.txt +++ b/test/unit_tests/CMakeLists.txt @@ -23,6 +23,7 @@ add_executable( serialize.cpp validate.cpp version.cpp + tensor.cpp ) if (MSVC) diff --git a/test/unit_tests/serialize.cpp b/test/unit_tests/serialize.cpp index c8a3fcc0..5b28d5e6 100644 --- a/test/unit_tests/serialize.cpp +++ b/test/unit_tests/serialize.cpp @@ -142,6 +142,85 @@ TEST_CASE("Graph key", "[graph][key]") { REQUIRE(key == graph.key()); } +TEST_CASE("Graph key dynamic shape", "[graph][key][dynamic_shape]") { + namespace fe = cudnn_frontend; + if (cudnnGetCudartVersion() < 12000) { + SKIP("Test requires cuda toolkit 12.0 or above"); + } + // clang-format off + struct { + int64_t b, m, n, k; + } shapes[] = { + { 4, 16, 32, 64}, + { 8, 32, 64, 128}, + }; + // clang-format on + + constexpr int shapes_count = sizeof(shapes) / sizeof(shapes[0]); + size_t key = 0; // Save key between runs to verify that dim and stride information is deleted + + for (int idx_shape = 0; idx_shape < shapes_count; idx_shape++) { + auto b = shapes[idx_shape].b; + auto m = shapes[idx_shape].m; + auto n = shapes[idx_shape].n; + auto k = shapes[idx_shape].k; + + fe::graph::Graph graph; + graph.set_io_data_type(fe::DataType_t::HALF) + .set_intermediate_data_type(fe::DataType_t::FLOAT) + .set_compute_data_type(fe::DataType_t::FLOAT) + .set_dynamic_shape_enabled(true); + + auto X = + graph.tensor(fe::graph::Tensor_attributes().set_name("image").set_dim({b, m, k}).set_stride({m * k, 1, m})); + auto Y = graph.tensor( + fe::graph::Tensor_attributes().set_name("filter").set_dim({b, k, n}).set_stride({n * k, 1, k})); + + fe::graph::Matmul_attributes matmul; + auto Z = graph.matmul(X, Y, matmul); + + auto scale_options = fe::graph::Pointwise_attributes().set_mode(fe::PointwiseMode_t::MUL); + auto S = + graph.tensor(fe::graph::Tensor_attributes().set_name("scale").set_dim({b, m, n}).set_stride({m * n, n, 1})); + auto scale_output = graph.pointwise(Z, S, scale_options); + + auto bias_options = fe::graph::Pointwise_attributes().set_mode(fe::PointwiseMode_t::ADD); + auto B = + graph.tensor(fe::graph::Tensor_attributes().set_name("bias").set_dim({b, m, n}).set_stride({m * n, n, 1})); + auto bias_output = graph.pointwise(scale_output, B, bias_options); + + auto relu_options = fe::graph::Pointwise_attributes().set_mode(fe::PointwiseMode_t::RELU_FWD); + auto O = graph.pointwise(bias_output, relu_options); + O->set_output(true); + + cudnnHandle_t handle; + cudnnCreate(&handle); + + auto status = graph.validate(); + if (cudnnGetVersion() >= 90400) { + REQUIRE(status.is_good()); + } else { + REQUIRE(status.is_bad()); + SKIP("Dynamic shapes not supported pre 9.4"); + } + + REQUIRE(graph.build_operation_graph(handle).is_good()); + + if (!key) { + key = graph.key(); + } + + REQUIRE(graph.create_execution_plans({fe::HeurMode_t::A}).is_good()); + REQUIRE(key == graph.key()); + + REQUIRE(graph.check_support(handle).is_good()); + REQUIRE(key == graph.key()); + + REQUIRE(graph.build_plans(handle).is_good()); + REQUIRE(key == graph.key()); + } +} + TEST_CASE("Matmul fp8 fusion", "[graph][serialize]") { 
namespace fe = cudnn_frontend; // matmul problem size diff --git a/test/unit_tests/tensor.cpp b/test/unit_tests/tensor.cpp new file mode 100644 index 00000000..3c374233 --- /dev/null +++ b/test/unit_tests/tensor.cpp @@ -0,0 +1,48 @@ +/* + * Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining a + * copy of this software and associated documentation files (the "Software"), + * to deal in the Software without restriction, including without limitation + * the rights to use, copy, modify, merge, publish, distribute, sublicense, + * and/or sell copies of the Software, and to permit persons to whom the + * Software is furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL + * THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING + * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER + * DEALINGS IN THE SOFTWARE. + */ +#include + +#include + +TEST_CASE("tensor query checks", "[query_tensor_attributes_of_uid]") { + namespace fe = cudnn_frontend; + + fe::graph::Graph graph; + graph.set_io_data_type(fe::DataType_t::HALF) + .set_intermediate_data_type(fe::DataType_t::FLOAT) + .set_compute_data_type(fe::DataType_t::FLOAT); + + int64_t uid = 1; + std::string name = "image"; + + auto X = graph.tensor(fe::graph::Tensor_attributes() + .set_name(name) + .set_dim({8, 32, 16, 16}) + .set_stride({32 * 16 * 16, 1, 32 * 16, 32}) + .set_uid(uid)); + + fe::graph::Tensor_attributes t; + + REQUIRE(graph.query_tensor_attributes_of_uid(uid, t).is_good()); + + REQUIRE(t.get_name() == name); +} \ No newline at end of file
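The new unit test above checks that a tensor's name round-trips through its UID; the same query is what the serialization sample uses to recover attributes on a deserialized graph, for example to size device allocations. A short sketch follows, assuming X_UID was assigned with set_uid() when the graph was originally built, that Tensor_attributes exposes its dimensions through the usual get_dim() accessor, and that Surface<half> is the samples' buffer helper.

```cpp
fe::graph::Tensor_attributes attrs;
REQUIRE(graph.query_tensor_attributes_of_uid(X_UID, attrs).is_good());

// Size a device buffer from the queried dimensions instead of carrying the
// original builder-side shape information around.
int64_t volume = 1;
for (auto dim : attrs.get_dim()) volume *= dim;

Surface<half> x_device(volume, false);
```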