diff --git a/tests/test_equivariance.py b/tests/test_equivariance.py
index 9381a62d..baef50f1 100644
--- a/tests/test_equivariance.py
+++ b/tests/test_equivariance.py
@@ -1,7 +1,9 @@
 import torch
 from torchmdnet.models.model import create_model
-from utils import load_example_args
-
+from utils import load_example_args, create_example_batch
+from torchmdnet.models.output_modules import OutputModel
+import pytorch_lightning as pl
+import pytest
 
 def test_scalar_invariance():
     torch.manual_seed(1234)
@@ -23,7 +25,8 @@ def test_scalar_invariance():
     torch.testing.assert_allclose(y, y_rot)
 
 
-def test_vector_equivariance():
+@pytest.mark.parametrize("dtype", [torch.float32, torch.float64])
+def test_vector_equivariance(dtype):
     torch.manual_seed(1234)
     rotate = torch.tensor(
         [
@@ -31,19 +34,128 @@
             [0.1363630, 0.9431761, -0.3030248],
             [-0.0626055, 0.3134752, 0.9475304],
         ]
-    )
+    ).to(dtype)
     model = create_model(
         load_example_args(
             "equivariant-transformer",
             prior_model=None,
             output_model="VectorOutput",
+            dtype=dtype,
         )
     )
     z = torch.ones(100, dtype=torch.long)
-    pos = torch.randn(100, 3)
+    pos = torch.randn(100, 3).to(dtype)
     batch = torch.arange(50, dtype=torch.long).repeat_interleave(2)
 
     y = model(z, pos, batch)[0]
     y_rot = model(z, pos @ rotate, batch)[0]
     torch.testing.assert_allclose(y @ rotate, y_rot)
