
Commit 9380408

shoyer authored and Xarray-Beam authors committed
Use multi-stage rechunking.
As described in pangeo-data/rechunker#89, this can yield significant performance benefits for rechunking large arrays.

PiperOrigin-RevId: 518325665
1 parent 443aeae commit 9380408
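
For background: rechunker's multi-stage algorithm breaks a single read-to-write pass into several intermediate passes, trading extra I/O for intermediate chunks whose sizes stay between min_mem and max_mem. A minimal sketch of the underlying call this commit adopts (the sizes are borrowed from the new unit test below and are purely illustrative):

    from rechunker import algorithm

    # Each returned stage is a (read_chunks, intermediate_chunks,
    # write_chunks) triple of shape tuples; a nonzero min_mem lets the
    # planner insert extra stages rather than shrinking intermediate
    # chunks below the floor.
    stages = algorithm.multistage_rechunking_plan(
        shape=(1000, 200, 300),
        source_chunks=(1, 200, 300),
        target_chunks=(1000, 20, 20),
        itemsize=8,
        min_mem=1_000_000,   # 1 MB floor per intermediate chunk
        max_mem=10_000_000,  # 10 MB ceiling per in-memory chunk
    )
    for read_chunks, intermediate_chunks, write_chunks in stages:
        print(read_chunks, '->', intermediate_chunks, '->', write_chunks)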

File tree

5 files changed: +45 -20 lines changed


examples/era5_rechunk.py (+1 -1)

@@ -20,7 +20,7 @@
 
 INPUT_PATH = flags.DEFINE_string('input_path', None, help='Input Zarr path')
 OUTPUT_PATH = flags.DEFINE_string('output_path', None, help='Output Zarr path')
-RUNNER = flags.DEFINE_string('runner', None, 'beam.runners.Runner')
+RUNNER = flags.DEFINE_string('runner', None, help='beam.runners.Runner')
 
 
 # pylint: disable=expression-not-assigned

setup.py (+2 -2)

@@ -20,7 +20,7 @@
     'apache_beam>=2.31.0',
     'dask',
     'immutabledict',
-    'rechunker',
+    'rechunker>=0.5.1',
     'zarr',
     'xarray',
 ]
@@ -42,7 +42,7 @@
 
 setuptools.setup(
     name='xarray-beam',
-    version='0.5.1',
+    version='0.6.0',
     license='Apache 2.0',
     author='Google LLC',
     author_email='[email protected]',

xarray_beam/__init__.py (+1 -1)

@@ -45,4 +45,4 @@
     DatasetToZarr,
 )
 
-__version__ = '0.5.1'
+__version__ = '0.6.0'

xarray_beam/_src/rechunk.py (+26 -14)

@@ -16,6 +16,7 @@
 import dataclasses
 import itertools
 import logging
+import math
 import textwrap
 from typing import (
     Any,
@@ -75,17 +76,22 @@ def rechunking_plan(
     source_chunks: Mapping[str, int],
     target_chunks: Mapping[str, int],
     itemsize: int,
+    min_mem: int,
     max_mem: int,
-) -> List[Dict[str, int]]:
+) -> List[List[Dict[str, int]]]:
   """Make a rechunking plan."""
-  plan_shapes = algorithm.rechunking_plan(
+  stages = algorithm.multistage_rechunking_plan(
       shape=tuple(dim_sizes.values()),
       source_chunks=tuple(source_chunks[dim] for dim in dim_sizes),
       target_chunks=tuple(target_chunks[dim] for dim in dim_sizes),
       itemsize=itemsize,
+      min_mem=min_mem,
       max_mem=max_mem,
   )
-  return [dict(zip(dim_sizes.keys(), shapes)) for shapes in plan_shapes]
+  plan = []
+  for stage in stages:
+    plan.append([dict(zip(dim_sizes.keys(), shapes)) for shapes in stage])
+  return plan
 
 
 def _consolidate_chunks_in_var_group(
@@ -511,7 +517,8 @@ def __init__(
       source_chunks: Mapping[str, Union[int, Tuple[int, ...]]],
       target_chunks: Mapping[str, Union[int, Tuple[int, ...]]],
       itemsize: int,
-      max_mem: int = 2**30,  # 1 GB
+      min_mem: Optional[int] = None,
+      max_mem: int = 2 ** 30,  # 1 GB
   ):
     """Initialize Rechunk().
 
@@ -524,13 +531,16 @@ def __init__(
       itemsize: approximate number of bytes per xarray.Dataset element, after
         indexing out by all dimensions, e.g., `4 * len(dataset)` for float32
         data or roughly `dataset.nbytes / np.prod(dataset.sizes)`.
+      min_mem: minimum memory that a single intermediate chunk must consume.
       max_mem: maximum memory that a single intermediate chunk may consume.
     """
     if source_chunks.keys() != target_chunks.keys():
       raise ValueError(
           'source_chunks and target_chunks have different keys: '
           f'{source_chunks} vs {target_chunks}'
       )
+    if min_mem is None:
+      min_mem = max_mem // 100
     self.dim_sizes = dim_sizes
     self.source_chunks = normalize_chunks(source_chunks, dim_sizes)
     self.target_chunks = normalize_chunks(target_chunks, dim_sizes)
@@ -539,27 +549,29 @@ def __init__(
         self.source_chunks,
         self.target_chunks,
         itemsize=itemsize,
+        min_mem=min_mem,
         max_mem=max_mem,
     )
-    self.read_chunks, self.intermediate_chunks, self.write_chunks = plan
+    plan = (
+        [[self.source_chunks, self.source_chunks, plan[0][0]]]
+        + plan
+        + [[plan[-1][-1], self.target_chunks, self.target_chunks]]
+    )
+    self.stage_in, (_, *intermediates, _), self.stage_out = zip(*plan)
 
-    # TODO(shoyer): multi-stage rechunking, when supported by rechunker:
-    # https://github.com/pangeo-data/rechunker/pull/89
-    self.stage_in = [self.source_chunks, self.read_chunks, self.write_chunks]
-    self.stage_out = [self.read_chunks, self.write_chunks, self.target_chunks]
     logging.info(
         'Rechunking plan:\n'
         + '\n'.join(
-            f'{s} -> {t}' for s, t in zip(self.stage_in, self.stage_out)
+            f'Stage{i}: {s} -> {t}'
+            for i, (s, t) in enumerate(zip(self.stage_in, self.stage_out))
         )
     )
-    min_size = itemsize * np.prod(list(self.intermediate_chunks.values()))
+    min_size = min(
+        itemsize * math.prod(chunks.values()) for chunks in intermediates
+    )
     logging.info(f'Smallest intermediates have size {min_size:1.3e}')
 
   def expand(self, pcoll):
-    # TODO(shoyer): consider splitting xarray.Dataset objects into separate
-    # arrays for rechunking, which is more similar to what Rechunker does and
-    # in principle could be more efficient.
     for stage, (in_chunks, out_chunks) in enumerate(
         zip(self.stage_in, self.stage_out)
     ):
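
To make the stage bookkeeping above concrete, here is the same padding and transposing logic in isolation, with hypothetical single-dimension chunk dicts standing in for a real plan:

    # Hypothetical two-stage plan; each stage is [read, intermediate, write].
    plan = [
        [{'t': 10}, {'t': 50}, {'t': 100}],
        [{'t': 100}, {'t': 500}, {'t': 1000}],
    ]
    source_chunks = {'t': 1}
    target_chunks = {'t': 1000}

    # Bracket the plan with an initial stage that consolidates the source
    # chunks into the first read chunks, and a final stage that splits the
    # last write chunks into the target chunks.
    plan = (
        [[source_chunks, source_chunks, plan[0][0]]]
        + plan
        + [[plan[-1][-1], target_chunks, target_chunks]]
    )
    # Transpose: stage_in[i] -> stage_out[i] is one Beam rechunking stage;
    # the middle column (minus the padded ends) holds the true intermediates.
    stage_in, (_, *intermediates, _), stage_out = zip(*plan)
    assert intermediates == [{'t': 50}, {'t': 500}]

This shape is what lets expand() treat every stage uniformly as a consolidate-then-split between an input and an output chunking.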

xarray_beam/_src/rechunk_test.py (+15 -2)

@@ -63,22 +63,24 @@ def test_normalize_chunks_errors(self):
 
   def test_rechunking_plan(self):
     # this trivial case fits entirely into memory
-    plan = rechunk.rechunking_plan(
+    plan, = rechunk.rechunking_plan(
         dim_sizes={'x': 10, 'y': 20},
         source_chunks={'x': 1, 'y': 20},
         target_chunks={'x': 10, 'y': 1},
         itemsize=1,
+        min_mem=0,
         max_mem=200,
     )
     expected = [{'x': 10, 'y': 20}] * 3
     self.assertEqual(plan, expected)
 
     # this harder case doesn't
-    read_chunks, _, write_chunks = rechunk.rechunking_plan(
+    (read_chunks, _, write_chunks), = rechunk.rechunking_plan(
         dim_sizes={'t': 1000, 'x': 200, 'y': 300},
         source_chunks={'t': 1, 'x': 200, 'y': 300},
         target_chunks={'t': 1000, 'x': 20, 'y': 20},
         itemsize=8,
+        min_mem=0,
         max_mem=10_000_000,
     )
     self.assertGreater(read_chunks['t'], 1)
@@ -88,6 +90,17 @@ def test_rechunking_plan(self):
     self.assertGreater(read_chunks['x'], 20)
     self.assertGreater(read_chunks['y'], 20)
 
+    # multiple stages
+    stages = rechunk.rechunking_plan(
+        dim_sizes={'t': 1000, 'x': 200, 'y': 300},
+        source_chunks={'t': 1, 'x': 200, 'y': 300},
+        target_chunks={'t': 1000, 'x': 20, 'y': 20},
+        itemsize=8,
+        min_mem=1_000_000,
+        max_mem=10_000_000,
+    )
+    self.assertGreater(len(stages), 1)
+
   def test_consolidate_and_split_chunks(self):
     consolidated = [
         (
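
End to end, a pipeline opts into multi-stage rechunking simply by passing min_mem (when omitted, it defaults to max_mem // 100). A hedged usage sketch; the toy dataset, chunk sizes, and output path below are placeholders for illustration, not part of this commit:

    import apache_beam as beam
    import numpy as np
    import xarray
    import xarray_beam as xbeam

    # Placeholder dataset; a real pipeline would open an existing store.
    dataset = xarray.Dataset(
        {'foo': (('time', 'x', 'y'), np.zeros((1000, 200, 300), np.float32))}
    )

    with beam.Pipeline() as p:
      (
          p
          | xbeam.DatasetToChunks(dataset, chunks={'time': 1})
          | xbeam.Rechunk(
              dim_sizes=dict(dataset.sizes),
              source_chunks={'time': 1, 'x': 200, 'y': 300},
              target_chunks={'time': 1000, 'x': 20, 'y': 20},
              itemsize=4,  # bytes per float32 element
              min_mem=2**24,  # 16 MB floor; omit to get max_mem // 100
              max_mem=2**30,  # 1 GB ceiling, the default
          )
          | xbeam.ChunksToZarr('example_output.zarr')
      )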
