model.py
import re
from typing import Optional, List, Tuple, Dict
import torch
from torch.autograd import grad
from torch import nn, Tensor
from torch_scatter import scatter
from pytorch_lightning.utilities import rank_zero_warn
from torchmdnet.models import output_modules
from torchmdnet.models.wrappers import AtomFilter
from torchmdnet.models.utils import dtype_mapping
from torchmdnet import priors
import warnings


def create_model(args, prior_model=None, mean=None, std=None):
    """Create a model from the given arguments.

    See :func:`get_args` in scripts/train.py for a description of the arguments.

    Parameters
    ----------
    args (dict): Arguments for the model.
    prior_model (nn.Module, optional): Prior model to use. Defaults to None.
    mean (torch.Tensor, optional): Mean of the training data. Defaults to None.
    std (torch.Tensor, optional): Standard deviation of the training data. Defaults to None.

    Returns
    -------
    nn.Module: An instance of the TorchMD_Net model.
    """
    args["dtype"] = "float32" if "dtype" not in args else args["dtype"]
    args["dtype"] = dtype_mapping[args["dtype"]] if isinstance(args["dtype"], str) else args["dtype"]
    shared_args = dict(
        hidden_channels=args["embedding_dimension"],
        num_layers=args["num_layers"],
        num_rbf=args["num_rbf"],
        rbf_type=args["rbf_type"],
        trainable_rbf=args["trainable_rbf"],
        activation=args["activation"],
        cutoff_lower=args["cutoff_lower"],
        cutoff_upper=args["cutoff_upper"],
        max_z=args["max_z"],
        max_num_neighbors=args["max_num_neighbors"],
        dtype=args["dtype"],
    )

    # representation network
    if args["model"] == "graph-network":
        from torchmdnet.models.torchmd_gn import TorchMD_GN

        is_equivariant = False
        representation_model = TorchMD_GN(
            num_filters=args["embedding_dimension"],
            aggr=args["aggr"],
            neighbor_embedding=args["neighbor_embedding"],
            **shared_args,
        )
    elif args["model"] == "transformer":
        from torchmdnet.models.torchmd_t import TorchMD_T

        is_equivariant = False
        representation_model = TorchMD_T(
            attn_activation=args["attn_activation"],
            num_heads=args["num_heads"],
            distance_influence=args["distance_influence"],
            neighbor_embedding=args["neighbor_embedding"],
            **shared_args,
        )
    elif args["model"] == "equivariant-transformer":
        from torchmdnet.models.torchmd_et import TorchMD_ET

        is_equivariant = True
        representation_model = TorchMD_ET(
            attn_activation=args["attn_activation"],
            num_heads=args["num_heads"],
            distance_influence=args["distance_influence"],
            neighbor_embedding=args["neighbor_embedding"],
            **shared_args,
        )
    elif args["model"] == "tensornet":
        from torchmdnet.models.tensornet import TensorNet

        # Setting is_equivariant to False to enforce the use of the Scalar output module instead of EquivariantScalar
        is_equivariant = False
        representation_model = TensorNet(
            equivariance_invariance_group=args["equivariance_invariance_group"],
            **shared_args,
        )
    else:
        raise ValueError(f'Unknown architecture: {args["model"]}')

    # atom filter
    if not args["derivative"] and args["atom_filter"] > -1:
        representation_model = AtomFilter(representation_model, args["atom_filter"])
    elif args["atom_filter"] > -1:
        raise ValueError("Derivative and atom filter can't be used together")

    # prior model
    if args["prior_model"] and prior_model is None:
        # instantiate prior model if it was not passed to create_model (i.e. when loading a model)
        prior_model = create_prior_models(args)

    # create output network
    output_prefix = "Equivariant" if is_equivariant else ""
    output_model = getattr(output_modules, output_prefix + args["output_model"])(
        args["embedding_dimension"],
        activation=args["activation"],
        reduce_op=args["reduce_op"],
        dtype=args["dtype"],
    )

    # combine representation and output network
    model = TorchMD_Net(
        representation_model,
        output_model,
        prior_model=prior_model,
        mean=mean,
        std=std,
        derivative=args["derivative"],
        dtype=args["dtype"],
    )
    return model
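

# A minimal sketch of building a model from a plain hyperparameter dictionary.
# The keys below are the ones read by create_model; the specific values are
# illustrative assumptions, not recommended settings -- real configurations
# come from scripts/train.py and contain more options.
#
#   args = dict(
#       model="equivariant-transformer",
#       embedding_dimension=128,
#       num_layers=6,
#       num_rbf=50,
#       rbf_type="expnorm",
#       trainable_rbf=False,
#       activation="silu",
#       attn_activation="silu",
#       num_heads=8,
#       distance_influence="both",
#       neighbor_embedding=True,
#       cutoff_lower=0.0,
#       cutoff_upper=5.0,
#       max_z=100,
#       max_num_neighbors=32,
#       output_model="Scalar",
#       reduce_op="add",
#       derivative=True,
#       atom_filter=-1,
#       prior_model=None,
#   )
#   model = create_model(args)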


def load_model(filepath, args=None, device="cpu", **kwargs):
    """Load a TorchMD_Net model from a checkpoint file.

    Keyword arguments override the hyperparameters stored in the checkpoint.
    """
    ckpt = torch.load(filepath, map_location="cpu")
    if args is None:
        args = ckpt["hyper_parameters"]

    for key, value in kwargs.items():
        if key not in args:
            warnings.warn(f"Unknown hyperparameter: {key}={value}")
        args[key] = value

    model = create_model(args)

    state_dict = {re.sub(r"^model\.", "", k): v for k, v in ckpt["state_dict"].items()}
    # The following are for backward compatibility with models created when atomref was
    # the only supported prior.
    if 'prior_model.initial_atomref' in state_dict:
        state_dict['prior_model.0.initial_atomref'] = state_dict['prior_model.initial_atomref']
        del state_dict['prior_model.initial_atomref']
    if 'prior_model.atomref.weight' in state_dict:
        state_dict['prior_model.0.atomref.weight'] = state_dict['prior_model.atomref.weight']
        del state_dict['prior_model.atomref.weight']

    model.load_state_dict(state_dict)
    return model.to(device)
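

# A hedged usage sketch: load a trained checkpoint and move it to the GPU.
# The checkpoint path is hypothetical; any extra keyword argument
# (e.g. derivative=True) overrides the hyperparameters saved in the checkpoint.
#
#   model = load_model("checkpoints/epoch=99.ckpt", device="cuda", derivative=True)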


def create_prior_models(args, dataset=None):
    """Parse the prior_model configuration option and create the prior models."""
    prior_models = []
    if args['prior_model']:
        prior_model = args['prior_model']
        prior_names = []
        prior_args = []
        if not isinstance(prior_model, list):
            prior_model = [prior_model]
        for prior in prior_model:
            if isinstance(prior, dict):
                for key, value in prior.items():
                    prior_names.append(key)
                    if value is None:
                        prior_args.append({})
                    else:
                        prior_args.append(value)
            else:
                prior_names.append(prior)
                prior_args.append({})
        if 'prior_args' in args:
            prior_args = args['prior_args']
            if not isinstance(prior_args, list):
                prior_args = [prior_args]
        for name, arg in zip(prior_names, prior_args):
            assert hasattr(priors, name), (
                f"Unknown prior model {name}. "
                f"Available models are {', '.join(priors.__all__)}"
            )
            # initialize the prior model
            prior_models.append(getattr(priors, name)(dataset=dataset, **arg))
    return prior_models
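

# Hedged examples of the configuration shapes this parser accepts. The prior
# names and arguments are illustrative; whether a given prior can be built with
# empty arguments depends on that prior class in torchmdnet.priors.
#
#   args["prior_model"] = "Atomref"                     # a single prior by name, default arguments
#   args["prior_model"] = {"Atomref": {"max_z": 100}}   # a single prior with explicit arguments
#   args["prior_model"] = [{"Atomref": None}, "ZBL"]    # a list mixing both forms
#
#   priors_list = create_prior_models(args, dataset=my_dataset)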


class TorchMD_Net(nn.Module):
    """The TorchMD_Net class combines a given representation model
    (such as the equivariant transformer), an output model (such as
    the scalar output module) and a prior model (such as the atomref
    prior), producing a module that takes a series of atom features
    as input and outputs a scalar value (i.e. the energy for each
    batch/molecule) and, if derivative is True, the negative of its
    derivative with respect to the positions (i.e. the forces for each atom).
    """

    def __init__(
        self,
        representation_model,
        output_model,
        prior_model=None,
        mean=None,
        std=None,
        derivative=False,
        dtype=torch.float32,
    ):
        super(TorchMD_Net, self).__init__()
        self.representation_model = representation_model.to(dtype=dtype)
        self.output_model = output_model.to(dtype=dtype)
        if not output_model.allow_prior_model and prior_model is not None:
            prior_model = None
            rank_zero_warn(
                (
                    "Prior model was given but the output model does "
                    "not allow prior models. Dropping the prior model."
                )
            )
        if isinstance(prior_model, priors.base.BasePrior):
            prior_model = [prior_model]
        self.prior_model = None if prior_model is None else torch.nn.ModuleList(prior_model).to(dtype=dtype)

        self.derivative = derivative
        self.standardize = mean is not None and std is not None
        if (mean is None and std is not None) or (mean is not None and std is None):
            raise ValueError("Either both mean and std must be given or none of them.")
        mean = torch.scalar_tensor(0) if mean is None else mean
        self.register_buffer("mean", mean.to(dtype=dtype))
        std = torch.scalar_tensor(1) if std is None else std
        self.register_buffer("std", std.to(dtype=dtype))

        self.reset_parameters()

    def reset_parameters(self):
        self.representation_model.reset_parameters()
        self.output_model.reset_parameters()
        if self.prior_model is not None:
            for prior in self.prior_model:
                prior.reset_parameters()

    def forward(
        self,
        z: Tensor,
        pos: Tensor,
        batch: Optional[Tensor] = None,
        q: Optional[Tensor] = None,
        s: Optional[Tensor] = None,
        extra_args: Optional[Dict[str, Tensor]] = None,
    ) -> Tuple[Tensor, Optional[Tensor]]:
        """Compute the output of the model.

        Args:
            z (Tensor): Atomic numbers of the atoms in the molecule. Shape (N,).
            pos (Tensor): Atomic positions in the molecule. Shape (N, 3).
            batch (Tensor, optional): Batch indices for the atoms in the molecule. Shape (N,).
            q (Tensor, optional): Atomic charges in the molecule. Shape (N,).
            s (Tensor, optional): Atomic spins in the molecule. Shape (N,).
            extra_args (Dict[str, Tensor], optional): Extra arguments to pass to the prior model.
        """
        assert z.dim() == 1 and z.dtype == torch.long
        batch = torch.zeros_like(z) if batch is None else batch

        if self.derivative:
            pos.requires_grad_(True)

        # run the potentially wrapped representation model
        x, v, z, pos, batch = self.representation_model(z, pos, batch, q=q, s=s)

        # apply the output network
        x = self.output_model.pre_reduce(x, v, z, pos, batch)

        # scale by data standard deviation
        if self.standardize:
            x = x * self.std

        # apply atom-wise prior model
        if self.prior_model is not None:
            for prior in self.prior_model:
                x = prior.pre_reduce(x, z, pos, batch, extra_args)

        # aggregate atoms
        x = self.output_model.reduce(x, batch)

        # shift by data mean
        if self.standardize:
            x = x + self.mean

        # apply output model after reduction
        y = self.output_model.post_reduce(x)

        # apply molecular-wise prior model
        if self.prior_model is not None:
            for prior in self.prior_model:
                y = prior.post_reduce(y, z, pos, batch, extra_args)

        # compute gradients with respect to coordinates
        if self.derivative:
            grad_outputs: List[Optional[torch.Tensor]] = [torch.ones_like(y)]
            dy = grad(
                [y],
                [pos],
                grad_outputs=grad_outputs,
                create_graph=True,
                retain_graph=True,
            )[0]
            if dy is None:
                raise RuntimeError("Autograd returned None for the force prediction.")
            return y, -dy
        # TODO: return only `out` once Union typing works with TorchScript (https://github.com/pytorch/pytorch/pull/53180)
        return y, None
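

# A hedged end-to-end sketch (atom numbers, positions and batch indices are
# illustrative): evaluate a batch of two molecules with a model created with
# derivative=True, so the forward pass returns both the per-molecule energy
# and the forces (the negative gradient of the energy w.r.t. the positions).
#
#   z = torch.tensor([8, 1, 1, 6, 1, 1, 1, 1], dtype=torch.long)  # water + methane
#   pos = torch.randn(len(z), 3)
#   batch = torch.tensor([0, 0, 0, 1, 1, 1, 1, 1])                # molecule index per atom
#   energy, forces = model(z, pos, batch=batch)
#   # energy holds one value per molecule; forces has shape (num_atoms, 3)
#   # when derivative=True, otherwise the second return value is None.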