Apply residuals when reading a table #1654

Open · wants to merge 7 commits into base: main

47 changes: 27 additions & 20 deletions pyiceberg/io/pyarrow.py
@@ -1381,9 +1381,7 @@ def _get_column_projection_values(
 def _task_to_record_batches(
     fs: FileSystem,
     task: FileScanTask,
-    bound_row_filter: BooleanExpression,
     projected_schema: Schema,
-    projected_field_ids: Set[int],
     positional_deletes: Optional[List[ChunkedArray]],
     case_sensitive: bool,
     name_mapping: Optional[NameMapping] = None,
@@ -1401,8 +1399,8 @@ def _task_to_record_batches(
     file_schema = pyarrow_to_schema(physical_schema, name_mapping, downcast_ns_timestamp_to_us=True)
 
     pyarrow_filter = None
-    if bound_row_filter is not AlwaysTrue():
-        translated_row_filter = translate_column_names(bound_row_filter, file_schema, case_sensitive=case_sensitive)
+    if task.residual is not AlwaysTrue():
+        translated_row_filter = translate_column_names(task.residual, file_schema, case_sensitive=case_sensitive)
         bound_file_filter = bind(file_schema, translated_row_filter, case_sensitive=case_sensitive)
         pyarrow_filter = expression_to_pyarrow(bound_file_filter)

@@ -1412,7 +1410,13 @@ def _task_to_record_batches(
         task.file, projected_schema, partition_spec, file_schema.field_ids
     )
 
-    file_project_schema = prune_columns(file_schema, projected_field_ids, select_full_types=False)
+    file_project_schema = prune_columns(
+        file_schema,
+        {
+            id for id in projected_schema.field_ids if not isinstance(projected_schema.find_type(id), (MapType, ListType))
+        }.union(extract_field_ids(task.residual)),
+        select_full_types=False,
+    )
 
     fragment_scanner = ds.Scanner.from_fragment(
         fragment=fragment,
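
Note on the hunk above: with the precomputed projected_field_ids parameter gone, the column set is rebuilt per task from the projected schema (map and list types excluded, as in the removed _projected_field_ids property) unioned with whatever field IDs the task's residual references, so a filter-only column is still read. A minimal runnable sketch of that computation against a toy schema (the schema, column names, and filter below are invented for illustration):

from pyiceberg.expressions import EqualTo
from pyiceberg.expressions.visitors import bind, extract_field_ids
from pyiceberg.schema import Schema, prune_columns
from pyiceberg.types import LongType, NestedField, StringType

# Toy stand-in for the schema read from a data file's footer.
file_schema = Schema(
    NestedField(1, "id", LongType(), required=True),
    NestedField(2, "name", StringType(), required=False),
    NestedField(3, "category", StringType(), required=False),
)

# Project only "name" (field 2), but filter on "category" (field 3):
# the residual's field IDs must be included or the filter cannot evaluate.
residual = bind(file_schema, EqualTo("category", "books"), case_sensitive=True)
selected_ids = {2}.union(extract_field_ids(residual))  # {2, 3}

file_project_schema = prune_columns(file_schema, selected_ids, select_full_types=False)
print(file_project_schema)  # only "name" and "category" remain
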
@@ -1506,7 +1510,7 @@ class ArrowScan:
     _table_metadata: TableMetadata
     _io: FileIO
     _projected_schema: Schema
-    _bound_row_filter: BooleanExpression
+    _bound_row_filter: Optional[BooleanExpression]
     _case_sensitive: bool
     _limit: Optional[int]
     """Scan the Iceberg Table and create an Arrow construct.
@@ -1525,26 +1529,25 @@ def __init__(
         table_metadata: TableMetadata,
         io: FileIO,
         projected_schema: Schema,
-        row_filter: BooleanExpression,
+        row_filter: Optional[BooleanExpression] = None,
         case_sensitive: bool = True,
         limit: Optional[int] = None,
     ) -> None:
         self._table_metadata = table_metadata
         self._io = io
         self._projected_schema = projected_schema
-        self._bound_row_filter = bind(table_metadata.schema(), row_filter, case_sensitive=case_sensitive)
+        if row_filter is not None:
+            deprecation_message(
+                deprecated_in="0.9.0",
+                removed_in="0.10.0",
+                help_message="row_filter is marked as deprecated, and will be removed in 0.10.0. Please make sure to set the residual on the ScanTasks.",
+            )
+            self._bound_row_filter = bind(table_metadata.schema(), row_filter, case_sensitive=case_sensitive)
+        else:
+            self._bound_row_filter = None
         self._case_sensitive = case_sensitive
         self._limit = limit
 
-    @property
-    def _projected_field_ids(self) -> Set[int]:
-        """Set of field IDs that should be projected from the data files."""
-        return {
-            id
-            for id in self._projected_schema.field_ids
-            if not isinstance(self._projected_schema.find_type(id), (MapType, ListType))
-        }.union(extract_field_ids(self._bound_row_filter))
-
     def to_table(self, tasks: Iterable[FileScanTask]) -> pa.Table:
         """Scan the Iceberg table and return a pa.Table.

@@ -1565,7 +1568,10 @@ def to_table(self, tasks: Iterable[FileScanTask]) -> pa.Table:
         deletes_per_file = _read_all_delete_files(self._io, tasks)
         executor = ExecutorFactory.get_or_create()
 
-        def _table_from_scan_task(task: FileScanTask) -> pa.Table:
+        if self._bound_row_filter is not None:
+            tasks = [task._set_residual(expr=self._bound_row_filter) for task in tasks]
+
+        def _table_from_scan_task(task: FileScanTask) -> Optional[pa.Table]:
             batches = list(self._record_batches_from_scan_tasks_and_deletes([task], deletes_per_file))
             if len(batches) > 0:
                 return pa.Table.from_batches(batches)
@@ -1635,6 +1641,9 @@ def to_record_batches(self, tasks: Iterable[FileScanTask]) -> Iterator[pa.RecordBatch]:
             ResolveError: When a required field cannot be found in the file
             ValueError: When a field type in the file cannot be projected to the schema type
         """
+        if self._bound_row_filter is not None:
+            tasks = [task._set_residual(expr=self._bound_row_filter) for task in tasks]
+
         deletes_per_file = _read_all_delete_files(self._io, tasks)
         return self._record_batches_from_scan_tasks_and_deletes(tasks, deletes_per_file)

@@ -1648,9 +1657,7 @@ def _record_batches_from_scan_tasks_and_deletes(
             batches = _task_to_record_batches(
                 _fs_from_file_path(self._io, task.file.file_path),
                 task,
-                self._bound_row_filter,
                 self._projected_schema,
-                self._projected_field_ids,
                 deletes_per_file.get(task.file.file_path),
                 self._case_sensitive,
                 self._table_metadata.name_mapping(),

12 changes: 10 additions & 2 deletions pyiceberg/table/__init__.py
@@ -1546,6 +1546,10 @@ def __init__(
         self.length = length or data_file.file_size_in_bytes
         self.residual = residual
 
+    def _set_residual(self, expr: BooleanExpression) -> "FileScanTask":
+        self.residual = expr
+        return self
+
 
 def _open_manifest(
     io: FileIO,
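
_set_residual mutates the task and returns self, which is what lets to_table and to_record_batches rewrite a task list in a single comprehension. A self-contained sketch of the same pattern (the Task stand-in below is invented; the real FileScanTask takes a DataFile):

from dataclasses import dataclass
from typing import List, Optional

from pyiceberg.expressions import AlwaysTrue, BooleanExpression


@dataclass
class Task:  # simplified stand-in for FileScanTask
    residual: Optional[BooleanExpression] = None

    def _set_residual(self, expr: BooleanExpression) -> "Task":
        self.residual = expr  # mutate in place...
        return self           # ...and return self so comprehensions stay terse


tasks: List[Task] = [Task(), Task()]
tasks = [t._set_residual(expr=AlwaysTrue()) for t in tasks]
assert all(t.residual == AlwaysTrue() for t in tasks)
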
@@ -1741,8 +1745,12 @@ def plan_files(self) -> Iterable[FileScanTask]:
                     data_entry,
                     positional_delete_entries,
                 ),
-                residual=residual_evaluators[data_entry.data_file.spec_id](data_entry.data_file).residual_for(
-                    data_entry.data_file.partition
+                residual=bind(
+                    self.table_metadata.schema(),
+                    residual_evaluators[data_entry.data_file.spec_id](data_entry.data_file).residual_for(
+                        data_entry.data_file.partition
+                    ),
+                    case_sensitive=self.case_sensitive,
                 ),
             )
             for data_entry in data_entries
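
Binding the residual here matters downstream: _task_to_record_batches feeds task.residual into translate_column_names, which resolves fields by ID against the file schema, and field IDs only exist on bound expressions. A runnable sketch of that translation step (schemas and names invented), showing how a bound filter survives a column rename:

from pyiceberg.expressions import EqualTo
from pyiceberg.expressions.visitors import bind, translate_column_names
from pyiceberg.schema import Schema
from pyiceberg.types import LongType, NestedField

# Current table schema: field 1 was renamed to "user_id" at some point.
table_schema = Schema(NestedField(1, "user_id", LongType(), required=True))
# An older data file still carries the original column name "id".
file_schema = Schema(NestedField(1, "id", LongType(), required=True))

bound = bind(table_schema, EqualTo("user_id", 42), case_sensitive=True)

# Field ID 1 resolves to the file's column name, so the pushed-down
# filter matches what is physically stored in the file.
translated = translate_column_names(bound, file_schema, case_sensitive=True)
print(translated)  # references "id", not "user_id"
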
3 changes: 2 additions & 1 deletion tests/io/test_pyarrow.py
@@ -984,7 +984,8 @@ def project(
                 partition={},
                 record_count=3,
                 file_size_in_bytes=3,
-            )
+            ),
+            residual=expr or AlwaysTrue(),
         )
         for file in files
     ]
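
The test helper now threads the filter through the task itself rather than relying on ArrowScan's row_filter. A condensed sketch of the resulting construction, mirroring the kwargs visible in the hunk (the file path is invented; AlwaysTrue is the fallback when the test passes no filter):

from pyiceberg.expressions import AlwaysTrue
from pyiceberg.manifest import DataFile, DataFileContent, FileFormat
from pyiceberg.table import FileScanTask

task = FileScanTask(
    DataFile(
        content=DataFileContent.DATA,
        file_path="file.parquet",  # invented path
        file_format=FileFormat.PARQUET,
        partition={},
        record_count=3,
        file_size_in_bytes=3,
    ),
    residual=AlwaysTrue(),  # stands in for: expr or AlwaysTrue()
)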