Skip to content

Support for tensor outputs in TensorNet #196

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 7 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
122 changes: 117 additions & 5 deletions tests/test_equivariance.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
import torch
from torchmdnet.models.model import create_model
from utils import load_example_args

from utils import load_example_args, create_example_batch
from torchmdnet.models.output_modules import OutputModel
import pytorch_lightning as pl
import pytest

def test_scalar_invariance():
torch.manual_seed(1234)
Expand All @@ -23,27 +25,137 @@ def test_scalar_invariance():
torch.testing.assert_allclose(y, y_rot)


def test_vector_equivariance():
@pytest.mark.parametrize("dtype", [torch.float32, torch.float64])
def test_vector_equivariance(dtype):
torch.manual_seed(1234)
rotate = torch.tensor(
[
[0.9886788, -0.1102370, 0.1017945],
[0.1363630, 0.9431761, -0.3030248],
[-0.0626055, 0.3134752, 0.9475304],
]
)
).to(dtype)

model = create_model(
load_example_args(
"equivariant-transformer",
prior_model=None,
output_model="VectorOutput",
dtype=dtype,
)
)
z = torch.ones(100, dtype=torch.long)
pos = torch.randn(100, 3)
pos = torch.randn(100, 3).to(dtype)
batch = torch.arange(50, dtype=torch.long).repeat_interleave(2)

y = model(z, pos, batch)[0]
y_rot = model(z, pos @ rotate, batch)[0]
torch.testing.assert_allclose(y @ rotate, y_rot)


@pytest.mark.parametrize("dtype", [torch.float32, torch.float64])
def test_tensornet_energy_invariance(dtype):
torch.manual_seed(1234)
pl.seed_everything(1234)

# create model and sample batch
args = load_example_args(
"tensornet",
remove_prior=True,
output_model="Scalar",
derivative=True,
dtype=dtype,
)
model = create_model(args)
natoms = 10
z = torch.ones(natoms, dtype=torch.long)
pos = torch.randn(natoms, 3).to(dtype)
batch = torch.zeros_like(z)
pos.to(dtype)
batch = torch.zeros_like(batch)
# run step
y, _ = model(z, pos, batch)
alpha = torch.rand(1).to(dtype) * 2 * 3.141592653589793
beta = torch.rand(1).to(dtype) * 2 * 3.141592653589793
gamma = torch.rand(1).to(dtype) * 2 * 3.141592653589793
Rx = torch.tensor( [ [1, 0, 0], [0, torch.cos(alpha), -torch.sin(alpha)], [0, torch.sin(alpha), torch.cos(alpha)] ])
Ry = torch.tensor( [ [torch.cos(beta), 0, torch.sin(beta)], [0, 1, 0], [-torch.sin(beta), 0, torch.cos(beta)] ] )
Rz = torch.tensor( [ [torch.cos(gamma), -torch.sin(gamma), 0], [torch.sin(gamma), torch.cos(gamma), 0], [0, 0, 1] ] )
rotate = (Rx @ Ry @ Rz).to(dtype)
y_rot, _ = model(z, pos @ rotate, batch)
torch.testing.assert_allclose(y, y_rot, rtol=1e-13 if dtype == torch.float64 else 1e-6, atol= 0)


from torch_scatter import scatter
class TensorOutput(OutputModel):
""" Output model for tensor properties.
Only compatible with TensorNet

"""
def __init__(
self,
hidden_channels,
activation="silu",
allow_prior_model=True,
reduce_op="sum",
dtype=torch.float
):
super(TensorOutput, self).__init__(
allow_prior_model=allow_prior_model, reduce_op=reduce_op
)
self.reset_parameters()

def reduce(self, input, batch):
I, A, S = input
I = scatter(I.sum(-3), batch, dim=0, reduce=self.reduce_op)
A = scatter(A.sum(-3), batch, dim=0, reduce=self.reduce_op)
S = scatter(S.sum(-3), batch, dim=0, reduce=self.reduce_op)
return I+A+S

def reset_parameters(self):
pass

def pre_reduce(self, x, v, z, pos, batch):
return v

def post_reduce(self, x):
return x

@pytest.mark.parametrize("dtype", [torch.float32, torch.float64])
def test_tensornet_equivariance(dtype):
torch.manual_seed(1234)
pl.seed_everything(1234)

# create model and sample batch
args = load_example_args(
"tensornet",
remove_prior=True,
output_model="Scalar",
derivative=False,
dtype=dtype,
)
model = create_model(args)
model.output_model = TensorOutput(args["embedding_dimension"],
activation=args["activation"],
reduce_op=args["reduce_op"],
dtype=args["dtype"])

natoms = 10
z = torch.ones(natoms, dtype=torch.long)
pos = torch.randn(natoms, 3).to(dtype)
batch = torch.zeros_like(z)
pos.to(dtype)
batch = torch.zeros_like(batch)
# run step
X = model(z, pos, batch)[0]

alpha = torch.rand(1).to(dtype) * 2 * 3.141592653589793
beta = torch.rand(1).to(dtype) * 2 * 3.141592653589793
gamma = torch.rand(1).to(dtype) * 2 * 3.141592653589793
Rx = torch.tensor( [ [1, 0, 0], [0, torch.cos(alpha), -torch.sin(alpha)], [0, torch.sin(alpha), torch.cos(alpha)] ])
Ry = torch.tensor( [ [torch.cos(beta), 0, torch.sin(beta)], [0, 1, 0], [-torch.sin(beta), 0, torch.cos(beta)] ] )
Rz = torch.tensor( [ [torch.cos(gamma), -torch.sin(gamma), 0], [torch.sin(gamma), torch.cos(gamma), 0], [0, 0, 1] ] )
rotate = (Rx @ Ry @ Rz).to(dtype)

