Skip to content

Commit

Permalink
Caches expensive cudnn graph validation (Lightning-AI#271)
Browse files Browse the repository at this point in the history
Co-authored-by: Vedaanta Agarwalla <vagarwalla@ipp2-1949.nvidia.com>
  • Loading branch information
vedaanta and Vedaanta Agarwalla authored Apr 25, 2024
1 parent 34d21e1 commit 2578766
Showing 1 changed file with 2 additions and 16 deletions.
18 changes: 2 additions & 16 deletions thunder/executors/cudnnex.py
Original file line number Diff line number Diff line change
Expand Up @@ -163,17 +163,10 @@ def _make_cudnn_sdpa_forward_graph(query, key, value, attn_mask, dropout_p, is_c

softmax_stats.set_output(True).set_data_type(torch_to_cudnn_dtype(torch.float32))

# Validate the graph before querying the cache key
# Validation makes sure all missing properties are inferred and filled, as they affect cache key.
graph.validate()
cache_key = graph.key()

# If a built graph does not exist in cache already, make one and place it in
if cache_key not in _cudnnex_cache:
graph.build_operation_graph()
graph.create_execution_plans([cudnn.heur_mode.A])
graph.check_support()
graph.build_plans(cudnn.build_plan_policy.HEURISTICS_CHOICE)
graph.build([cudnn.heur_mode.A])

_cudnnex_cache[cache_key] = (
Q,
Expand Down Expand Up @@ -464,17 +457,10 @@ def _make_cudnn_sdpa_backward_graph(
torch_to_cudnn_dtype(value.dtype)
)

# Validate the graph before querying the cache key
# Validation makes sure all missing properties are inferred and filled, as they affect cache key.
graph.validate()
cache_key = graph.key()

# If a built graph does not exist in cache already, make one and place it in
if cache_key not in _cudnnex_cache:
graph.build_operation_graph()
graph.create_execution_plans([cudnn.heur_mode.A])
graph.check_support()
graph.build_plans(cudnn.build_plan_policy.HEURISTICS_CHOICE)
graph.build([cudnn.heur_mode.A])

_cudnnex_cache[cache_key] = (
Q,
Expand Down

0 comments on commit 2578766

Please sign in to comment.