diff --git a/llm_analysis/analysis.py b/llm_analysis/analysis.py
index 32b1cb6..4740aea 100644
--- a/llm_analysis/analysis.py
+++ b/llm_analysis/analysis.py
@@ -1230,6 +1230,7 @@ def get_latency_fwd_per_tp_comm(self, batch_size: int, seq_len: int,
     def get_latency_fwd_per_layer_shared_dp_comm(self) -> float:
         dp_size = self.parallelism_config.dp_size
         ep_size = self.parallelism_config.ep_size
+        tp_size = self.parallelism_config.tp_size
 
         def time_allgather(S, n, B):
             # https://github.com/NVIDIA/nccl-tests/blob/master/doc/PERFORMANCE.md#allgather
@@ -1243,15 +1244,17 @@ def time_allgather(S, n, B):
             self.get_num_params_per_layer_layernorm()
         ) * self.dtype_config.weight_bits / BITS_PER_BYTE
 
-        latency_allgather_params_mlp = time_allgather(
-            params_bytes_mlp, dp_size / ep_size,
-            (self.get_intra_node_bandwidth()
-             if dp_size <= 8 else self.get_inter_node_bandwidth()) * 10**9)
+        # assuming tp and dp are preferred when sharding intra node, pp is only applied across nodes
+        # when (dp_size * tp_size) <= 8, the data parallel processes are within a node
+        bandwidth = self.get_intra_node_bandwidth() if (
+            dp_size * tp_size) <= 8 else self.get_inter_node_bandwidth()
+
+        latency_allgather_params_mlp = time_allgather(params_bytes_mlp,
+                                                      dp_size / ep_size,
+                                                      bandwidth * 10**9)
 
         latency_allgather_params_non_mlp = time_allgather(
-            params_bytes_non_mlp, dp_size,
-            (self.get_intra_node_bandwidth()
-             if dp_size <= 8 else self.get_inter_node_bandwidth()) * 10**9)
+            params_bytes_non_mlp, dp_size, bandwidth * 10**9)
 
         latency_fwd_per_layer_shared_dp_comm = latency_allgather_params_mlp + latency_allgather_params_non_mlp
 
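As a rough sanity check of the patched cost model (not part of the diff), the standalone sketch below mirrors the NCCL allgather estimate S * (n - 1) / (B * n) and the (dp_size * tp_size) <= 8 intra-node check. The function names, bandwidths, and parameter sizes here are illustrative assumptions, not values or APIs from llm_analysis.

# Standalone sketch of the patched logic; names and numbers are assumptions,
# not part of the llm_analysis codebase.

def time_allgather(S: float, n: float, B: float) -> float:
    # NCCL all-gather bus model: S * (n - 1) / (B * n)
    # https://github.com/NVIDIA/nccl-tests/blob/master/doc/PERFORMANCE.md#allgather
    return S * (n - 1) / (B * n)


def shared_dp_allgather_latency(params_bytes_mlp: float,
                                params_bytes_non_mlp: float,
                                dp_size: int,
                                tp_size: int,
                                ep_size: int,
                                intra_node_bw_gbs: float = 300.0,
                                inter_node_bw_gbs: float = 50.0) -> float:
    # tp and dp are assumed to shard intra-node first, so when dp_size * tp_size <= 8
    # every data-parallel peer sits inside one node and sees the intra-node bandwidth.
    bandwidth_gbs = intra_node_bw_gbs if dp_size * tp_size <= 8 else inter_node_bw_gbs
    bandwidth = bandwidth_gbs * 10**9  # GB/s -> bytes/s

    # MLP (expert) parameters are additionally sharded ep_size ways; non-MLP are not.
    latency_mlp = time_allgather(params_bytes_mlp, dp_size / ep_size, bandwidth)
    latency_non_mlp = time_allgather(params_bytes_non_mlp, dp_size, bandwidth)
    return latency_mlp + latency_non_mlp


if __name__ == "__main__":
    # dp=8, tp=1 fits in one 8-GPU node; dp=8, tp=2 spans nodes and uses inter-node bandwidth.
    print(shared_dp_allgather_latency(2e8, 5e7, dp_size=8, tp_size=1, ep_size=1))
    print(shared_dp_allgather_latency(2e8, 5e7, dp_size=8, tp_size=2, ep_size=1))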