diff --git a/torchrec/distributed/sharding/dynamic_sharding.py b/torchrec/distributed/sharding/dynamic_sharding.py
index caa937db2..420c0ea24 100644
--- a/torchrec/distributed/sharding/dynamic_sharding.py
+++ b/torchrec/distributed/sharding/dynamic_sharding.py
@@ -73,7 +73,6 @@ def shards_all_to_all(
         sharded_t = state_dict[extend_shard_name(shard_name)]
         assert param.ranks is not None
         dst_ranks = param.ranks
-        state_dict[extend_shard_name(shard_name)]
         # pyre-ignore
         src_ranks = module.module_sharding_plan[shard_name].ranks
@@ -140,7 +139,7 @@ def shards_all_to_all(
         input=local_input_tensor,
         output_split_sizes=local_output_splits,
         input_split_sizes=local_input_splits,
-        group=dist.group.WORLD,
+        group=env.process_group,  # TODO: 2D uses env.sharding_pg
     )
     flattened_output_names_lengths = [
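For context, below is a minimal self-contained sketch of why the all-to-all should take the sharding environment's process group rather than the global dist.group.WORLD. It is not the TorchRec implementation: ShardingEnvStub, redistribute_shards, and the even split sizes are illustrative assumptions; only torch.distributed calls that exist (init_process_group, get_world_size, all_to_all_single) are used.

# Hedged sketch, not TorchRec code. ShardingEnvStub and redistribute_shards
# are hypothetical names used only for illustration.
import os

import torch
import torch.distributed as dist


class ShardingEnvStub:
    """Minimal stand-in for an env object that carries a process group."""

    def __init__(self, process_group: dist.ProcessGroup) -> None:
        self.process_group = process_group


def redistribute_shards(local_shard: torch.Tensor, env: ShardingEnvStub) -> torch.Tensor:
    """Exchange shard data across the ranks of env.process_group.

    Using env.process_group instead of dist.group.WORLD keeps the collective
    scoped to the ranks that actually participate in the sharding environment,
    which is what a 2D / sub-group setup needs (see the TODO about
    env.sharding_pg in the patch above).
    """
    pg = env.process_group
    world_size = dist.get_world_size(group=pg)

    # Even split for illustration; real code derives splits from shard sizes.
    split = local_shard.numel() // world_size
    input_splits = [split] * world_size
    output_splits = [split] * world_size

    output = torch.empty(sum(output_splits), dtype=local_shard.dtype)
    dist.all_to_all_single(
        output=output,
        input=local_shard,
        output_split_sizes=output_splits,
        input_split_sizes=input_splits,
        group=pg,  # scoped group, not dist.group.WORLD
    )
    return output


if __name__ == "__main__":
    # Single-process gloo group so the sketch runs standalone
    # (recent PyTorch builds implement all_to_all on the gloo backend).
    os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
    os.environ.setdefault("MASTER_PORT", "29500")
    dist.init_process_group("gloo", rank=0, world_size=1)
    env = ShardingEnvStub(dist.group.WORLD)
    print(redistribute_shards(torch.arange(4, dtype=torch.float32), env))
    dist.destroy_process_group()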