
Allow appending extra values to embedding vector #289

Open
wants to merge 5 commits into main
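A minimal usage sketch of the new option, pieced together from the tests and `create_model` changes below. The helper imports (`load_example_args`, `create_example_batch` from `tests/utils.py`) and the 6-atom / 2-molecule shapes are taken from the test suite; they are illustrative, not a documented API.

```python
import torch
from torchmdnet.models.model import create_model
from utils import load_example_args, create_example_batch  # helpers from tests/utils.py

# Ask for two extra fields to be appended to each atom's embedding vector
args = load_example_args("tensornet", remove_prior=True)
args["extra_embedding"] = ["atomic", "global"]  # a single string is accepted too
model = create_model(args)

z, pos, batch = create_example_batch()  # assumed: 6 atoms split over 2 molecules
y, neg_dy = model(
    z,
    pos,
    batch=batch,
    extra_args={
        "atomic": torch.rand(6),  # one value per atom
        "global": torch.rand(2),  # one value per molecule, broadcast to atoms via `batch`
    },
)
```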
21 changes: 21 additions & 0 deletions tests/test_model.py
@@ -98,6 +98,16 @@ def test_torchscript_dynamic_shapes(model_name, device):
grad_outputs=grad_outputs,
)[0]

@mark.parametrize("model_name", models.__all_models__)
@mark.parametrize("device", ["cpu", "cuda"])
def test_torchscript_extra_embedding(model_name, device):
if device == "cuda" and not torch.cuda.is_available():
pytest.skip("CUDA not available")
args = load_example_args(model_name, remove_prior=True)
args["extra_embedding"] = "atomic"
model = create_model(args)
torch.jit.script(model).to(device=device)

#Currently only tensornet is CUDA graph compatible
@mark.parametrize("model_name", ["tensornet"])
def test_cuda_graph_compatible(model_name):
@@ -227,3 +237,14 @@ def test_gradients(model_name):
torch.autograd.gradcheck(
model, (z, pos, batch), eps=1e-4, atol=1e-3, rtol=1e-2, nondet_tol=1e-3
)


@mark.parametrize("model_name", models.__all_models__)
@mark.parametrize("use_batch", [True, False])
def test_extra_embedding(model_name, use_batch):
z, pos, batch = create_example_batch()
args = load_example_args(model_name, prior_model=None)
args["extra_embedding"] = ["atomic", "global"]
model = create_model(args)
batch = batch if use_batch else None
model(z, pos, batch=batch, extra_args={'atomic':torch.rand(6), 'global':torch.rand(2)})
21 changes: 19 additions & 2 deletions torchmdnet/models/model.py
@@ -38,6 +38,12 @@ def create_model(args, prior_model=None, mean=None, std=None):
args["static_shapes"] = False
if "vector_cutoff" not in args:
args["vector_cutoff"] = False
if "extra_embedding" not in args:
extra_embedding = None
elif isinstance(args["extra_embedding"], str):
extra_embedding = [args["extra_embedding"]]
else:
extra_embedding = args["extra_embedding"]

shared_args = dict(
hidden_channels=args["embedding_dimension"],
@@ -57,6 +63,7 @@
else None
),
dtype=dtype,
extra_embedding=extra_embedding
)

# representation network
@@ -370,7 +377,7 @@ def forward(
If this is omitted, periodic boundary conditions are not applied.
q (Tensor, optional): Atomic charges in the molecule. Shape: (N,).
s (Tensor, optional): Atomic spins in the molecule. Shape: (N,).
extra_args (Dict[str, Tensor], optional): Extra arguments to pass to the prior model.
extra_args (Dict[str, Tensor], optional): Extra arguments to pass to the model.

Returns:
Tuple[Tensor, Optional[Tensor]]: The output of the model and the derivative of the output with respect to the positions if derivative is True, None otherwise.
@@ -380,9 +387,19 @@

if self.derivative:
pos.requires_grad_(True)
if self.representation_model.extra_embedding is None:
extra_embedding_args = None
else:
assert extra_args is not None
extra_embedding_args = []
for arg in self.representation_model.extra_embedding:
t = extra_args[arg]
if t.shape != z.shape:
t = t[batch]
extra_embedding_args.append(t)
# run the potentially wrapped representation model
x, v, z, pos, batch = self.representation_model(
z, pos, batch, box=box, q=q, s=s
z, pos, batch, box=box, q=q, s=s, extra_embedding_args=extra_embedding_args
)
# apply the output network
x = self.output_model.pre_reduce(x, v, z, pos, batch)
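The shape check in `forward` above is what distinguishes per-atom from per-molecule extra fields: anything whose shape differs from `z` is treated as per-molecule and expanded to per-atom values by indexing with `batch`. A standalone sketch of that broadcast, with made-up numbers:

```python
import torch

z = torch.tensor([1, 6, 8, 1, 1, 8])      # atomic numbers, 6 atoms
batch = torch.tensor([0, 0, 0, 1, 1, 1])   # atoms 0-2 belong to molecule 0, atoms 3-5 to molecule 1

per_molecule = torch.tensor([0.2, 0.7])    # e.g. a "global" field, one value per molecule
per_atom = per_molecule[batch]             # tensor([0.2, 0.2, 0.2, 0.7, 0.7, 0.7])

assert per_atom.shape == z.shape           # now it can be appended to each atom's embedding
```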
29 changes: 25 additions & 4 deletions torchmdnet/models/tensornet.py
@@ -3,7 +3,7 @@
# (See accompanying file README.md file or copy at http://opensource.org/licenses/MIT)

import torch
from typing import Optional, Tuple
from typing import Optional, List, Tuple
from torch import Tensor, nn
from torchmdnet.models.utils import (
CosineCutoff,
@@ -120,6 +120,9 @@ class TensorNet(nn.Module):
(default: :obj:`True`)
check_errors (bool, optional): Whether to check for errors in the distance module.
(default: :obj:`True`)
extra_embedding (tuple, optional): The names of extra fields to append to the embedding
vector for each atom.
(default: :obj:`None`)
"""

def __init__(
@@ -139,6 +142,7 @@ def __init__(
check_errors=True,
dtype=torch.float32,
box_vecs=None,
extra_embedding=None
):
super(TensorNet, self).__init__()

@@ -163,6 +167,7 @@
self.activation = activation
self.cutoff_lower = cutoff_lower
self.cutoff_upper = cutoff_upper
self.extra_embedding = extra_embedding
act_class = act_class_mapping[activation]
self.distance_expansion = rbf_class_mapping[rbf_type](
cutoff_lower, cutoff_upper, num_rbf, trainable_rbf
@@ -176,6 +181,7 @@
trainable_rbf,
max_z,
dtype,
extra_embedding
)

self.layers = nn.ModuleList()
@@ -228,6 +234,7 @@ def forward(
box: Optional[Tensor] = None,
q: Optional[Tensor] = None,
s: Optional[Tensor] = None,
extra_embedding_args: Optional[List[Tensor]] = None
) -> Tuple[Tensor, Optional[Tensor], Tensor, Tensor, Tensor]:
# Obtain graph, with distances and relative position vectors
edge_index, edge_weight, edge_vec = self.distance(pos, batch, box)
@@ -258,7 +265,7 @@
# Normalizing edge vectors by their length can result in NaNs, breaking Autograd.
# I avoid dividing by zero by setting the weight of self edges and self loops to 1
edge_vec = edge_vec / edge_weight.masked_fill(mask, 1).unsqueeze(1)
X = self.tensor_embedding(zp, edge_index, edge_weight, edge_vec, edge_attr)
X = self.tensor_embedding(zp, edge_index, edge_weight, edge_vec, edge_attr, extra_embedding_args)
for layer in self.layers:
X = layer(X, edge_index, edge_weight, edge_attr, q)
I, A, S = decompose_tensor(X)
@@ -287,6 +294,7 @@ def __init__(
trainable_rbf=False,
max_z=128,
dtype=torch.float32,
extra_embedding=None
):
super(TensorEmbedding, self).__init__()

@@ -297,6 +305,10 @@
self.cutoff = CosineCutoff(cutoff_lower, cutoff_upper)
self.max_z = max_z
self.emb = nn.Embedding(max_z, hidden_channels, dtype=dtype)
if extra_embedding is not None:
self.reshape_embedding = nn.Linear(hidden_channels+len(extra_embedding), hidden_channels, dtype=dtype)
else:
self.reshape_embedding = None
self.emb2 = nn.Linear(2 * hidden_channels, hidden_channels, dtype=dtype)
self.act = activation()
self.linears_tensor = nn.ModuleList()
@@ -319,15 +331,23 @@ def reset_parameters(self):
self.distance_proj2.reset_parameters()
self.distance_proj3.reset_parameters()
self.emb.reset_parameters()
if self.reshape_embedding is not None:
self.reshape_embedding.reset_parameters()
self.emb2.reset_parameters()
for linear in self.linears_tensor:
linear.reset_parameters()
for linear in self.linears_scalar:
linear.reset_parameters()
self.init_norm.reset_parameters()

def _get_atomic_number_message(self, z: Tensor, edge_index: Tensor) -> Tensor:
def _get_atomic_number_message(self, z: Tensor, edge_index: Tensor, extra_embedding_args: Optional[List[Tensor]]) -> Tensor:
Z = self.emb(z)
if self.reshape_embedding is not None and extra_embedding_args is not None:
tensors = [Z]
for t in extra_embedding_args:
tensors.append(t.unsqueeze(1))
Z = torch.cat(tensors, dim=1)
Z = self.reshape_embedding(Z)
Zij = self.emb2(
Z.index_select(0, edge_index.t().reshape(-1)).view(
-1, self.hidden_channels * 2
@@ -362,8 +382,9 @@ def forward(
edge_weight: Tensor,
edge_vec_norm: Tensor,
edge_attr: Tensor,
extra_embedding_args: Optional[List[Tensor]]
) -> Tensor:
Zij = self._get_atomic_number_message(z, edge_index)
Zij = self._get_atomic_number_message(z, edge_index, extra_embedding_args)
Iij, Aij, Sij = self._get_tensor_messages(
Zij, edge_weight, edge_vec_norm, edge_attr
)
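All three representation models handle the extra fields the same way: each value is concatenated as one additional channel next to the learned atom embedding, and a small linear layer (`reshape_embedding`) projects the result back to `hidden_channels`. A minimal sketch of that step in isolation, with arbitrary sizes:

```python
import torch
from torch import nn

hidden_channels, max_z, n_extra, n_atoms = 8, 128, 2, 6

emb = nn.Embedding(max_z, hidden_channels)
reshape_embedding = nn.Linear(hidden_channels + n_extra, hidden_channels)

z = torch.tensor([1, 6, 8, 1, 1, 8])
extra = [torch.rand(n_atoms), torch.rand(n_atoms)]  # extra fields, already per-atom

x = emb(z)                                                    # (n_atoms, hidden_channels)
x = torch.cat([x] + [t.unsqueeze(1) for t in extra], dim=1)   # (n_atoms, hidden_channels + n_extra)
x = reshape_embedding(x)                                      # back to (n_atoms, hidden_channels)
```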
21 changes: 19 additions & 2 deletions torchmdnet/models/torchmd_et.py
@@ -2,7 +2,7 @@
# Distributed under the MIT License.
# (See accompanying file README.md file or copy at http://opensource.org/licenses/MIT)

from typing import Optional, Tuple
from typing import Optional, List, Tuple
import torch
from torch import Tensor, nn
from torchmdnet.models.utils import (
@@ -79,7 +79,9 @@ class TorchMD_ET(nn.Module):
(default: :obj:`False`)
check_errors (bool, optional): Whether to check for errors in the distance module.
(default: :obj:`True`)

extra_embedding (tuple, optional): The names of extra fields to append to the embedding
vector for each atom.
(default: :obj:`None`)
"""

def __init__(
@@ -102,6 +104,7 @@ def __init__(
box_vecs=None,
vector_cutoff=False,
dtype=torch.float32,
extra_embedding=None
):
super(TorchMD_ET, self).__init__()

@@ -133,10 +136,15 @@ def __init__(
self.cutoff_upper = cutoff_upper
self.max_z = max_z
self.dtype = dtype
self.extra_embedding = extra_embedding

act_class = act_class_mapping[activation]

self.embedding = nn.Embedding(self.max_z, hidden_channels, dtype=dtype)
if extra_embedding is not None:
self.reshape_embedding = nn.Linear(hidden_channels+len(extra_embedding), hidden_channels, dtype=dtype)
else:
self.reshape_embedding = None

self.distance = OptimizedDistance(
cutoff_lower,
@@ -181,6 +189,8 @@ def __init__(

def reset_parameters(self):
self.embedding.reset_parameters()
if self.reshape_embedding is not None:
self.reshape_embedding.reset_parameters()
self.distance_expansion.reset_parameters()
if self.neighbor_embedding is not None:
self.neighbor_embedding.reset_parameters()
@@ -196,8 +206,15 @@ def forward(
box: Optional[Tensor] = None,
q: Optional[Tensor] = None,
s: Optional[Tensor] = None,
extra_embedding_args: Optional[List[Tensor]] = None
) -> Tuple[Tensor, Tensor, Tensor, Tensor, Tensor]:
x = self.embedding(z)
if self.reshape_embedding is not None and extra_embedding_args is not None:
tensors = [x]
for t in extra_embedding_args:
tensors.append(t.unsqueeze(1))
x = torch.cat(tensors, dim=1)
x = self.reshape_embedding(x)

edge_index, edge_weight, edge_vec = self.distance(pos, batch, box)
# This assert must be here to convince TorchScript that edge_vec is not None
21 changes: 19 additions & 2 deletions torchmdnet/models/torchmd_gn.py
@@ -2,7 +2,7 @@
# Distributed under the MIT License.
# (See accompanying file README.md file or copy at http://opensource.org/licenses/MIT)

from typing import Optional, Tuple
from typing import Optional, List, Tuple
import torch
from torch import Tensor, nn
from torchmdnet.models.utils import (
@@ -86,7 +86,9 @@ class TorchMD_GN(nn.Module):
(default: :obj:`None`)
check_errors (bool, optional): Whether to check for errors in the distance module.
(default: :obj:`True`)

extra_embedding (tuple, optional): The names of extra fields to append to the embedding
vector for each atom.
(default: :obj:`None`)
"""

def __init__(
@@ -107,6 +109,7 @@ def __init__(
aggr="add",
dtype=torch.float32,
box_vecs=None,
extra_embedding=None
):
super(TorchMD_GN, self).__init__()

@@ -136,10 +139,15 @@ def __init__(
self.cutoff_upper = cutoff_upper
self.max_z = max_z
self.aggr = aggr
self.extra_embedding = extra_embedding

act_class = act_class_mapping[activation]

self.embedding = nn.Embedding(self.max_z, hidden_channels, dtype=dtype)
if extra_embedding is not None:
self.reshape_embedding = nn.Linear(hidden_channels+len(extra_embedding), hidden_channels, dtype=dtype)
else:
self.reshape_embedding = None

self.distance = OptimizedDistance(
cutoff_lower,
@@ -184,6 +192,8 @@ def __init__(

def reset_parameters(self):
self.embedding.reset_parameters()
if self.reshape_embedding is not None:
self.reshape_embedding.reset_parameters()
self.distance_expansion.reset_parameters()
if self.neighbor_embedding is not None:
self.neighbor_embedding.reset_parameters()
@@ -198,8 +208,15 @@ def forward(
box: Optional[Tensor] = None,
s: Optional[Tensor] = None,
q: Optional[Tensor] = None,
extra_embedding_args: Optional[List[Tensor]] = None
) -> Tuple[Tensor, Optional[Tensor], Tensor, Tensor, Tensor]:
x = self.embedding(z)
if self.reshape_embedding is not None and extra_embedding_args is not None:
tensors = [x]
for t in extra_embedding_args:
tensors.append(t.unsqueeze(1))
x = torch.cat(tensors, dim=1)
x = self.reshape_embedding(x)

edge_index, edge_weight, _ = self.distance(pos, batch, box)
edge_attr = self.distance_expansion(edge_weight)