class MultiGVPConv(MessagePassing):
    '''GVPConv for handling multiple conformations

    Message-passing layer whose node features carry an extra conformation
    axis. ``forward`` folds that axis into the channel axis so that
    ``self.propagate`` operates on 2-D node tensors; ``message`` unfolds it
    again, applies ``self.message_func`` to the (source, edge, target)
    features and folds the result back with ``_merge_multi``.

    NOTE(review): ``d_s``, ``d_v`` and ``n_conf`` are read in ``forward``
    and ``message`` but are never assigned in this listing. Unless they
    exist at module level, both methods raise NameError. They presumably
    come from the layer configuration set up in the elided ``__init__`` —
    confirm and reference them explicitly (e.g. as attributes on ``self``).
    '''

    def __init__(self, …):
        # Constructor body is elided in this listing (``…``). It must at
        # least provide ``self.message_func``, which ``message`` calls.
        # TODO(review): confirm what else it sets up (see class NOTE).
        …

    def forward(self, x_s, x_v, edge_index, edge_attr):
        '''
        Run one round of message passing over all conformations at once.

        :param x_s: scalar node features, [n_nodes, n_conf, d_s]
        :param x_v: vector node features, [n_nodes, n_conf, d_v, 3]
        :param edge_index: graph connectivity, forwarded to ``self.propagate``
        :param edge_attr: edge features, forwarded unchanged to ``message``
            (presumably an (s, v) tuple, since ``message`` hands it to
            ``tuple_cat`` next to (s, v) tuples — confirm)
        :return: tuple (s, v) as produced by ``_split_multi``
        '''

        # stack scalar feats along axis 1:
        # [n_nodes, n_conf, d_s] -> [n_nodes, n_conf * d_s]
        # NOTE(review): ``.view`` may raise on non-contiguous input (e.g.
        # after a permute/transpose); ``.reshape`` would not.
        x_s = x_s.view(x_s.shape[0], x_s.shape[1] * x_s.shape[2])

        # stack vector feat along axis 1:
        # [n_nodes, n_conf, d_v, 3] -> [n_nodes, n_conf * d_v*3]
        x_v = x_v.view(x_v.shape[0], x_v.shape[1] * x_v.shape[2]*3)

        # message passing and aggregation: ``s``/``v`` are exposed to
        # ``message`` as ``s_i``/``s_j`` and ``v_i``/``v_j``
        message = self.propagate(
            edge_index, s=x_s, v=x_v, edge_attr=edge_attr)

        # split scalar and vector channels
        # NOTE(review): ``n_conf`` is not defined in this method; it could be
        # read from ``x_s.shape[1]`` before the flattening above. ``d_s`` and
        # ``d_v`` here must match the *output* widths of
        # ``self.message_func`` (presumably ``propagate`` returns the
        # aggregated outputs of ``message``). Those widths need not equal
        # the input widths used for unstacking in ``message`` — confirm.
        return _split_multi(message, d_s, d_v, n_conf)

    def message(self, s_i, v_i, s_j, v_j, edge_attr):
        '''
        Build the message for each edge j -> i.

        ``s_i``/``v_i`` and ``s_j``/``v_j`` are the flattened ``s``/``v``
        tensors passed to ``propagate``. They are gathered for the target
        (i) and source (j) endpoint of each edge by ``MessagePassing``.

        :param edge_attr: edge features, passed straight to ``tuple_cat``
        :return: 2-D tensor from ``_merge_multi`` holding the scalar and
            flattened vector outputs of ``self.message_func``
        '''

        # unstack scalar feats (rows are presumably one per edge):
        # [n_rows, n_conf * d_s] -> [n_rows, n_conf, d_s]
        # n_conf is recovered as shape[1] // d_s. Here d_s is the *input*
        # per-conformation scalar width (undefined in this listing, see the
        # class NOTE).
        s_i = s_i.view(s_i.shape[0], s_i.shape[1]//d_s, d_s)
        s_j = s_j.view(s_j.shape[0], s_j.shape[1]//d_s, d_s)

        # unstack vector feats:
        # [n_rows, n_conf * d_v*3] -> [n_rows, n_conf, d_v, 3]
        v_i = v_i.view(v_i.shape[0], v_i.shape[1]//(d_v*3), d_v, 3)
        v_j = v_j.view(v_j.shape[0], v_j.shape[1]//(d_v*3), d_v, 3)

        # message function for edge j-i: (source, edge, target) features.
        # NOTE(review): ``tuple_cat`` is defined elsewhere; presumably it
        # concatenates the (s, v) tuples channel-wise — confirm.
        message = tuple_cat((s_j, v_j), edge_attr, (s_i, v_i))
        message = self.message_func(message) # GVP

        # merge scalar and vector channels along axis 1; ``message`` is
        # unpacked into exactly two arguments, i.e. an (s, v) pair
        return _merge_multi(*message)
43 |
def _split_multi(x, d_s, d_v, n_conf):
|
44 |
“‘
|
45 |
Splits a merged representation of (s, v) back into a tuple.
|
46 |
“‘
|
47 |
s = x[…, :−3 * d_v * n_conf].view(x.shape[0], n_conf, d_s)
|
48 |
v = x[…, −3 * d_v * n_conf:].view(x.shape[0], n_conf, d_v, 3)
|
49 |
return s, v
|
50 |
|
51 |
def _merge_multi(s, v):
|
52 |
“‘
|
53 |
Merges a tuple (s, v) into a single ‘torch.Tensor’,
|
54 |
where the vector channels are flattened and
|
55 |
appended to the scalar channels.
|
56 |
“‘
|
57 |
# s: [n_nodes, n_conf, d] -> [n_nodes, n_conf * d_s]
|
58 |
s = s.view(s.shape[0], s.shape[1] * s.shape[2])
|
59 |
# v: [n_nodes, n_conf, d, 3] -> [n_nodes, n_conf * d_v*3]
|
60 |
v = v.view(v.shape[0], v.shape[1] * v.shape[2]*3)
|
61 |
return torch.cat([s, v], −1)
|