Skip to content

Commit 8fe2ecd

Browse files
committed
Refactor Metadata in Transaction
Today, we have a copy of the `TableMetadata` on the `Table` and the `Transaction`. This PR changes that logic to re-use the one on the table, and add the changes to the one on the `Transaction`. This also allows us to stack changes, for example, to first change a schema, and then write data with the new schema right away. Also a prerequisite for #1772
1 parent 76d02ad commit 8fe2ecd

File tree

4 files changed

+42
-8
lines changed

4 files changed

+42
-8
lines changed

pyiceberg/table/__init__.py

+4-4
Original file line numberDiff line numberDiff line change
@@ -243,7 +243,6 @@ class TableProperties:
243243

244244
class Transaction:
245245
_table: Table
246-
table_metadata: TableMetadata
247246
_autocommit: bool
248247
_updates: Tuple[TableUpdate, ...]
249248
_requirements: Tuple[TableRequirement, ...]
@@ -255,12 +254,15 @@ def __init__(self, table: Table, autocommit: bool = False):
255254
table: The table that will be altered.
256255
autocommit: Option to automatically commit the changes when they are staged.
257256
"""
258-
self.table_metadata = table.metadata
259257
self._table = table
260258
self._autocommit = autocommit
261259
self._updates = ()
262260
self._requirements = ()
263261

262+
@property
263+
def table_metadata(self) -> TableMetadata:
264+
return update_table_metadata(self._table.metadata, self._updates)
265+
264266
def __enter__(self) -> Transaction:
265267
"""Start a transaction to update the table."""
266268
return self
@@ -286,8 +288,6 @@ def _apply(self, updates: Tuple[TableUpdate, ...], requirements: Tuple[TableRequ
286288
if type(new_requirement) not in existing_requirements:
287289
self._requirements = self._requirements + (new_requirement,)
288290

289-
self.table_metadata = update_table_metadata(self.table_metadata, updates)
290-
291291
if self._autocommit:
292292
self.commit_transaction()
293293
self._updates = ()

pyiceberg/table/update/__init__.py

+7-1
Original file line numberDiff line numberDiff line change
@@ -360,7 +360,8 @@ def _(update: SetCurrentSchemaUpdate, base_metadata: TableMetadata, context: _Ta
360360
@_apply_table_update.register(AddPartitionSpecUpdate)
361361
def _(update: AddPartitionSpecUpdate, base_metadata: TableMetadata, context: _TableMetadataUpdateContext) -> TableMetadata:
362362
for spec in base_metadata.partition_specs:
363-
if spec.spec_id == update.spec.spec_id:
363+
# Only raise in case of a discrepancy
364+
if spec.spec_id == update.spec.spec_id and spec != update.spec:
364365
raise ValueError(f"Partition spec with id {spec.spec_id} already exists: {spec}")
365366

366367
metadata_updates: Dict[str, Any] = {
@@ -525,6 +526,11 @@ def _(update: RemoveSnapshotRefUpdate, base_metadata: TableMetadata, context: _T
525526

526527
@_apply_table_update.register(AddSortOrderUpdate)
527528
def _(update: AddSortOrderUpdate, base_metadata: TableMetadata, context: _TableMetadataUpdateContext) -> TableMetadata:
529+
for sort in base_metadata.sort_orders:
530+
# Only raise in case of a discrepancy
531+
if sort.order_id == update.sort_order.order_id and sort != update.sort_order:
532+
raise ValueError(f"Sort-order with id {sort.order_id} already exists: {sort}")
533+
528534
context.add_update(update)
529535
return base_metadata.model_copy(
530536
update={

tests/integration/test_rest_schema.py

+1-3
Original file line numberDiff line numberDiff line change
@@ -154,7 +154,7 @@ def test_schema_evolution_via_transaction(catalog: Catalog) -> None:
154154
NestedField(field_id=4, name="col_integer", field_type=IntegerType(), required=False),
155155
)
156156

157-
with pytest.raises(CommitFailedException) as exc_info:
157+
with pytest.raises(CommitFailedException, match="Requirement failed: current schema id has changed: expected 2, found 3"):
158158
with tbl.transaction() as tx:
159159
# Start a new update
160160
schema_update = tx.update_schema()
@@ -165,8 +165,6 @@ def test_schema_evolution_via_transaction(catalog: Catalog) -> None:
165165
# stage another update in the transaction
166166
schema_update.add_column("col_double", DoubleType()).commit()
167167

168-
assert "Requirement failed: current schema changed: expected id 2 != 3" in str(exc_info.value)
169-
170168
assert tbl.schema() == Schema(
171169
NestedField(field_id=1, name="col_uuid", field_type=UUIDType(), required=False),
172170
NestedField(field_id=2, name="col_fixed", field_type=FixedType(25), required=False),

tests/integration/test_writes/test_writes.py

+30
Original file line numberDiff line numberDiff line change
@@ -1683,3 +1683,33 @@ def test_write_optional_list(session_catalog: Catalog) -> None:
16831683
session_catalog.load_table(identifier).append(df_2)
16841684

16851685
assert len(session_catalog.load_table(identifier).scan().to_arrow()) == 4
1686+
1687+
1688+
@pytest.mark.integration
1689+
@pytest.mark.parametrize("format_version", [1, 2])
1690+
def test_evolve_and_write(
1691+
spark: SparkSession, session_catalog: Catalog, arrow_table_with_null: pa.Table, format_version: int
1692+
) -> None:
1693+
identifier = "default.test_evolve_and_write"
1694+
tbl = _create_table(session_catalog, identifier, properties={"format-version": format_version}, schema=Schema())
1695+
1696+
numbers = pa.array([1, 2, 3, 4], type=pa.int32())
1697+
1698+
with tbl.transaction() as tx:
1699+
with tx.update_schema() as upd:
1700+
upd.add_column("id", IntegerType())
1701+
1702+
tx.append(
1703+
pa.Table.from_arrays(
1704+
[
1705+
numbers,
1706+
],
1707+
schema=pa.schema(
1708+
[
1709+
pa.field("id", pa.int32(), nullable=True),
1710+
]
1711+
),
1712+
)
1713+
)
1714+
1715+
assert tbl.scan().to_arrow().column(0).combine_chunks() == numbers

0 commit comments

Comments
 (0)