Skip to content

Commit

Permalink
# cudnn FE 1.7.0 Release notes: (#111)
Browse files Browse the repository at this point in the history
## New API

- Kernel Cache support for dynamic graphs
Added New APIs to enable kernel cache support for graphs with dynamic shapes. Please refer to [documentation](docs/dynamic_kernel_cache.md) for API details.

Added examples `Convolution fprop dynamic shape`, `CSBR Graph dynamic shape`, `Matmul dynamic shape` and `Bias + Matmul dynamic shape` to showcase use of dynamic shapes and kernel cache.

- Two new APIs to describe the plan in the form engine number and knobs are introduced.
```
error_t
get_plan_name(std::string &name) const;

error_t
get_plan_name_at_index(int64_t plan_index, std::string &name) const;
```
Note:
This name can be used later if you want to deselect_plan_by_name, if run into any potential errors.

- Added an API to query tensor attributes from its UID in a graph.
`query_tensor_with_uid(int64_t const uid, Tensor_attributes &tensor) const;`

## Improvements

- sdpa fp16 bprop node can now compute dbias when padding mask is enabled.

- sdpa fp8 (forward and bprop) nodes now support optional bias, dropout and padding mask.

- Matmul fp8 node can now accept M,N,K overrides.

- Added new python notebooks for implementing BatchNorm and BatchNorm bprop using cuDNN.

- Updated [benchmark numbers](benchmark) with cudnn 9.4.0 for fp16 and fp8 datatypes.

- Fixed compilation issues when `NV_CUDNN_DISABLE_EXCEPTION` is enabled.

## Bug fixes

- Fixed a crash when the output dimension of dgrad node is not specified. This now returns an error message instead.

- Fixed incorrect SDPA stats stride inferencing.

- Fixed a bug in sdpa test when sliding window attention is enabled and query sequence length (s_q) is greater than key length (s_kv). This case is now not supported.
  • Loading branch information
Anerudhan authored Sep 23, 2024
1 parent 9f8cc9a commit de355c7
Show file tree
Hide file tree
Showing 65 changed files with 2,727 additions and 491 deletions.
2 changes: 1 addition & 1 deletion CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
cmake_minimum_required(VERSION 3.17)

project(cudnn_frontend VERSION 1.6.1)
project(cudnn_frontend VERSION 1.7.0)

option(CUDNN_FRONTEND_SKIP_JSON_LIB "Defines whether FE should not include nlohmann/json.hpp." OFF)
option(CUDNN_FRONTEND_BUILD_SAMPLES "Defines if samples are built or not." ON)
Expand Down
22 changes: 20 additions & 2 deletions README.FE.1.0.md
Original file line number Diff line number Diff line change
Expand Up @@ -51,8 +51,8 @@ FE v1.0 API follows a functional style of building a graph. Operations take in i
| [Scale dot product attention](docs/operations/Attention.md) | sdpa<br> SDPA_attributes | sdpa |
| [Scale dot product attention backward](docs/operations/Attention.md) | sdpa_backward<br> SDPA_backward_attributes | sdpa_backward |
| [Scale dot product attention FP8](docs/operations/Attention.md) | sdpa_fp8<br> SDPA_fp8_attributes | sdpa_fp8 |
| [Scale dot product attention backward FP8](docs/operations/Attention.md) | sdpa_fp8_backward<br> SDPA_fp8_backward_attributes | sdpa_fp8_backward
| Slice | slice<br> Slice_attributes | slice |
| [Scale dot product attention backward FP8](docs/operations/Attention.md) | sdpa_fp8_backward<br> SDPA_fp8_backward_attributes | sdpa_fp8_backward |
| [Slice](docs/operations/Slice.md) | slice<br> Slice_attributes | slice |

### Creating the Graph
Instantiate an object of class `cudnn_frontend::graph::Graph` which will house tensors and operations.
Expand Down Expand Up @@ -156,6 +156,7 @@ cudnn_frontend::graph::Graph& cudnn_frontend::graph::Plans::select_behavior_note
cudnn_frontend::graph::Graph& cudnn_frontend::graph::Plans::deselect_numeric_notes(std::vector<cudnn_frontend::NumericalNote_t> const&);
cudnn_frontend::graph::Graph& cudnn_frontend::graph::Plans::deselect_behavior_notes(std::vector<cudnn_frontend::BehaviorNote_t> const&);
cudnn_frontend::graph::Graph& cudnn_frontend::graph::Plans::deselect_workspace_greater_than(int64_t const workspace);
cudnn_frontend::graph::Graph& cudnn_frontend::graph::Plans::deselect_shared_mem_greater_than(int64_t const shared_memory);
```

### Autotuning
Expand Down Expand Up @@ -205,6 +206,23 @@ Get workspace to run autotune on all plans.

`get_autotune_workspace_size() const`


### Serialization

Frontend v1.0 API provides two flavors of serialization. One is to checkpoint after the initial graph specification (before calling validate) and other after building the execution plan (to save on plan creation).

`void serialize(json &j) const`
`void deserialize(const json &j)`
The above two APIs are meant to capture the user specified input tensors and nodes into the graph. This can be used to generate the log (for debugging) or to visualize the graph being created.

`error_t serialize(std::vector<uint8_t> &data) const`
`error_t deserialize(cudnnHandle_t handle, std::vector<uint8_t> const &data)`

A fully built graph can be serialized into a binary blob of data with the above two APIs.
Note:
1. Not all engine configs support serialization.
2. It is the users responsibility to make sure the UIDs of tensor being passed to the variant pack remain consistent before and after serialization.

### Error handling

C++ API returns a error object which has a error code and error message.
Expand Down
6 changes: 3 additions & 3 deletions benchmark/Dockerfile
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
FROM nvcr.io/nvidia/pytorch:24.03-py3
FROM nvcr.io/nvidia/pytorch:24.07-py3

RUN apt-get update && \
wget https://developer.download.nvidia.com/compute/cuda/repos/ubuntu2204/x86_64/cuda-keyring_1.1-1_all.deb && \
Expand All @@ -9,12 +9,12 @@ RUN apt-get update && \

RUN pip uninstall -y cudnn

RUN CMAKE_BUILD_PARALLEL_LEVEL=16 pip install git+https://github.com/NVIDIA/cudnn-frontend.git -v
RUN pip install nvidia-cudnn-frontend

COPY benchmark_flash_attention.py .

ENV LD_LIBRARY_PATH=/usr/lib/x86_64-linux-gnu/:$LD_LIBRARY_PATH

CMD ["python", "/workspace/benchmark_flash_attention.py"]
CMD ["python", "benchmark_flash_attention.py"]

WORKDIR /workspace
Loading

0 comments on commit de355c7

Please sign in to comment.