Skip to content

Commit 13926ad

Browse files
committed
Fixed test failures
1 parent 58f298f commit 13926ad

File tree

5 files changed

+9
-4
lines changed

5 files changed

+9
-4
lines changed

torchmdnet/models/tensornet.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -234,7 +234,7 @@ def forward(
234234
box: Optional[Tensor] = None,
235235
q: Optional[Tensor] = None,
236236
s: Optional[Tensor] = None,
237-
extra_embedding_args: [Optional[Tuple[Tensor]]] = None
237+
extra_embedding_args: Optional[Tuple[Tensor]] = None
238238
) -> Tuple[Tensor, Optional[Tensor], Tensor, Tensor, Tensor]:
239239
# Obtain graph, with distances and relative position vectors
240240
edge_index, edge_weight, edge_vec = self.distance(pos, batch, box)

torchmdnet/models/torchmd_et.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -206,7 +206,7 @@ def forward(
206206
box: Optional[Tensor] = None,
207207
q: Optional[Tensor] = None,
208208
s: Optional[Tensor] = None,
209-
extra_embedding_args: [Optional[Tuple[Tensor]]] = None
209+
extra_embedding_args: Optional[Tuple[Tensor]] = None
210210
) -> Tuple[Tensor, Tensor, Tensor, Tensor, Tensor]:
211211
x = self.embedding(z)
212212
if self.reshape_embedding is not None:

torchmdnet/models/torchmd_gn.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -208,7 +208,7 @@ def forward(
208208
box: Optional[Tensor] = None,
209209
s: Optional[Tensor] = None,
210210
q: Optional[Tensor] = None,
211-
extra_embedding_args: [Optional[Tuple[Tensor]]] = None
211+
extra_embedding_args: Optional[Tuple[Tensor]] = None
212212
) -> Tuple[Tensor, Optional[Tensor], Tensor, Tensor, Tensor]:
213213
x = self.embedding(z)
214214
if self.reshape_embedding is not None:

torchmdnet/models/torchmd_t.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -202,7 +202,7 @@ def forward(
202202
box: Optional[Tensor] = None,
203203
s: Optional[Tensor] = None,
204204
q: Optional[Tensor] = None,
205-
extra_embedding_args: [Optional[Tuple[Tensor]]] = None
205+
extra_embedding_args: Optional[Tuple[Tensor]] = None
206206
) -> Tuple[Tensor, Optional[Tensor], Tensor, Tensor, Tensor]:
207207
x = self.embedding(z)
208208
if self.reshape_embedding is not None:

torchmdnet/optimize.py

+5
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@ def __init__(self, model):
3333

3434
super().__init__()
3535
self.model = model
36+
self.extra_embedding = model.extra_embedding
3637

3738
self.neighbors = CFConvNeighbors(self.model.cutoff_upper)
3839

@@ -58,12 +59,16 @@ def forward(
5859
box: Optional[pt.Tensor] = None,
5960
q: Optional[pt.Tensor] = None,
6061
s: Optional[pt.Tensor] = None,
62+
extra_embedding_args: Optional[Tuple[pt.Tensor]] = None
6163
) -> Tuple[pt.Tensor, Optional[pt.Tensor], pt.Tensor, pt.Tensor, pt.Tensor]:
6264

6365
assert pt.all(batch == 0)
6466
assert box is None, "Box is not supported"
6567

6668
x = self.model.embedding(z)
69+
if self.model.reshape_embedding is not None:
70+
x = pt.cat((x,)+tuple(t.unsqueeze(1) for t in extra_embedding_args), dim=1)
71+
x = self.model.reshape_embedding(x)
6772

6873
self.neighbors.build(pos)
6974
for inter, conv in zip(self.model.interactions, self.convs):

0 commit comments

Comments
 (0)