Abstract
Graph Neural Networks (GNNs) are a powerful tool for machine learning on graphs. GNNs combine node feature information with the graph structure by recursively passing neural messages along edges of the input graph. However, incorporating both graph structure and feature information leads to complex models and explaining predictions made by GNNs remains unsolved. Here we propose GnnExplainer, the first general, model-agnostic approach for providing interpretable explanations for predictions of any GNN-based model on any graph-based machine learning task. Given an instance, GnnExplainer identifies a compact subgraph structure and a small subset of node features that have a crucial role in GNN’s prediction. Further, GnnExplainer can generate consistent and concise explanations for an entire class of instances. We formulate GnnExplainer as an optimization task that maximizes the mutual information between a GNN’s prediction and distribution of possible subgraph structures. Experiments on synthetic and real-world graphs show that our approach can identify important graph structures as well as node features, and outperforms alternative baseline approaches by up to 43.0% in explanation accuracy. GnnExplainer provides a variety of benefits, from the ability to visualize semantically relevant structures to interpretability, to giving insights into errors of faulty GNNs.
1. Introduction
In many real-world applications, including social, information, chemical, and biological domains, data can be naturally modeled as graphs [9, 41, 49]. Graphs are powerful data representations but are challenging to work with because they require modeling of rich relational information as well as node feature information [45, 46]. To address this challenge, Graph Neural Networks (GNNs) have emerged as state-of-the-art for machine learning on graphs, due to their ability to recursively incorporate information from neighboring nodes in the graph, naturally capturing both graph structure and node features [16, 21, 40, 44].
Despite their strengths, GNNs lack transparency as they do not easily allow for a human-intelligible explanation of their predictions. Yet, the ability to understand GNN’s predictions is important and useful for several reasons: (i) it can increase trust in the GNN model, (ii) it improves model’s transparency in a growing number of decision-critical applications pertaining to fairness, privacy and other safety challenges [11], and (iii) it allows practitioners to get an understanding of the network characteristics, identify and correct systematic patterns of mistakes made by models before deploying them in the real world.
While currently there are no methods for explaining GNNs, recent approaches for explaining other types of neural networks have taken one of two main routes. One line of work locally approximates models with simpler surrogate models, which are then probed for explanations [25, 29, 30]. Other methods carefully examine models for relevant features and find good qualitative interpretations of high level features [6, 13, 27, 32] or identify influential input instances [23, 38]. However, these approaches fall short in their ability to incorporate relational information, the essence of graphs. Since this aspect is crucial for the success of machine learning on graphs, any explanation of GNN’s predictions should leverage rich relational information provided by the graph as well as node features.
Here we propose GnnExplainer, an approach for explaining predictions made by GNNs. GnnExplainer takes a trained GNN and its prediction(s), and it returns an explanation in the form of a small subgraph of the input graph together with a small subset of node features that are most influential for the prediction(s) (Figure 1). The approach is model-agnostic and can explain predictions of any GNN on any machine learning task for graphs, including node classification, link prediction, and graph classification. It handles single- as well as multi-instance explanations. In the case of single-instance explanations, GnnExplainer explains a GNN’s prediction for one particular instance (i.e., a node label, a new link, a graph-level label). In the case of multi-instance explanations, GnnExplainer provides an explanation that consistently explains a set of instances (e.g., nodes of a given class).
GnnExplainer specifies an explanation as a rich subgraph of the entire graph the GNN was trained on, such that the subgraph maximizes the mutual information with GNN’s prediction(s). This is achieved by formulating a mean field variational approximation and learning a real-valued graph mask which selects the important subgraph of the GNN’s computation graph. Simultaneously, GnnExplainer also learns a feature mask that masks out unimportant node features (Figure 1).
We evaluate GnnExplainer on synthetic as well as real-world graphs. Experiments show that GnnExplainer provides consistent and concise explanations of GNN’s predictions. On synthetic graphs with planted network motifs, which play a role in determining node labels, we show that GnnExplainer accurately identifies the subgraphs/motifs as well as node features that determine node labels outperforming alternative baseline approaches by up to 43.0% in explanation accuracy. Further, using two real-world datasets we show how GnnExplainer can provide important domain insights by robustly identifying important graph structures and node features that influence a GNN’s predictions. Specifically, using molecular graphs and social interaction networks, we show that GnnExplainer can identify important domain-specific graph structures, such as NO2 chemical groups or ring structures in molecules, and star structures in Reddit threads. Overall, experiments demonstrate that GnnExplainer provides consistent and concise explanations for GNN-based models for different machine learning tasks on graphs.
2. Related work
Although the problem of explaining GNNs is not well-studied, the related problems of interpretability and neural debugging have received substantial attention in machine learning. At a high level, we can group those interpretability methods for non-graph neural networks into two main families.
Methods in the first family formulate simple proxy models of full neural networks. This can be done in a model-agnostic way, usually by learning a locally faithful approximation around the prediction, for example through linear models [29] or sets of rules, representing sufficient conditions on the prediction [3, 25, 47]. Methods in the second family identify important aspects of the computation, for example, through feature gradients [13, 43], backpropagation of neurons’ contributions to the input features [6, 31, 32], and counterfactual reasoning [19]. However, the saliency maps [43] produced by these methods have been shown to be misleading in some instances [2] and prone to issues like gradient saturation [31, 32]. These issues are exacerbated on discrete inputs such as graph adjacency matrices since the gradient values can be very large but only on very small intervals. Because of that, such approaches are not suitable for explaining predictions made by neural networks on graphs.
Instead of creating new, inherently interpretable models, post-hoc interpretability methods [1, 14, 15, 17, 23, 38] consider models as black boxes and then probe them for relevant information. However, no work has been done to leverage relational structures like graphs. The lack of methods for explaining predictions on graph-structured data is problematic, as in many cases, predictions on graphs are induced by a complex combination of nodes and paths of edges between them. For example, in some tasks, an edge is important only when another alternative path exists in the graph to form a cycle, and those two features, only when considered together, can accurately predict node labels [10, 12]. Their joint contribution thus cannot be modeled as a simple linear combination of individual contributions.
Finally, recent GNN models augment interpretability via attention mechanisms [28, 33, 34]. However, although the learned edge attention values can indicate important graph structure, the values are the same for predictions across all nodes. This conflicts with many applications where an edge is essential for predicting the label of one node but not the label of another node. Furthermore, these approaches are either limited to specific GNN architectures or cannot explain predictions by jointly considering both graph structure and node feature information.
3. Formulating explanations for graph neural networks
Let G denote a graph on edges E and nodes V that are associated with d-dimensional node features $\mathcal{X} = \{x_1, \ldots, x_n\}$, $x_i \in \mathbb{R}^d$. Without loss of generality, we consider the problem of explaining a node classification task (see Section 4.4 for other tasks). Let f denote a label function on nodes f : V ↦ {1,…,C} that maps every node in V to one of C classes. The GNN model Φ is optimized on all nodes in the training set and is then used for prediction, i.e., to approximate f on new nodes.
3.1. Background on graph neural networks
At layer l, the update of GNN model Φ involves three key computations [4, 45, 46]. (1) First, the model computes neural messages between every pair of nodes. The message for node pair $(v_i, v_j)$ is a function Msg of $v_i$’s and $v_j$’s representations $h_i^{l-1}$ and $h_j^{l-1}$ in the previous layer and of the relation $r_{ij}$ between the nodes: $m_{ij}^{l} = \mathrm{Msg}(h_i^{l-1}, h_j^{l-1}, r_{ij})$. (2) Second, for each node $v_i$, GNN aggregates messages from $v_i$’s neighborhood and calculates an aggregated message $M_i^{l}$ via an aggregation method Agg [16, 35]: $M_i^{l} = \mathrm{Agg}(\{m_{ij}^{l} \mid v_j \in \mathcal{N}_{v_i}\})$, where $\mathcal{N}_{v_i}$ is the neighborhood of node $v_i$ whose definition depends on a particular GNN variant. (3) Finally, GNN takes the aggregated message $M_i^{l}$ along with $v_i$’s representation $h_i^{l-1}$ from the previous layer, and it non-linearly transforms them to obtain $v_i$’s representation $h_i^{l}$ at layer $l$: $h_i^{l} = \mathrm{Update}(M_i^{l}, h_i^{l-1})$. The final embedding for node $v_i$ after L layers of computation is $z_i = h_i^{L}$. Our GnnExplainer provides explanations for any GNN that can be formulated in terms of Msg, Agg, and Update computations.
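The three computations can be illustrated with a minimal sketch in plain Python. The concrete choices below (a relation-weighted copy of the neighbor representation for Msg, a mean for Agg, and a tanh of a sum for Update) are assumptions made for illustration only; the text leaves all three functions abstract, and the names `msg`, `agg`, `update`, and `gnn_layer` are hypothetical.

```python
import math

def msg(h_i, h_j, r_ij):
    # Hypothetical Msg: the neighbor's representation scaled by the relation weight.
    return [r_ij * x for x in h_j]

def agg(messages, dim):
    # Hypothetical Agg: element-wise mean (zero vector if there are no neighbors).
    if not messages:
        return [0.0] * dim
    return [sum(m[t] for m in messages) / len(messages) for t in range(dim)]

def update(M_i, h_i):
    # Hypothetical Update: non-linear transform of aggregated message and previous state.
    return [math.tanh(a + b) for a, b in zip(M_i, h_i)]

def gnn_layer(h, neighbors, relation):
    """Apply one Msg -> Agg -> Update round to every node."""
    dim = len(next(iter(h.values())))
    new_h = {}
    for i in h:
        messages = [msg(h[i], h[j], relation.get((i, j), 1.0)) for j in neighbors[i]]
        new_h[i] = update(agg(messages, dim), h[i])
    return new_h

# Path graph 0 - 1 - 2 with 2-dimensional input features.
h0 = {0: [1.0, 0.0], 1: [0.0, 1.0], 2: [1.0, 1.0]}
neighbors = {0: [1], 1: [0, 2], 2: [1]}
h1 = gnn_layer(h0, neighbors, relation={})
h2 = gnn_layer(h1, neighbors, relation={})  # final embeddings z_i = h_i^L with L = 2
```

After two layers, node 0's embedding depends on node 2 even though the two are not adjacent, which is why an explanation has to look at the whole multi-hop neighborhood.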
3.2. GnnExplainer: Problem formulation
Our key insight is the observation that the computation graph of node v, which is defined by the GNN’s neighborhood-based aggregation (Figure 2), fully determines all the information the GNN uses to generate prediction $\hat{y}$ at node v. In particular, v’s computation graph tells the GNN how to generate v’s embedding z. Let us denote that computation graph by $G_c(v)$, the associated binary adjacency matrix by $A_c(v) \in \{0,1\}^{n \times n}$, and the associated feature set by $X_c(v) = \{x_j \mid v_j \in G_c(v)\}$. The GNN model Φ learns a conditional distribution $P_\Phi(Y \mid G_c, X_c)$, where Y is a random variable representing labels {1,…,C}, indicating the probability of nodes belonging to each of C classes.
A GNN’s prediction is given by $\hat{y} = \Phi(G_c(v), X_c(v))$, meaning that it is fully determined by the model Φ, graph structural information $G_c(v)$, and node feature information $X_c(v)$. In effect, this observation implies that we only need to consider graph structure $G_c(v)$ and node features $X_c(v)$ to explain $\hat{y}$ (Figure 2A). Formally, GnnExplainer generates explanation for prediction $\hat{y}$ as $(G_S, X_S^F)$, where $G_S$ is a small subgraph of the computation graph, $X_S$ is the associated feature set of $G_S$, and $X_S^F$ is a small subset of node features (masked out by the mask F, i.e., $X_S^F = \{x_j^F \mid v_j \in G_S\}$) that are most important for explaining $\hat{y}$ (Figure 2B).
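As a rough illustration of what a computation graph contains, the sketch below collects the nodes and edges that can influence node v under plain L-hop neighborhood aggregation. This is an assumption: as noted in Section 3.1, the neighborhood definition depends on the GNN variant (for example, sampling-based models use fewer edges). The function name `computation_graph` is hypothetical.

```python
from collections import deque

def computation_graph(adj, v, num_layers):
    """Nodes within num_layers hops of v, and the edges that carry messages toward v."""
    dist = {v: 0}
    queue = deque([v])
    while queue:
        u = queue.popleft()
        if dist[u] == num_layers:
            continue  # nodes at the frontier do not pull in further neighbors
        for w in adj[u]:
            if w not in dist:
                dist[w] = dist[u] + 1
                queue.append(w)
    nodes = sorted(dist)
    # An edge matters only if at least one endpoint is closer than num_layers hops;
    # an edge between two frontier nodes can never reach v in num_layers rounds.
    edges = sorted(
        (a, b)
        for a in nodes
        for b in adj[a]
        if b in dist and a < b and min(dist[a], dist[b]) < num_layers
    )
    return nodes, edges

# Undirected path 0 - 1 - 2 - 3 - 4; a 2-layer GNN at node 0 sees only nodes 0, 1, 2.
adj = {0: [1], 1: [0, 2], 2: [1, 3], 3: [2, 4], 4: [3]}
nodes, edges = computation_graph(adj, 0, 2)
```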
4. GnnExplainer
Next we describe our approach GnnExplainer. Given a trained GNN model Φ and a prediction (i.e., single-instance explanation, Sections 4.1 and 4.2) or a set of predictions (i.e., multi-instance explanations, Section 4.3), the GnnExplainer will generate an explanation by identifying a subgraph of the computation graph and a subset of node features that are most influential for the model Φ’s prediction. In the case of explaining a set of predictions, GnnExplainer will aggregate individual explanations in the set and automatically summarize it with a prototype. We conclude this section with a discussion on how GnnExplainer can be used for any machine learning task on graphs, including link prediction and graph classification (Section 4.4).
4.1. Single-instance explanations
Given a node v, our goal is to identify a subgraph GS ⊆ Gc and the associated features XS = {xj|vj ∈ GS} that are important for the GNN’s prediction $\hat{y}$. For now, we assume that XS is a small subset of d-dimensional node features; we will later discuss how to automatically determine which dimensions of node features need to be included in explanations (Section 4.2). We formalize the notion of importance using mutual information MI and formulate the GnnExplainer as the following optimization framework:
$$\max_{G_S} MI\left(Y, (G_S, X_S)\right) = H(Y) - H(Y \mid G = G_S, X = X_S) \tag{1}$$
For node v, MI quantifies the change in the probability of prediction $\hat{y} = \Phi(G_c, X_c)$ when v’s computation graph is limited to explanation subgraph GS and its node features are limited to XS.
For example, consider the situation where vj ∈ Gc(vi), vj ≠ vi. Then, if removing vj from Gc(vi) strongly decreases the probability of prediction $\hat{y}_i$, the node vj is a good counterfactual explanation for the prediction at vi. Similarly, consider the situation where (vj, vk) ∈ Gc(vi), vj, vk ≠ vi. Then, if removing an edge between vj and vk strongly decreases the probability of prediction $\hat{y}_i$, then the absence of that edge is a good counterfactual explanation for the prediction at vi.
Examining Eq. (1), we see that the entropy term H(Y) is constant because Φ is fixed for a trained GNN. As a result, maximizing mutual information between the predicted label distribution Y and explanation (GS, XS) is equivalent to minimizing conditional entropy H(Y|G = GS, X = XS), which can be expressed as follows:
$$H(Y \mid G = G_S, X = X_S) = -\mathbb{E}_{Y \mid G_S, X_S}\left[\log P_\Phi(Y \mid G = G_S, X = X_S)\right] \tag{2}$$
Explanation for prediction $\hat{y}$ is thus a subgraph GS that minimizes uncertainty of Φ when the GNN computation is limited to GS. In effect, GS maximizes probability of $\hat{y}$ (Figure 2). To obtain a compact explanation, we impose a constraint on GS’s size as: |GS| ≤ KM, so that GS has at most KM nodes. In effect, this implies that GnnExplainer aims to denoise Gc by taking KM edges that give the highest mutual information with the prediction.
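To make the conditional-entropy criterion concrete: for a fixed candidate explanation, the conditional entropy is simply the entropy of the class distribution the model outputs when restricted to that candidate. The sketch below compares two candidates; the probability vectors are made-up numbers standing in for the trained model's outputs, and the names `entropy` and `candidates` are hypothetical.

```python
import math

def entropy(p):
    """Shannon entropy (in nats) of a discrete distribution."""
    return -sum(q * math.log(q) for q in p if q > 0)

# Hypothetical model outputs P(Y | G = G_S, X = X_S) for two candidate subgraphs.
candidates = {
    "S1": [0.90, 0.05, 0.05],  # the model stays confident on this subgraph
    "S2": [0.40, 0.30, 0.30],  # the model becomes uncertain on this subgraph
}

# The better explanation leaves the model with the least uncertainty.
best = min(candidates, key=lambda name: entropy(candidates[name]))
```

Here `best` is `"S1"`, the candidate on which the model's prediction remains close to certain.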
GnnExplainer’s optimization framework
Direct optimization of GnnExplainer’s objective is not tractable as Gc has exponentially many subgraphs GS that are candidate explanations for $\hat{y}$. We thus consider a fractional adjacency matrix¹ for subgraphs GS, i.e., $A_S \in [0,1]^{n \times n}$, and enforce the subgraph constraint as: $A_S[j,k] \le A_c[j,k]$ for all j, k. This continuous relaxation can be interpreted as a variational approximation of the distribution of subgraphs of Gc. In particular, if we treat $G_S \sim \mathcal{G}$ as a random graph variable, the objective in Eq. (2) becomes:
$$\min_{\mathcal{G}} \mathbb{E}_{G_S \sim \mathcal{G}}\, H(Y \mid G = G_S, X = X_S) \tag{3}$$
With convexity assumption, Jensen’s inequality gives the following upper bound:
$$\min_{\mathcal{G}} H\left(Y \mid G = \mathbb{E}_{\mathcal{G}}[G_S], X = X_S\right) \tag{4}$$
In practice, due to the complexity of neural networks, the convexity assumption does not hold. However, experimentally, we found that minimizing this objective with regularization often leads to a local minimum corresponding to high-quality explanations.
To tractably estimate $\mathbb{E}_{\mathcal{G}}$, we use mean-field variational approximation and decompose $\mathcal{G}$ into a multivariate Bernoulli distribution as: $P_{\mathcal{G}}(G_S) = \prod_{(j,k) \in G_c} A_S[j,k]$. This allows us to estimate the expectation with respect to the mean-field approximation, thereby obtaining $A_S$ in which the (j, k)-th entry represents the expectation on whether edge (vj, vk) exists. We observed empirically that this approximation together with a regularizer for promoting discreteness [40] converges to good local minima despite the non-convexity of GNNs. The conditional entropy in Equation 4 can be optimized by replacing $\mathbb{E}_{\mathcal{G}}[G_S]$ with a masked version of the computation graph’s adjacency matrix, Ac ⨀ σ(M), where $M \in \mathbb{R}^{n \times n}$ denotes the mask that we need to learn, ⨀ denotes element-wise multiplication, and σ denotes the sigmoid that maps the mask to $[0,1]^{n \times n}$.
In some applications, instead of finding an explanation in terms of the model’s confidence, users care more about “why does the trained model predict a certain class label”, or “how to make the trained model predict a desired class label”. We can replace the conditional entropy objective in Equation 4 with a cross entropy objective between the label class and the model prediction². To answer these queries, a computationally efficient version of GnnExplainer’s objective, which we optimize using gradient descent, is as follows:
$$\min_{M} -\sum_{c=1}^{C} \mathbb{1}[y = c] \log P_\Phi\left(Y = y \mid G = A_c \odot \sigma(M), X = X_c\right) \tag{5}$$
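The masking step Ac ⨀ σ(M) can be sketched in a few lines of plain Python. The matrices below are illustrative values, not learned ones, and `masked_adjacency` is a hypothetical helper name. Because the sigmoid output lies strictly between 0 and 1, the product never exceeds the original entry, so the subgraph constraint holds by construction.

```python
import math

def sigmoid(x):
    return 1.0 / (1.0 + math.exp(-x))

def masked_adjacency(A_c, M):
    """Element-wise product of the adjacency matrix with the squashed mask."""
    n = len(A_c)
    return [[A_c[j][k] * sigmoid(M[j][k]) for k in range(n)] for j in range(n)]

# Star-shaped computation graph: node 0 is connected to nodes 1 and 2.
A_c = [[0, 1, 1],
       [1, 0, 0],
       [1, 0, 0]]
# Illustrative mask logits: edge (0, 1) is favored, edge (0, 2) is suppressed.
M = [[0.0, 3.0, -3.0],
     [3.0, 0.0, 0.0],
     [-3.0, 0.0, 0.0]]
A_S = masked_adjacency(A_c, M)
```

Entries where Ac is zero stay exactly zero regardless of the mask value, so the mask can only down-weight existing edges and never invent new ones.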
The masking approach is also found in Neural Relational Inference [22], albeit with different motivation and objective. Lastly, we compute the element-wise multiplication of σ(M) and Ac and remove low values in M through thresholding to arrive at the explanation GS for the GNN model’s prediction at node v.
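The following toy sketch shows the mechanics of learning a mask by gradient descent and then thresholding it. It is not the GnnExplainer implementation: the "model" is a single logistic unit whose logit is a sum of gated per-edge contributions (`scores`, hypothetical numbers), the loss is the cross entropy for the predicted class plus a small size penalty, and the gradient is written out by hand instead of using automatic differentiation.

```python
import math

def sigmoid(x):
    return 1.0 / (1.0 + math.exp(-x))

def learn_edge_mask(scores, steps=200, lr=0.5, size_weight=0.05):
    """Minimize -log p(y) + size_weight * sum(sigmoid(m)) over mask logits m."""
    m = [0.0] * len(scores)
    for _ in range(steps):
        gates = [sigmoid(x) for x in m]
        p = sigmoid(sum(g * s for g, s in zip(gates, scores)))  # toy model output
        for k, (g, s) in enumerate(zip(gates, scores)):
            # Chain rule: d(-log p)/d(logit) = -(1 - p); d(gate)/d(m_k) = g * (1 - g).
            grad = (-(1.0 - p) * s + size_weight) * g * (1.0 - g)
            m[k] -= lr * grad
    return m

def threshold(mask_logits, tau=0.5):
    """Keep the indices of edges whose squashed mask value exceeds tau."""
    return [k for k, x in enumerate(mask_logits) if sigmoid(x) > tau]

# Edge 0 supports the prediction, edge 1 opposes it, edge 2 barely matters.
mask = learn_edge_mask([2.0, -1.5, 0.1])
explanation_edges = threshold(mask)
```

Only the strongly supporting edge survives thresholding; the weak edge is removed because its contribution to the likelihood is smaller than the size penalty.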
4.2. Joint learning of graph structural and node feature information
To identify what node features are most important for prediction $\hat{y}$, GnnExplainer learns a feature selector F for nodes in explanation GS. Instead of defining XS to consist of all node features, i.e., $X_S = \{x_j \mid v_j \in G_S\}$, GnnExplainer considers $X_S^F$ as a subset of features of nodes in GS, which are defined through a binary feature selector $F \in \{0,1\}^d$ (Figure 2B):
$$X_S^F = \{x_j^F \mid v_j \in G_S\}, \quad x_j^F = [x_{j,t_1}, \ldots, x_{j,t_k}] \text{ for } F_{t_i} = 1 \tag{6}$$
where $x_j^F$ has node features that are not masked out by F. Explanation (GS, XS) is then jointly optimized for maximizing the mutual information objective:
$$\max_{G_S, F} MI\left(Y, (G_S, F)\right) = H(Y) - H(Y \mid G = G_S, X = X_S^F) \tag{7}$$
which represents a modified objective function from Eq. (1) that considers structural and node feature information to generate an explanation for prediction $\hat{y}$.
Learning binary feature selector F
We specify $X_S^F$ as XS ⨀ F, where F acts as a feature mask that we need to learn. Intuitively, if a particular feature is not important, the corresponding weights in GNN’s weight matrix take values close to zero. In effect, this implies that masking the feature out does not decrease predicted probability for $\hat{y}$. Conversely, if the feature is important then masking it out would decrease predicted probability. However, in some cases this approach ignores features that are important for prediction but take values close to zero. To address this issue we marginalize over all feature subsets and use a Monte Carlo estimate to sample from empirical marginal distribution for nodes in XS during training [48]. Further, we use a reparametrization trick [20] to backpropagate gradients in Eq. (7) to the feature mask F. In particular, to backpropagate through a d-dimensional random variable X we reparametrize X as: X = Z + (XS − Z) ⨀ F s.t. $\sum_j F_j \le K_F$, where Z is a d-dimensional random variable sampled from the empirical distribution and KF is a parameter representing the maximum number of features to be kept in the explanation.
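The reparametrization X = Z + (XS − Z) ⨀ F can be read as follows: where F is 1 the node keeps its own feature value, and where F is 0 the value is replaced by a draw from the empirical distribution. The sketch below illustrates this for one node with a hard 0/1 mask. The data, the per-dimension sampling scheme, and the name `reparametrize` are assumptions for illustration; during learning the mask would take continuous values.

```python
import random

def reparametrize(x_s, feature_mask, empirical_rows, rng):
    """Return Z + (X_S - Z) * F, with Z drawn per dimension from empirical_rows."""
    z = [rng.choice(empirical_rows)[t] for t in range(len(x_s))]
    return [z_t + (x_t - z_t) * f_t for x_t, z_t, f_t in zip(x_s, z, feature_mask)]

rng = random.Random(0)
# Hypothetical feature vectors of other nodes, used as the empirical distribution.
empirical_rows = [[0.0, 5.0, 1.0], [2.0, 7.0, 3.0], [4.0, 9.0, 5.0]]
x_s = [1.0, 6.0, 2.0]
feature_mask = [1.0, 0.0, 1.0]  # two features kept, so any K_F >= 2 is satisfied
x = reparametrize(x_s, feature_mask, empirical_rows, rng)
```

Because the expression is linear in F, gradients of the objective with respect to X pass through to the mask, which is the point of the reparametrization.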
Integrating additional constraints into explanations
To impose further properties on the explanation we can extend GnnExplainer’s objective function in Eq. (7) with regularization terms. For example, we use element-wise entropy to encourage structural and node feature masks to be discrete. Further, GnnExplainer can encode domain-specific constraints through techniques like Lagrange multipliers or additional regularization terms. We include a number of regularization terms to produce explanations with desired properties. We penalize large explanations by adding the sum of all elements of the mask parameters as a regularization term.
Finally, it is important to note that each explanation must be a valid computation graph. In particular, explanation (GS, XS) needs to allow GNN’s neural messages to flow towards node v such that GNN can make prediction $\hat{y}$. Importantly, GnnExplainer automatically provides explanations that represent valid computation graphs because it optimizes structural masks across entire computation graphs. Even if a disconnected edge is important for neural message-passing, it will not be selected for explanation as it cannot influence GNN’s prediction. In effect, this implies that the explanation GS tends to be a small connected subgraph.
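The two regularizers mentioned above (a size penalty and an element-wise entropy) might look as follows for a mask that has already been squashed to [0, 1]. The averaging of the entropy term, the `eps` constant, and the function name are assumptions; the text does not specify the weighting of these terms.

```python
import math

def mask_regularizers(mask_probs, eps=1e-12):
    """Return (size penalty, mean element-wise binary entropy) for mask values in [0, 1]."""
    size = sum(mask_probs)
    ent = sum(
        -p * math.log(p + eps) - (1.0 - p) * math.log(1.0 - p + eps)
        for p in mask_probs
    ) / len(mask_probs)
    return size, ent

soft_size, soft_entropy = mask_regularizers([0.5, 0.5, 0.5, 0.5])
crisp_size, crisp_entropy = mask_regularizers([0.99, 0.01, 0.01, 0.01])
```

The near-binary mask has both a smaller size and a much lower entropy, so adding these terms to the objective pushes the optimizer toward small, discrete explanations.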
4.3. Multi-instance explanations through graph prototypes
The output of a single-instance explanation (Sections 4.1 and 4.2) is a small subgraph of the input graph and a small subset of associated node features that are most influential for a single prediction. To answer questions like “How did a GNN predict that a given set of nodes all have label c?”, we need to obtain a global explanation of class c. Our goal here is to provide insight into how the identified subgraph for a particular node relates to a graph structure that explains an entire class. GnnExplainer can provide multi-instance explanations based on graph alignments and prototypes. Our approach has two stages:
First, for a given class c (or any set of predictions that we want to explain), we choose a reference node vc, for example, by computing the mean embedding of all nodes assigned to c. We then take explanation GS(vc) for reference vc and align it to explanations of other nodes assigned to class c. Finding optimal matching of large graphs is challenging in practice. However, the single-instance GnnExplainer generates small graphs (Section 4.2) and thus near-optimal pairwise graph matchings can be efficiently computed.
Second, we aggregate aligned adjacency matrices into a graph prototype Aproto using, for example, a robust median-based approach. Prototype Aproto gives insights into graph patterns shared between nodes that belong to the same class. One can then study prediction for a particular node by comparing explanation for that node’s prediction (i.e., returned by single-instance explanation approach) to the prototype (see Appendix for more information).
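The aggregation stage can be sketched as an element-wise median over adjacency matrices that are assumed to be already aligned to the reference node's explanation (the alignment itself is not shown). The matrices are illustrative and `graph_prototype` is a hypothetical name.

```python
from statistics import median

def graph_prototype(aligned_adjs):
    """Element-wise median of equally sized, already aligned adjacency matrices."""
    n = len(aligned_adjs[0])
    return [
        [median(A[j][k] for A in aligned_adjs) for k in range(n)]
        for j in range(n)
    ]

aligned = [
    [[0, 1, 1], [1, 0, 0], [1, 0, 0]],
    [[0, 1, 1], [1, 0, 0], [1, 0, 0]],
    [[0, 1, 0], [1, 0, 1], [0, 1, 0]],  # one noisy explanation with a different edge
]
A_proto = graph_prototype(aligned)
```

The median keeps the star pattern shared by the majority and discards the edge that appears in only one explanation, which is the robustness property the text refers to.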
4.4. GnnExplainer model extensions
Any machine learning task on graphs
In addition to explaining node classification, GnnExplainer provides explanations for link prediction and graph classification with no change to its optimization algorithm. When predicting a link (vj, vk), GnnExplainer learns two masks XS(vj) and XS(vk) for both endpoints of the link. When classifying a graph, the adjacency matrix in Eq. (5) is the union of adjacency matrices for all nodes in the graph whose label we want to explain. However, note that in graph classification, unlike node classification, due to the aggregation of node embeddings, it is no longer true that the explanation GS is necessarily a connected subgraph. Depending on application, in some scenarios such as chemistry where explanation is a functional group and should be connected, one can extract the largest connected component as the explanation.
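For the graph classification case, extracting the largest connected component from a thresholded explanation is a standard graph traversal. A minimal sketch, assuming the explanation is given as an undirected edge list and that "largest" means most nodes; the atom labels are made up.

```python
def largest_connected_component(edges):
    """Return the edges of the connected component with the most nodes."""
    adj = {}
    for a, b in edges:
        adj.setdefault(a, set()).add(b)
        adj.setdefault(b, set()).add(a)
    seen, best = set(), set()
    for start in adj:
        if start in seen:
            continue
        component, stack = {start}, [start]
        while stack:  # depth-first traversal of one component
            u = stack.pop()
            for w in adj[u]:
                if w not in component:
                    component.add(w)
                    stack.append(w)
        seen |= component
        if len(component) > len(best):
            best = component
    return [(a, b) for a, b in edges if a in best]

# A three-atom ring plus a separate two-atom fragment.
explanation_edges = [("C1", "C2"), ("C2", "C3"), ("C3", "C1"), ("N", "O")]
kept = largest_connected_component(explanation_edges)
```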
Any GNN model
Modern GNNs are based on message passing architectures on the input graph. The message passing computation graphs can be composed in many different ways and GnnExplainer can account for all of them. Thus, GnnExplainer can be applied to: Graph Convolutional Networks [21], Gated Graph Sequence Neural Networks [26], Jumping Knowledge Networks [36], Attention Networks [33], Graph Networks [4], GNNs with various node aggregation schemes [7, 5, 18, 16, 40, 39, 35], Line-Graph NNs [8], position-aware GNN [42], and many other GNN architectures.
Computational complexity
The number of parameters in GnnExplainer’s optimization depends on the size of computation graph Gc for node v whose prediction we aim to explain. In particular, the size of Gc(v)’s adjacency matrix Ac(v) is equal to the size of the mask M, which needs to be learned by GnnExplainer. However, since computation graphs are typically relatively small compared to the size of exhaustive L-hop neighborhoods (e.g., 2–3 hop neighborhoods [21], sampling-based neighborhoods [39], neighborhoods with attention [33]), GnnExplainer can effectively generate explanations even when input graphs are large.
5. Experiments
We begin by describing the graphs, alternative baseline approaches, and experimental setup. We then present experiments on explaining GNNs for node classification and graph classification tasks. Our qualitative and quantitative analysis demonstrates that GnnExplainer is accurate and effective in identifying explanations, both in terms of graph structure and node features.
Synthetic datasets
We construct four kinds of node classification datasets (Table 1). (1) In BA-Shapes, we start with a base Barabási-Albert (BA) graph on 300 nodes and a set of 80 five-node “house”-structured network motifs, which are attached to randomly selected nodes of the base graph. The resulting graph is further perturbed by adding 0.1N random edges. Nodes are assigned to 4 classes based on their structural roles. In a house-structured motif, there are 3 types of roles: the top, middle and bottom node of the house. Therefore there are 4 different classes, corresponding to nodes at the top, middle, bottom of houses, and nodes that do not belong to a house. (2) BA-Community dataset is a union of two BA-Shapes graphs. Nodes have normally distributed feature vectors and are assigned to one of 8 classes based on their structural roles and community memberships. (3) In Tree-Cycles, we start with a base 8-level balanced binary tree and 80 six-node cycle motifs, which are attached to random nodes of the base graph. (4) Tree-Grid is the same as Tree-Cycles except that 3-by-3 grid motifs are attached to the base tree graph in place of cycle motifs.
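A simplified sketch of the BA-Shapes construction is shown below. It only covers attaching house motifs and assigning structural-role labels. Several details are assumptions: the base graph here is a tiny stand-in rather than a Barabási-Albert graph, the random edge perturbation is omitted, each house is connected to the base graph through one of its bottom nodes, and the role names are illustrative.

```python
import random

# Five-node house: nodes 0-1 form the floor, 2-3 the middle, 4 the roof tip.
HOUSE_EDGES = [(0, 1), (0, 2), (1, 3), (2, 3), (2, 4), (3, 4)]
HOUSE_ROLES = ["bottom", "bottom", "middle", "middle", "top"]

def attach_houses(base_edges, num_base_nodes, attach_points):
    """Attach one house motif per attach point and label nodes by structural role."""
    edges = list(base_edges)
    roles = ["base"] * num_base_nodes
    for anchor in attach_points:
        offset = len(roles)  # index of the first node of the new house
        edges += [(offset + a, offset + b) for a, b in HOUSE_EDGES]
        roles += HOUSE_ROLES
        edges.append((anchor, offset))  # assumed: link a bottom node to the base graph
    return edges, roles

rng = random.Random(0)
base_edges = [(0, 1), (1, 2)]              # stand-in for the base graph
attach_points = rng.sample(range(3), 2)    # randomly selected base nodes
edges, roles = attach_houses(base_edges, 3, attach_points)
```

The four distinct role labels correspond to the four classes described above: top, middle, and bottom of a house, plus nodes outside any house.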
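Table 1: Explanation accuracy of Att, Grad, and GnnExplainer on the four synthetic node classification datasets.
Explanation accuracy | BA-Shapes | BA-Community | Tree-Cycles | Tree-Grid |
---|---|---|---|---|
Att | 0.815 | 0.739 | 0.824 | 0.612 |
Grad | 0.882 | 0.750 | 0.905 | 0.667 |
GnnExplainer | 0.925 | 0.836 | 0.948 | 0.875 |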
Real-world datasets
We consider two graph classification datasets: (1) Mutag is a dataset of 4,337 molecule graphs labeled according to their mutagenic effect on the Gram-negative bacterium S. typhimurium [10]. (2) Reddit-Binary is a dataset of 2,000 graphs, each representing an online discussion thread on Reddit. In each graph, nodes are users participating in a thread, and edges indicate that one user replied to another user’s comment. Graphs are labeled according to the type of user interactions in the thread: r/IAmA and r/AskReddit contain Question-Answer interactions, while r/TrollXChromosomes and r/atheism contain Online-Discussion interactions [37].
Alternative baseline approaches
Many explainability methods cannot be directly applied to graphs (Section 2). Nevertheless, we here consider the following alternative approaches that can provide insights into predictions made by GNNs: (1) Grad is a gradient-based method. We compute the gradient of the GNN’s loss function with respect to the adjacency matrix and the associated node features, similar to a saliency map approach. (2) Att is a graph attention GNN (GAT) [33] that learns attention weights for edges in the computation graph, which we use as a proxy measure of edge importance. While Att does consider graph structure, it does not explain using node features and can only explain GAT models. Furthermore, in Att it is not obvious which attention weights need to be used for edge importance, since a 1-hop neighbor of a node can also be a 2-hop neighbor of the same node due to cycles. Each edge’s importance is thus computed as the average attention weight across all layers.
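The edge-importance rule used for the attention baseline (average attention weight across layers) can be sketched as below. The attention values are made-up numbers, the helper name is hypothetical, and the sketch assumes every edge has a weight in every layer; an edge missing from a layer would contribute zero for that layer.

```python
def average_attention(per_layer_attention):
    """Average each edge's attention weight over all layers."""
    totals = {}
    for layer in per_layer_attention:
        for edge, weight in layer.items():
            totals[edge] = totals.get(edge, 0.0) + weight
    num_layers = len(per_layer_attention)
    return {edge: total / num_layers for edge, total in totals.items()}

# Hypothetical attention weights from a 2-layer attention model.
layers = [
    {(0, 1): 0.8, (1, 2): 0.2},
    {(0, 1): 0.4, (1, 2): 0.6},
]
importance = average_attention(layers)
```

Note that the resulting scores are attached to edges, not to (edge, target node) pairs, which is the limitation of this baseline discussed in Section 2.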
Setup and implementation details
For each dataset, we first train a single GNN and use Grad and GnnExplainer to explain the predictions made by the GNN. Note that the Att baseline requires using a graph attention architecture like GAT [33]. We thus train a separate GAT model on the same dataset and use the learned edge attention weights for explanation. Hyperparameters KM, KF control the size of subgraph and feature explanations respectively, which is informed by prior knowledge about the dataset. For synthetic datasets, we set KM to be the size of ground truth. On real-world datasets, we set KM = 10. We set KF = 5 for all datasets. We further fix our weight regularization hyperparameters across all node and graph classification experiments. We refer readers to the Appendix for more training details (Code and datasets are available at https://github.com/RexYing/gnn-model-explainer).
Results
We investigate the following questions: Does GnnExplainer provide sensible explanations? How do explanations compare to the ground-truth knowledge? How does GnnExplainer perform on various graph-based prediction tasks? Can it explain predictions made by different GNNs?
1). Quantitative analyses
Results on node classification datasets are shown in Table 1. We have ground-truth explanations for synthetic datasets and we use them to calculate explanation accuracy for all explanation methods. Specifically, we formalize the explanation problem as a binary classification task, where edges in the ground-truth explanation are treated as labels and importance weights given by explainability method are viewed as prediction scores. A better explainability method predicts high scores for edges that are in the ground-truth explanation, and thus achieves higher explanation accuracy. Results show that GnnExplainer outperforms alternative approaches by 17.1% on average. Further, GnnExplainer achieves up to 43.0% higher accuracy on the hardest Tree-Grid dataset.
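The two summary figures can be checked directly against the accuracies in Table 1. The sketch below uses one plausible reading of them, namely the relative improvement of GnnExplainer over each baseline on each dataset, averaged over all eight baseline-dataset pairs. That reading is an assumption, but it reproduces both reported numbers. The values are listed in the order the datasets are introduced.

```python
# Explanation accuracies from Table 1, one value per synthetic dataset.
accuracy = {
    "Att": [0.815, 0.739, 0.824, 0.612],
    "Grad": [0.882, 0.750, 0.905, 0.667],
    "GnnExplainer": [0.925, 0.836, 0.948, 0.875],
}

# Relative improvement (in percent) over each baseline on each dataset.
gains = [
    100.0 * (ours / base - 1.0)
    for name in ("Att", "Grad")
    for ours, base in zip(accuracy["GnnExplainer"], accuracy[name])
]
average_gain = sum(gains) / len(gains)
largest_gain = max(gains)
print(round(average_gain, 1), round(largest_gain, 1))  # prints: 17.1 43.0
```

The largest gain comes from the last column (0.875 versus 0.612 for Att), consistent with the statement that the biggest improvement is on Tree-Grid.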
2). Qualitative analyses
Results are shown in Figures 3–5. In a topology-based prediction task with no node features, e.g. BA-Shapes and Tree-Cycles, GnnExplainer correctly identifies network motifs that explain node labels, i.e. structural labels (Figure 3). As illustrated in the figures, house, cycle and tree motifs are identified by GnnExplainer but not by baseline methods. In Figure 4, we investigate explanations for graph classification task. In Mutag example, colors indicate node features, which represent atoms (hydrogen H, carbon C, etc). GnnExplainer correctly identifies carbon ring as well as chemical groups NH2 and NO2, which are known to be mutagenic [10].
Further, in Reddit-Binary example, we see that Question-Answer graphs (2nd row in Figure 4B) have 2–3 high degree nodes that simultaneously connect to many low degree nodes, which makes sense because in QA threads on Reddit we typically have 2–3 experts who all answer many different questions [24]. Conversely, we observe that discussion patterns commonly exhibit tree-like patterns (2nd row in Figure 4A), since a thread on Reddit is usually a reaction to a single topic [24]. On the other hand, Grad and Att methods give incorrect or incomplete explanations. For example, both baseline methods miss cycle motifs in Mutag dataset and more complex grid motifs in Tree-Grid dataset. Furthermore, although edge attention weights in Att can be interpreted as importance scores for message passing, the weights are shared across all nodes in the input graph, and as such Att fails to provide high quality single-instance explanations.
An essential criterion for explanations is that they must be interpretable, i.e., provide a qualitative understanding of the relationship between the input nodes and the prediction. Such a requirement implies that explanations should be easy to understand while remaining exhaustive. This means that a GNN explainer should take into account both the structure of the underlying graph as well as the associated features when they are available. Figure 5 shows results of an experiment in which GnnExplainer jointly considers structural information as well as information from a small number of feature dimensions³. While GnnExplainer indeed highlights a compact feature representation in Figure 5, gradient-based approaches struggle to cope with the added noise, giving high importance scores to irrelevant feature dimensions.
Further experiments on multi-instance explanations using graph prototypes are in Appendix.
6. Conclusion
We present GnnExplainer, a novel method for explaining predictions of any GNN on any graph-based machine learning task without requiring modification of the underlying GNN architecture or re-training. We show how GnnExplainer can leverage the recursive neighborhood-aggregation scheme of graph neural networks to identify important graph pathways as well as highlight relevant node feature information that is passed along edges of the pathways. While the problem of explainability of machine-learning predictions has received substantial attention in recent literature, our work is unique in the sense that it presents an approach that operates on relational structures (graphs with rich node features) and provides a straightforward interface for making sense out of GNN predictions, debugging GNN models, and identifying systematic patterns of mistakes.
Acknowledgments
Jure Leskovec is a Chan Zuckerberg Biohub investigator. We gratefully acknowledge the support of DARPA under FA865018C7880 (ASED) and MSC; NIH under No. U54EB020405 (Mobilize); ARO under No. 38796-Z8424103 (MURI); IARPA under No. 2017-17071900005 (HFC), NSF under No. OAC-1835598 (CINES) and HDR; Stanford Data Science Initiative, Chan Zuckerberg Biohub, JD.com, Amazon, Boeing, Docomo, Huawei, Hitachi, Observe, Siemens, UST Global. The U.S. Government is authorized to reproduce and distribute reprints for Governmental purposes notwithstanding any copyright notation thereon. Any opinions, findings, and conclusions or recommendations expressed in this material are those of the authors and do not necessarily reflect the views, policies, or endorsements, either expressed or implied, of DARPA, NIH, ONR, or the U.S. Government.
A. Multi-instance explanations
The problem of multi-instance explanations for graph neural networks is challenging and an important area to study.
Here we propose a solution based on GnnExplainer that finds the common components of a set of 10 explanations, one for each of 10 different instances in the same label class. More research in this area is necessary to design efficient multi-instance explanation methods. The main challenge in practice is the difficulty of performing graph alignment under noise and under variation in the neighborhood structures of nodes in the same class. The problem is closely related to finding the maximum common subgraph of the explanation graphs, which is an NP-hard problem. In the following we introduce a neural approach to this problem. However, note that existing graph libraries (based on heuristics or integer programming relaxations) for finding the maximum common subgraph of graphs can replace the neural components of the following procedure when identifying and aligning with a prototype.
The output of a single-instance GnnExplainer indicates what graph structural and node feature information is important for a given prediction. To obtain an understanding of “why is a given set of nodes classified with label y”, we want to also obtain a global explanation of the class, which can shed light on how the identified structure for a given node is related to a prototypical structure unique for its label. To this end, we propose an alignment-based multi-instance GnnExplainer.
For any given class, we first choose a reference node. Intuitively, this node should be a prototypical node for the class. Such a node can be found by computing the mean of the embeddings of all nodes in the class and choosing the node whose embedding is closest to the mean. Alternatively, if one has prior knowledge about the important computation subgraph, one can choose the node whose computation subgraph best matches that prior knowledge.
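The mean-embedding heuristic above can be illustrated with a minimal sketch. The function name `choose_reference_node`, the toy embeddings, and the use of Euclidean distance are assumptions made for illustration; the text does not specify the distance measure.

```python
import numpy as np

def choose_reference_node(embeddings, node_ids):
    """Return the id of the node whose embedding is closest to the class mean.

    embeddings: array of shape (num_nodes_in_class, embedding_dim)
    node_ids:   sequence of node identifiers, one per embedding row
    """
    embeddings = np.asarray(embeddings, dtype=float)
    class_mean = embeddings.mean(axis=0)
    # Euclidean distance is an assumption; any embedding-space metric could be used.
    distances = np.linalg.norm(embeddings - class_mean, axis=1)
    return node_ids[int(np.argmin(distances))]

# Toy example: four nodes of one class with 2-dimensional embeddings (hypothetical values).
class_embeddings = np.array([[0.0, 0.0], [1.0, 1.0], [2.0, 2.0], [1.0, 0.8]])
class_node_ids = [10, 11, 12, 13]
reference = choose_reference_node(class_embeddings, class_node_ids)
```

In this toy class the mean embedding is (1.0, 0.95), so the node with embedding (1.0, 1.0) is selected as the reference.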
Given the reference node for class c, vc, and its associated important computation subgraph GS(vc), we align each of the identified computation subgraphs for all nodes in class c to the reference GS(vc). Following the idea of differentiable pooling [40], we use a relaxed alignment matrix to find the correspondence between nodes in a computation subgraph GS(v) and nodes in the reference computation subgraph GS(vc). Let Av and Xv be the adjacency matrix and the associated feature matrix of the to-be-aligned computation subgraph. Similarly, let A* and X* be the adjacency matrix and the associated feature matrix of the reference computation subgraph. Then we optimize the relaxed alignment matrix $P \in \mathbb{R}^{n_v \times n^*}$, where nv is the number of nodes in GS(v) and n* is the number of nodes in GS(vc), as follows:
$$\min_{P} \; \left| P^{\top} A_v P - A^{*} \right| + \left| P^{\top} X_v - X^{*} \right| \tag{8}$$
The first term in Eq. (8) specifies that after alignment, the aligned adjacency for GS(v) should be as close to A* as possible. The second term in the equation specifies that the features of the aligned nodes should also be close.
In practice, it is often non-trivial for the relaxed graph matching to find a good optimum when matching two large graphs. However, thanks to the single-instance explainer, which produces concise subgraphs for important message-passing, a matching that is close to the best alignment can be efficiently computed.
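The two terms described above can be evaluated numerically. The following sketch assumes the objective has the form |PᵀAvP − A*| + |PᵀXv − X*| with an entrywise L1 norm (the choice of norm is an assumption), and it only evaluates the objective for a given P rather than optimizing it. The helper name `alignment_loss` and the toy graphs are hypothetical.

```python
import numpy as np

def alignment_loss(P, A_v, X_v, A_ref, X_ref):
    """Structure mismatch plus feature mismatch after aligning with P.

    P has shape (n_v, n_ref); entrywise L1 norms are an illustrative assumption.
    """
    structure_term = np.abs(P.T @ A_v @ P - A_ref).sum()
    feature_term = np.abs(P.T @ X_v - X_ref).sum()
    return structure_term + feature_term

# Reference subgraph: the path 0 - 1 - 2 with 2-dimensional node features.
A_ref = np.array([[0, 1, 0], [1, 0, 1], [0, 1, 0]], dtype=float)
X_ref = np.array([[1, 0], [0, 1], [1, 0]], dtype=float)

# To-be-aligned subgraph: the same graph with relabeled nodes.
# Node i of the subgraph corresponds to reference node perm[i].
perm = [2, 0, 1]
P_true = np.eye(3)[perm]            # P_true[i, perm[i]] = 1
A_v = P_true @ A_ref @ P_true.T
X_v = P_true @ X_ref

loss_correct = alignment_loss(P_true, A_v, X_v, A_ref, X_ref)
loss_identity = alignment_loss(np.eye(3), A_v, X_v, A_ref, X_ref)
```

With the correct permutation both terms vanish, whereas the identity alignment leaves a positive mismatch in both structure and features. A relaxed (real-valued) P would be optimized by gradient descent on the same quantity.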
Prototype by alignment
We align the adjacency matrices of all nodes in class c, such that they are aligned with respect to the ordering defined by the reference adjacency matrix. We then use the median to generate a prototype that is resistant to outliers, Aproto = median(Ai), where Ai is the aligned adjacency matrix representing the explanation for the i-th node in class c. The prototype Aproto allows users to gain insights into structural graph patterns shared between nodes that belong to the same class. Users can then investigate a particular node by comparing its explanation to the class prototype.
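As a minimal sketch of the prototype step, the entrywise median over already-aligned adjacency matrices can be computed as follows. The function name `median_prototype` and the three toy matrices are hypothetical, and the sketch assumes the alignment has already been carried out.

```python
import numpy as np

def median_prototype(aligned_adjacencies):
    """Entrywise median over aligned adjacency matrices of identical shape."""
    stacked = np.stack(aligned_adjacencies, axis=0)
    return np.median(stacked, axis=0)

# Two explanations agree on a star centered at node 0; the third has a spurious edge (1, 2).
A1 = np.array([[0, 1, 1], [1, 0, 0], [1, 0, 0]], dtype=float)
A2 = A1.copy()
A3 = np.array([[0, 1, 1], [1, 0, 1], [1, 1, 0]], dtype=float)

A_proto = median_prototype([A1, A2, A3])
```

The spurious edge appears in only one of the three explanations, so the median discards it and the prototype equals the shared star structure.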
B. Experiments on multi-instance explanations and prototypes
In the context of multi-instance explanations, an explainer must not only highlight information locally relevant to a particular prediction, but also help emphasize higher-level correlations across instances. These instances can be related in arbitrary ways, but the most evident is class membership. The assumption is that members of a class share common characteristics, and the model should help highlight them. For example, mutagenic compounds are often found to have certain characteristic functional groups, such as NO2, a pair of oxygen atoms together with a nitrogen atom. A trained eye might notice that Figure 6 already hints at their presence. The evidence grows stronger when a prototype is generated by GnnExplainer, shown in Figure 6. The model is able to pick up on this functional structure and promote it as archetypal of mutagenic compounds.
C. Further implementation details
Training details
We use the Adam optimizer to train both the GNN and the explanation methods. All GNN models are trained for 1000 epochs with learning rate 0.001, reaching accuracy of at least 85% for graph classification datasets and 95% for node classification datasets. The train/validation/test split is 80/10/10% for all datasets. In GnnExplainer, we use the same optimizer and learning rate, and train for 100–300 epochs. This is efficient since GnnExplainer only needs to be trained on a local computation graph with fewer than 100 nodes.
Regularization
In addition to the graph size constraint and the graph Laplacian constraint, we further impose a feature size constraint, which requires that the number of unmasked features does not exceed a threshold. The regularization hyperparameters are 0.005 for subgraph size, 0.5 for the Laplacian, and 0.1 for the feature explanation. The same hyperparameter values are used across all experiments.
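To make the role of the three coefficients concrete, the sketch below combines three penalty terms into a single weighted regularizer. Only the coefficient values come from the text; the specific forms of the penalties (sum of edge-mask entries, a Laplacian quadratic form over a label vector, and sum of feature-mask entries) are assumptions made for illustration, as are all names.

```python
import numpy as np

# Coefficients as reported in the text.
SIZE_COEF = 0.005
LAPLACIAN_COEF = 0.5
FEATURE_COEF = 0.1

def regularization(edge_mask, feature_mask, labels):
    """Weighted sum of three penalties (the penalty forms are illustrative assumptions).

    edge_mask:    symmetric (n, n) array with entries in [0, 1]
    feature_mask: (d,) array with entries in [0, 1]
    labels:       (n,) float vector of node labels
    """
    size_penalty = edge_mask.sum()                       # favors small subgraphs
    laplacian = np.diag(edge_mask.sum(axis=1)) - edge_mask
    laplacian_penalty = labels @ laplacian @ labels      # penalizes kept edges joining different labels
    feature_penalty = feature_mask.sum()                 # favors few unmasked features
    return (SIZE_COEF * size_penalty
            + LAPLACIAN_COEF * laplacian_penalty
            + FEATURE_COEF * feature_penalty)

# Toy masks (hypothetical values): one strong edge (0, 1) and one weak edge (0, 2).
edge_mask = np.array([[0.0, 0.9, 0.1], [0.9, 0.0, 0.0], [0.1, 0.0, 0.0]])
feature_mask = np.array([1.0, 0.2])
labels = np.array([1.0, 1.0, 0.0])
reg_value = regularization(edge_mask, feature_mask, labels)
```

Under these assumed forms, the regularizer would be added to the explanation objective, so that larger masks and label-inconsistent edges increase the loss.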
Subgraph extraction
To extract the explanation subgraph GS, we first compute the importance weights on edges (gradients for the Grad baseline, attention weights for the Att baseline, and the masked adjacency for GnnExplainer). A threshold is used to remove low-weight edges and identify the explanation subgraph GS. The ground-truth explanations of all datasets are connected subgraphs. Therefore, for node classification, we identify the explanation as the connected component of GS containing the explained node. For graph classification, we identify the explanation as the largest connected component of GS. For all methods, we perform a search to find the maximum threshold such that the explanation has size at least KM. When multiple edges have tied importance weights, all of them are included in the explanation.
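The extraction procedure for node classification can be sketched as follows. The function names, the dictionary representation of edge weights, and the choice to measure explanation size in nodes are assumptions; the text does not specify how size is counted or how the threshold search is implemented.

```python
from collections import deque

def component_containing(node, edges):
    """Nodes reachable from `node` through the given undirected edges."""
    adjacency = {}
    for u, v in edges:
        adjacency.setdefault(u, set()).add(v)
        adjacency.setdefault(v, set()).add(u)
    seen = {node}
    queue = deque([node])
    while queue:
        current = queue.popleft()
        for neighbor in adjacency.get(current, ()):
            if neighbor not in seen:
                seen.add(neighbor)
                queue.append(neighbor)
    return seen

def extract_explanation(edge_weights, node, min_size):
    """Find the largest threshold whose component around `node` has >= min_size nodes.

    edge_weights: dict mapping undirected edges (u, v) to importance weights.
    Edges with weight >= threshold are kept, so tied edges are included together.
    """
    for threshold in sorted(set(edge_weights.values()), reverse=True):
        kept = [edge for edge, weight in edge_weights.items() if weight >= threshold]
        component = component_containing(node, kept)
        if len(component) >= min_size:
            return threshold, component
    # No threshold reaches min_size: fall back to the full component.
    return None, component_containing(node, list(edge_weights))

# Toy importance weights (hypothetical values); node 0 is the explained node.
weights = {(0, 1): 0.9, (1, 2): 0.8, (2, 3): 0.8, (3, 4): 0.1, (5, 6): 0.95}
threshold, component = extract_explanation(weights, node=0, min_size=3)
```

Here the thresholds 0.95 and 0.9 yield components of one and two nodes, so the search settles on 0.8, at which point both tied edges enter together. The high-weight edge (5, 6) is excluded because it is not connected to the explained node.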
Footnotes
For typed edges, we define where Ce is the number of edge types.
The label class is the label class predicted by the GNN model to be explained when answering “why does the trained model predict a certain class label”. The question “how to make the trained model predict a desired class label” can instead be answered by using the ground-truth label class.
Feature explanations are shown for the two datasets with node features, i.e., Mutag and BA-Community.
References
- [1] Adadi A and Berrada M. Peeking Inside the Black-Box: A Survey on Explainable Artificial Intelligence (XAI). IEEE Access, 6:52138–52160, 2018.
- [2] Adebayo J, Gilmer J, Muelly M, Goodfellow I, Hardt M, and Kim B. Sanity checks for saliency maps. In NeurIPS, 2018.
- [3] Gethsiyal Augasta M and Kathirvalavakumar T. Reverse Engineering the Neural Networks for Rule Extraction in Classification Problems. Neural Processing Letters, 35(2):131–150, April 2012.
- [4] Battaglia Peter W, Hamrick Jessica B, Bapst Victor, Sanchez-Gonzalez Alvaro, Zambaldi Vinicius, Malinowski Mateusz, Tacchetti Andrea, Raposo David, Santoro Adam, Faulkner Ryan, et al. Relational inductive biases, deep learning, and graph networks. arXiv:1806.01261, 2018.
- [5] Chen J, Zhu J, and Song L. Stochastic training of graph convolutional networks with variance reduction. In ICML, 2018.
- [6] Chen Jianbo, Song Le, Wainwright Martin J, and Jordan Michael I. Learning to explain: An information-theoretic perspective on model interpretation. arXiv preprint arXiv:1802.07814, 2018.
- [7] Chen Jie, Ma Tengfei, and Xiao Cao. Fastgcn: fast learning with graph convolutional networks via importance sampling. In ICLR, 2018.
- [8] Chen Z, Li L, and Bruna J. Supervised community detection with line graph neural networks. In ICLR, 2019.
- [9] Cho E, Myers S, and Leskovec J. Friendship and mobility: user movement in location-based social networks. In KDD, 2011.
- [10] Debnath A et al. Structure-activity relationship of mutagenic aromatic and heteroaromatic nitro compounds. Correlation with molecular orbital energies and hydrophobicity. Journal of Medicinal Chemistry, 34(2):786–797, 1991.
- [11] Doshi-Velez F and Kim B. Towards A Rigorous Science of Interpretable Machine Learning. arXiv:1702.08608, 2017.
- [12] Duvenaud D et al. Convolutional networks on graphs for learning molecular fingerprints. In NIPS, 2015.
- [13] Erhan D, Bengio Y, Courville A, and Vincent P. Visualizing higher-layer features of a deep network. University of Montreal, 1341(3):1, 2009.
- [14] Fisher A, Rudin C, and Dominici F. All Models are Wrong but many are Useful: Variable Importance for Black-Box, Proprietary, or Misspecified Prediction Models, using Model Class Reliance. arXiv:1801.01489, January 2018.
- [15] Guidotti R et al. A Survey of Methods for Explaining Black Box Models. ACM Comput. Surv., 51(5):93:1–93:42, 2018.
- [16] Hamilton W, Ying Z, and Leskovec J. Inductive representation learning on large graphs. In NIPS, 2017.
- [17] Hooker G. Discovering additive structure in black box functions. In KDD, 2004.
- [18] Huang WB, Zhang T, Rong Y, and Huang J. Adaptive sampling towards fast graph representation learning. In NeurIPS, 2018.
- [19] Kang Bo, Lijffijt Jefrey, and De Bie Tijl. Explaine: An approach for explaining network embedding-based link predictions. arXiv:1904.12694, 2019.
- [20] Kingma Diederik P and Welling Max. Auto-encoding variational bayes. In NeurIPS, 2013.
- [21] Kipf TN and Welling M. Semi-supervised classification with graph convolutional networks. In ICLR, 2016.
- [22] Kipf Thomas, Fetaya Ethan, Wang Kuan-Chieh, Welling Max, and Zemel Richard. Neural relational inference for interacting systems. In ICML, 2018.
- [23] Koh PW and Liang P. Understanding black-box predictions via influence functions. In ICML, 2017.
- [24] Kumar Srijan, Hamilton William L, Leskovec Jure, and Jurafsky Dan. Community interaction and conflict on the web. In WWW, pages 933–943, 2018.
- [25] Lakkaraju H, Kamar E, Caruana R, and Leskovec J. Interpretable & Explorable Approximations of Black Box Models, 2017.
- [26] Li Y, Tarlow D, Brockschmidt M, and Zemel R. Gated graph sequence neural networks. arXiv:1511.05493, 2015.
- [27] Lundberg S and Lee Su-In. A Unified Approach to Interpreting Model Predictions. In NIPS, 2017.
- [28] Neil D et al. Interpretable Graph Convolutional Neural Networks for Inference on Noisy Knowledge Graphs. In ML4H Workshop at NeurIPS, 2018.
- [29] Ribeiro M, Singh S, and Guestrin C. Why should i trust you?: Explaining the predictions of any classifier. In KDD, 2016.
- [30] Schmitz GJ, Aldrich C, and Gouws FS. ANN-DT: an algorithm for extraction of decision trees from artificial neural networks. IEEE Transactions on Neural Networks, 1999.
- [31] Shrikumar A, Greenside P, and Kundaje A. Learning Important Features Through Propagating Activation Differences. In ICML, 2017.
- [32] Sundararajan M, Taly A, and Yan Q. Axiomatic Attribution for Deep Networks. In ICML, 2017.
- [33] Velickovic P, Cucurull G, Casanova A, Romero A, Liò P, and Bengio Y. Graph attention networks. In ICLR, 2018.
- [34] Xie T and Grossman J. Crystal graph convolutional neural networks for an accurate and interpretable prediction of material properties. Phys. Rev. Lett., 2018.
- [35] Xu K, Hu W, Leskovec J, and Jegelka S. How powerful are graph neural networks? In ICLR, 2019.
- [36] Xu K, Li C, Tian Y, Sonobe T, Kawarabayashi K, and Jegelka S. Representation learning on graphs with jumping knowledge networks. In ICML, 2018.
- [37] Yanardag Pinar and Vishwanathan SVN. Deep graph kernels. In KDD, pages 1365–1374. ACM, 2015.
- [38] Yeh C, Kim J, Yen I, and Ravikumar P. Representer point selection for explaining deep neural networks. In NeurIPS, 2018.
- [39] Ying R, He R, Chen K, Eksombatchai P, Hamilton W, and Leskovec J. Graph convolutional neural networks for web-scale recommender systems. In KDD, 2018.
- [40] Ying Z, You J, Morris C, Ren X, Hamilton W, and Leskovec J. Hierarchical graph representation learning with differentiable pooling. In NeurIPS, 2018.
- [41] You J, Liu B, Ying R, Pande V, and Leskovec J. Graph convolutional policy network for goal-directed molecular graph generation. 2018.
- [42] You J, Ying Rex, and Leskovec J. Position-aware graph neural networks. In ICML, 2019.
- [43] Zeiler M and Fergus R. Visualizing and Understanding Convolutional Networks. In ECCV, 2014.
- [44] Zhang M and Chen Y. Link prediction based on graph neural networks. In NIPS, 2018.
- [45] Zhang Z, Peng C, and Zhu W. Deep Learning on Graphs: A Survey. arXiv:1812.04202, 2018.
- [46] Zhou J, Cui G, Zhang Z, Yang C, Liu Z, and Sun M. Graph Neural Networks: A Review of Methods and Applications. arXiv:1812.08434, 2018.
- [47] Zilke J, Loza Mencia E, and Janssen F. DeepRED - Rule Extraction from Deep Neural Networks. In Discovery Science. Springer International Publishing, 2016.
- [48] Zintgraf L, Cohen T, Adel T, and Welling M. Visualizing deep neural network decisions: Prediction difference analysis. In ICLR, 2017.
- [49] Zitnik M, Agrawal M, and Leskovec J. Modeling polypharmacy side effects with graph convolutional networks. Bioinformatics, 34, 2018.