diff --git a/pyiceberg/table/__init__.py b/pyiceberg/table/__init__.py index 8f7b45f532..e2108004e8 100644 --- a/pyiceberg/table/__init__.py +++ b/pyiceberg/table/__init__.py @@ -243,7 +243,6 @@ class TableProperties: class Transaction: _table: Table - table_metadata: TableMetadata _autocommit: bool _updates: Tuple[TableUpdate, ...] _requirements: Tuple[TableRequirement, ...] @@ -255,12 +254,15 @@ def __init__(self, table: Table, autocommit: bool = False): table: The table that will be altered. autocommit: Option to automatically commit the changes when they are staged. """ - self.table_metadata = table.metadata self._table = table self._autocommit = autocommit self._updates = () self._requirements = () + @property + def table_metadata(self) -> TableMetadata: + return update_table_metadata(self._table.metadata, self._updates) + def __enter__(self) -> Transaction: """Start a transaction to update the table.""" return self @@ -286,8 +288,6 @@ def _apply(self, updates: Tuple[TableUpdate, ...], requirements: Tuple[TableRequ if type(new_requirement) not in existing_requirements: self._requirements = self._requirements + (new_requirement,) - self.table_metadata = update_table_metadata(self.table_metadata, updates) - if self._autocommit: self.commit_transaction() self._updates = () diff --git a/pyiceberg/table/update/__init__.py b/pyiceberg/table/update/__init__.py index f60ac1e3ee..4905c31bfb 100644 --- a/pyiceberg/table/update/__init__.py +++ b/pyiceberg/table/update/__init__.py @@ -360,7 +360,8 @@ def _(update: SetCurrentSchemaUpdate, base_metadata: TableMetadata, context: _Ta @_apply_table_update.register(AddPartitionSpecUpdate) def _(update: AddPartitionSpecUpdate, base_metadata: TableMetadata, context: _TableMetadataUpdateContext) -> TableMetadata: for spec in base_metadata.partition_specs: - if spec.spec_id == update.spec.spec_id: + # Only raise in case of a discrepancy + if spec.spec_id == update.spec.spec_id and spec != update.spec: raise ValueError(f"Partition spec with id {spec.spec_id} already exists: {spec}") metadata_updates: Dict[str, Any] = { @@ -525,6 +526,11 @@ def _(update: RemoveSnapshotRefUpdate, base_metadata: TableMetadata, context: _T @_apply_table_update.register(AddSortOrderUpdate) def _(update: AddSortOrderUpdate, base_metadata: TableMetadata, context: _TableMetadataUpdateContext) -> TableMetadata: + for sort in base_metadata.sort_orders: + # Only raise in case of a discrepancy + if sort.order_id == update.sort_order.order_id and sort != update.sort_order: + raise ValueError(f"Sort-order with id {sort.order_id} already exists: {sort}") + context.add_update(update) return base_metadata.model_copy( update={ diff --git a/tests/integration/test_rest_schema.py b/tests/integration/test_rest_schema.py index 6a704839e2..fd975d81c9 100644 --- a/tests/integration/test_rest_schema.py +++ b/tests/integration/test_rest_schema.py @@ -154,7 +154,7 @@ def test_schema_evolution_via_transaction(catalog: Catalog) -> None: NestedField(field_id=4, name="col_integer", field_type=IntegerType(), required=False), ) - with pytest.raises(CommitFailedException) as exc_info: + with pytest.raises(CommitFailedException, match="Requirement failed: current schema id has changed: expected 2, found 3"): with tbl.transaction() as tx: # Start a new update schema_update = tx.update_schema() @@ -165,8 +165,6 @@ def test_schema_evolution_via_transaction(catalog: Catalog) -> None: # stage another update in the transaction schema_update.add_column("col_double", DoubleType()).commit() - assert "Requirement failed: current schema changed: expected id 2 != 3" in str(exc_info.value) - assert tbl.schema() == Schema( NestedField(field_id=1, name="col_uuid", field_type=UUIDType(), required=False), NestedField(field_id=2, name="col_fixed", field_type=FixedType(25), required=False), diff --git a/tests/integration/test_writes/test_writes.py b/tests/integration/test_writes/test_writes.py index 8575b588b8..66ef908986 100644 --- a/tests/integration/test_writes/test_writes.py +++ b/tests/integration/test_writes/test_writes.py @@ -1683,3 +1683,37 @@ def test_write_optional_list(session_catalog: Catalog) -> None: session_catalog.load_table(identifier).append(df_2) assert len(session_catalog.load_table(identifier).scan().to_arrow()) == 4 + + +@pytest.mark.integration +@pytest.mark.parametrize("format_version", [1, 2]) +def test_evolve_and_write( + spark: SparkSession, session_catalog: Catalog, arrow_table_with_null: pa.Table, format_version: int +) -> None: + identifier = "default.test_evolve_and_write" + tbl = _create_table(session_catalog, identifier, properties={"format-version": format_version}, schema=Schema()) + other_table = session_catalog.load_table(identifier) + + numbers = pa.array([1, 2, 3, 4], type=pa.int32()) + + with tbl.update_schema() as upd: + # This is not known by other_table + upd.add_column("id", IntegerType()) + + with other_table.transaction() as tx: + # Refreshes the underlying metadata, and the schema + other_table.refresh() + tx.append( + pa.Table.from_arrays( + [ + numbers, + ], + schema=pa.schema( + [ + pa.field("id", pa.int32(), nullable=True), + ] + ), + ) + ) + + assert session_catalog.load_table(identifier).scan().to_arrow().column(0).combine_chunks() == numbers