Skip to main content
[Preprint]. 2024 May 25:2024.03.31.587283. [Version 2] doi: 10.1101/2024.03.31.587283

Listing 1:

PyG-style pseudocode for a multi-state GVP-GNN layer. We update node features for each conformer independently while maintaining permutation equivariance of the updated feature tensors along both the first (no. of nodes) and second (no. of conformations) axes.

1 class MultiGVPConv ( MessagePassing ):
2  ''' GVPConv for handling multiple conformations '''
3  def __init__ (self , ...) :
5   ...
6
7  def forward (self, x_s , x_v , edge_index, edge_attr ):
8
9   # stack scalar feats along axis 1:
10   # [ n_nodes , n_conf , d_s ] -> [n_nodes , n_conf * d_s]
11   x_s = x_s. view (x_s . shape [0] , x_s. shape [1] * x_s. shape [2])
12
13   # stack vector feat along axis 1:
14   # [ n_nodes , n_conf , d_v , 3] -> [ n_nodes , n_conf * d_v *3]
15   x_v = x_v. view (x_v . shape [0] , x_v. shape [1] * x_v. shape [2]*3)
16
17   # message passing and aggregation
18   message = self . propagate (
19    edge_index , s=x_s , v=x_v , edge_attr = edge_attr )
20
21   # split scalar and vector channels
22   return _split_multi ( message , d_s , d_v , n_conf )
23
24  def message (self , s_i , v_i , s_j , v_j , edge_attr ):
25
26   # unstack scalar feats:
27   # [ n_nodes , n_conf * d_s] -> [ n_nodes , n_conf , d_s]
28   s_i = s_i. view (s_i . shape [0] , s_i. shape [1]// d_s , d_s)
29   s_j = s_j. view (s_j . shape [0] , s_j. shape [1]// d_s , d_s)
30
31   # unstack vector feats:
32   # [ n_nodes , n_conf * d_v *3] -> [ n_nodes , n_conf , d_v , 3]
33   v_i = v_i. view (v_i . shape [0] , v_i. shape [1]//( d_v *3) , d_v , 3)
34   v_j = v_j. view (v_j . shape [0] , v_j. shape [1]//( d_v *3) , d_v , 3)
35
36   # message function for edge j-i
37   message = tuple_cat (( s_j , v_j), edge_attr , (s_i , v_i))
38   message = self . message_func ( message ) # GVP
39
40   # merge scalar and vector channels along axis 1
41   return _merge_multi (* message )
42
43 def _split_multi (x, d_s , d_v , n_conf ):
44  '''
45  Splits a merged representation of (s, v) back into a tuple.
46  '''
47  s = x[... , :-3 * d_v * n_conf ]. view (x. shape [0] , n_conf , d_s)
48  v = x[... , -3 * d_v * n_conf :]. view (x. shape [0] , n_conf , d_v , 3)
49  return s , v
50
51 def _merge_multi(s, v):
52  '''
53  Merges a tuple (s, v) into a single `torch.Tensor`,
54  where the vector channels are flattened and
55  appended to the scalar channels.
56  '''
57  # s: [ n_nodes , n_conf , d_s] -> [ n_nodes , n_conf * d_s ]
58  s = s. view (s. shape [0] , s. shape [1] * s. shape [2])
59  # v: [ n_nodes , n_conf , d_v , 3] -> [ n_nodes , n_conf * d_v *3]
60  v = v. view (v. shape [0] , v. shape [1] * v. shape [2]*3)
61  return torch .cat ([s, v], -1)