class MultiGVPConv(MessagePassing):
    """GVPConv for handling multiple conformations.

    Node features carry a conformation axis: scalars are
    ``[n_nodes, n_conf, d_s]`` and vectors are ``[n_nodes, n_conf, d_v, 3]``.
    Both are flattened to one row per node before ``propagate`` so the
    ``MessagePassing`` machinery can gather and aggregate them. They are
    unflattened again inside ``message`` and after aggregation.
    """

    def __init__(self, *args, **kwargs):
        # NOTE(review): the constructor body (and its real parameter list) is
        # elided ("...") in this listing. The methods below assume it
        # initialises ``MessagePassing`` and sets:
        #   self.message_func -- network applied to the concatenated per-edge
        #                        (scalar, vector) tuple in ``message``;
        #   self.so, self.vo  -- output scalar / vector channels per
        #                        conformation (gvp-pytorch ``GVPConv`` naming).
        # Confirm against the full source.
        ...

    def forward(self, x_s, x_v, edge_index, edge_attr):
        """Run one round of message passing over every conformation at once.

        Args:
            x_s: scalar node features, ``[n_nodes, n_conf, d_s]``.
            x_v: vector node features, ``[n_nodes, n_conf, d_v, 3]``.
            edge_index: graph connectivity, passed to ``propagate``.
            edge_attr: edge features, passed through to ``message``.

        Returns:
            Tuple ``(s, v)`` of shapes ``[n_nodes, n_conf, self.so]`` and
            ``[n_nodes, n_conf, self.vo, 3]``.
        """
        n_nodes, n_conf = x_s.shape[0], x_s.shape[1]

        # Stack scalar feats along axis 1:
        # [n_nodes, n_conf, d_s] -> [n_nodes, n_conf * d_s]
        # ``reshape`` (unlike ``view``) also accepts non-contiguous inputs.
        x_s = x_s.reshape(n_nodes, x_s.shape[1] * x_s.shape[2])

        # Stack vector feats along axis 1:
        # [n_nodes, n_conf, d_v, 3] -> [n_nodes, n_conf * d_v * 3]
        x_v = x_v.reshape(n_nodes, x_v.shape[1] * x_v.shape[2] * 3)

        # Message passing and aggregation. ``n_conf`` is forwarded so that
        # ``message`` can unflatten its inputs without knowing d_s / d_v.
        aggregated = self.propagate(
            edge_index, s=x_s, v=x_v, edge_attr=edge_attr, n_conf=n_conf
        )

        # Split the aggregated flat tensor back into scalar / vector channels
        # using the OUTPUT channel counts (these differ from the input ones
        # whenever the GVP changes the feature width).
        return _split_multi(aggregated, self.so, self.vo, n_conf)

    def message(self, s_i, v_i, s_j, v_j, edge_attr, n_conf):
        """Compute the message for every edge j -> i.

        ``s_*`` / ``v_*`` are the flattened features built in ``forward``,
        lifted to the edge endpoints by ``propagate``. ``n_conf`` is the
        conformation count that ``forward`` passes to ``propagate``.
        Returns the network output merged into one flat tensor per edge.
        """
        # Unstack scalar feats:
        # [n_edges, n_conf * d_s] -> [n_edges, n_conf, d_s]
        d_s = s_i.shape[1] // n_conf
        s_i = s_i.reshape(s_i.shape[0], n_conf, d_s)
        s_j = s_j.reshape(s_j.shape[0], n_conf, d_s)

        # Unstack vector feats:
        # [n_edges, n_conf * d_v * 3] -> [n_edges, n_conf, d_v, 3]
        d_v = v_i.shape[1] // (3 * n_conf)
        v_i = v_i.reshape(v_i.shape[0], n_conf, d_v, 3)
        v_j = v_j.reshape(v_j.shape[0], n_conf, d_v, 3)

        # Message function for edge j -> i.
        message = tuple_cat((s_j, v_j), edge_attr, (s_i, v_i))
        message = self.message_func(message)  # GVP

        # Merge scalar and vector channels along the last axis so the
        # aggregation step sees a single flat tensor.
        return _merge_multi(*message)
42 |
def _split_multi(x, d_s, d_v, n_conf):
    """Split a merged representation of ``(s, v)`` back into a tuple.

    This is the inverse of ``_merge_multi``. The last axis of ``x`` holds
    ``n_conf * d_s`` scalar channels followed by ``n_conf * d_v * 3``
    flattened vector channels.

    Args:
        x: tensor of shape ``[..., n_conf * d_s + n_conf * d_v * 3]``.
        d_s: scalar channels per conformation.
        d_v: vector channels per conformation.
        n_conf: number of conformations.

    Returns:
        Tuple ``(s, v)`` of shapes ``[..., n_conf, d_s]`` and
        ``[..., n_conf, d_v, 3]``.
    """
    # Compute the boundary as a non-negative index. Slicing with
    # ``-3 * d_v * n_conf`` breaks when that product is 0: ``x[..., :-0]``
    # is empty and ``x[..., -0:]`` is the whole tensor.
    split = x.shape[-1] - 3 * d_v * n_conf
    lead = tuple(x.shape[:-1])
    s = x[..., :split].reshape(*lead, n_conf, d_s)
    v = x[..., split:].reshape(*lead, n_conf, d_v, 3)
    return s, v
50 |
def _merge_multi(s, v):
    """Merge a tuple ``(s, v)`` into a single ``torch.Tensor``.

    The vector channels are flattened and appended to the scalar channels
    along the last axis; ``_split_multi`` reverses this layout.

    Args:
        s: scalar features, ``[..., n_conf, d_s]``.
        v: vector features, ``[..., n_conf, d_v, 3]``.

    Returns:
        Tensor of shape ``[..., n_conf * d_s + n_conf * d_v * 3]``.
    """
    # ``reshape`` rather than ``view``: inputs may be non-contiguous (e.g.
    # produced by a transpose), which ``view`` rejects. Sizes are spelled out
    # instead of using -1 so inputs with zero rows (no edges) still reshape.
    # s: [..., n_conf, d_s] -> [..., n_conf * d_s]
    s = s.reshape(*s.shape[:-2], s.shape[-2] * s.shape[-1])
    # v: [..., n_conf, d_v, 3] -> [..., n_conf * d_v * 3]
    v = v.reshape(*v.shape[:-3], v.shape[-3] * v.shape[-2] * 3)
    return torch.cat([s, v], -1)