Description
In populate_cuda_graph, the cuDNN graph is first populated into an intermediate CUDA graph, and that intermediate CUDA graph is then added as a child graph to the user's CUDA graph:
cudnn-frontend/include/cudnn_frontend/graph_interface.h
Lines 588 to 604 in be6c079
    // Finally get the backend cuda graph.
    cudaGraph_t backend_cuda_graph;
    // Initialize the cudnn cuda graph.
    // The responsibility to destroy is on the user.
    detail::cu_graph_create(&backend_cuda_graph, 0);  // 0 is just what the API says to pass
    _CUDNN_CHECK_CUDNN_ERROR(detail::populate_cuda_graph(handle,
                                                         plans.execution_plans[candidate]->get_raw_desc(),
                                                         variant_pack_descriptor.get_ptr(),
                                                         backend_cuda_graph));
    // Clone BE graph into a graph_node
    // This same call also places the newly created into FE's graph
    // TODO: BE graph is at the end, so put in appropriate dependencies
    cudaGraphNode_t backend_cuda_graph_node;
    detail::cuda_graph_add_child_graph_node(
        &backend_cuda_graph_node, cudnn_cuda_graph, &last_node, last_node != nullptr, backend_cuda_graph);
I think this extra intermediate graph just adds overhead and complexity.
When using the populate_cuda_graph API, users always populate an intermediate CUDA graph of their own to hold the results from cuDNN, i.e.:
CudaGraph cuda_graph(device);
graph.populate_cuda_graph(..., cuda_graph);
cudaGraphNode_t node;
cudaGraphAddChildGraphNode(&node, root_cuda_graph, NULL, 0, cuda_graph);
I guess the current behavior of populate_cuda_graph assumes a use case like the one below, so it takes care of creating an intermediate CUDA graph itself:
graph.populate_cuda_graph(..., root_cuda_graph);
However, that assumption breaks when updating the CUDA graph, because you can't do something like this:
graph.update_cuda_graph(..., root_cuda_graph);
So to use the native CUDA graph APIs, users always need to create an intermediate CUDA graph themselves, and the final graph looks like this:
Root graph -> Intermediate child graph by user -> Intermediate child graph by cuDNN -> Actual kernel
and the intermediate child graph by cuDNN would be redundant.
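To make the nesting concrete, here is a toy model in plain C++ that counts the child-graph levels between the root graph and the kernel. It makes no CUDA or cuDNN calls: `ToyGraph`, `populate_nested`, `populate_direct`, `build_root`, and `levels_to_kernel` are all hypothetical names invented for this sketch. `populate_direct` is only my reading of what dropping the cuDNN-created graph would look like, not an existing API.

```cpp
#include <cassert>
#include <memory>
#include <string>
#include <vector>

// Toy stand-in for a CUDA graph. This is NOT the CUDA runtime API.
struct ToyGraph {
    std::string name;
    std::vector<std::shared_ptr<ToyGraph>> child_graphs;  // child-graph nodes
    int kernel_nodes = 0;                                 // kernel nodes held directly
};

using GraphPtr = std::shared_ptr<ToyGraph>;

GraphPtr make_graph(const std::string& name) {
    auto g = std::make_shared<ToyGraph>();
    g->name = name;
    return g;
}

// Number of child-graph levels between `g` and its deepest kernel node.
// Returns -1 if no kernel node is reachable from `g`.
int levels_to_kernel(const ToyGraph& g) {
    int best = g.kernel_nodes > 0 ? 0 : -1;
    for (const auto& child : g.child_graphs) {
        int depth = levels_to_kernel(*child);
        if (depth >= 0 && depth + 1 > best) best = depth + 1;
    }
    return best;
}

// Models the current behavior: the library wraps its kernel in a graph of
// its own and adds that graph as a child of the graph passed by the user.
void populate_nested(ToyGraph& user_graph) {
    auto backend = make_graph("intermediate child graph by cuDNN");
    backend->kernel_nodes = 1;
    user_graph.child_graphs.push_back(backend);
}

// Models the assumed alternative: the kernel goes straight into the graph
// passed by the user, with no extra wrapper.
void populate_direct(ToyGraph& user_graph) {
    user_graph.kernel_nodes += 1;
}

// Builds: root graph -> intermediate child graph by user -> (populated content).
GraphPtr build_root(bool nested) {
    auto root = make_graph("root graph");
    auto user = make_graph("intermediate child graph by user");
    if (nested) {
        populate_nested(*user);
    } else {
        populate_direct(*user);
    }
    root->child_graphs.push_back(user);
    return root;
}
```

In this model, `levels_to_kernel(*build_root(true))` is 2 (user graph, then cuDNN graph), while `levels_to_kernel(*build_root(false))` is 1, which is the one level users need anyway in order to call update_cuda_graph.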
Refs ml-explore/mlx#2857.