Abstract
Significant progress has been made using fMRI to characterize the brain changes that occur in autism spectrum disorder (ASD), a complex neurodevelopmental disorder. However, because fMRI is high-dimensional and has a low signal-to-noise ratio, it is difficult to learn informative and robust regional brain representations that serve both graph-level classification and region-level detection of functional differences between ASD and healthy control (HC) groups. Here, we model whole-brain fMRI as a graph, which preserves geometrical and temporal information, and use a Graph Neural Network (GNN) to learn from the graph-structured fMRI data. We investigate adding a mutual information (MI) loss (Infomax), an unsupervised term that encourages high MI between each nodal representation and its corresponding graph-level summary representation, to learn a better graph embedding. Specifically, we developed a pipeline consisting of a GNN encoder, a classifier and a discriminator, which forces the encoded nodal representations both to benefit classification and to reveal the common nodal patterns in a graph. We simultaneously optimize the graph-level classification loss and the Infomax loss. We demonstrate that Infomax graph embedding acts as a regularization term that improves classification performance. Furthermore, we found separable nodal representations of the ASD and HC groups in the prefrontal cortex, cingulate cortex, visual regions, and other brain regions related to social, emotional and executive functions. Compared with a GNN trained with classification loss only, the proposed pipeline facilitates training more robust ASD classification models. Moreover, the separable nodal representations can detect functional differences between the two groups and may contribute to revealing new ASD biomarkers.
Keywords: fMRI, ASD Classification, Graph Embedding, Mutual Information Loss
1. INTRODUCTION
Autism spectrum disorder (ASD) affects the structure and function of the brain. Functional magnetic resonance imaging (fMRI) produces 4D spatial-temporal data describing functional activation, but with a very low signal-to-noise ratio (SNR). It can be used to characterize neural pathways and the brain changes that occur in ASD. However, its high dimensionality and low SNR make fMRI difficult to analyze. Here, we address the problem of learning good fMRI representations for identifying ASD and detecting brain functional differences between ASD and healthy control (HC) subjects. To utilize the spatial-temporal information of fMRI, we represent whole-brain fMRI as a graph in which each brain region of interest (ROI) is a node, the edges are computed from the fMRI correlation matrix, and the node features are predetermined, thereby preserving both geometrical and temporal information. The Graph Neural Network (GNN), a deep learning architecture for analyzing graph-structured data, has been used in ASD classification.1 In addition to improving ASD classification, a core objective of our work is to discover useful representations for detecting brain regional differences between ASD and HC. The simple idea explored here is to train a representation-learning function tied to the end-goal task, which both maximizes the mutual information (MI) between the nodal representations and the graph-level representation and minimizes the loss of the end-goal task. MI is notoriously difficult to compute, particularly in continuous and high-dimensional settings. Fortunately, the recently proposed MINE2 enables effective computation of MI between high-dimensional input/output pairs of a deep neural network by training a statistics network as a classifier that distinguishes samples drawn from the joint distribution of two random variables from samples drawn from the product of their marginals. During training of a GNN, we simultaneously optimize the classification loss and the Infomax loss,3 which maximizes the MI between local and global representations.
In this way, we tune the learned representations to suit both classification and the detection of group-level regional functional differences. Results show improved classification and reveal functional differences between ASD and HC through the separable embedded brain regions encoded by the GNN.
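For reference, MINE2 rests on the Donsker-Varadhan lower bound on MI, which is maximized over the parameters θ of a statistics network T_θ that contrasts samples from the joint distribution with samples from the product of the marginals (Z_1 and Z_2 below are generic random variables, not symbols used elsewhere in this paper):

```latex
I(Z_1; Z_2) \;\ge\; \sup_{\theta}\;
  \mathbb{E}_{\mathbb{P}_{Z_1 Z_2}}\!\left[T_\theta(z_1, z_2)\right]
  - \log \mathbb{E}_{\mathbb{P}_{Z_1} \otimes \mathbb{P}_{Z_2}}\!\left[e^{T_\theta(z_1, z_2)}\right]
```

DGI3 and Hjelm et al.13 keep the same joint-versus-marginals contrast but train it with a binary cross-entropy objective rather than this exact bound.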
Our contributions are summarized as follows:
We formulate the discovery of abnormal brain function in ASD as a node embedding problem using a GNN.
To the best of our knowledge, this is the first application of deep graph infomax (DGI) to fMRI analysis. The learned representations benefit both the summarization of nodes of interest and the downstream graph-level classification task.
We discover group-level brain ROI differences between the ASD and HC groups and provide potential biomarkers for ASD diagnosis and treatment.
Paper structure:
In Section 2, we summarize traditional methods for analyzing fMRI as a graph, as well as GNN techniques. In Section 3, we introduce the methods used in our study: Section 3.1 describes how fMRI images are converted to graphs, Section 3.2 details the network architecture, and Section 3.3 develops the loss function. The experiments, results and evaluation methods are presented in Section 4. We conclude the paper in Section 5.
2. RELATED WORK
2.1. Analyzing fMRI as a Graph
Typically, the brain is parcellated with an atlas into a fixed number of ROIs, say R. The mean fMRI time series of each ROI is then extracted and correlated with the mean time series of every other ROI, yielding an R × R correlation matrix for each subject. With these connectivity matrices at hand, we hope to reduce the dimensionality of the signal to a few hundred components that retain cognitive information. The functional network of the brain can be modelled as a graph in which each node is a brain region and the edges represent the direction and strength of the connections between regions. There are several ways to measure the edge weights, generally divided into two groups. The first group of methods focuses on finding causality between brain regions and thus derives directed edges; these methods include Granger causality4 and dynamic causal modelling.5 The second group obtains undirected edges by using simpler statistical association measures to quantify the inter-node connection strength. Modelling the brain as a graph allows us to use graph-theoretical metrics to study the functional network. In summary, to extract the whole-brain functional connectivity network of each subject, each ROI is treated as a network node, and a connectivity measure is used to connect these nodes.6 This connectivity measure a_ij must quantify the relationship between the time series of ROIs i and j. Correlation and mutual information metrics have been used extensively for this purpose.7
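The correlation-matrix construction described above can be sketched in a few lines. This is a minimal illustration, assuming Pearson correlation (the text only says "correlation") and toy hand-made time series; the function names are hypothetical.

```python
import math

def roi_mean_series(voxel_series):
    """Average the voxel time series inside one ROI into a single mean series."""
    n_vox, n_t = len(voxel_series), len(voxel_series[0])
    return [sum(v[t] for v in voxel_series) / n_vox for t in range(n_t)]

def pearson(x, y):
    """Pearson correlation of two equal-length, non-constant time series."""
    n = len(x)
    mx, my = sum(x) / n, sum(y) / n
    cov = sum((a - mx) * (b - my) for a, b in zip(x, y))
    sx = math.sqrt(sum((a - mx) ** 2 for a in x))
    sy = math.sqrt(sum((b - my) ** 2 for b in y))
    return cov / (sx * sy)

def correlation_matrix(roi_series):
    """R x R matrix with a_ij = corr(mean series of ROI i, mean series of ROI j)."""
    r = len(roi_series)
    return [[pearson(roi_series[i], roi_series[j]) for j in range(r)] for i in range(r)]

# Toy example: three ROIs observed over four time points.
roi_a = roi_mean_series([[1.0, 2.0, 3.0, 4.0], [1.0, 2.0, 3.0, 4.0]])
roi_b = [2.0, 4.0, 6.0, 8.0]   # rises together with roi_a
roi_c = [4.0, 3.0, 2.0, 1.0]   # falls as roi_a rises
A = correlation_matrix([roi_a, roi_b, roi_c])
```

The resulting matrix is symmetric with ones on the diagonal, which is why correlation-based measures give undirected edges.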
2.2. Graph Neural Network
Unlike the traditional graph learning methods mentioned above, GNNs are state-of-the-art deep learning methods for analyzing graph-structured data. GNNs combine node features, connection patterns and graph structure by using a neural network to embed node information and pass it along the edges of the graph. GNN methods aim to generalize the convolutional neural networks (CNNs) used in image classification to graph structures. For example, the Graph Convolutional Network (GCN)8 builds on spectral graph convolutions and provides a fast approximate computation, while Hamilton et al.9 propose another variant for inductive representation learning on general graphs in the spatial domain. Owing to their convincing performance and interpretability, GNNs have recently become widely applied graph analysis methods, especially in fMRI analysis.1,10,11
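For reference, the GCN layer of Kipf and Welling8 follows from a first-order approximation of spectral graph convolutions and reduces to the propagation rule below, with self-loops Ã = A + I_N, degree matrix D̃_ii = Σ_j Ã_ij, and a nonlinearity σ (ReLU in their experiments). Mean-aggregation variants replace the symmetric normalization with D̃⁻¹Ã.

```latex
H^{(l+1)} = \sigma\!\left(\tilde{D}^{-1/2}\,\tilde{A}\,\tilde{D}^{-1/2}\,H^{(l)}\,W^{(l)}\right),
\qquad \tilde{A} = A + I_N,\quad \tilde{D}_{ii} = \sum_j \tilde{A}_{ij}
```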
3. METHODOLOGY
3.1. Data Definition and Notations
Given a labeled graph set $\mathcal{G} = \{(G_1, y_1), \ldots, (G_n, y_n)\}$, the general graph classification problem is to learn a classifier that maps each graph $G_i$ to its label $y_i$. In practice, each graph is usually given as a triple $G = (V, A, X)$, where $V = \{v_1, \ldots, v_N\}$ is the set of $N$ nodes, $A = \{a_{ij}\}_{N \times N}$ is the set of edges with $a_{ij}$ denoting the edge weight, and $X \in \mathbb{R}^{N \times D}$ is the set of node features.
Suppose each brain is parcellated into $N$ ROIs based on its T1 structural MRI. We define an undirected graph on the brain regions, $G = (V, A)$. Each node $v_i \in V$ is associated with a feature vector $x_i \in \mathbb{R}^D$ (the $i$-th row of $X$), where $D$ is the dimension of the node attributes. For node attributes, we concatenate handcrafted features: degree of connectivity, General Linear Model (GLM) coefficients, mean and standard deviation of the task-fMRI signal, and ROI center coordinates. $A$ is calculated from the correlation of the mean fMRI time series in each ROI.
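The node-attribute concatenation can be sketched as follows. Counting one degree value, four GLM coefficients (β1 to β4, see Section 4.1), one mean, one standard deviation and three center coordinates gives D = 1 + 4 + 1 + 1 + 3 = 10, matching the D = 10 used in the experiments. The field names and the concatenation order are assumptions for illustration; the text does not fix an order.

```python
from dataclasses import dataclass
from typing import List, Tuple

@dataclass
class RoiAttributes:
    """Handcrafted attributes of one ROI (hypothetical field names)."""
    degree: float                                   # degree of connectivity in A
    glm_betas: Tuple[float, float, float, float]    # (beta1, beta2, beta3, beta4)
    mean_bold: float                                # mean task-fMRI signal of the ROI
    std_bold: float                                 # std of the task-fMRI signal of the ROI
    center: Tuple[float, float, float]              # ROI center coordinates (x, y, z)

def node_feature_vector(attr: RoiAttributes) -> List[float]:
    """Concatenate the attributes into one row x_i of X: 1 + 4 + 1 + 1 + 3 = 10 values."""
    return [attr.degree, *attr.glm_betas, attr.mean_bold, attr.std_bold, *attr.center]

def feature_matrix(rois: List[RoiAttributes]) -> List[List[float]]:
    """Stack the N node vectors into the N x D matrix X."""
    return [node_feature_vector(r) for r in rois]
```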
3.2. Graph Embedding
In addition to improving ASD classification, a core objective of our work is to discover useful representations for detecting brain regional differences between ASD and HC. The overview and each component of the architecture that achieves this goal are discussed in this section.
3.2.1. Overview of the pipeline
We present the pipeline in Figure 1. It contains three components: a node encoder, a graph classifier and a data discriminator. The graph convolutional layers (Section 3.2.2) encode the input graph into a node feature map $H = [h_1, \ldots, h_N]^\top$ that reflects useful local structure. Next, the node representations are summarized into a global feature by pooling and readout (Section 3.2.3). Given a graph $G$, we also generate a negative graph $G'$ whose embedded node representation is $H'$. A discriminator (Section 3.2.4) is trained to separate the positive pairs $(h_i, s)$ from the negative pairs $(h'_j, s)$, where $s$ is the graph-level summary vector.
Figure 1:
The flowchart of our proposed ASD classification and graph embedding architecture. The top row is a Graph Neural Network architecture that classifies ASD and HC. The bottom row is a graph infomax pipeline that encourages better graph embedding. Here, (a) and (b) are positive samples; (c) and (d) are negative samples; (a) and (c) (or (b) and (d)) form a paired graph. The inputs of discriminator $D$ are the summary vector $s$ generated from the positive sample, paired with a node embedding from $H$ or $H'$. A positive pair $(h_i, s)$ should receive a True (T) output from $D$, whereas a negative pair $(h'_j, s)$ should receive a False (F) output. The encoder, classifier and discriminator are trained simultaneously.
3.2.2. Encoder: Graph Convolutional Layer
Generally speaking, GNNs inductively learn a node representation by recursively aggregating and transforming the feature vectors of neighboring nodes. Our node embedding network (the encoder) is an $L$-layer supervised GraphSAGE architecture,9 which learns an embedding function mapping the input node features $X$ to the output $H = H^{(L)}$. The embedding function is based on the mean-pooling (MP) propagation rule used in Hamilton et al.,9 where $\hat{A} = A + I$ is the adjacency matrix with inserted self-loops and $\hat{D}$ is its corresponding diagonal degree matrix with $\hat{D}_{ii} = \sum_j \hat{A}_{ij}$. Our encoder can be written as:
$$H^{(l+1)} = \sigma\left(\hat{D}^{-1}\hat{A}\,H^{(l)}W^{(l)}\right), \qquad H^{(0)} = X, \qquad (1)$$
where $W^{(l)}$ is a learnable projection matrix at layer $l$ and $\sigma$ is the sigmoid function.
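A single mean-aggregation layer, with self-loops A_hat = A + I and row normalization by the degree matrix D_hat, can be sketched in plain Python. This is an illustrative sketch rather than the authors' implementation; it assumes non-negative edge weights (for example, thresholded or absolute correlations) so that every degree is positive, and the helper names are hypothetical.

```python
import math

def matmul(a, b):
    """Plain-list matrix product a @ b."""
    return [[sum(a[i][k] * b[k][j] for k in range(len(b))) for j in range(len(b[0]))]
            for i in range(len(a))]

def sigmoid(x):
    return 1.0 / (1.0 + math.exp(-x))

def mean_aggregation_layer(adj, h, w):
    """One encoder layer: sigmoid(D_hat^{-1} A_hat H W) with A_hat = A + I.

    Assumes non-negative edge weights so every degree D_hat_ii is positive.
    """
    n = len(adj)
    a_hat = [[adj[i][j] + (1.0 if i == j else 0.0) for j in range(n)] for i in range(n)]
    deg = [sum(row) for row in a_hat]                   # D_hat_ii = sum_j A_hat_ij
    summed = matmul(a_hat, h)                           # neighbours plus the node itself
    mean = [[v / deg[i] for v in row] for i, row in enumerate(summed)]  # row-normalise
    return [[sigmoid(v) for v in row] for row in matmul(mean, w)]

def encoder(adj, x, weights):
    """Stack L layers: H^(0) = X, H^(l+1) = layer(H^(l)); returns H = H^(L)."""
    h = x
    for w in weights:
        h = mean_aggregation_layer(adj, h, w)
    return h
```

Dividing by D_hat turns the neighbourhood sum into a mean, so high-degree ROIs do not produce systematically larger activations.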
3.2.3. Classifier: Pooling and Readout Layer
To aggregate the information of each node for graph-level classification, we use Dense hierarchical pooling (DHP12) to cluster nodes together. After each DHP, the number of nodes in the graph decreases. At the last level $L$, pooling is performed with a filtering matrix $F \in \mathbb{R}^{N_L \times Q}$:
$$X_p = F^\top H^{(L)}, \qquad A_p = F^\top A^{(L)} F, \qquad (2)$$
which produces the pooled node features $X_p$ and adjacency matrix $A_p$, from which a readout vector is generated. The final number of nodes $Q$ is predefined. $F$ is learned by another GraphSAGE convolutional layer optimized with the regularization loss $L_{reg} = \|A^{(L)} - FF^\top\|_F$, where $\|\cdot\|_F$ denotes the Frobenius norm. The readout vector is fed to a multilayer perceptron (MLP) to obtain the final classification output $p$, the probability of being an ASD subject.
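The dense pooling step and its Frobenius-norm regularizer can be sketched as below. The row-wise softmax that turns per-node scores into soft cluster assignments follows the differentiable-pooling convention of Ying et al.;12 it is an assumption here, since the text only says that the filtering matrix is produced by a GraphSAGE layer. Function names are hypothetical.

```python
import math

def transpose(m):
    return [list(col) for col in zip(*m)]

def matmul(a, b):
    return [[sum(a[i][k] * b[k][j] for k in range(len(b))) for j in range(len(b[0]))]
            for i in range(len(a))]

def softmax_rows(logits):
    """Row-wise softmax so each node's soft assignments over the Q clusters sum to 1."""
    out = []
    for row in logits:
        m = max(row)
        e = [math.exp(v - m) for v in row]
        total = sum(e)
        out.append([v / total for v in e])
    return out

def dense_pool(f, h, adj):
    """Return X_p = F^T H (Q x d) and A_p = F^T A F (Q x Q)."""
    ft = transpose(f)
    return matmul(ft, h), matmul(matmul(ft, adj), f)

def link_regularizer(adj, f):
    """L_reg = ||A - F F^T||_F; small when connected nodes share a cluster."""
    fft = matmul(f, transpose(f))
    n = len(adj)
    return math.sqrt(sum((adj[i][j] - fft[i][j]) ** 2 for i in range(n) for j in range(n)))
```

With a hard two-cluster assignment of a four-node graph made of two connected pairs, each cluster pools its two nodes into one, and the regularizer only penalizes the missing self-loops on the diagonal.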
3.2.4. Discriminator: Encouraging Good Representation
Following the intuition of Deep Graph Infomax,3 a good representation does not benefit from encoding noise. To obtain a representation better suited to classification, we maximize the average MI between the high-level (graph) representation and the local aggregated embedding of each node, which favours encoding aspects of the data that are shared across nodes and reduces noisy encoding.13 The graph-level summary vector $s = \sigma\left(\frac{1}{N}\sum_{i=1}^{N} h_i\right)$ serves as an input to the discriminator, where $\sigma$ is the logistic sigmoid nonlinearity. A discriminator $D$, which assigns probability scores to local-global pairs, is used as a proxy for maximizing the MI. We randomly sample an instance from the opposite class as the negative sample $(X', A')$. The discriminator scores summary-node representation pairs by applying a simple bilinear scoring function:3
$$D(h_i, s) = \sigma\left(h_i^\top M s\right), \qquad (3)$$
where $M$ is a learnable scoring matrix and $\sigma$ is the logistic sigmoid nonlinearity, used to convert scores into probabilities of $(h_i, s)$ being a positive pair.
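The summary readout and the bilinear discriminator can be sketched as follows, assuming the DGI-style average readout s = sigmoid(mean of the h_i) and a fixed scoring matrix M (learned in practice). Note that the summary is computed from the positive graph only and then paired with both positive and negative node embeddings. Names are hypothetical.

```python
import math

def sigmoid(x):
    return 1.0 / (1.0 + math.exp(-x))

def summary_vector(h):
    """s = sigmoid(mean_i h_i): DGI-style average readout over the N node embeddings."""
    n = len(h)
    return [sigmoid(sum(row[k] for row in h) / n) for k in range(len(h[0]))]

def discriminator(h_i, m, s):
    """Bilinear score D(h_i, s) = sigmoid(h_i^T M s)."""
    ms = [sum(m[a][b] * s[b] for b in range(len(s))) for a in range(len(m))]
    return sigmoid(sum(x * y for x, y in zip(h_i, ms)))

def pair_scores(h_pos, h_neg, m):
    """Score positive (h_i, s) and negative (h'_j, s) pairs; s comes from the positive graph."""
    s = summary_vector(h_pos)
    return [discriminator(h, m, s) for h in h_pos], [discriminator(h, m, s) for h in h_neg]
```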
3.3. Loss function
To learn useful, predictive representations, the Infomax loss $L_2$ encourages nodes of the same graph to have similar representations, while enforcing that the representations of disparate nodes are highly distinct. Following DGI,3 it takes the form of a binary cross-entropy over positive and negative pairs:
$$L_2 = -\frac{1}{N + N'}\left(\sum_{i=1}^{N}\log D(h_i, s) + \sum_{j=1}^{N'}\log\left(1 - D(h'_j, s)\right)\right).$$
To ensure the performance of the downstream classification, we use binary cross-entropy as the classification loss $L_1$. Therefore, the loss function of our model is written as:
$$L = L_1 + \lambda L_2, \qquad (4)$$
where $\lambda$ weights the Infomax term. Thus, we can achieve both graph classification and the detection of differing node representations through end-to-end training.
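The combined objective can be sketched numerically. This assumes L2 is the DGI-style binary cross-entropy averaged over all N + N' positive and negative pairs, and that the two terms are combined with a weighting coefficient `lam`; the weight and its default value are hypothetical, not values reported in this paper.

```python
import math

def bce(p, y, eps=1e-12):
    """Binary cross-entropy for one probability p and label y in {0, 1} (p clipped for log safety)."""
    p = min(max(p, eps), 1 - eps)
    return -(y * math.log(p) + (1 - y) * math.log(1 - p))

def infomax_loss(pos_scores, neg_scores):
    """L2: label 1 for (h_i, s) pairs, label 0 for (h'_j, s) pairs, averaged over all pairs."""
    terms = [bce(p, 1) for p in pos_scores] + [bce(p, 0) for p in neg_scores]
    return sum(terms) / len(terms)

def total_loss(p_asd, y, pos_scores, neg_scores, lam=1.0):
    """L = L1 + lam * L2, with L1 the classification BCE; lam is a hypothetical weight."""
    return bce(p_asd, y) + lam * infomax_loss(pos_scores, neg_scores)
```

A discriminator that outputs 0.5 everywhere gives L2 = log 2, while a perfect discriminator drives L2 toward zero, so minimizing L2 pushes the encoder toward node embeddings that are predictable from their own graph's summary.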
4. EXPERIMENT AND RESULTS
4.1. Data Acquisition and Preprocessing
For the fMRI scans, subjects viewed point-light animations of coherent and scrambled biological motion in a block design.14 We tested our method on a group of 75 children with ASD and 43 age- and IQ-matched healthy controls, collected at the Yale Child Study Center.1 The experimental paradigm features coherent and scrambled point-light animations created from motion capture data. The coherent biological motion depicts an adult male actor performing movements relevant to early childhood experiences.15 The scrambled motion combines the trajectories of 16 randomly selected points from the coherent displays. Six biological motion clips and six scrambled motion clips were presented in an alternating-block design (24 s per block). The fMRI data were preprocessed following the pipeline in Yang et al.,16 using Freesurfer, as follows: 1) motion correction using MCFLIRT, 2) interleaved slice timing correction, 3) BET brain extraction, 4) spatial smoothing (FWHM = 5 mm), and 5) high-pass temporal filtering. The functional and anatomical data were registered to the MNI152 standard brain atlas.17 The graph data were augmented as described in our previous work,1 resulting in 750 ASD graphs and 860 HC graphs. We split the data into 5 folds based on subjects; four folds were used as training data and the held-out fold was used for testing. Each node attribute vector $x_i \in \mathbb{R}^{10}$ includes node degree, mean fMRI, standard deviation of fMRI, ROI center coordinates and GLM parameters. Specifically, the GLM parameters of the “biopoint task” are: β1, the coefficient of the biological motion matrix; β3, the coefficient of the scrambled motion matrix; and β2 and β4, the coefficients of the derivatives of those two matrices.
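Splitting by subject, rather than by graph, keeps all augmented graphs of one child in the same fold, so no subject leaks from training into testing. A minimal sketch, assuming a round-robin assignment of sorted subject IDs (not stratified by class) and a hypothetical ID layout consistent with the counts above (750 / 75 = 10 graphs per ASD subject, 860 / 43 = 20 per HC subject):

```python
from collections import defaultdict

def subject_folds(graph_subjects, n_folds=5):
    """Group graph indices into folds so that all graphs of a subject share one fold.

    graph_subjects[i] is the subject ID of augmented graph i. Subjects are dealt
    round-robin over the sorted IDs: an illustrative choice, not a stratified split.
    """
    fold_of = {s: k % n_folds for k, s in enumerate(sorted(set(graph_subjects)))}
    folds = defaultdict(list)
    for idx, s in enumerate(graph_subjects):
        folds[fold_of[s]].append(idx)
    return [folds[k] for k in range(n_folds)]

def cross_validation_splits(graph_subjects, n_folds=5):
    """Yield (train, test) index lists: four folds train, the held-out fold tests."""
    folds = subject_folds(graph_subjects, n_folds)
    for k in range(n_folds):
        train = [i for j, fold in enumerate(folds) if j != k for i in fold]
        yield train, folds[k]

# Hypothetical subject IDs for the 750 ASD and 860 HC augmented graphs.
subjects = [f"asd{i:02d}" for i in range(75) for _ in range(10)] + \
           [f"hc{i:02d}" for i in range(43) for _ in range(20)]
splits = list(cross_validation_splits(subjects))
```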
4.2. Experiment and Results
We tested classifier performance on the Destrieux atlas18 (148 ROIs) using the proposed GNN trained with L1 alone and with L1 + L2, to examine the advantage of including the graph infomax loss L2. In our GNN setting, D = 10 and the pooling ratio r = 0.5. We used the Adam optimizer with an initial learning rate of 0.001, halved every 20 epochs. We trained the network for 100 epochs on all splits and measured instance classification performance by F-score (Table 1). We varied the architecture, using either two graph convolutional layers with kernel sizes (F, F) or one graph convolutional layer with kernel size (F), with F tested at 8 and 16. The regularization parameters were adjusted correspondingly to obtain the best performance.
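The step learning-rate schedule (initial rate 0.001, halved every 20 epochs, 100 epochs) amounts to the following; the function name is hypothetical and epochs are counted from 0.

```python
def learning_rate(epoch, base_lr=0.001, step=20, factor=2.0):
    """Step decay: divide base_lr by `factor` once every `step` epochs (0-based epochs)."""
    return base_lr / factor ** (epoch // step)

schedule = [learning_rate(e) for e in range(100)]   # one value per training epoch
```

Over 100 epochs this yields five plateaus, ending at 0.001 / 16 = 6.25e-5. If the model is written in PyTorch (the paper does not name a framework), the same schedule corresponds to `torch.optim.lr_scheduler.StepLR(optimizer, step_size=20, gamma=0.5)`.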
Table 1:
Performance of different loss functions and GNN architectures (mean ± std)
Loss (conv layers) | L1 (16,16) | L1 (8,8) | L1 (16) | L1+L2 (16,16) | L1+L2 (8,8) | L1+L2 (16) |
F-score | 0.57±0.11 | 0.70±0.06 | 0.63±0.01 | 0.68±0.08 | 0.69±0.05 | 0.66±0.03 |
For notational convenience, we use L1(·) and L1+L2(·) to denote a model with a given GNN architecture trained with the corresponding loss. Under the (8, 8) architecture, we could not find an obvious advantage of adding L2 (F-score 0.70 ± 0.06 for L1 vs. 0.69 ± 0.05 for L1 + L2). However, when we increased the encoder's complexity to (16, 16), the L1 model easily overfitted (0.57 ± 0.11), while the L1 + L2 model maintained similar performance (0.68 ± 0.08). This may indicate that L2 acts as a regularizer and restrains the embedding from fitting data noise. In the (16) case, the L1 model underfitted, while the L1 + L2 model performed slightly better (0.66 ± 0.03 vs. 0.63 ± 0.01). This is probably because L2 encourages encoding signals common to the nodes of a graph and hence ignoring data noise, or simply because the L1 + L2 model had more trainable parameters.
After training, we extracted the nodal embedding vectors after the last graph convolutional layer and used t-SNE19 to visualize the node representations in 2D space, as shown in Fig. 2. Only with L1 + L2 did we find linearly separable nodal representations of ASD and HC for certain regions; we visually examined all the nodal representation embeddings learned with L1 alone and verified that they could not be linearly separated into the two groups. We marked the regions whose Silhouette score20 was greater than 0.1 (resulting in 31 regions using L1 + L2 (8, 8), shown in Fig. 3a) as the brain ROIs with functional differences between ASD and HC. We compared the results with a GLM z-statistics analysis using FSL21 (shown in Fig. 3b). Our proposed method clearly marked the prefrontal cortex, whereas the GLM method did not highlight those regions. Both our method and the GLM analysis highlighted the cingulate cortex. These regions were indicated as ASD biomarkers in Yang et al.16 and Kaiser et al.14 We also used Neurosynth22 to decode the functional keywords associated with the separable regions found by our method, as shown in Fig. 4. The decoded functional keywords of our detected regions suggest that these regions might have social-, mental-, visual- and default-mode-related functional differences between the ASD and HC groups. Potentially, our proposed method can be used as a tool to identify new brain biomarkers for better understanding the underlying roots of ASD.
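The region-selection rule (Silhouette score greater than 0.1) can be sketched as below. It assumes that, for each region, the score is computed over the per-subject embedding vectors of that node with diagnosis (ASD or HC) as the cluster label; the text does not state whether the t-SNE coordinates or the original embeddings were used, and the data layout is hypothetical.

```python
import math

def silhouette_score(points, labels):
    """Mean silhouette coefficient s(i) = (b - a) / max(a, b) (Rousseeuw, 1987).

    Needs at least two distinct labels; points in singleton clusters get s(i) = 0.
    """
    clusters = {}
    for p, lab in zip(points, labels):
        clusters.setdefault(lab, []).append(p)
    scores = []
    for p, lab in zip(points, labels):
        own = clusters[lab]
        if len(own) == 1:
            scores.append(0.0)
            continue
        # a: mean distance to the rest of its own group (distance to itself is 0)
        a = sum(math.dist(p, q) for q in own) / (len(own) - 1)
        # b: smallest mean distance to any other group
        b = min(sum(math.dist(p, q) for q in pts) / len(pts)
                for other, pts in clusters.items() if other != lab)
        denom = max(a, b)
        scores.append((b - a) / denom if denom > 0 else 0.0)
    return sum(scores) / len(scores)

def separable_regions(region_embeddings, labels, threshold=0.1):
    """Indices of regions whose per-subject embeddings give silhouette > threshold."""
    return [r for r, pts in enumerate(region_embeddings)
            if silhouette_score(pts, labels) > threshold]
```

A region whose ASD and HC embeddings form two compact, distant groups scores close to 1, while interleaved groups score below 0. scikit-learn's `sklearn.metrics.silhouette_score` computes the same Euclidean coefficient and could replace the hand-written version.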
Figure 2:
Embedded representations of 24 of the 148 brain regions visualized by t-SNE. HC is colored green and ASD red. The region representations circled by blue boxes are linearly separable.
Figure 3:
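Brain regions highlighting functional differences between ASD and HC: (a) regions with a Silhouette score greater than 0.1 detected by the proposed method; (b) GLM z-statistics computed with FSL.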
Figure 4:
The correlation between the functional keyword and the regions in Figure 3a decoded by Neurosynth.
5. CONCLUSION
We applied a GNN to identify ASD and designed a loss function that encourages better node representations and detects brain regions in which ASD and HC are separable. By incorporating the mutual information between local and global representations, the proposed loss function improved classification performance in certain cases. The added Infomax loss L2 potentially regularizes the embedding of noisy fMRI and increases model robustness. By examining the embedded node representations, we found that ASD and HC had separable representations in regions related to the default mode, social function, emotion regulation, visual function, etc. This finding is consistent with prior literature,1,14 and our approach could potentially discover new functional differences between ASD and HC. Overall, the proposed method provides an efficient and objective way of embedding ASD and HC brain graphs.
7. ACKNOWLEDGEMENTS
Data collection and sharing for this project was funded by the Autism Brain Imaging Data Exchange dataset (ABIDE).23 Parts of this research were supported by the National Institutes of Health (NIH) [R01NS035193, R01MH100028].
DECLARATION OF COMPETING INTEREST
The authors declare that they have no known competing financial interests or personal relationships that could have appeared to influence the work reported in this paper.
REFERENCES
- [1]. Li X, Dvornek NC, Zhou Y, Zhuang J, Ventola P, and Duncan JS, “Graph neural network for interpreting task-fmri biomarkers,” in [International Conference on Medical Image Computing and Computer-Assisted Intervention], 485–493, Springer (2019).
- [2]. Belghazi MI et al., “Mine: mutual information neural estimation,” ICML (2018).
- [3]. Veličković P et al., “Deep graph infomax,” ICLR (2019).
- [4]. Granger CW, “Investigating causal relations by econometric models and cross-spectral methods,” Econometrica: Journal of the Econometric Society, 424–438 (1969).
- [5]. Friston KJ, Harrison L, and Penny W, “Dynamic causal modelling,” Neuroimage 19(4), 1273–1302 (2003).
- [6]. Bullmore E and Sporns O, “Complex brain networks: graph theoretical analysis of structural and functional systems,” Nature reviews neuroscience 10(3), 186 (2009).
- [7]. Rubinov M and Sporns O, “Complex network measures of brain connectivity: uses and interpretations,” Neuroimage 52(3), 1059–1069 (2010).
- [8]. Kipf TN and Welling M, “Semi-supervised classification with graph convolutional networks,” arXiv preprint arXiv:1609.02907 (2016).
- [9]. Hamilton WL, Ying R, and Leskovec J, “Inductive representation learning on large graphs,” in [NIPS], (2017).
- [10]. Yan Y, Zhu J, Duda M, Solarz E, Sripada C, and Koutra D, “Groupinn: Grouping-based interpretable neural network for classification of limited, noisy brain data,” in [Proceedings of the 25th ACM SIGKDD International Conference on Knowledge Discovery & Data Mining], 772–782 (2019).
- [11]. Yang H, Li X, Wu Y, Li S, Lu S, Duncan JS, Gee JC, and Gu S, “Interpretable multimodality embedding of cerebral cortex using attention graph network for identifying bipolar disorder,” (2019).
- [12]. Ying Z et al., “Hierarchical graph representation learning with differentiable pooling,” in [NeurIPS], 4805–4815 (2018).
- [13]. Hjelm RD et al., “Learning deep representations by mutual information estimation and maximization,” arXiv preprint arXiv:1808.06670 (2018).
- [14]. Kaiser MD et al., “Neural signatures of autism,” PNAS (2010).
- [15]. Klin A, Lin DJ, Gorrindo P, Ramsay G, and Jones W, “Two-year-olds with autism orient to non-social contingencies rather than biological motion,” Nature 459(7244), 257 (2009).
- [16]. Yang D et al., “Brain responses to biological motion predict treatment outcome in young children with autism,” Translational psychiatry 6(11), e948 (2016).
- [17]. Venkataraman A, Yang DY-J, Pelphrey KA, and Duncan JS, “Bayesian community detection in the space of group-level functional differences,” IEEE transactions on medical imaging 35(8), 1866–1882 (2016).
- [18]. Destrieux C et al., “Automatic parcellation of human cortical gyri and sulci using standard anatomical nomenclature,” Neuroimage 53(1), 1–15 (2010).
- [19]. Maaten L. v. d. and Hinton G, “Visualizing data using t-sne,” Journal of machine learning research 9(Nov), 2579–2605 (2008).
- [20]. Rousseeuw PJ, “Silhouettes: a graphical aid to the interpretation and validation of cluster analysis,” Journal of computational and applied mathematics 20, 53–65 (1987).
- [21]. Jenkinson M et al., “FSL,” Neuroimage (2012).
- [22]. Yarkoni T et al., “Large-scale automated synthesis of human functional neuroimaging data,” Nature methods 8(8), 665 (2011).
- [23]. Di Martino A et al., “The autism brain imaging data exchange: towards a large-scale evaluation of the intrinsic brain architecture in autism,” Molecular psychiatry (2014).