import dataclasses
import itertools
import logging
+ import math
import textwrap
from typing import (
    Any,
@@ -75,17 +76,22 @@ def rechunking_plan(
    source_chunks: Mapping[str, int],
    target_chunks: Mapping[str, int],
    itemsize: int,
+     min_mem: int,
    max_mem: int,
- ) -> List[Dict[str, int]]:
+ ) -> List[List[Dict[str, int]]]:
  """Make a rechunking plan."""
-   plan_shapes = algorithm.rechunking_plan(
+   stages = algorithm.multistage_rechunking_plan(
      shape=tuple(dim_sizes.values()),
      source_chunks=tuple(source_chunks[dim] for dim in dim_sizes),
      target_chunks=tuple(target_chunks[dim] for dim in dim_sizes),
      itemsize=itemsize,
+       min_mem=min_mem,
      max_mem=max_mem,
  )
-   return [dict(zip(dim_sizes.keys(), shapes)) for shapes in plan_shapes]
+   plan = []
+   for stage in stages:
+     plan.append([dict(zip(dim_sizes.keys(), shapes)) for shapes in stage])
+   return plan


def _consolidate_chunks_in_var_group(
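Illustrative usage sketch (not part of the diff): rechunking_plan now returns a list of stages rather than a single triple. The call below uses hypothetical dimension sizes and assumes dim_sizes is the function's leading argument, which the body above suggests but the hunk does not show.

    # Hypothetical example; all sizes here are made up for illustration.
    stages = rechunking_plan(
        dim_sizes={'time': 2920, 'lat': 25, 'lon': 53},
        source_chunks={'time': 2920, 'lat': 5, 'lon': 53},
        target_chunks={'time': 10, 'lat': 25, 'lon': 53},
        itemsize=4,
        min_mem=10 * 2**20,
        max_mem=2**30,
    )
    # Each stage is a [read_chunks, intermediate_chunks, write_chunks] triple
    # of dicts keyed by dimension name.
    for read_chunks, intermediate_chunks, write_chunks in stages:
      print(read_chunks, '->', intermediate_chunks, '->', write_chunks)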
@@ -511,7 +517,8 @@ def __init__(
      source_chunks: Mapping[str, Union[int, Tuple[int, ...]]],
      target_chunks: Mapping[str, Union[int, Tuple[int, ...]]],
      itemsize: int,
-       max_mem: int = 2**30,  # 1 GB
+       min_mem: Optional[int] = None,
+       max_mem: int = 2**30,  # 1 GB
  ):
    """Initialize Rechunk().
@@ -524,13 +531,16 @@ def __init__(
      itemsize: approximate number of bytes per xarray.Dataset element, after
        indexing out by all dimensions, e.g., `4 * len(dataset)` for float32
        data or roughly `dataset.nbytes / np.prod(dataset.sizes)`.
+       min_mem: minimum memory that a single intermediate chunk must consume.
      max_mem: maximum memory that a single intermediate chunk may consume.
    """
    if source_chunks.keys() != target_chunks.keys():
      raise ValueError(
          'source_chunks and target_chunks have different keys: '
          f'{source_chunks} vs {target_chunks}'
      )
+     if min_mem is None:
+       min_mem = max_mem // 100
    self.dim_sizes = dim_sizes
    self.source_chunks = normalize_chunks(source_chunks, dim_sizes)
    self.target_chunks = normalize_chunks(target_chunks, dim_sizes)
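A quick check on that default (illustrative, not part of the diff): with the default max_mem of 2**30 bytes, min_mem falls back to one percent of it.

    max_mem = 2**30           # 1 GiB, the default above
    min_mem = max_mem // 100  # 10_737_418 bytes, roughly 10.7 MB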
@@ -539,27 +549,29 @@ def __init__(
        self.source_chunks,
        self.target_chunks,
        itemsize=itemsize,
+         min_mem=min_mem,
        max_mem=max_mem,
    )
-     self.read_chunks, self.intermediate_chunks, self.write_chunks = plan
+     plan = (
+         [[self.source_chunks, self.source_chunks, plan[0][0]]]
+         + plan
+         + [[plan[-1][-1], self.target_chunks, self.target_chunks]]
+     )
+     self.stage_in, (_, *intermediates, _), self.stage_out = zip(*plan)

-     # TODO(shoyer): multi-stage rechunking, when supported by rechunker:
-     # https://github.com/pangeo-data/rechunker/pull/89
-     self.stage_in = [self.source_chunks, self.read_chunks, self.write_chunks]
-     self.stage_out = [self.read_chunks, self.write_chunks, self.target_chunks]
    logging.info(
        'Rechunking plan:\n'
        + '\n'.join(
-           f'{s} -> {t}' for s, t in zip(self.stage_in, self.stage_out)
+           f'Stage{i}: {s} -> {t}'
+           for i, (s, t) in enumerate(zip(self.stage_in, self.stage_out))
        )
    )
-     min_size = itemsize * np.prod(list(self.intermediate_chunks.values()))
+     min_size = min(
+         itemsize * math.prod(chunks.values()) for chunks in intermediates
+     )
    logging.info(f'Smallest intermediates have size {min_size:1.3e}')

  def expand(self, pcoll):
-     # TODO(shoyer): consider splitting xarray.Dataset objects into separate
-     # arrays for rechunking, which is more similar to what Rechunker does and
-     # in principle could be more efficient.
    for stage, (in_chunks, out_chunks) in enumerate(
        zip(self.stage_in, self.stage_out)
    ):
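For readers tracing the zip(*plan) unpacking above, here is a small self-contained sketch (placeholder strings rather than real chunk dicts, purely illustrative) of how the padded plan splits into per-stage inputs, intermediates, and outputs:

    # Stand-in values: each stage of the multistage plan is [read, intermediate, write].
    plan = [['r0', 'i0', 'w0'], ['r1', 'i1', 'w1']]
    source, target = 'src', 'tgt'

    # Pad with a leading source -> first-read stage and a trailing
    # last-write -> target stage, mirroring the diff above.
    padded = (
        [[source, source, plan[0][0]]]
        + plan
        + [[plan[-1][-1], target, target]]
    )

    stage_in, (_, *intermediates, _), stage_out = zip(*padded)
    assert stage_in == ('src', 'r0', 'r1', 'w1')
    assert stage_out == ('r0', 'w0', 'w1', 'tgt')
    assert intermediates == ['i0', 'i1']  # true intermediates only

The padding adds an opening stage from the source chunks to the plan's first read chunks and a closing stage from the plan's last write chunks to the target chunks, while the starred unpacking discards the padded middle entries so that only genuine intermediate chunk sizes feed the "smallest intermediates" log line.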