|
25 | 25 | Union,
|
26 | 26 | )
|
27 | 27 |
|
| 28 | +import pandas as pd |
28 | 29 | from torch import nn
|
29 | 30 |
|
30 | 31 | from torchrec.distributed.embedding_types import EmbeddingComputeKernel
|
|
36 | 37 | InferenceStorageReservation,
|
37 | 38 | )
|
38 | 39 | from torchrec.distributed.planner.types import (
|
| 40 | + CriticalPathEstimate, |
39 | 41 | ParameterConstraints,
|
40 | 42 | Perf,
|
41 | 43 | ShardingOption,
|
@@ -319,7 +321,7 @@ def log(
|
319 | 321 | )
|
320 | 322 |
|
321 | 323 | # Max perf and HBM to help root cause imbalance
|
322 |
| - self._log_max_perf_and_max_hbm(perf, used_hbm) |
| 324 | + self._log_max_perf_and_max_hbm(perf, used_hbm, best_plan) |
323 | 325 | self._log_storage_reservation_stats(
|
324 | 326 | storage_reservation,
|
325 | 327 | topology,
|
@@ -445,7 +447,9 @@ def _log_plan_imbalance_stats(
|
445 | 447 | f"# {'Imbalance stats range 0-1, higher means more imbalanced' : <{self._width-3}}#"
|
446 | 448 | )
|
447 | 449 |
|
448 |
| - def _log_max_perf_and_max_hbm(self, perfs: List[Perf], used_hbm: List[int]) -> None: |
| 450 | + def _log_max_perf_and_max_hbm( |
| 451 | + self, perfs: List[Perf], used_hbm: List[int], best_plan: List[ShardingOption] |
| 452 | + ) -> None: |
449 | 453 | total_perfs = [perf.total for perf in perfs]
|
450 | 454 |
|
451 | 455 | max_total_perf_text = f"Longest Critical Path (Maximum of Total Perf): {_generate_max_text(total_perfs)}"
|
@@ -480,6 +484,8 @@ def _log_max_perf_and_max_hbm(self, perfs: List[Perf], used_hbm: List[int]) -> N
|
480 | 484 | )
|
481 | 485 | sum_of_maxima_text = f"Sum of Maxima: {round(sum_of_maxima, 3)} ms"
|
482 | 486 |
|
| 487 | + critical_path_estimate = _calculate_critical_path(best_plan) |
| 488 | + |
483 | 489 | self._stats_table.append(f"#{'' : ^{self._width-2}}#")
|
484 | 490 | self._stats_table.append(f"# {max_total_perf_text : <{self._width-3}}#")
|
485 | 491 | self._stats_table.append(f"# {mean_total_perf_text : <{self._width-3}}#")
|
@@ -512,6 +518,15 @@ def _log_max_perf_and_max_hbm(self, perfs: List[Perf], used_hbm: List[int]) -> N
|
512 | 518 | self._stats_table.append(
|
513 | 519 | f"# {'High Median HBM: '+_generate_rank_hbm_stats(used_hbm, statistics.median_high) : <{self._width-3}}#"
|
514 | 520 | )
|
| 521 | + self._stats_table.append( |
| 522 | + f"# {'Critical Path (comms): '+str(round(critical_path_estimate.comms_estimate, 3)) : <{self._width-3}}#" |
| 523 | + ) |
| 524 | + self._stats_table.append( |
| 525 | + f"# {'Critical Path (compute): '+str(round(critical_path_estimate.comp_estimate, 3)) : <{self._width-3}}#" |
| 526 | + ) |
| 527 | + self._stats_table.append( |
| 528 | + f"# {'Critical Path (comms + compute): '+str(round(critical_path_estimate.comp_estimate, 3)) : <{self._width-3}}#" |
| 529 | + ) |
515 | 530 |
|
516 | 531 | max_used_hbm = max(used_hbm)
|
517 | 532 | mean_used_hbm = statistics.mean(used_hbm)
|
@@ -1052,6 +1067,76 @@ def _reduce_int_list(input_list: List[int]) -> str:
|
1052 | 1067 | return ", ".join(reduced)
|
1053 | 1068 |
|
1054 | 1069 |
|
def _calculate_critical_path(best_plan: List[ShardingOption]) -> CriticalPathEstimate:
    """
    Calculates the critical path of the sharding plan. Makes the following assumptions:

    1. There is a synchronization point across the ranks after each of the 4 events: Fwd/Bwd x Comms/Comp.
    2. There are additional synchronization points during communication (both fwd & bwd) for each module <> sharding type combination.
        i. Communication operations for each shard from the same module <> sharding type group are executed sequentially.
        ii. Ranks need to synchronize before they can begin the communication operation for the next module <> sharding type group.
    3. There are additional synchronization points during computation (both fwd & bwd) at the rank level.
        i. Computation operations for each shard from the same module are executed sequentially.
        ii. Ranks need to synchronize before they can begin the next set of events.

    Args:
        best_plan (List[ShardingOption]): the chosen sharding option per module,
            each carrying per-shard rank and Perf estimates.

    Returns:
        CriticalPathEstimate: (comms_estimate, comp_estimate) in the same time
            units as the underlying Perf fields.
    """

    # Accumulated perf keyed by (rank, module, sharding_type, direction, perf_type),
    # where direction is "fwd"/"bwd"/"prefetch" and perf_type is "compute"/"comms".
    perf_data = defaultdict(float)
    for so in best_plan:
        module = so.module
        sharding_type = so.sharding_type
        # BUG FIX: sort (rank, perf) pairs *together*. The previous code sorted
        # the rank list independently of the perf list and then zipped them,
        # which attributes perfs to the wrong ranks whenever so.shards is not
        # already ordered by rank.
        rank_perfs = sorted(
            ((cast(int, shard.rank), cast(Perf, shard.perf)) for shard in so.shards),
            key=lambda rank_perf: rank_perf[0],
        )
        for rank, perf in rank_perfs:
            # "prefetch_compute" splits into direction="prefetch", type="compute"
            # and thus contributes to the compute estimate below.
            for perf_type, value in (
                ("fwd_compute", perf.fwd_compute),
                ("fwd_comms", perf.fwd_comms),
                ("bwd_compute", perf.bwd_compute),
                ("bwd_comms", perf.bwd_comms),
                ("prefetch_compute", perf.prefetch_compute),
            ):
                direction, kind = perf_type.split("_")
                perf_data[(rank, module, sharding_type, direction, kind)] += value

    perf_df = pd.DataFrame.from_dict(perf_data, orient="index", columns=["perf"])
    perf_df.index = pd.MultiIndex.from_tuples(
        perf_df.index,
        names=["rank", "module", "sharding_type", "direction", "perf_type"],
    )

    # Comms: per (module, sharding_type, direction) group, ranks synchronize,
    # so the group's cost is the max over ranks; groups run sequentially (sum).
    comms_estimate = (
        perf_df.xs("comms", level="perf_type")
        .groupby(["rank", "module", "sharding_type", "direction"])
        .sum()
        .groupby(["module", "sharding_type", "direction"])
        .max()
        .sum()
        .item()
    )

    # Compute: each rank runs its shards sequentially (sum per rank), ranks
    # synchronize per direction (max over ranks), directions run sequentially.
    comp_estimate = (
        perf_df.xs("compute", level="perf_type")
        .groupby(["rank", "direction"])
        .sum()
        .groupby(["direction"])
        .max()
        .sum()
        .item()
    )

    return CriticalPathEstimate(comms_estimate, comp_estimate)
| 1139 | + |
1055 | 1140 | class NoopEmbeddingStats(Stats):
|
1056 | 1141 | """
|
1057 | 1142 | Noop Stats for a sharding planner execution.
|
|
0 commit comments