Skip to main content
[Preprint]. 2024 Oct 6:2024.03.31.587283. [Version 3] doi: 10.1101/2024.03.31.587283

Listing 1:

PyG-style pseudocode for a multi-state GVP-GNN layer. We update node features for each conformational state independently, while keeping the updated feature tensors permutation equivariant along both the first axis (number of nodes) and the second axis (number of conformations).

class MultiGVPConv(MessagePassing):
    """
    GVPConv for handling multiple conformations.

    Node features carry an extra conformation axis: scalars are
    ``[n_nodes, n_conf, d_s]`` and vectors are ``[n_nodes, n_conf, d_v, 3]``.
    ``forward`` folds that axis into the channel axis so that ``propagate``
    receives 2-D feature matrices. ``message`` restores the axis, applies the
    GVP message function to every conformation, and folds it again for
    aggregation.

    NOTE(review): the original listing elides ``__init__`` and uses undefined
    names for the channel widths. The attribute names below are placeholders
    that a concrete implementation must set:
        self.si, self.vi: scalar / vector channels per conformation of the
            node features fed into ``message``.
        self.so, self.vo: scalar / vector channels per conformation produced
            by ``self.message_func`` (these are what ``forward`` splits).
        self.message_func: the GVP applied to each edge message.
    """

    def __init__(self, *args, **kwargs):
        # Construction is elided in the original listing. It presumably calls
        # ``super().__init__(...)`` (PyG requirement) and sets the attributes
        # documented on the class.
        ...

    def forward(self, x_s, x_v, edge_index, edge_attr):
        """
        Run one round of message passing over all conformations at once.

        Args:
            x_s: scalar node features, ``[n_nodes, n_conf, d_s]``.
            x_v: vector node features, ``[n_nodes, n_conf, d_v, 3]``.
            edge_index: graph connectivity, forwarded to ``propagate``.
            edge_attr: edge features, forwarded to ``message``.

        Returns:
            Tuple ``(s, v)`` with shapes ``[n_nodes, n_conf, self.so]`` and
            ``[n_nodes, n_conf, self.vo, 3]``.
        """
        n_conf = x_s.shape[1]

        # Stack scalar feats along axis 1:
        # [n_nodes, n_conf, d_s] -> [n_nodes, n_conf * d_s]
        # reshape (not view) so non-contiguous inputs do not raise.
        x_s = x_s.reshape(x_s.shape[0], x_s.shape[1] * x_s.shape[2])

        # Stack vector feats along axis 1:
        # [n_nodes, n_conf, d_v, 3] -> [n_nodes, n_conf * d_v * 3]
        x_v = x_v.reshape(x_v.shape[0], x_v.shape[1] * x_v.shape[2] * 3)

        # Message passing and aggregation on the flattened features.
        message = self.propagate(edge_index, s=x_s, v=x_v, edge_attr=edge_attr)

        # The aggregated messages have the message function's OUTPUT widths,
        # which generally differ from the input widths used in ``message``.
        return _split_multi(message, self.so, self.vo, n_conf)

    def message(self, s_i, v_i, s_j, v_j, edge_attr):
        """
        Compute the message for edge j -> i, for every conformation.

        PyG lifts the node features passed to ``propagate`` onto edges, so
        the ``*_i`` / ``*_j`` tensors are indexed per edge.

        Args:
            s_i, s_j: flattened scalar features, ``[n_edges, n_conf * self.si]``.
            v_i, v_j: flattened vector features,
                ``[n_edges, n_conf * self.vi * 3]``.
            edge_attr: edge features for ``tuple_cat``. NOTE(review): presumably
                an ``(s, v)`` tuple shaped to concatenate with the
                per-conformation node features; confirm.

        Returns:
            Flattened messages as produced by ``_merge_multi``.
        """
        # Derive n_conf once from the scalar channels. This avoids dividing
        # by ``self.vi * 3``, which is zero when there are no vector channels.
        n_conf = s_i.shape[1] // self.si

        # Unstack scalar feats:
        # [n_edges, n_conf * si] -> [n_edges, n_conf, si]
        s_i = s_i.reshape(s_i.shape[0], n_conf, self.si)
        s_j = s_j.reshape(s_j.shape[0], n_conf, self.si)

        # Unstack vector feats:
        # [n_edges, n_conf * vi * 3] -> [n_edges, n_conf, vi, 3]
        v_i = v_i.reshape(v_i.shape[0], n_conf, self.vi, 3)
        v_j = v_j.reshape(v_j.shape[0], n_conf, self.vi, 3)

        # Message function for edge j -> i (GVP applied per conformation).
        message = tuple_cat((s_j, v_j), edge_attr, (s_i, v_i))
        message = self.message_func(message)

        # Merge scalar and vector channels back into one flat tensor.
        return _merge_multi(*message)
43
def _split_multi(x, d_s, d_v, n_conf):
44
“‘
45
Splits a merged representation of (s, v) back into a tuple.
46
“‘
47
s = x[…, :−3 * d_v * n_conf].view(x.shape[0], n_conf, d_s)
48
v = x[…, −3 * d_v * n_conf:].view(x.shape[0], n_conf, d_v, 3)
49
return s, v
50
51
def _merge_multi(s, v):
52
“‘
53
Merges a tuple (s, v) into a single ‘torch.Tensor’,
54
where the vector channels are flattened and
55
appended to the scalar channels.
56
“‘
57
# s: [n_nodes, n_conf, d] -> [n_nodes, n_conf * d_s]
58
s = s.view(s.shape[0], s.shape[1] * s.shape[2])
59
# v: [n_nodes, n_conf, d, 3] -> [n_nodes, n_conf * d_v*3]
60
v = v.view(v.shape[0], v.shape[1] * v.shape[2]*3)
61
return torch.cat([s, v], −1)