+
+
+@pytest.mark.parametrize("dtype", [torch.float32, torch.float64])
+def test_tensornet_energy_invariance(dtype):
+    torch.manual_seed(1234)
+    pl.seed_everything(1234)
+
+    # create model and sample batch
+    args = load_example_args(
+        "tensornet",
+        remove_prior=True,
+        output_model="Scalar",
+        derivative=True,
+        dtype=dtype,
+    )
+    model = create_model(args)
+    natoms = 10
+    z = torch.ones(natoms, dtype=torch.long)
+    pos = torch.randn(natoms, 3).to(dtype)
+    batch = torch.zeros_like(z)
+    pos.to(dtype)
+    batch = torch.zeros_like(batch)
+    # run step
+    y, _ = model(z, pos, batch)
+    alpha = torch.rand(1).to(dtype) * 2 * 3.141592653589793
+    beta = torch.rand(1).to(dtype) * 2 * 3.141592653589793
+    gamma = torch.rand(1).to(dtype) * 2 * 3.141592653589793
+    Rx = torch.tensor([[1, 0, 0], [0, torch.cos(alpha), -torch.sin(alpha)], [0, torch.sin(alpha), torch.cos(alpha)]])
+    Ry = torch.tensor([[torch.cos(beta), 0, torch.sin(beta)], [0, 1, 0], [-torch.sin(beta), 0, torch.cos(beta)]])
+    Rz = torch.tensor([[torch.cos(gamma), -torch.sin(gamma), 0], [torch.sin(gamma), torch.cos(gamma), 0], [0, 0, 1]])
+    rotate = (Rx @ Ry @ Rz).to(dtype)
+    y_rot, _ = model(z, pos @ rotate, batch)
+    torch.testing.assert_allclose(y, y_rot, rtol=1e-13 if dtype == torch.float64 else 1e-6, atol=0)
+
+
+from torch_scatter import scatter
+class TensorOutput(OutputModel):
+    """Output model for tensor properties.
+    Only compatible with TensorNet
+
+    """
+    def __init__(
+        self,
+        hidden_channels,
+        activation="silu",
+        allow_prior_model=True,
+        reduce_op="sum",
+        dtype=torch.float
+    ):
+        super(TensorOutput, self).__init__(
+            allow_prior_model=allow_prior_model, reduce_op=reduce_op
+        )
+        self.reset_parameters()
+
+    def reduce(self, input, batch):
+        I, A, S = input
+        I = scatter(I.sum(-3), batch, dim=0, reduce=self.reduce_op)
+        A = scatter(A.sum(-3), batch, dim=0, reduce=self.reduce_op)
+        S = scatter(S.sum(-3), batch, dim=0, reduce=self.reduce_op)
+        return I + A + S
+
+    def reset_parameters(self):
+        pass
+
+    def pre_reduce(self, x, v, z, pos, batch):
+        return v
+
+    def post_reduce(self, x):
+        return x
+
+@pytest.mark.parametrize("dtype", [torch.float32, torch.float64])
+def test_tensornet_equivariance(dtype):
+    torch.manual_seed(1234)
+    pl.seed_everything(1234)
+
+    # create model and sample batch
+    args = load_example_args(
+        "tensornet",
+        remove_prior=True,
+        output_model="Scalar",
+        derivative=False,
+        dtype=dtype,
+    )
+    model = create_model(args)
+    model.output_model = TensorOutput(args["embedding_dimension"],
+                                      activation=args["activation"],
+                                      reduce_op=args["reduce_op"],
+                                      dtype=args["dtype"])
+
+    natoms = 10
+    z = torch.ones(natoms, dtype=torch.long)
+    pos = torch.randn(natoms, 3).to(dtype)
+    batch = torch.zeros_like(z)
+    pos.to(dtype)
+    batch = torch.zeros_like(batch)
+    # run step
+    X = model(z, pos, batch)[0]
+
+    alpha = torch.rand(1).to(dtype) * 2 * 3.141592653589793
+    beta = torch.rand(1).to(dtype) * 2 * 3.141592653589793
+    gamma = torch.rand(1).to(dtype) * 2 * 3.141592653589793
+    Rx = torch.tensor([[1, 0, 0], [0, torch.cos(alpha), -torch.sin(alpha)], [0, torch.sin(alpha), torch.cos(alpha)]])
+    Ry = torch.tensor([[torch.cos(beta), 0, torch.sin(beta)], [0, 1, 0], [-torch.sin(beta), 0, torch.cos(beta)]])
+    Rz = torch.tensor([[torch.cos(gamma), -torch.sin(gamma), 0], [torch.sin(gamma), torch.cos(gamma), 0], [0, 0, 1]])
+    rotate = (Rx @ Ry @ Rz).to(dtype)
+
+    Xrot = model(z, pos @ rotate, batch)[0]
+    torch.testing.assert_allclose(rotate.t() @ (X @ rotate), Xrot, rtol=5e-13 if dtype == torch.float64 else 5e-5, atol=0)
diff --git a/torchmdnet/models/model.py b/torchmdnet/models/model.py
index 27992e73..2b1502a6 100644
--- a/torchmdnet/models/model.py
+++ b/torchmdnet/models/model.py
@@ -214,7 +214,9 @@ def __init__(
         self.prior_model = None if prior_model is None else torch.nn.ModuleList(prior_model).to(dtype=dtype)
 
         self.derivative = derivative
-
+        self.standardize = mean is not None and std is not None
+        if (mean is None and std is not None) or (mean is not None and std is None):
+            raise ValueError("Either both mean and std must be given or none of them.")
         mean = torch.scalar_tensor(0) if mean is None else mean
         self.register_buffer("mean", mean.to(dtype=dtype))
         std = torch.scalar_tensor(1) if std is None else std
@@ -261,7 +263,7 @@ def forward(
         x = self.output_model.pre_reduce(x, v, z, pos, batch)
 
         # scale by data standard deviation
-        if self.std is not None:
+        if self.standardize:
             x = x * self.std
 
         # apply atom-wise prior model
@@ -273,17 +275,17 @@ def forward(
         x = self.output_model.reduce(x, batch)
 
         # shift by data mean
-        if self.mean is not None:
+        if self.standardize:
             x = x + self.mean
 
         # apply output model after reduction
-        y = self.output_model.post_reduce(x)
+        x = self.output_model.post_reduce(x)
 
         # apply molecular-wise prior model
         if self.prior_model is not None:
             for prior in self.prior_model:
-                y = prior.post_reduce(y, z, pos, batch, extra_args)
-
+                y = prior.post_reduce(x, z, pos, batch, extra_args)
+
         # compute gradients with respect to coordinates
         if self.derivative:
             grad_outputs: List[Optional[torch.Tensor]] = [torch.ones_like(y)]
diff --git a/torchmdnet/models/output_modules.py b/torchmdnet/models/output_modules.py
index d3d7e1b8..8443d31b 100644
--- a/torchmdnet/models/output_modules.py
+++ b/torchmdnet/models/output_modules.py
@@ -1,11 +1,11 @@
 from abc import abstractmethod, ABCMeta
 from torch_scatter import scatter
-from typing import Optional
+from typing import Optional, Tuple, Union
 from torchmdnet.models.utils import act_class_mapping, GatedEquivariantBlock
 from torchmdnet.utils import atomic_masses
 from torch_scatter import scatter
 import torch
-from torch import nn
+from torch import nn, Tensor
 
 
 __all__ = ["Scalar", "DipoleMoment", "ElectronicSpatialExtent"]
@@ -30,7 +30,6 @@ def reduce(self, x, batch):
     def post_reduce(self, x):
         return x
 
-
 class Scalar(OutputModel):
     def __init__(
         self,
@@ -58,7 +57,7 @@ def reset_parameters(self):
         nn.init.xavier_uniform_(self.output_network[2].weight)
         self.output_network[2].bias.data.fill_(0)
 
-    def pre_reduce(self, x, v: Optional[torch.Tensor], z, pos, batch):
+    def pre_reduce(self, x, v: Union[Tensor, Tuple[Tensor, Tensor, Tensor], None], z, pos, batch):
         return self.output_network(x)
 
diff --git a/torchmdnet/models/tensornet.py b/torchmdnet/models/tensornet.py
index 97974864..ee05d164 100644
--- a/torchmdnet/models/tensornet.py
+++ b/torchmdnet/models/tensornet.py
@@ -187,8 +187,7 @@ def forward(
         batch: Tensor,
         q: Optional[Tensor] = None,
         s: Optional[Tensor] = None,
-    ) -> Tuple[Tensor, Optional[Tensor], Tensor, Tensor, Tensor]:
-
+    ) -> Tuple[Tensor, Optional[Tuple[Tensor, Tensor, Tensor]], Tensor, Tensor, Tensor]:
         # Obtain graph, with distances and relative position vectors
         edge_index, edge_weight, edge_vec = self.distance(pos, batch)
         # This assert convinces TorchScript that edge_vec is a Tensor and not an Optional[Tensor]
@@ -207,7 +206,7 @@
         x = torch.cat((tensor_norm(I), tensor_norm(A), tensor_norm(S)), dim=-1)
         x = self.out_norm(x)
         x = self.act(self.linear((x)))
-        return x, None, z, pos, batch
+        return x, (I, A, S), z, pos, batch
 
 
 class TensorEmbedding(MessagePassing):
@@ -402,7 +401,7 @@ def forward(self, X, edge_index, edge_weight, edge_attr):
         A = self.linears_tensor[4](A.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)
         S = self.linears_tensor[5](S.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)
         dX = I + A + S
-        X = X + dX + dX**2
+        X = X + dX + torch.matrix_power(dX, 2)
         return X
 
     def message(self, I_j, A_j, S_j, edge_attr):
diff --git a/torchmdnet/models/wrappers.py b/torchmdnet/models/wrappers.py
index 565e9455..0de62b6a 100644
--- a/torchmdnet/models/wrappers.py
+++ b/torchmdnet/models/wrappers.py
@@ -50,7 +50,10 @@ def forward(
         atom_mask = z > self.remove_threshold
         x = x[atom_mask]
         if v is not None:
-            v = v[atom_mask]
+            if isinstance(v, tuple):
+                v = tuple(vi[atom_mask] for vi in v)
+            else:
+                v = v[atom_mask]
         z = z[atom_mask]
         pos = pos[atom_mask]
         batch = batch[atom_mask]