Xrot = model(z, pos @ rotate, batch)[0]
torch.testing.assert_allclose(rotate.t()@(X @ rotate), Xrot, rtol=5e-13 if dtype == torch.float64 else 5e-5, atol= 0)
14 changes: 8 additions & 6 deletions torchmdnet/models/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -214,7 +214,9 @@ def __init__(
self.prior_model = None if prior_model is None else torch.nn.ModuleList(prior_model).to(dtype=dtype)

self.derivative = derivative

self.standardize = mean is not None and std is not None
if (mean is None and std is not None) or (mean is not None and std is None):
raise ValueError("Either both mean and std must be given or none of them.")
mean = torch.scalar_tensor(0) if mean is None else mean
self.register_buffer("mean", mean.to(dtype=dtype))
std = torch.scalar_tensor(1) if std is None else std
Expand Down Expand Up @@ -261,7 +263,7 @@ def forward(
x = self.output_model.pre_reduce(x, v, z, pos, batch)

# scale by data standard deviation
if self.std is not None:
if self.standardize:
x = x * self.std

# apply atom-wise prior model
Expand All @@ -273,17 +275,17 @@ def forward(
x = self.output_model.reduce(x, batch)

# shift by data mean
if self.mean is not None:
if self.standardize:
x = x + self.mean

# apply output model after reduction
y = self.output_model.post_reduce(x)
x = self.output_model.post_reduce(x)

# apply molecular-wise prior model
if self.prior_model is not None:
for prior in self.prior_model:
y = prior.post_reduce(y, z, pos, batch, extra_args)

y = prior.post_reduce(x, z, pos, batch, extra_args)
# compute gradients with respect to coordinates
if self.derivative:
grad_outputs: List[Optional[torch.Tensor]] = [torch.ones_like(y)]
Expand Down
7 changes: 3 additions & 4 deletions torchmdnet/models/output_modules.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
from abc import abstractmethod, ABCMeta
from torch_scatter import scatter
from typing import Optional
from typing import Optional, Tuple, Union
from torchmdnet.models.utils import act_class_mapping, GatedEquivariantBlock
from torchmdnet.utils import atomic_masses
from torch_scatter import scatter
import torch
from torch import nn
from torch import nn, Tensor


__all__ = ["Scalar", "DipoleMoment", "ElectronicSpatialExtent"]
Expand All @@ -30,7 +30,6 @@ def reduce(self, x, batch):
def post_reduce(self, x):
return x


class Scalar(OutputModel):
def __init__(
self,
Expand Down Expand Up @@ -58,7 +57,7 @@ def reset_parameters(self):
nn.init.xavier_uniform_(self.output_network[2].weight)
self.output_network[2].bias.data.fill_(0)

def pre_reduce(self, x, v: Optional[torch.Tensor], z, pos, batch):
def pre_reduce(self, x, v: Union[Tensor, Tuple[Tensor,Tensor,Tensor], None], z, pos, batch):
return self.output_network(x)


Expand Down
7 changes: 3 additions & 4 deletions torchmdnet/models/tensornet.py
Original file line number Diff line number Diff line change
Expand Up @@ -187,8 +187,7 @@ def forward(
batch: Tensor,
q: Optional[Tensor] = None,
s: Optional[Tensor] = None,
) -> Tuple[Tensor, Optional[Tensor], Tensor, Tensor, Tensor]:

) -> Tuple[Tensor, Optional[Tuple[Tensor,Tensor,Tensor]], Tensor, Tensor, Tensor]:
# Obtain graph, with distances and relative position vectors
edge_index, edge_weight, edge_vec = self.distance(pos, batch)
# This assert convinces TorchScript that edge_vec is a Tensor and not an Optional[Tensor]
Expand All @@ -207,7 +206,7 @@ def forward(
x = torch.cat((tensor_norm(I), tensor_norm(A), tensor_norm(S)), dim=-1)
x = self.out_norm(x)
x = self.act(self.linear((x)))
return x, None, z, pos, batch
return x, (I, A, S), z, pos, batch


class TensorEmbedding(MessagePassing):
Expand Down Expand Up @@ -402,7 +401,7 @@ def forward(self, X, edge_index, edge_weight, edge_attr):
A = self.linears_tensor[4](A.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)
S = self.linears_tensor[5](S.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)
dX = I + A + S
X = X + dX + dX**2
X = X + dX + torch.matrix_power(dX,2)
return X

def message(self, I_j, A_j, S_j, edge_attr):
Expand Down
5 changes: 4 additions & 1 deletion torchmdnet/models/wrappers.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,10 @@ def forward(
atom_mask = z > self.remove_threshold
x = x[atom_mask]
if v is not None:
v = v[atom_mask]
if isinstance(v, tuple):
v = tuple(vi[atom_mask] for vi in v)
else:
v = v[atom_mask]
z = z[atom_mask]
pos = pos[atom_mask]
batch = batch[atom_mask]
Expand Down