Commit 787e711

Caner Gocmen authored and facebook-github-bot committed
Update critical path definition (#2879)
Summary:
Pull Request resolved: #2879

Update the critical path definition in the planner logs to match what we think is the most realistic option. See the comments on the `_calculate_critical_path` function for the detailed logic.

Reviewed By: iamzainhuda

Differential Revision: D72410003

fbshipit-source-id: 050bed84b8bbe1b3b9489f3093a2297277cdac57
1 parent 8ec0540 commit 787e711

2 files changed (+96 −2)
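At a high level, the new definition treats comms and compute as separate synchronized phases. Informally, restating the logic of `_calculate_critical_path` from the diff below:

    critical_path = sum over (module, sharding_type, direction) groups of the max across ranks of comms time
                  + sum over directions of the max across ranks of compute time

The slowest rank gates each synchronized group, and groups run back to back, hence the max within each group followed by a sum across groups.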

torchrec/distributed/planner/stats.py (+87 −2)
@@ -25,6 +25,7 @@
     Union,
 )
 
+import pandas as pd
 from torch import nn
 
 from torchrec.distributed.embedding_types import EmbeddingComputeKernel
@@ -36,6 +37,7 @@
     InferenceStorageReservation,
 )
 from torchrec.distributed.planner.types import (
+    CriticalPathEstimate,
     ParameterConstraints,
     Perf,
     ShardingOption,
@@ -319,7 +321,7 @@ def log(
         )
 
         # Max perf and HBM to help root cause imbalance
-        self._log_max_perf_and_max_hbm(perf, used_hbm)
+        self._log_max_perf_and_max_hbm(perf, used_hbm, best_plan)
         self._log_storage_reservation_stats(
             storage_reservation,
             topology,
@@ -445,7 +447,9 @@ def _log_plan_imbalance_stats(
             f"# {'Imbalance stats range 0-1, higher means more imbalanced' : <{self._width-3}}#"
         )
 
-    def _log_max_perf_and_max_hbm(self, perfs: List[Perf], used_hbm: List[int]) -> None:
+    def _log_max_perf_and_max_hbm(
+        self, perfs: List[Perf], used_hbm: List[int], best_plan: List[ShardingOption]
+    ) -> None:
         total_perfs = [perf.total for perf in perfs]
 
         max_total_perf_text = f"Longest Critical Path (Maximum of Total Perf): {_generate_max_text(total_perfs)}"
@@ -480,6 +484,8 @@ def _log_max_perf_and_max_hbm(self, perfs: List[Perf], used_hbm: List[int]) -> None:
         )
         sum_of_maxima_text = f"Sum of Maxima: {round(sum_of_maxima, 3)} ms"
 
+        critical_path_estimate = _calculate_critical_path(best_plan)
+
         self._stats_table.append(f"#{'' : ^{self._width-2}}#")
         self._stats_table.append(f"# {max_total_perf_text : <{self._width-3}}#")
         self._stats_table.append(f"# {mean_total_perf_text : <{self._width-3}}#")
@@ -512,6 +518,15 @@ def _log_max_perf_and_max_hbm(self, perfs: List[Perf], used_hbm: List[int]) -> None:
         self._stats_table.append(
             f"# {'High Median HBM: '+_generate_rank_hbm_stats(used_hbm, statistics.median_high) : <{self._width-3}}#"
         )
+        self._stats_table.append(
+            f"# {'Critical Path (comms): '+str(round(critical_path_estimate.comms_estimate, 3)) : <{self._width-3}}#"
+        )
+        self._stats_table.append(
+            f"# {'Critical Path (compute): '+str(round(critical_path_estimate.comp_estimate, 3)) : <{self._width-3}}#"
+        )
+        self._stats_table.append(
+            f"# {'Critical Path (comms + compute): '+str(round(critical_path_estimate.total(), 3)) : <{self._width-3}}#"
+        )
 
         max_used_hbm = max(used_hbm)
         mean_used_hbm = statistics.mean(used_hbm)
@@ -1052,6 +1067,76 @@ def _reduce_int_list(input_list: List[int]) -> str:
     return ", ".join(reduced)
 
 
+def _calculate_critical_path(best_plan: List[ShardingOption]) -> CriticalPathEstimate:
+    """
+    Calculates the critical path of the sharding plan. Makes the following assumptions:
+
+    1. There is a synchronization point across the ranks after each of the 4 events: Fwd/Bwd x Comms/Comp.
+    2. There are additional synchronization points during communication (both fwd & bwd) for each module <> sharding type combination.
+        i. Communication operations for each shard from the same module <> sharding type group are executed sequentially.
+        ii. Ranks need to synchronize before they can begin the communication operation for the next module <> sharding type group.
+    3. There are additional synchronization points during computation (both fwd & bwd) at the rank level.
+        i. Computation operations for each shard from the same module are executed sequentially.
+        ii. Ranks need to synchronize before they can begin the next set of events.
+    """
+
+    perf_data = defaultdict(float)
+    for so in best_plan:
+        module = so.module
+        sharding_type = so.sharding_type
+        ranks = sorted([cast(int, shard.rank) for shard in so.shards])
+        shard_perfs = [cast(Perf, shard.perf) for shard in so.shards]
+        perf_breakdowns = [
+            {
+                "fwd_compute": perf.fwd_compute,
+                "fwd_comms": perf.fwd_comms,
+                "bwd_compute": perf.bwd_compute,
+                "bwd_comms": perf.bwd_comms,
+                "prefetch_compute": perf.prefetch_compute,
+            }
+            for perf in shard_perfs
+        ]
+
+        for rank, perf_breakdown in zip(ranks, perf_breakdowns):
+            for perf_type in perf_breakdown:
+                perf_data[
+                    (
+                        rank,
+                        module,
+                        sharding_type,
+                        perf_type.split("_")[0],  # fwd, bwd, or prefetch
+                        perf_type.split("_")[1],  # compute or comms
+                    )
+                ] += perf_breakdown[perf_type]
+    perf_df = pd.DataFrame.from_dict(perf_data, orient="index", columns=["perf"])
+    perf_df.index = pd.MultiIndex.from_tuples(
+        perf_df.index,
+        names=["rank", "module", "sharding_type", "direction", "perf_type"],
+    )
+
+    comms_estimate = (
+        perf_df.xs("comms", level="perf_type")
+        .groupby(["rank", "module", "sharding_type", "direction"])
+        .sum()
+        .groupby(["module", "sharding_type", "direction"])
+        .max()
+        .sum()
+        .item()
+    )
+
+    comp_estimate = (
+        perf_df.xs("compute", level="perf_type")
+        .groupby(["rank", "direction"])
+        .sum()
+        .groupby(["direction"])
+        .max()
+        .sum()
+        .item()
+    )
+
+    return CriticalPathEstimate(comms_estimate, comp_estimate)
+
+
 class NoopEmbeddingStats(Stats):
     """
     Noop Stats for a sharding planner execution.
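To make the two reductions concrete, here is a minimal, self-contained sketch of the same groupby chain on made-up numbers (two ranks, two module <> sharding type groups, forward direction only). The table names, sharding types, and perf values are hypothetical; only the reduction mirrors `_calculate_critical_path` above.

import pandas as pd

# (rank, module, sharding_type, direction, perf_type) -> perf in ms (made up)
perf_data = {
    (0, "table_a", "table_wise", "fwd", "comms"): 2.0,
    (1, "table_a", "table_wise", "fwd", "comms"): 3.0,
    (0, "table_b", "row_wise", "fwd", "comms"): 4.0,
    (1, "table_b", "row_wise", "fwd", "comms"): 1.0,
    (0, "table_a", "table_wise", "fwd", "compute"): 5.0,
    (1, "table_a", "table_wise", "fwd", "compute"): 2.0,
    (0, "table_b", "row_wise", "fwd", "compute"): 1.0,
    (1, "table_b", "row_wise", "fwd", "compute"): 3.0,
}
perf_df = pd.DataFrame.from_dict(perf_data, orient="index", columns=["perf"])
perf_df.index = pd.MultiIndex.from_tuples(
    perf_df.index,
    names=["rank", "module", "sharding_type", "direction", "perf_type"],
)

# Comms: each module <> sharding type group is gated by its slowest rank,
# and groups run back to back: max(2.0, 3.0) + max(4.0, 1.0) = 7.0
comms_estimate = (
    perf_df.xs("comms", level="perf_type")
    .groupby(["rank", "module", "sharding_type", "direction"])
    .sum()
    .groupby(["module", "sharding_type", "direction"])
    .max()
    .sum()
    .item()
)

# Compute: each rank runs its shards sequentially, so each direction is
# gated by its slowest rank: max(5.0 + 1.0, 2.0 + 3.0) = 6.0
comp_estimate = (
    perf_df.xs("compute", level="perf_type")
    .groupby(["rank", "direction"])
    .sum()
    .groupby(["direction"])
    .max()
    .sum()
    .item()
)

print(comms_estimate, comp_estimate)  # 7.0 6.0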

torchrec/distributed/planner/types.py (+9 −0)
@@ -831,3 +831,12 @@ def log(
         See class description
         """
         ...
+
+
+@dataclass
+class CriticalPathEstimate:
+    comms_estimate: float
+    comp_estimate: float
+
+    def total(self) -> float:
+        return self.comms_estimate + self.comp_estimate
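For completeness, a tiny usage sketch of the new dataclass (a standalone copy with made-up values; in TorchRec it is imported from `torchrec.distributed.planner.types`). `total()` is what backs the "Critical Path (comms + compute)" log line above.

from dataclasses import dataclass


@dataclass
class CriticalPathEstimate:
    comms_estimate: float
    comp_estimate: float

    def total(self) -> float:
        return self.comms_estimate + self.comp_estimate


# Made-up values in milliseconds.
estimate = CriticalPathEstimate(comms_estimate=7.0, comp_estimate=6.0)
print(round(estimate.total(), 3))  # 13.0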
