Abstract
Diagnosis prediction, a key factor in enhancing healthcare efficiency, remains a focal point in clinical decision support research. However, the time-series, sparse, and noisy characteristics of electronic health record (EHR) data pose great challenges. Existing methods commonly address these issues by using RNNs and incorporating prior knowledge from medical knowledge bases, but they neglect the local spatial characteristics and spatial–temporal correlation of the data. Consequently, we propose MDPG, a diagnosis prediction model based on patient knowledge graphs. First, we represent patients’ electronic visit records as a patient-centered temporal knowledge graph, capturing the local spatial structure and temporal characteristics of the visit information. We then design a spatial graph convolution block, a temporal self-attention block, and a spatial–temporal synchronous graph convolution block to capture the spatial, temporal, and spatial–temporal correlations embedded in the graph, respectively. Finally, we predict patients’ future states through multi-label classification. We conduct comprehensive experiments on two real-world datasets and evaluate the results using visit-level precision@k and code-level accuracy@k metrics. The experimental results demonstrate that MDPG outperforms all baseline models.
Keywords: Medical knowledge graphs, Patient knowledge graphs, Healthcare representation learning, Diagnosis prediction, Patient risk prediction
Introduction
Amidst the ongoing evolution of hospital informatization, the electronic health record (EHR) system has emerged as a vital repository of patients’ clinical information, amassing a substantial volume of data on clinical diagnoses and treatment [1]. These data not only encompass details about a series of diagnoses, procedures, medications, and other medical events that transpired during the clinical diagnosis and treatment of patients [2] but also meticulously record patients’ demographic information. In recent years, propelled by the swift advancement of artificial intelligence technology, patients’ EHR data have found extensive application in predictive healthcare tasks, yielding excellent results [3, 4] and fostering the rapid development of personalized precision medicine. Notably, the diagnosis prediction task, which foresees patients’ future diagnoses based on their historical visit records, stands as a cornerstone of personalized medicine and has emerged as a prevalent research focus in both industry and academia.
Electronic health record data is characterized by being time-series, high-dimensional, sparse, and noisy, posing significant challenges for diagnosis prediction. Three primary challenges arise: (1) How can EHR data be appropriately modeled to enhance semantic representation and comprehensively capture the embedded diagnosis and treatment knowledge? (2) EHR data not only encompasses abundant semantic information but also exhibits temporal characteristics. How can the temporal dependence of medical codes be addressed? (3) How can the robustness of models be enhanced, and the impact of data sparsity and noise on prediction be mitigated?
For modeling temporal EHR data, the majority of existing models resort to recurrent neural networks (RNNs), classical models designed for time-series data, to capture the temporal properties inherent in EHR data. In RETAIN [5], an RNN and an EHR sequence in reverse time order are employed to predict future diagnoses, enabling an explanation of the contribution of each medical code observed in historical visits to the current visit. However, RNN models, constrained by their forgetting behavior, often prioritize short-term memory [6, 7] and encounter vanishing-gradient issues when confronted with long sequences. Consequently, such models perform suboptimally when patients have a lengthy visit history. Additionally, Dipole [8] employs a bidirectional recurrent neural network (BRNN) [9] with attention to address the challenge of handling long sequences. Despite considering the temporal characteristics of EHR data, the aforementioned methods struggle to address issues related to data sparsity and noise.
Medical knowledge bases, such as the International Classification of Diseases (ICD), Clinical Classification Software (CCS) [10], and Systematized Nomenclature of Medical-Clinical Terms (SNOMED-CT) [11], embody meticulously established and authoritative medical consensus achieved through manual calibration. These knowledge bases not only encompass rich information about medical events pertaining to diagnoses, medications, and procedures but also exhibit a robust hierarchical structure. Consequently, researchers have commenced leveraging the prior knowledge within medical knowledge bases to furnish additional medical information to prediction models, thereby fortifying their robustness and alleviating the impact of data insufficiency and noise. GRAM [12] leverages medical knowledge and graph-based attention mechanism to acquire stable medical code representations, effectively addressing the issue of data sparsity in the absence of sufficient EHR data for training. However, GRAM solely incorporates prior medical knowledge in the process of learning medical code representations. Moreover, KAME [13] directly employs medical prior knowledge throughout the prediction process, encompassing the learning of code representations, generation of visit embeddings, and making predictions, ultimately enhancing the accuracy and interpretability of the model. Conversely, CAMP [14] captures the fine-grained developmental patterns of patients’ conditions using a collaborative attentional memory network. It also models the interplay between the patient’s background and historical visit records. Fundamentally, the medical knowledge base is a heterogeneous graph with medical concepts serving as nodes and parent–child relations as edges.
While the aforementioned models partially address the temporal relevance of EHR data, they fall short of fully leveraging the information in the medical knowledge base and neglect the equally crucial graph structure feature [15], namely, spatial relevance. Presently, graph neural networks (GNNs) have emerged as the preeminent solution for graph representation learning [16, 17]. Influenced by spatial–temporal graphs, GNDP [15] devises a diagnosis prediction model based on graph neural networks. This model employs spatial–temporal graphs to restructure patient EHR data and utilizes both temporal and spatial features to acquire resilient patient representations, enhancing prediction accuracy. Owing to the outstanding expressiveness of the graph structure, GNDP’s graph-based modeling of patient EHR data adeptly captures the spatial semantics inherent in EHR data. However, it still has the following shortcomings: (1) Relying solely on the medical knowledge base yields a single type of semantic relation and cannot incorporate patient background information [14]; its semantic expressiveness is therefore weak, leading to an incomplete portrayal of the semantic information within the EHR data and a lack of flexibility. (2) It overlooks the spatial–temporal correlation among visit records. As illustrated in Fig. 1, there is not only the spatial correlation represented by the black lines, indicating the impact of each medical event on the patient’s state during a visit, and the temporal correlation represented by the purple lines, indicating the self-influence of the patient’s state across different visits, but also the spatial–temporal correlation represented by the red lines, indicating the impact of the current patient’s state on the medical events of different visits. GNDP addresses spatial and temporal correlation but neglects spatial–temporal correlation. (3) The attention mechanism significantly outperforms RNNs and CNNs in capturing long-range dependencies and in parallelism [18], suggesting that employing attention to capture the temporal correlation of EHR data is a more reasonable approach.
Fig. 1.
The spatial, temporal and spatial–temporal dynamic correlation between patient visit records
Knowledge graphs, such as Freebase and Wikidata, are characterized by rich semantic expression and serve as the cornerstone of cognitive intelligence. Given the time-sensitivity of knowledge, the temporal knowledge graph has recently emerged as a prominent research topic [19]. Essentially, a temporal knowledge graph is a dynamic graph whose goal is to model temporal knowledge and facilitate reasoning over it, introducing a novel approach for handling temporally dynamic data. Therefore, to better articulate the structure and semantics of EHR data and address the aforementioned challenges, in this paper we introduce a novel multi-disease diagnosis prediction model, MDPG, based on patient knowledge graphs. This model performs diagnosis prediction by leveraging modeling approaches from temporal knowledge graphs. In contrast to existing methods that model patient EHR data as a collection of medical codes or as a medical code graph based on medical knowledge bases, we directly fuse patient historical visit records, medical knowledge bases, and patient demographic information. This integration forms a patient-centered temporal knowledge graph that comprehensively and accurately portrays the medical information pertinent to the patient’s diagnosis and treatment process. Subsequently, we devise a prediction model utilizing the dynamic graph neural network DySAT [20] and the spatial–temporal synchronous graph convolutional network STSGCN [21] to capture the spatial, temporal, and spatial–temporal features of the patient knowledge graph.
We conduct comprehensive experiments on two real-world medical datasets. Our method achieves significantly better prediction accuracy in diagnosis prediction compared to other baseline methods, demonstrating its effectiveness and advancement. In summary, our main contributions are as follows:
We propose a novel multi-disease diagnosis prediction model based on graph neural networks, which can predict patients’ future health status by leveraging the historical visit information in their EHR data. The model considers not only the spatial and temporal correlations of the data but also the spatial–temporal correlation. This consideration significantly mitigates the impact of data insufficiency and noise on diagnosis prediction, leading to improved prediction performance and robustness.
We introduce an effective method for modeling patient EHR data. This method models patient visit records, medical prior knowledge, and patient demographic information as a patient-centered temporal knowledge graph. With the powerful semantic representation capability of the knowledge graph, it can not only comprehensively and finely represent diagnosis and treatment information at a local level using a multi-relational heterogeneous graph, enhancing semantic expression capabilities, but also represent the temporal correlation between the patient’s successive medical records through a dynamic graph.
The spatial–temporal synchronous convolution block is introduced to capture the spatial–temporal characteristics of EHR data. In contrast to existing models that learn patient representations from space and time separately, it simultaneously captures correlations among medical events from a spatial–temporal perspective to enhance prediction performance.
To mitigate the impact of long-range visit records, we employ self-attention instead of RNN to simply and efficiently learn the temporal features of EHR data. This approach not only captures the higher-order associations of temporal relations but also enhances the model’s concurrency performance.
Related work
In this section, we review related work on diagnosis prediction based on EHR data. After that, we present some work on temporal knowledge graphs and graph neural networks that are relevant to our work.
Diagnosis prediction based on EHR data
Deep learning techniques have greatly improved the performance of data mining tasks based on EHR data. Some typical tasks such as disease progression [22–25], electronic genotyping and phenotyping [2, 26–28], missing value mining [29, 30], unplanned admissions, and risk prediction [31, 32] have employed classical deep learning models such as recurrent neural networks (RNNs) and convolutional neural networks (CNNs) as the underlying framework to enhance performance. Diagnosis prediction stands out as one of the most important and challenging tasks among these applications. Its goal is to predict the future health status of patients based on their historical visit records, thereby assisting doctors in improving efficiency and reducing decision-related risks. The challenge lies in effectively modeling the information from patients’ historical visits and constructing prediction models with high robustness.
Med2Vec [3] is an early representative approach that has paved the way for new research ideas in diagnosis prediction. It is an unsupervised method designed to learn low-dimensional representations of medical codes (i.e., diagnosis codes, procedure codes, and medication codes) and apply them to diagnosis prediction. However, this method does not account for the temporal relationship among medical codes across visit records and cannot model the temporal properties of the data. Building on it, many excellent methods have subsequently been developed to further address these challenges. These methods can be categorized into two main groups: approaches that utilize RNNs to model the temporal correlation of patient visit records, thereby improving prediction performance, and approaches that alleviate data insufficiency and noise, enhancing model robustness by incorporating medical prior knowledge.
RETAIN [5] and Dipole [8] are representative examples of the first category. Recognizing that a patient’s most recent health status carries more information than earlier ones, RETAIN uses an RNN with reverse time-ordered EHR sequences to model patient visit information for binary prediction tasks. Furthermore, it provides a reasonable interpretation of the contributions of medical codes present in historical visits to the current prediction. However, the forgetting behavior of RNNs hampers the model’s ability to handle long-range visit records. Additionally, RETAIN employs location-based attention to predict future diagnoses: it predicts the visit information at the subsequent time step by calculating attention weights based on the positions of preceding visits, yet it ignores the relationships among all visits. To address this limitation, Dipole utilizes an attention-based bidirectional recurrent neural network (BRNN) to model patient visit information and learn a low-dimensional representation for diagnosis prediction. The BRNN mitigates the impact of long sequences by integrating all available past and future input information. The model also proposes three attention mechanisms, location-based, general, and concatenation-based, to assign different weights to all visits of each patient.
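As a rough illustration of the three scoring functions just mentioned, the sketch below uses scalar hidden states and made-up fixed weights `w`, `W`, `v`, `b` in place of learned parameters and vector states; it conveys the shape of the location-based, general, and concatenation-based scores, not Dipole's actual implementation.

```python
import math

# Scalar stand-ins for the three attention scoring styles; w, W, v, b are
# arbitrary fixed numbers here, whereas Dipole learns them as vectors/matrices.
w, W, v, b = 0.5, 0.8, 1.2, 0.1

def location_score(h_t):
    """Location-based: depends only on the visit's own hidden state."""
    return w * h_t + b

def general_score(h_last, h_t):
    """General: bilinear interaction between the latest and a past state."""
    return h_last * W * h_t

def concat_score(h_last, h_t):
    """Concatenation-based: nonlinearity over both states combined."""
    return v * math.tanh(W * (h_last + h_t))  # scalar stand-in for W[h_last; h_t]
```

In the vector case, each score would be normalized with a softmax over all past visits to produce the attention weights.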
All of the aforementioned models are susceptible to data insufficiency and noise. This is another challenge in diagnosis prediction [12, 13], and it is difficult to address through algorithms alone. Therefore, researchers favor the second category of models based on medical knowledge bases, which has been experimentally demonstrated to effectively enhance the performance of RNN-based models through the incorporation of medical prior knowledge.
GRAM [12] utilizes medical knowledge bases to encode relationships between hierarchical clinical structures and medical concepts. It optimizes the learning of low-dimensional representations for medical codes through a graph-based attention mechanism, considering the frequency of medical concepts in EHR data and their ancestors in the medical knowledge bases, thereby mitigating the challenge of learning representations for rare medical codes. GRAM uses medical prior knowledge solely for learning low-dimensional representations of medical codes; it then utilizes these representations to build the visit representation, contributing to the final prediction. Furthermore, KAME [13] incorporates medical prior knowledge throughout the entire process, encompassing the learning of code representations, the generation of visit representations, and the prediction. It learns representations of medical codes and their ancestor codes separately. The patient visit vectors, generated using medical codes and associated high-level knowledge, are then used to compute attention weights via knowledge attention to improve prediction accuracy. CAMP [14] contends that the aforementioned models fail to capture the fine-grained progression patterns of patients’ conditions and overlook patient demographic information [3, 33]. It therefore proposes a co-attentional memory network that captures the fine-grained dynamic health status of patients, utilizing key-value memory networks (KV-MNs) [34] to store different categories of information. The model is designed as a three-way interactive neural network based on shared attention, integrating patient visit records, demographic information, and fine-grained patient status in a mutually reinforcing manner. Nevertheless, the approaches mentioned above do not fully harness the knowledge present in the knowledge bases.
A medical knowledge base is essentially a graph comprising a set of medical concepts and their relationships. The nodes in the graph represent medical concepts, and the edges signify hierarchical relationships between concepts, i.e., parent–child relations. These approaches solely utilize the hierarchical relations in medical knowledge bases, neglecting the global structural information inherent in medical knowledge bases as a graph structure. This structure, a crucial feature of graphs, reveals the hidden pattern of the medical knowledge bases in the spatial domain. Furthermore, these approaches use medical prior knowledge as external information separated from EHR data, introducing additional noise. Consequently, GNDP [15] models patient visit records as spatial–temporal graphs that encompass medical concepts and medical codes, naturally infusing medical prior knowledge into patient visit data. Subsequently, a spatial–temporal graph convolutional network (ST-GCN) [35] is employed to capture the temporal correlation of visit data and the spatial correlation between medical knowledge, yielding a more robust and accurate patient representation.
Temporal knowledge graphs
Due to the influence of temporal changes on knowledge, temporal knowledge graphs have become an important research branch within the field of knowledge graphs. Researchers integrate temporal information into knowledge graphs in various ways to ensure the effectiveness of structured knowledge [19, 36]. Among these works, two types are particularly relevant to our focus. One category is temporal information embedding, which incorporates temporal aspects into traditional knowledge representation and completion tasks. [37] studies time-range prediction on triples with temporal annotations. It defines a vector-based TTransE algorithm derived from TransE and uses a factorization machine (FM) to enhance the scalability and performance of the model; nevertheless, despite incorporating temporal properties through the addition of time ranges, the scalability of the model remains limited. HyTE [38] extends TransH to solve the many-to-many relationship problem by partitioning the temporal knowledge graph into multiple static snapshots and projecting the entities and relations of each snapshot onto a timestamp-specific hyperplane. LSE4KGC [39] unites relations and timestamps, extending the TransE and DistMult models. It employs an LSTM to learn a representation of the temporal knowledge graph, capturing temporal information; the model predicts timestamps in an incremental manner and can support multi-granularity predictions to some extent. ToKEi [40] argues that introducing a temporal dimension often imposes constraints on the granularity of time. Consequently, it puts forth a method capable of handling knowledge with diverse temporal granularity. This method supports multiple time points and validity intervals, including discontinuous intervals, and can elegantly adjust temporal validity to nearly arbitrary granularity. DyHE [41] considers spatial information in addition to temporal information.
It extends each triple into a quintuple, where l represents location and t represents time. It then employs Dihedron algebra, a rich 4D algebra in hypercomplex spaces, which provides a suitable theoretical foundation for embedding quintuples that consider time and location, offering an effective means to capture the spatio-temporal information inherent in knowledge.
The other category focuses on dynamic entities, studying the impact of temporal changes on entities and their corresponding relations. Know-Evolve [42] employs stochastic processes to model the evolution of events over time and utilizes RNNs to learn the non-linearly evolving entities, enabling the study of the evolution of entities and relations as well as the acquisition of entity representations at different moments. However, due to the lack of contextual information incorporation, the model fails to capture the interaction between nodes. On the other hand, RE-NET [43] models temporal sequences using a combined RNN-based event encoder and neighbor aggregator. Specifically, it employs RNNs to capture the temporal relationships between entities and neighbor aggregators to aggregate interactions occurring simultaneously. RE-GCN [44] further considers the structural dependencies among concurrent facts in knowledge graphs, the sequential patterns of temporally adjacent facts, and the static attributes of entities. It effectively models all historical information in Temporal Knowledge Graphs as evolving representations, which are simultaneously applicable for both entity and relation prediction. [45] argues that existing methods fail to explicitly capture latent relations between co-occurring entities within a time slice and overlook latent relations spanning across different time slices among entities. Therefore, a model for latent relations learning is proposed. The model utilizes a structural encoder to obtain representations for each historical entity, uncovers latent relational graphs by exploring missing associative relations between entities, and then extracts temporal representations from the structural encoder and latent relations learning module for the final prediction task. This effectively harnesses the capability of capturing crucial latent associative relations.
Graph neural networks
Graphs, as non-Euclidean data structures, possess flexible scalability and semantic expressive capabilities. Due to the rapid development of graph neural networks, numerous models exhibiting excellent performance have been developed [16, 17, 46]. These models have found successful applications in various fields, addressing diverse problems. Among them, dynamic graph convolution and spatial–temporal graph convolution methods are the most relevant works to this paper.
In the realm of dynamic graph convolution methods, DySAT [20], a stacked dynamic graph neural network, and EvolveGCN [47], an integrated dynamic neural network, stand out as the two most representative works. DySAT models a dynamic graph as a series of snapshots to capture the evolving patterns over time. For each snapshot, it first computes the local information of each node in the graph using graph attention; on top of that, it computes the dynamic information of the nodes over time using temporal self-attention. To mitigate the impact of frequent changes in the node set on the model, EvolveGCN treats the parameters of the same GCN layer at each moment as a sequence and uses an RNN to adjust the GCN parameters at each time step, capturing the dynamics in the evolving network parameters. This approach concentrates solely on the model itself, not on the nodes, thereby avoiding limitations arising from changes in the nodes. ROLAND [48] posits that existing dynamic graph neural networks have not integrated the state-of-the-art designs of static GNNs, limiting their performance. Consequently, it reemploys static GNNs for dynamic graphs while preserving the effective designs of static GNNs, and introduces a live-update pipeline that simulates real-world usage by allowing dynamic updates during evaluation. In dynamic graph representation learning, the design of existing two-stage models overlooks the temporal dynamics between consecutive graphs. To address this issue, DynGNN [49] incorporates an RNN into the graph neural network, integrating the two-stage framework into a single-stage graph representation model, resulting in a more compact form that generates superior representations. This approach fuses temporal and topological correlations in feature learning from low-level to high-level features, enabling the model to capture the evolutionary patterns of dynamic networks during the feature learning process.
In terms of spatial–temporal graph convolution, ST-GCN [35] employs the spatial–temporal graph for skeleton-based action recognition. It initially models human actions as a spatial–temporal graph with joint coordinates as nodes and connections of body structures and time frames as edges. The model then aggregates spatial dimensional information using GCN and temporal dimensional information using TCN, which is actually a general CNN. In the domain of traffic, STSGCN [21] models the traffic network as a spatial–temporal graph structure. It employs a spatial–temporal synchronous graph convolution module to capture spatial–temporal correlations. This enables the model to consider not only the features of the spatial dimension and the features of the temporal dimension in the spatial–temporal graph but also the features of the spatial–temporal dimension.
Whether it’s temporal knowledge graphs, dynamic graphs, or spatial–temporal graphs, they have all developed in different domains and stages. Essentially, they are research methods for graphs with dynamic properties from various dimensions. Inspired by these approaches, this paper integrates dynamic graph modeling and representation learning methods that can express rich semantics to improve the performance of medical clinical diagnosis prediction.
The MDPG method
In this section, we introduce our MDPG model, a multi-disease diagnosis prediction model based on graph neural networks and attention mechanisms. First, we give a formal description of the concepts related to diagnosis prediction. Then, we detail the design and implementation details of each component of the model.
Basic notations
Definition 1
(Patient visit records) We denote all unique medical codes (i.e., diagnosis codes, procedure codes, and medication codes) that appear in the patient EHR dataset as the set C = {c_1, c_2, ..., c_|C|}, where |C| denotes the number of medical codes. In addition, patients are denoted by the set P. For each patient with T visit records, his or her visit records can be represented as a sequence ⟨V_1, V_2, ..., V_T⟩. Each visit in this sequence is a subset of the medical code set defined above, i.e., V_t ⊆ C.
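Definition 1 can be illustrated with a minimal sketch; the medical codes below are toy values invented for illustration, not codes from the datasets used in the paper.

```python
# A patient's EHR as a time-ordered sequence of visits, each visit a subset
# of the global medical-code set C (Definition 1). All codes are made up.
C = {"D401", "D250", "P3893", "M850"}  # all unique medical codes (toy)

visits = [
    {"D401"},           # V_1: a diagnosis code
    {"D401", "M850"},   # V_2: the same diagnosis plus a medication code
    {"D250", "P3893"},  # V_3: a new diagnosis and a procedure code
]

assert all(v <= C for v in visits)  # every visit is a subset of C
T = len(visits)                     # number of visit records for this patient
```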
Definition 2
(Medical knowledge bases) A medical knowledge base K, constructed using CCS, is a hierarchical structure containing various medical concepts with parent–child relations. Each leaf node in K represents a specific medical concept, corresponding to a medical code in the EHR dataset. Each ancestor node provides a more general categorical description of its leaf nodes, with different levels of ancestor nodes representing different granularities of abstraction. All ancestor nodes in K constitute the set A. Given a set C of medical codes, we can obtain a subgraph G_K of K containing the medical code nodes and all ancestor nodes associated with them. We represent G_K as an undirected graph to ensure that information can flow in both directions between nodes.
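The ancestor expansion in Definition 2 can be sketched as follows; the hierarchy and code names are toy stand-ins for a CCS-style knowledge base, not real CCS categories.

```python
# Toy parent-child hierarchy standing in for a CCS-style knowledge base.
parent = {
    "D401": "HTN_GRP",   # leaf code -> category
    "HTN_GRP": "CIRC",   # category -> chapter
    "CIRC": "ROOT",
    "D250": "DIAB_GRP",
    "DIAB_GRP": "ENDO",
    "ENDO": "ROOT",
}

def ancestors(code):
    """All nodes on the path from a code up to the root (excluding the code)."""
    out = []
    while code in parent:
        code = parent[code]
        out.append(code)
    return out

def concept_subgraph(codes):
    """Medical codes plus every ancestor on their root paths (Definition 2)."""
    nodes = set(codes)
    for c in codes:
        nodes.update(ancestors(c))
    return nodes
```

For the two toy leaf codes, `concept_subgraph({"D401", "D250"})` returns the codes together with all five ancestor nodes up to the shared root.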
Definition 3
(Patient demographic information) Patient demographic information consists of medically relevant patient characteristics that significantly influence diagnosis prediction, such as gender and age. In this paper, we use the set D to denote the set of attribute codes for patient demographic characteristics. The continuous variables in it are discretized in the manner of [33]; for example, we divide the value range of age into several groups.
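The age discretization mentioned above might look like the following sketch; the bin width and label format are assumptions for illustration, not the paper's actual scheme.

```python
def age_group(age, width=10):
    """Discretize a continuous age into a fixed-width bin label."""
    lo = (age // width) * width
    return f"age_{lo}_{lo + width - 1}"

# Demographic attribute codes for one toy patient: a categorical attribute
# kept as-is, and a continuous one replaced by its bin label.
demographics = {"gender": "male", "age_code": age_group(63)}
```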
Definition 4
(Diagnosis prediction) With the above notation, the diagnosis prediction task is defined as follows: given a sequence ⟨V_1, V_2, ..., V_T⟩ representing a patient’s previous visit records, a medical knowledge base K, and patient demographic information D, diagnosis prediction aims to generate the medical codes contained in the next visit record V_{T+1}.
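Framed as multi-label classification, the prediction target for Definition 4 can be sketched as a multi-hot vector over the code set; the codes below are toy values.

```python
# Build the multi-hot target for the next visit from a toy patient record:
# the input is all visits but the last, the label marks the last visit's codes.
codes = sorted({"D401", "D250", "M850"})        # toy code vocabulary
index = {c: i for i, c in enumerate(codes)}

visits = [{"D401"}, {"D401", "M850"}, {"D250"}]  # toy patient, T = 3
history, target_visit = visits[:-1], visits[-1]

y = [0] * len(codes)
for c in target_visit:
    y[index[c]] = 1   # multi-hot prediction target for the next visit
```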
Modeling of patient knowledge graphs
Patient EHR data constitutes typical complex spatial–temporal knowledge, exhibiting both spatial and temporal correlations. A knowledge graph is a semantic network based on a graph data structure that can highly abstract and summarize objective facts of the real world and visually describe complex semantic information, providing a favorable modeling tool for EHR data. Therefore, we reconstruct the patient visit information into a patient-centered knowledge graph with temporal characteristics, referred to as the patient knowledge graph (PKG), to express the complex semantics embedded in the EHR data. It can comprehensively and precisely describe the spatial semantic correlations of visit information within a single visit and also clearly depict the temporal dependencies of visit information across multiple visits. In comparison to previous approaches that model EHR data as a collection of medical codes or a tree structure based on medical concepts, patient knowledge graphs can better organize visit information and completely and clearly represent the spatial–temporal semantics of the data in the clinical diagnosis and treatment process.
Given the input, which consists of patient historical visit records, medical knowledge bases, and patient demographic information, we model it as an undirected multi-relational heterogeneous graph G = (V, R, L, T). V is the set of nodes, consisting of medical codes, medical concept codes, and attribute codes indicating demographic characteristics. R is the set of relations describing the semantic information on the edges of the graph; relation types encompass hierarchy, diagnosis, medication, procedure, age, gender, and chronology. L is the set of edges representing links between nodes. T is the number of visits. Only patients with two or more visits are selected, so T is a positive integer greater than or equal to two.
Figure 2 illustrates an example of a patient knowledge graph. In the figure, each visit record of the patient is modeled as a multi-relational heterogeneous subgraph G_t, where Pt denotes the patient, labels on the edges denote semantic relations, class D nodes denote the patient’s diagnosis codes, class P nodes denote the patient’s procedure codes, and class M nodes denote the patient’s medication codes. Nodes labeled with demographic attributes, such as male and an age group, signify the patient’s demographic information. All medical codes of the patient are aligned with the medical codes in the medical knowledge base and subsequently linked to their parent nodes (including all nodes on the path to the root node), forming a multi-order neighborhood with parent–child relations. Finally, the subgraphs are arranged in the order of the patient’s visits to express the temporal relationships.
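A minimal sketch of how one visit could be turned into such a subgraph: relation names mirror the paper, while the codes, hierarchy, and the helper `visit_subgraph` are illustrative assumptions rather than the paper's construction code.

```python
# Build one visit subgraph as a list of (head, relation, tail) edges:
# patient -> medical codes, codes -> ancestors, patient -> demographics.
def visit_subgraph(patient, visit_codes, code_relation, demo, parent):
    edges = []
    for c in sorted(visit_codes):
        edges.append((patient, code_relation[c], c))   # diagnosis/procedure/medication
        node = c
        while node in parent:                          # link to all ancestors
            edges.append((node, "hierarchy", parent[node]))
            node = parent[node]
    for rel, val in demo.items():                      # age and gender attribute nodes
        edges.append((patient, rel, val))
    return edges

parent = {"D401": "HTN_GRP", "HTN_GRP": "ROOT"}        # toy hierarchy
g1 = visit_subgraph("Pt", {"D401"}, {"D401": "diagnosis"},
                    {"gender": "male", "age": "age_60_69"}, parent)
# The full PKG would be the time-ordered list of such subgraphs, one per visit.
```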
Fig. 2.
An example of a patient knowledge graph
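To make the construction above concrete, the following sketch (not the authors' code; all codes, relation names, and the data layout are hypothetical) assembles the typed edges of one visit subgraph and chains consecutive visits with a chronology relation:

```python
# Illustrative sketch of PKG assembly as (head, relation, tail) triples.
def build_visit_subgraph(patient_id, diagnoses, procedures, medications,
                         demographics, parents):
    """Return the typed edges of a single-visit subgraph."""
    triples = []
    for d in diagnoses:
        triples.append((patient_id, "diagnosis", d))
        # link each code to its ancestors in the medical ontology
        for anc in parents.get(d, []):
            triples.append((d, "hierarchy", anc))
    for p in procedures:
        triples.append((patient_id, "procedure", p))
    for m in medications:
        triples.append((patient_id, "medication", m))
    for rel, value in demographics.items():   # e.g. {"gender": "male"}
        triples.append((patient_id, rel, value))
    return triples

def link_visits(visit_nodes):
    """Chain consecutive visit subgraphs with a chronology relation."""
    return [(visit_nodes[t], "chronology", visit_nodes[t + 1])
            for t in range(len(visit_nodes) - 1)]
```

A patient with two visits would yield two such subgraphs joined by one chronology edge, mirroring the temporal arrangement shown in Fig. 2.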
The framework of MDPG
In this section, we first describe the framework of our MDPG model, and then detail the specific implementation details of each component. Finally, the objective function is given.
The overall framework of MDPG is depicted in Fig. 3. The whole framework consists of four core components. The first component is the spatial graph convolution module, illustrated as (a) in Fig. 3, which captures the spatial relevance of knowledge in the PKG through the graph attention mechanism. The second component, depicted as (b) in Fig. 3, is the temporal self-attention module, which further captures the temporal relevance of knowledge in the PKG using the self-attention mechanism, building on the spatial convolution module. The third component, illustrated as (c) in Fig. 3, is the spatial–temporal synchronous graph convolution module, which adopts a spatial–temporal graph model to learn the spatial–temporal correlation of knowledge in the PKG. Using these three components, we obtain patient representations in the spatial, temporal, and spatial–temporal dimensions, respectively. Subsequently, we fuse the representations to generate a more robust and accurate patient representation and feed it into the fourth component, the prediction module, illustrated as (d) in Fig. 3. The prediction module achieves the diagnostic classification of multiple diseases using multi-label learning techniques. In the following, we elaborate on each of these four components.
Fig. 3.
Overview of our proposed MDPG method
Spatial graph convolution block
To capture the local spatial correlation of patient diagnosis and treatment knowledge contained in each visit subgraph of the PKG, the spatial graph convolution block is designed as a multilayer graph convolution network consisting of multiple stackable spatial self-attention layers. The overall structure is illustrated as (a) in Fig. 3. At each time step of the patient visit sequence, we capture the structural and semantic information of the patient’s local neighbors through separate spatial self-attention layers. After multiple layers of information extraction, higher-order features of the patient node’s neighbors can be captured.
The input to the spatial self-attention layer consists of a subgraph of patient visits, a set of vectors representing all nodes in the graph, and a set of vectors representing all relations in the graph, where D denotes the dimension of the vectors. The layer aims to extract semantic features from the neighboring nodes and corresponding edges of each node in the graph, taking the influence of the graph structure into account, and then to calculate the representation (i.e., the distributed vector representation) of each node and relation. Since the edges of the graph carry semantic relations, following DySAT [20] and referencing the convolution calculations of GAT [17] and CompGCN [46], when computing the representation of each node we consider both the influence of neighboring nodes and the influence of the semantic relations on the edges connecting them. Additionally, the influence of the node on itself through a self-loop edge is taken into account. The PKG in this paper is a non-attributed graph, so nodes and edges with attributes are not considered; thus, in the initialization stage, one-hot encoding is applied to nodes and edges. Specifically, in the patient's visit subgraph, for any node v, as shown in Eq. 1, we first compute the attention scores between v and all its first-order neighbor nodes using its own representation, its neighbors' representations, and the corresponding relation representations. These scores are then transformed into assignable weight coefficients by the softmax function; each coefficient reflects the degree of semantic contribution of a neighbor node to node v. Finally, the representations of all neighbor nodes of v, including v itself, are weighted and summed using these coefficients according to Eq. 2. This process yields the final representation of node v at this layer, where F denotes the dimension of the output vector.
$e_{vu} = \sigma\big(\mathbf{a}^{\top}\,[\,\mathbf{W}\mathbf{h}_v \,\|\, \mathbf{W}\mathbf{h}_u \,\|\, \mathbf{W}_{\mathrm{rel}}\,\mathbf{h}_{r_{vu}}\,]\big), \qquad \alpha_{vu} = \mathrm{softmax}_{u}(e_{vu})$ | 1 |
$\mathbf{h}'_v = \sigma\Big(\sum_{u \in \mathcal{N}(v) \cup \{v\}} \alpha_{vu}\,\mathbf{W}\mathbf{h}_u\Big)$ | 2 |
where $\mathbf{W}$ and $\mathbf{W}_{\mathrm{rel}}$ are shared, trainable weight matrices that map the nodes and relations into the low-dimensional embedding space; $\mathbf{a}$ is a weight vector parameterizing the attention mechanism, implemented as a single-layer feedforward neural network; $\mathcal{N}(v)$ is the set of first-order neighbor nodes of node v; $\|$ denotes the concatenation operation; $\sigma$ is a nonlinear activation function, such as LeakyReLU; and $\top$ denotes transposition.
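As a concrete illustration of Eqs. 1–2, the following NumPy sketch performs one relation-aware attention step for a single node. The weights are random stand-ins for the trainable parameters, and using a zero vector as the self-loop relation is our assumption, not a detail from the paper:

```python
import numpy as np

rng = np.random.default_rng(0)

def leaky_relu(x, slope=0.2):
    return np.where(x > 0, x, slope * x)

def softmax(x):
    e = np.exp(x - x.max())
    return e / e.sum()

def attend(h_v, neighbors, relations, W, W_r, a):
    """One spatial self-attention step for node v (Eqs. 1-2 sketch)."""
    nbr_h = [h_v] + neighbors                 # include the self-loop
    nbr_r = [np.zeros_like(h_v)] + relations  # assumed zero self-loop relation
    scores = np.array([
        leaky_relu(a @ np.concatenate([W @ h_v, W @ h_u, W_r @ r]))
        for h_u, r in zip(nbr_h, nbr_r)
    ])
    alpha = softmax(scores)                   # assignable weight coefficients
    return sum(al * (W @ h_u) for al, h_u in zip(alpha, nbr_h))

D, F = 8, 4                                   # input / output dimensions
W, W_r = rng.normal(size=(F, D)), rng.normal(size=(F, D))
a = rng.normal(size=3 * F)
h_v = rng.normal(size=D)
out = attend(h_v, [rng.normal(size=D)], [rng.normal(size=D)], W, W_r, a)
```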
As the influence of the semantic relations on the edges is taken into account, the relation representations must be updated after the node update in each layer to ensure that relation input is available for the subsequent layer. The update formula is as follows:
$\mathbf{h}'_r = \mathbf{W}_{\mathrm{rel}}\,\mathbf{h}_r$ | 3 |
To stabilize the learning process of the self-attention, we further employ a multi-head attention mechanism to capture the features of each input in different subspaces. By concatenating the features of each head, we obtain the final feature representation. The formula is as follows:
$\mathbf{h}'_v = \big\Vert_{k=1}^{K}\, \sigma\Big(\sum_{u \in \mathcal{N}(v) \cup \{v\}} \alpha^{k}_{vu}\,\mathbf{W}^{k}\mathbf{h}_u\Big)$ | 4 |
where K is the number of attention heads. We choose the representation sequence of the patient node as the final output of the spatial graph convolution block; it is subsequently fed as input to the temporal self-attention block.
Temporal self-attention block
The objective of the temporal self-attention block is to capture temporal correlations among the patient representations of each visit subgraph in the PKG. It consists of a location encoding layer, multiple stacked temporal self-attention layers and a feedforward layer.
The location encoding layer is responsible for encoding the absolute temporal position of the patient representation of each visit subgraph within the whole knowledge graph, recording the temporal location information contained in the patient's visit records. Given a PKG with T visit records, it employs the location encoding computation method outlined in [50] to obtain a location encoding sequence corresponding to the patient representation sequence generated by the spatial graph convolution module.
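The sinusoidal location encoding of [50], reused here for visit positions, can be sketched as follows (the dimensions are illustrative):

```python
import numpy as np

def position_encoding(T, D):
    """Sinusoidal location encoding: sin on even dims, cos on odd dims."""
    pos = np.arange(T)[:, None]                    # visit index t
    i = np.arange(D)[None, :]                      # embedding dimension index
    angle = pos / np.power(10000.0, (2 * (i // 2)) / D)
    return np.where(i % 2 == 0, np.sin(angle), np.cos(angle))  # (T, D)

pe = position_encoding(T=5, D=16)   # one encoding per visit position
```

Each row is added to the corresponding patient representation before the temporal self-attention calculation.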
The temporal self-attention layer employs scaled dot-product self-attention in the query-key-value form proposed by [18]. This method captures the correlation between the patient representation at any given moment and the patient representations at all other moments. After obtaining the location encoding through the location encoding layer, we add it to the patient representation sequence to form the input of the temporal self-attention layer. Following the temporal self-attention calculation, the resulting sequence serves as the output. For ease of description, we denote the input and output of this layer as matrices. As depicted in Eq. 5, we first project the input matrix into low-dimensional spaces using three linear projection matrices, extracting the query, key, and value to obtain the Q, K, and V matrices, respectively. Attention scores are then computed by taking the inner product of Q and K, and are normalized into a matrix of assignable attention coefficients that conform to a probability distribution. These coefficients determine how much attention should be given to the patient representations at other moments in the visit sequence when learning the current patient representation. Finally, the coefficients in this matrix are weighted and summed with the corresponding values in V to obtain the output patient representation matrix. Clearly, for any patient representation, the model learns the temporal correlation within the patient visit sequence through the temporal self-attention calculation: in the matrix calculations, a higher score at a position yields a larger value after multiplication, so the model assigns more attention to that position, whereas a lower score yields a smaller value and receives correspondingly less attention.
$\mathbf{Q} = \mathbf{X}\mathbf{W}^{Q},\;\; \mathbf{K} = \mathbf{X}\mathbf{W}^{K},\;\; \mathbf{V} = \mathbf{X}\mathbf{W}^{V}, \qquad \mathbf{Z} = \mathrm{softmax}\Big(\dfrac{\mathbf{Q}\mathbf{K}^{\top}}{\sqrt{d_k}}\Big)\mathbf{V}$ | 5 |
where softmax is a normalization function ensuring that the weight scores conform to a probability distribution, and $\sqrt{d_k}$ is a scaling factor that decouples the softmax distribution from the dimension so that gradient values remain stable during training. We choose the patient representation of the last visit as the final output of this layer. Consequently, we do not need a mask matrix to prevent the current patient representation from attending to future visits during the calculation.
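A minimal NumPy sketch of the scaled dot-product self-attention in Eq. 5, with random matrices standing in for the learned projections:

```python
import numpy as np

def self_attention(X, Wq, Wk, Wv):
    """Scaled dot-product self-attention over a visit sequence (Eq. 5 sketch)."""
    Q, K, V = X @ Wq, X @ Wk, X @ Wv
    d = Q.shape[-1]
    scores = Q @ K.T / np.sqrt(d)                 # (T, T) attention scores
    scores -= scores.max(axis=-1, keepdims=True)  # numerical stability
    A = np.exp(scores)
    A /= A.sum(axis=-1, keepdims=True)            # row-wise softmax over visits
    return A @ V                                  # weighted sum of values

rng = np.random.default_rng(0)
T, D, dk = 4, 8, 8                                # visits, input dim, head dim
X = rng.normal(size=(T, D))                       # representations + encodings
Z = self_attention(X, rng.normal(size=(D, dk)), rng.normal(size=(D, dk)),
                   rng.normal(size=(D, dk)))
```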
Similarly, we employ multi-head attention in the temporal self-attention layer to enhance model stability. The equation is as follows:
$\mathbf{Z} = \big\Vert_{k=1}^{K}\, \mathrm{softmax}\Big(\dfrac{\mathbf{Q}_k\mathbf{K}_k^{\top}}{\sqrt{d_k}}\Big)\mathbf{V}_k$ | 6 |
where K is the number of attention heads.
After the temporal self-attention layer, a feed-forward neural network layer is appended for dimension transformation. The formula is as follows:
$\mathbf{Z}' = \sigma(\mathbf{Z}\mathbf{W} + \mathbf{b})$ | 7 |
where $\mathbf{W}$ is the learnable parameter matrix, $\mathbf{b}$ is the bias, and $\sigma$ is a nonlinear activation function, such as ReLU.
Spatial–temporal synchronous graph convolution block
As mentioned previously, nodes in the PKG exhibit not only spatial and temporal correlations but also crucial spatial–temporal correlations. Spatial–temporal correlation implicitly signifies the synchronous dependence of nodes in both temporal and spatial dimensions, a factor overlooked by existing methods. The aforementioned spatial and temporal blocks capture the spatial and temporal correlations of patients in visit records, respectively. To further synchronously capture the spatial–temporal dependencies of each node in the PKG and obtain the spatial–temporal representation of patients in the spatial–temporal dimension, we draw on the idea of STSGCN [21] and design the spatial–temporal synchronous convolutional block to capture spatial–temporal dependencies of patients in visit records.
To capture the influence of both spatial and temporal neighbor nodes on patient nodes, the PKG is partitioned into consecutive local spatial–temporal graphs based on the chronological order of visits. Leveraging the structural information of each local spatial–temporal graph, we can directly capture the dependencies between nodes and their spatial–temporal neighbors. Specifically, in the initial step, we connect each node with itself in the previous and next time steps through a temporal relation, creating local spatial–temporal relationships among the nodes; connecting nodes with themselves across adjacent time steps intuitively captures local spatial–temporal relationships [21]. Since the nodes present at each visit change dynamically, unlike STSGCN, we only connect nodes that are present in both of the adjacent time steps. Then, we construct a local spatial–temporal graph from every three consecutive time steps as a basic unit for observing the spatial–temporal characteristics of each node. As shown in (c) in Fig. 3, in each local spatial–temporal graph we take the patient node at the middle time step as the central node; in this way, the patient node perceives not only the impact of the medical events in the current visit but also the impact of medical events in the adjacent visits before and after it.
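The construction of the local spatial–temporal graphs can be sketched as follows; the visit contents are hypothetical, and, as described above, only nodes present in both adjacent visits are linked across time:

```python
def temporal_links(visit_nodes):
    """Link each node to itself across adjacent visits, but only when it
    appears in both visits (unlike STSGCN, visit node sets change)."""
    links = []
    for t in range(len(visit_nodes) - 1):
        shared = visit_nodes[t] & visit_nodes[t + 1]
        links.extend(((n, t), (n, t + 1)) for n in shared)
    return links

def local_windows(T, size=3, stride=1):
    """Sliding three-step windows defining the local spatial-temporal graphs."""
    return [tuple(range(s, s + size)) for s in range(0, T - size + 1, stride)]

# Hypothetical four-visit patient: the patient node Pt persists, codes vary.
visits = [{"Pt", "D1"}, {"Pt", "D1", "M1"}, {"Pt", "M1"}, {"Pt"}]
links = temporal_links(visits)        # Pt is linked across every adjacent pair
windows = local_windows(len(visits))  # T - 2 = 2 local graphs
```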
Next, we design a spatial–temporal synchronous graph convolution unit that takes a local spatial–temporal graph as input and learns the patient representation. As the local spatial–temporal graph incorporates visit subgraphs from three time steps, the unit employs at least two stacked graph convolution operations to broaden the aggregation region and capture features from at least two hops of spatial–temporal neighbor nodes. Each graph convolution operation again uses GAT to complete the node updates. The difference is that, for each node, we assign attention weights over both its spatial and temporal neighbors, together with the corresponding relations. Specifically, a single graph convolution operation is defined as follows:
$e_{vu} = \sigma\big(\mathbf{a}^{\top}\,[\,\mathbf{W}\mathbf{h}_v \,\|\, \mathbf{W}\mathbf{h}_u \,\|\, \mathbf{W}_{\mathrm{rel}}\,\mathbf{h}_{r_{vu}}\,]\big), \qquad \alpha_{vu} = \mathrm{softmax}_{u}(e_{vu})$ | 8 |
$\mathbf{h}'_v = \sigma\Big(\sum_{u \in \mathcal{N}_{st}(v) \cup \{v\}} \alpha_{vu}\,\mathbf{W}\mathbf{h}_u\Big)$ | 9 |
where $\mathbf{W}$ and $\mathbf{W}_{\mathrm{rel}}$ are trainable parameter matrices that map nodes and relations into the low-dimensional vector space, respectively; $\mathbf{a}$ is a parameter vector for extracting attention scores; $\mathcal{N}_{st}(v)$ is the set of first-order spatial–temporal neighbor nodes of node v; $\|$ denotes the concatenation operation; $\sigma$ is a nonlinear activation function, such as LeakyReLU; and $\top$ denotes transposition. After the nodes are updated, the relations are likewise updated for the next layer, in the same manner as in Eq. 3. Following multiple graph convolution operations, we select the representation of the patient node at the intermediate time step as the final output of each spatial–temporal synchronous graph convolution unit.
To capture the spatial–temporal correlation of patient information across the whole PKG, the spatial–temporal synchronous graph convolution block is formed by stacking multiple spatial–temporal synchronous graph convolution units. We utilize a sliding window mechanism to generate the local spatial–temporal graphs: specifically, a fixed-size sliding window spanning three time steps with a stride of 1. For a PKG with T time steps, we thus partition it into T − 2 local spatial–temporal graphs. These are sequentially fed into the spatial–temporal synchronous graph convolution units, yielding a collection of T − 2 patient distributed representations. We then concatenate the patient representations in this collection and feed them into a feed-forward neural network layer to generate the final result. The formula is as follows:
$\mathbf{h}_{st} = \sigma\big(\,[\,\mathbf{h}^{(1)} \,\|\, \cdots \,\|\, \mathbf{h}^{(T-2)}\,]\,\mathbf{W} + \mathbf{b}\,\big)$ | 10 |
where $\mathbf{W}$ is the learnable parameter matrix, $\mathbf{b}$ is a bias, and $\sigma$ is a nonlinear activation function, such as ReLU. The output contains the spatial–temporal correlation implied by the entire patient knowledge graph. Similarly, to stabilize the learning process, we employ a multi-head attention mechanism to extend the spatial–temporal synchronous convolution block; the details are not reiterated here.
Prediction layer
After the spatial graph convolution block, the temporal self-attention block, and the spatial–temporal synchronous graph convolution block, MDPG obtains two patient representations: the temporal representation produced by the temporal self-attention block (which builds on the spatial graph convolution block) and the spatial–temporal representation produced by the spatial–temporal synchronous graph convolution block. In the prediction layer, these representations are fused through weighted summation to create a unified, semantically enhanced patient representation. The fused representation is then transformed into a new vector space through a fully connected layer with activation functions, which further extracts semantic features and produces the model's final output as the diagnosis prediction result. We regard diagnosis prediction as a multi-label classification problem and therefore transform the output into a Bernoulli distribution using the sigmoid function. The specific definition is as follows:
$\hat{\mathbf{y}} = \mathrm{sigmoid}\big(\mathbf{W}\,(\mathbf{h}_T + \lambda\,\mathbf{h}_{st}) + \mathbf{b}\big)$ | 11 |
where $\hat{\mathbf{y}}$ is the prediction result, whose dimension equals the number of medical codes; $\mathbf{W}$ is the matrix of learnable parameters of the fully connected network; and $\lambda$ is a learnable scalar weight that controls the impact of the spatial–temporal synchronous graph convolution.
Objective function
Based on Eq. 11, we employ the following binary cross-entropy function as the objective function to calculate the loss between the patient's diagnosis prediction $\hat{\mathbf{y}}$ and its corresponding label, i.e., the ground truth $\mathbf{y}$.
$\mathcal{L} = -\dfrac{1}{C}\displaystyle\sum_{i=1}^{C}\big(\,y_i\log\hat{y}_i + (1-y_i)\log(1-\hat{y}_i)\,\big)$ | 12 |
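A minimal sketch of the prediction layer and the binary cross-entropy objective, with illustrative dimensions and a fixed fusion weight standing in for the learnable scalar:

```python
import numpy as np

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

def predict(h_temporal, h_st, W, b, lam=0.5):
    """Fuse the two patient representations and score every medical code."""
    h = h_temporal + lam * h_st          # weighted-sum fusion (lam is assumed fixed here)
    return sigmoid(W @ h + b)            # one probability per medical code

def bce_loss(y_hat, y):
    """Binary cross-entropy for multi-label diagnosis prediction."""
    eps = 1e-12                          # avoid log(0)
    return -np.mean(y * np.log(y_hat + eps) + (1 - y) * np.log(1 - y_hat + eps))

rng = np.random.default_rng(0)
D, C = 8, 12                             # embedding dim, number of medical codes
y_hat = predict(rng.normal(size=D), rng.normal(size=D),
                rng.normal(size=(C, D)), np.zeros(C))
loss = bce_loss(y_hat, rng.integers(0, 2, size=C))
```

In the actual model, the fusion weight, projection matrix, and bias would all be trained jointly by minimizing this loss.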
Experiments
In this section, we conduct experiments on two real-world medical datasets and evaluate the performance of our model. Compared with state-of-the-art prediction models, our model yields higher accuracy under different evaluation strategies. Before delving into the experimental details, we first provide an overview of the two medical datasets. Next, we list the baseline models for comparison. Following that, we describe the relevant performance evaluation metrics. Lastly, we analyze the experimental results.
Data description
Two datasets with typical characteristics are selected to separately evaluate the performance of the models in different aspects. One is the MIMIC-III Critical Care dataset and the other is the MedClin Clinical Diagnosis and Treatment dataset.
Data set-I
The MIMIC-III dataset [51, 52] is a publicly available EHR dataset widely employed in diverse medical clinical prediction tasks. It comprises the medical records of 7499 intensive care unit patients spanning an 11-year period. Compared with health insurance and specialty disease datasets, it contains a relatively small number of patients and visits and thus exhibits the characteristic of insufficient data. Leveraging this characteristic, we can better evaluate the performance of models when the training data are insufficient.
Data set-II
The MedClin dataset is an in-hospital clinical diagnosis and treatment dataset. Due to information security considerations, the MedClin dataset is presently inaccessible. This dataset was collected and constructed by the Neusoft Clinical Decision Support System in a real production environment. It contains patient data with longer diagnostic series from 2010 to 2020. In comparison to the MIMIC-III dataset, patients in the MedClin dataset exhibit a higher frequency of visits. Therefore, using this dataset, we can effectively evaluate the impact of the number of consecutive visits on the models.
Given that the characteristics of the aforementioned two datasets encompass real-world scenarios in patient diagnosis prediction, conducting experiments on them allows for a more effective and comprehensive evaluation of model performance. In order to further improve data quality, we perform additional inclusion and exclusion operations on the original data. Specifically, for the MIMIC-III dataset, we retain patients with at least two visits. For the MedClin dataset, we filter out patients with fewer than 5 visits. More details of both datasets are listed in Table 1.
Table 1.
Statistics of MIMIC-III and MedClin datasets
| Dataset | Data set-I | Data set-II |
|---|---|---|
| No. of patients | 7499 | 17,052 |
| No. of visits | 19,911 | 307,501 |
| Average no. of visits per patient | 2.66 | 18.03 |
| No. of unique ICD9 codes | 4880 | 5001 |
| Average no. of ICD9 codes per visit | 13.06 | 3.31 |
| Maximum no. of ICD9 codes per visit | 39 | 25 |
| No. of category codes | 171 | 160 |
| Average no. of category codes per visit | 10.16 | 2.39 |
| Maximum no. of category codes per visit | 30 | 15 |
ICD International Classification of Diseases
Experimental setup
In this section, we first introduce the baseline models, then describe the metrics used to evaluate the predictive models, and finally, we describe the implementation details.
Baseline methods
In order to thoroughly validate the effectiveness of MDPG, we choose the following state-of-the-art models for comparison. They are categorized into three groups: one comprises a variant of MDPG, another involves a knowledge-guided prediction model, and the last one is a model that does not utilize external prior knowledge.
Variants of MDPG
MDPG-. MDPG- excludes the spatial–temporal synchronous graph convolution block. Only spatial graph convolution block and temporal self-attention block are used.
MDPG-. MDPG- eliminates the semantic relations from the knowledge graph. The patient knowledge graph is viewed as a single relation graph.
MDPG-. MDPG- omits patient demographic information.
Knowledge-guided model
GRAM [12]. GRAM is the first model that uses a medical knowledge base as supplementary knowledge to learn medical code representation and achieve diagnosis prediction. It uses GRU to capture the temporal relationships between visit records and uses the attention mechanism to fuse features of medical codes and medical concepts. We employ a variant of GRAM, denoted as GRAM+, to learn the basic representation of medical codes and medical concepts through co-occurrence information.
KAME [13]. KAME incorporates medical knowledge directly into the entire prediction process, building upon the GRAM framework. Thus, within the GRAM framework, we generate representation vectors of prior medical knowledge using additional modules. These vectors are then concatenated with the hidden layer vectors generated by RNN and fed into the classifier for diagnosis prediction.
CAMP [14]. In addition to medical prior knowledge, CAMP supplements patient demographic information to improve predictive performance. To maintain experimental consistency, we only select two characteristics: age and gender.
GNDP [15]. GNDP is the pioneering model that employs graphs to consider the internal structure and semantics of patient visit records. It transforms patient visit information into a patient feature matrix and a graph adjacency matrix using a medical knowledge base. Subsequently, it learns the patient representation through convolution and graph convolution to achieve diagnosis prediction. We utilize the full version of GNDP, exhibiting the best performance for prediction.
Non-knowledge model
Dipole [8]. Dipole achieves the best performance compared to other models that use only RNNs. It uses a bidirectional recurrent neural network to address the challenge of long sequences of visit records. Additionally, it learns the significance of each visit for future predictions using three types of attention. In our implementation, we utilize a variant of location-based attention.
RNN. We use a single directional GRU [53] as an RNN baseline to implement diagnosis prediction.
GCN [16]. We use the graph neural network proposed in [16] as a baseline, which consists of two layers and cannot handle temporal correlation.
Evaluation metrics
Similar to prior studies [12, 13, 15], we adopt two metrics, namely visit-level precision@k and code-level accuracy@k, to measure the performance of all baseline models at different granularities.
Visit-level precision@k is a metric that measures the prediction performance of individual patient visits and is used to evaluate the coarse-grained performance of diagnosis prediction models. In the patient diagnosis prediction results, it is defined as the percentage of correct outcomes among the top k predicted diagnoses with the highest probability. The formula is as follows:
$\mathrm{precision@}k = \dfrac{\#\ \text{correct medical codes in the top-}k\ \text{predictions}}{\min\big(k,\ |\mathbf{y}_t|\big)}$ | 13 |
where the numerator denotes the number of correct medical codes among the top-k predictions ranked by probability, and $|\mathbf{y}_t|$ denotes the number of true values in the target visit, i.e., the number of medical codes with a label value of 1.
In contrast, code-level accuracy@k is a metric that measures the overall accuracy of multiple patient predictions and is used to evaluate the performance of diagnosis prediction models at a fine-grained level. The definition is as follows:
$\mathrm{accuracy@}k = \dfrac{1}{N}\displaystyle\sum_{n=1}^{N} \mathrm{precision@}k^{(n)}$ | 14 |
where N denotes the number of patients.
In our experiments, we vary the k-value from 5 to 30 to evaluate the performance of each model. For both metrics, larger values indicate better performance.
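The two metrics can be sketched as follows. Normalizing visit-level precision@k by min(k, number of true codes), and taking code-level accuracy@k as the average over patients, are our assumptions about the exact formulas, following the conventions of the cited works [12, 13]:

```python
import numpy as np

def precision_at_k(y_hat, y, k):
    """Visit-level precision@k for one visit (normalisation is assumed)."""
    topk = np.argsort(y_hat)[::-1][:k]      # indices of top-k predictions
    hits = int(np.sum(y[topk]))             # correct codes among them
    return hits / min(k, int(np.sum(y)))

def accuracy_at_k(all_y_hat, all_y, k):
    """Code-level accuracy@k averaged over N patients (assumed aggregation)."""
    return float(np.mean([precision_at_k(p, t, k)
                          for p, t in zip(all_y_hat, all_y)]))

y = np.array([1, 0, 1, 0, 0])                 # two true codes
y_hat = np.array([0.9, 0.1, 0.2, 0.8, 0.05])  # predicted probabilities
p2 = precision_at_k(y_hat, y, k=2)            # top-2 = {0, 3}, one hit
acc2 = accuracy_at_k([y_hat], [y], k=2)
```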
Implementation details
We implement our model with PyTorch 1.8. For the baseline models, the parameters are set as in their original proposals. We split the dataset into 75% for training, 10% for validation, and 15% for testing; the validation set is employed for selecting the optimal parameter values. Model training and testing are conducted on a Tesla V100 32G GPU. During the training phase, we use the Adam optimizer [54] and update the gradients with a minibatch of 100 patients. To prevent overfitting, we use two regularization strategies: L2 normalization with a coefficient of 0.00001 and dropout with a rate of 0.25. The learning rate is set to 0.005. To improve training stability, we vary the learning rate using the constant warmup method [55]: we linearly increase the learning rate over the initial 10% of training steps and decrease it thereafter.
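The warmup schedule described above can be sketched as follows; the linear shape of the decay phase is our assumption:

```python
def warmup_lr(step, total_steps, base_lr=0.005, warmup_frac=0.1):
    """Linear warmup over the first 10% of steps, then a (assumed) linear decay."""
    warmup = max(1, int(total_steps * warmup_frac))
    if step < warmup:
        return base_lr * (step + 1) / warmup          # ramp up to base_lr
    return base_lr * (total_steps - step) / (total_steps - warmup)

lrs = [warmup_lr(s, total_steps=100) for s in range(100)]
```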
Results and discussion
To demonstrate the necessity and effectiveness of MDPG’s various components, we conduct an ablation experiment on the MIMIC-III dataset. We evaluate MDPG variants using two types of metrics: visit-level precision and code-level accuracy. The evaluation results are presented in Table 2. Overall, the performance of MDPG outperforms all its variants in all aspects. This strongly suggests that each component of the model contributes to overall performance improvement.
Table 2.
Results of ablation experiment in terms of code-level accuracy and visit-level precision
| Code-level accuracy@K | Visit-level precision@K | ||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| Data set | Model | 5 | 10 | 15 | 20 | 25 | 30 | 5 | 10 | 15 | 20 | 25 | 30 |
| Data set-I | MDPG | 0.3627 | 0.5609 | 0.6897 | 0.7817 | 0.8401 | 0.8781 | 0.7622 | 0.6892 | 0.7270 | 0.7980 | 0.8431 | 0.8897 |
| MDPG- | 0.3591 | 0.5512 | 0.6810 | 0.7786 | 0.8339 | 0.8737 | 0.7552 | 0.6852 | 0.7208 | 0.7890 | 0.8361 | 0.8810 | |
| MDPG- | 0.3519 | 0.5506 | 0.6799 | 0.7698 | 0.8281 | 0.8700 | 0.7506 | 0.6846 | 0.7205 | 0.7882 | 0.8359 | 0.8807 | |
| MDPG- | 0.3500 | 0.5495 | 0.6779 | 0.7674 | 0.8259 | 0.8685 | 0.7488 | 0.6839 | 0.7196 | 0.7866 | 0.8355 | 0.8793 | |
The best values are highlighted in bold
Specifically, (1) when MDPG removes the spatial–temporal synchronous graph convolution block, code-level accuracy and visit-level precision each decrease by more than 1.2% at k = 5. This demonstrates that the spatial–temporal synchronous graph convolution can further explore potential correlations of patient EHR data in the spatial–temporal dimension that cannot be discovered in the spatial or temporal dimension alone. (2) When the patient knowledge graph is constructed without considering semantic relations, the evaluation results on the dataset are significantly reduced: at k = 5, code-level accuracy and visit-level precision decrease by 1.08% and 1.16%, respectively. This indicates that modeling patient visit information as a knowledge graph can effectively enhance the semantic representation of diagnosis and treatment data. It is also easy to understand why: the symbolic graph representation of a knowledge graph not only provides a well-structured organization of knowledge but also presents the complex semantics among knowledge, thus allowing a complete description of the patient's diagnosis and treatment process. (3) Disregarding patient demographic information also has a negative impact on the model: removing even simple attributes such as gender and age degrades its performance. Therefore, in the diagnosis prediction process, although prior knowledge such as knowledge bases can serve as a supplement to reduce noise interference and improve accuracy, personalized patient data are equally important.
From this, it can be observed that modeling patient visit data, introducing external medical prior knowledge, and employing efficient inference algorithms are key factors in improving the performance of diagnosis prediction models; our model addresses each of these three aspects to enhance predictive performance. To validate the superiority of MDPG, we compare it with each baseline model on Data set-I and Data set-II. MDPG achieves the best results among all baseline models.
As shown in Table 3, we observe the following. (1) Dipole and RNN perform better than GCN, indicating that considering the temporal characteristics of EHR data in the diagnosis prediction process is more important than considering spatial correlation alone. (2) GRAM, KAME, and CAMP, which introduce external knowledge, outperform the models that do not. In particular, CAMP's code-level accuracy and visit-level precision are higher than those of the other two, which indicates that models sensitive to patient demographic information can capture more fine-grained features. (3) Apart from MDPG, GNDP is the best performer among these models. This implies that, given external prior knowledge, prediction performance can be further improved by effectively modeling patient EHR data and mining the correlations of visit information from its spatial and temporal structures. Compared with GNDP on Data set-I, at k = 5, 15, and 30, MDPG improves code-level accuracy by 1.95%, 1.96%, and 1.52%, and visit-level precision by 1.89%, 0.88%, and 1.48%. On Data set-II, at k = 5, 15, and 30, MDPG improves code-level accuracy by 2.28%, 2.43%, and 1.88%, and visit-level precision by 2.64%, 1.44%, and 1.5%.
Table 3.
Results of comparative experiments-I in terms of code-level accuracy and visit-level precision
| Code-level Accuracy@K | Visit-level Precision@K | ||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| Data set | Model | 5 | 10 | 15 | 20 | 25 | 30 | 5 | 10 | 15 | 20 | 25 | 30 |
| Data set-I | MDPG | 0.3627 | 0.5609 | 0.6897 | 0.7817 | 0.8401 | 0.8781 | 0.7622 | 0.6892 | 0.7270 | 0.7980 | 0.8431 | 0.8897 |
| GNDP | 0.3432 | 0.5401 | 0.6701 | 0.7571 | 0.8176 | 0.8629 | 0.7433 | 0.6766 | 0.7182 | 0.7811 | 0.8338 | 0.8749 | |
| CAMP | 0.3225 | 0.5173 | 0.6489 | 0.7285 | 0.7933 | 0.8457 | 0.7219 | 0.6680 | 0.7074 | 0.7623 | 0.8219 | 0.8541 | |
| KAME | 0.3167 | 0.5100 | 0.6379 | 0.7240 | 0.7862 | 0.8303 | 0.7103 | 0.6568 | 0.6967 | 0.7562 | 0.8091 | 0.8470 | |
| GRAM | 0.3123 | 0.5026 | 0.6296 | 0.7142 | 0.7798 | 0.8266 | 0.6698 | 0.6447 | 0.6847 | 0.7439 | 0.8007 | 0.8424 | |
| Dipole | 0.2774 | 0.4556 | 0.5801 | 0.6671 | 0.7354 | 0.7902 | 0.6220 | 0.5839 | 0.6310 | 0.6912 | 0.7542 | 0.8017 | |
| RNN | 0.2760 | 0.4548 | 0.5751 | 0.6647 | 0.7350 | 0.7867 | 0.6158 | 0.5803 | 0.6243 | 0.6912 | 0.7542 | 0.8017 | |
| GCN | 0.2465 | 0.3902 | 0.4909 | 0.5941 | 0.6790 | 0.7317 | 0.5526 | 0.5328 | 0.5751 | 0.6249 | 0.7011 | 0.7324 | |
| Data set-II | MDPG | 0.6742 | 0.8247 | 0.8801 | 0.9279 | 0.9419 | 0.9584 | 0.7120 | 0.8383 | 0.8897 | 0.9316 | 0.9418 | 0.9565 |
| GNDP | 0.6514 | 0.7946 | 0.8558 | 0.9008 | 0.9191 | 0.9396 | 0.6856 | 0.8213 | 0.8753 | 0.9054 | 0.9244 | 0.9415 | |
| CAMP | 0.5992 | 0.7523 | 0.8338 | 0.8754 | 0.9004 | 0.9265 | 0.6569 | 0.7972 | 0.8539 | 0.8928 | 0.9121 | 0.9330 | |
| KAME | 0.5984 | 0.7525 | 0.8262 | 0.8738 | 0.8999 | 0.9221 | 0.6554 | 0.7901 | 0.8515 | 0.8841 | 0.9108 | 0.9284 | |
| GRAM | 0.5696 | 0.7454 | 0.8215 | 0.8733 | 0.9019 | 0.9202 | 0.6399 | 0.7825 | 0.8224 | 0.8829 | 0.9111 | 0.9255 | |
| Dipole | 0.5737 | 0.6824 | 0.7582 | 0.8093 | 0.8443 | 0.8809 | 0.6318 | 0.7390 | 0.8011 | 0.8357 | 0.8658 | 0.8943 | |
| RNN | 0.5662 | 0.6748 | 0.7474 | 0.8004 | 0.8370 | 0.8754 | 0.6314 | 0.7376 | 0.7898 | 0.8276 | 0.8469 | 0.8919 | |
| GCN | 0.5092 | 0.5657 | 0.6380 | 0.7192 | 0.7283 | 0.7496 | 0.5603 | 0.6534 | 0.6860 | 0.7168 | 0.7258 | 0.7595 | |
The best values are highlighted in bold
MDPG uses multi-relational graph-attention and self-attention mechanisms in the spatial, temporal, and spatial–temporal dimensions, respectively; at the algorithmic level, these mechanisms are ahead of models based on graph convolution and convolutional neural networks. In terms of EHR data modeling, MDPG models patient visit records with temporal characteristics as a patient knowledge graph. This knowledge representation, based on multi-relational heterogeneous graphs, describes the semantic information embedded in patient visit records more comprehensively than representations with only a parent–child hierarchy. Furthermore, considering patient demographic information further enhances the performance of MDPG. Consequently, MDPG attains state-of-the-art results compared with the other baseline models.
To make a more comprehensive comparison with the knowledge-guided models, we follow GNDP and divide the dataset into training, validation, and test sets in a 60%/10%/30% ratio, then conduct further experiments. The results are shown in Table 4. Although the training data are reduced by 15 percentage points, this has no adverse effect on the results on either dataset: MDPG still achieves the best scores on all metrics, demonstrating the model's stability.
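The split described above can be sketched as follows. This is a minimal illustration only; the shuffling strategy, random seed, and any patient-level constraints actually used by GNDP and MDPG are assumptions.

```python
import random

def split_patients(patient_ids, ratios=(0.6, 0.1, 0.3), seed=42):
    """Shuffle patients and split them into train/validation/test sets
    according to the given ratios (here 60%/10%/30%)."""
    ids = list(patient_ids)
    random.Random(seed).shuffle(ids)  # deterministic shuffle for reproducibility
    n_train = int(len(ids) * ratios[0])
    n_val = int(len(ids) * ratios[1])
    return ids[:n_train], ids[n_train:n_train + n_val], ids[n_train + n_val:]
```

Splitting at the patient level (rather than the visit level) avoids leaking a patient's history across the train/test boundary.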
Table 4.
Results of comparative experiments-II in terms of code-level accuracy and visit-level precision
| Data set | Model | Accuracy@5 | Accuracy@10 | Accuracy@15 | Accuracy@20 | Accuracy@25 | Accuracy@30 | Precision@5 | Precision@10 | Precision@15 | Precision@20 | Precision@25 | Precision@30 |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| Data set-I | MDPG | 0.3599 | 0.5598 | 0.6841 | 0.7769 | 0.8373 | 0.8759 | 0.7571 | 0.6886 | 0.7213 | 0.7932 | 0.8405 | 0.8875 |
| | GNDP | 0.3400 | 0.5387 | 0.6642 | 0.7520 | 0.8146 | 0.8606 | 0.7378 | 0.6754 | 0.7121 | 0.7760 | 0.8310 | 0.8725 |
| | CAMP | 0.3175 | 0.5084 | 0.6425 | 0.7269 | 0.7902 | 0.8400 | 0.7211 | 0.6649 | 0.7046 | 0.7537 | 0.8151 | 0.8435 |
| | KAME | 0.3117 | 0.5028 | 0.6296 | 0.7211 | 0.7822 | 0.8248 | 0.7057 | 0.6494 | 0.6910 | 0.7530 | 0.8049 | 0.8380 |
| | GRAM | 0.3041 | 0.4976 | 0.6217 | 0.7123 | 0.7737 | 0.8206 | 0.6647 | 0.6447 | 0.6847 | 0.7392 | 0.7913 | 0.8327 |
| Data set-II | MDPG | 0.6723 | 0.8199 | 0.8807 | 0.9232 | 0.9413 | 0.9537 | 0.7072 | 0.8379 | 0.8869 | 0.9247 | 0.9329 | 0.9438 |
| | GNDP | 0.6477 | 0.7885 | 0.8551 | 0.8950 | 0.9173 | 0.9340 | 0.6789 | 0.8201 | 0.8711 | 0.8974 | 0.9144 | 0.9280 |
| | CAMP | 0.5918 | 0.7501 | 0.8299 | 0.8696 | 0.8981 | 0.9254 | 0.6484 | 0.7942 | 0.8461 | 0.8872 | 0.9038 | 0.9255 |
| | KAME | 0.5971 | 0.7461 | 0.8237 | 0.8721 | 0.8879 | 0.9210 | 0.6501 | 0.7893 | 0.8444 | 0.8834 | 0.9086 | 0.9233 |
| | GRAM | 0.5640 | 0.7427 | 0.8128 | 0.8708 | 0.8909 | 0.9138 | 0.6337 | 0.7795 | 0.8181 | 0.8741 | 0.9037 | 0.9170 |
The best values are highlighted in bold
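For reference, the two evaluation metrics used in these tables can be sketched as follows. This is a minimal illustration based on the common definitions in the GRAM/KAME line of work; the function names and the exact normalization of precision@k (dividing by min(k, number of true codes)) are our assumptions, not code from the paper.

```python
import numpy as np

def visit_level_precision_at_k(y_true, y_score, k):
    """Average over visits of |top-k predicted codes ∩ true codes| / min(k, #true codes)."""
    scores = []
    for truth, score in zip(y_true, y_score):
        true_idx = set(np.flatnonzero(truth))
        if not true_idx:
            continue  # skip visits with no recorded diagnosis codes
        top_k = set(np.argsort(score)[::-1][:k])  # indices of k highest-scored codes
        scores.append(len(top_k & true_idx) / min(k, len(true_idx)))
    return float(np.mean(scores))

def code_level_accuracy_at_k(y_true, y_score, k):
    """Fraction of all true codes (pooled across visits) recovered in the top-k predictions."""
    correct, total = 0, 0
    for truth, score in zip(y_true, y_score):
        true_idx = set(np.flatnonzero(truth))
        top_k = set(np.argsort(score)[::-1][:k])
        correct += len(top_k & true_idx)
        total += len(true_idx)
    return correct / total
```

The visit-level metric rewards getting each individual visit right, while the code-level metric weights visits by how many diagnosis codes they contain.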
Data sufficiency analysis
Following KAME, we conduct a data sufficiency experiment on Data set-I and Data set-II to evaluate how data sufficiency affects each model. We first rank the sample labels in the training set by their frequency of occurrence, from lowest to highest. We then divide all labels into four groups, denoted I, II, III, and IV, at 25% intervals: group I contains the first 25% of labels, i.e., the rarest in the training set, while group IV contains the last 25%, i.e., the most common. Finally, we calculate each model's diagnosis prediction results for each group and evaluate performance with the code-level accuracy@20 metric to illustrate the impact of data sufficiency on the models. As depicted in Figs. 4 and 5, the x-axis denotes the baseline approaches and the y-axis denotes the average prediction accuracy of each model. We compare only against the competitive knowledge-guided baselines.
Fig. 4.
Code-Level Accuracy@20 of diagnosis prediction on the Data set-I
Fig. 5.
Code-Level Accuracy@20 of diagnosis prediction on the Data set-II
Figure 4 shows the evaluation results of each model on Data set-I. MDPG achieves the best performance in all four groups, both in group I, which contains rare medical codes, and in group IV, which contains the most common ones; GNDP ranks second. This suggests that modeling approaches that incorporate knowledge graphs and exploit graph structure while accounting for temporal features can better model patient characteristics, fuse external prior knowledge, and learn the underlying semantic representations, thereby improving prediction performance. Figure 5 shows the evaluation results on Data set-II, which are consistent with the findings on Data set-I. Accuracy improves overall, which we attribute to the richer feature information in Data set-II; this also indicates that the models remain applicable to EHR data with longer visit records.
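The grouping procedure used in this experiment can be sketched as follows. This is an illustrative reconstruction, not the authors' code; the function names, the handling of frequency ties, and the per-group accuracy computation are assumptions.

```python
import math
from collections import Counter

def group_labels_by_frequency(train_labels, n_groups=4):
    """Rank label codes by training-set frequency (rarest first) and
    split them into n_groups equally sized groups (I..IV at 25% intervals)."""
    freq = Counter(code for visit in train_labels for code in visit)
    ranked = [code for code, _ in sorted(freq.items(), key=lambda kv: kv[1])]
    size = math.ceil(len(ranked) / n_groups)
    return [set(ranked[i * size:(i + 1) * size]) for i in range(n_groups)]

def accuracy_at_k_per_group(y_true_codes, topk_codes, groups):
    """Code-level accuracy@k computed separately for each label group:
    of the true codes belonging to a group, how many appear in the top-k predictions."""
    results = []
    for group in groups:
        correct = total = 0
        for truth, topk in zip(y_true_codes, topk_codes):
            in_group = [c for c in truth if c in group]
            correct += sum(c in topk for c in in_group)
            total += len(in_group)
        results.append(correct / total if total else float('nan'))
    return results
```

Restricting the accuracy computation to each frequency quartile separates a model's performance on rare codes (group I) from its performance on common ones (group IV).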
Conclusion
In this paper, we propose MDPG, a novel clinical diagnosis prediction model. It predicts the future health status of patients from the historical diagnosis and treatment information in their electronic health records, thereby improving the efficiency of diagnosis and treatment and reducing the rates of misdiagnosis and missed diagnosis. MDPG first models patient visit information as a patient knowledge graph, achieving a comprehensive and lossless representation of medical record data; this representation is further enriched by naturally integrating external information, such as medical prior knowledge and patient demographic information, in a scalable form. Based on the patient knowledge graph, the model then uses graph neural networks and self-attention mechanisms to learn the spatial, temporal, and spatial–temporal correlations among medical events, generating a patient representation with deep semantics that supports high-performance multi-disease diagnosis prediction. This approach greatly reduces the impact of data sparsity and noise on prediction. We conduct comprehensive experiments on an open-source dataset and an in-hospital dataset, comparing MDPG with state-of-the-art models. The results demonstrate that MDPG outperforms all baselines: it better models sparse and complex EHR data and improves prediction performance.
Although MDPG effectively addresses the high dimensionality, sparsity, and noise of EHR data, it still has certain limitations. To better support clinical decision support systems, future work should focus on two aspects: (1) incorporating additional logical semantic relations into the patient knowledge graph to describe more complex events in EHR data, and (2) considering multimodal information, further integrating medical imaging knowledge to improve prediction performance.
Acknowledgements
We would like to thank the anonymous reviewers for their insightful suggestions. Our work is supported by the National Key Research and Development Program of China (Grant No. 2020AAA0109400).
Author contributions
WW designed the study, performed measurements, designed the analysis, and wrote the manuscript. YF designed the schema for PKG and extracted the data. HZ designed the analysis and the schema for PKG. XW designed the analysis. RC cleaned the clinical data. WC designed the schema for PKG. XZ designed the study and the analysis. All authors contributed to the article and approved the submitted version.
Data availability
The MIMIC-III dataset is publicly available at https://mimic.mit.edu/. The MedClin dataset is private and not currently available.
Declarations
Competing interests
The authors declare no potential conflict of interests.
Footnotes
Publisher's Note
Springer Nature remains neutral with regard to jurisdictional claims in published maps and institutional affiliations.
References
- 1.Birkhead GS, Klompas M, Shah NR. Uses of electronic health records for public health surveillance to advance public health. Annu Rev Public Health. 2015;36:345–59.
- 2.Jensen PB, Jensen LJ, Brunak S. Mining electronic health records: towards better research applications and clinical care. Nat Rev Genet. 2012;13(6):395–405.
- 3.Choi E, Bahadori MT, Searles E, Coffey C, Thompson M, Bost J, Tejedor-Sojo J, Sun J. Multi-layer representation learning for medical concepts. In: Proceedings of the 22nd ACM SIGKDD international conference on knowledge discovery and data mining, 2016; pp. 1495–1504.
- 4.Zhou J, Sun J, Liu Y, Hu J, Ye J. Patient risk prediction model via top-k stability selection. In: Proceedings of the 2013 SIAM international conference on data mining, 2013; pp. 55–63. SIAM.
- 5.Choi E, Bahadori MT, Sun J, Kulas J, Schuetz A, Stewart W. RETAIN: an interpretable predictive model for healthcare using reverse time attention mechanism. Adv Neural Inf Process Syst. 2016;29:3512–20.
- 6.Weston J, Chopra S, Bordes A. Memory networks. In: 3rd international conference on learning representations, ICLR 2015, San Diego, CA, USA, May 7–9, 2015, Conference Track Proceedings.
- 7.Song H, Rajan D, Thiagarajan J, Spanias A. Attend and diagnose: clinical time series analysis using attention models. In: Proceedings of the AAAI conference on artificial intelligence, 2018; vol. 32.
- 8.Ma F, Chitta R, Zhou J, You Q, Sun T, Gao J. Dipole: diagnosis prediction in healthcare via attention-based bidirectional recurrent neural networks. In: Proceedings of the 23rd ACM SIGKDD international conference on knowledge discovery and data mining, 2017; pp. 1903–1911.
- 9.Schuster M, Paliwal KK. Bidirectional recurrent neural networks. IEEE Trans Signal Process. 1997;45(11):2673–81.
- 10.Wei W-Q, Bastarache LA, Carroll RJ, Marlo JE, Osterman TJ, Gamazon ER, Cox NJ, Roden DM, Denny JC. Evaluating phecodes, clinical classification software, and ICD-9-CM codes for phenome-wide association studies in the electronic health record. PLoS ONE. 2017;12(7):e0175508.
- 11.Stearns MQ, Price C, Spackman KA, Wang AY. SNOMED clinical terms: overview of the development process and project status. In: Proceedings of the AMIA symposium, 2001; p. 662.
- 12.Choi E, Bahadori MT, Song L, Stewart WF, Sun J. GRAM: graph-based attention model for healthcare representation learning. In: Proceedings of the 23rd ACM SIGKDD international conference on knowledge discovery and data mining, 2017; pp. 787–795.
- 13.Ma F, You Q, Xiao H, Chitta R, Zhou J, Gao J. KAME: knowledge-based attention model for diagnosis prediction in healthcare. In: Proceedings of the 27th ACM international conference on information and knowledge management, 2018; pp. 743–752.
- 14.Gao J, Wang X, Wang Y, Yang Z, Gao J, Wang J, Tang W, Xie X. CAMP: co-attention memory networks for diagnosis prediction in healthcare. In: 2019 IEEE international conference on data mining (ICDM), 2019; pp. 1036–1041. IEEE.
- 15.Li Y, Qian B, Zhang X, Liu H. Graph neural network-based diagnosis prediction. Big Data. 2020;8(5):379–90.
- 16.Kipf TN, Welling M. Semi-supervised classification with graph convolutional networks. In: 5th international conference on learning representations, ICLR 2017, Toulon, France, April 24–26, 2017, Conference Track Proceedings.
- 17.Velickovic P, Cucurull G, Casanova A, Romero A, Liò P, Bengio Y. Graph attention networks. In: 6th international conference on learning representations, ICLR 2018, Vancouver, BC, Canada, April 30–May 3, 2018, Conference Track Proceedings.
- 18.Vaswani A, Shazeer N, Parmar N, Uszkoreit J, Jones L, Gomez AN, Kaiser Ł, Polosukhin I. Attention is all you need. Adv Neural Inf Process Syst. 2017;30:1–11.
- 19.Ji S, Pan S, Cambria E, Marttinen P, Philip SY. A survey on knowledge graphs: representation, acquisition, and applications. IEEE Trans Neural Netw Learn Syst. 2021;33(2):494–514.
- 20.Sankar A, Wu Y, Gou L, Zhang W, Yang H. DySAT: deep neural representation learning on dynamic graphs via self-attention networks. In: WSDM '20: the thirteenth ACM international conference on web search and data mining, Houston, TX, USA, February 3–7, 2020; pp. 519–527.
- 21.Song C, Lin Y, Guo S, Wan H. Spatial–temporal synchronous graph convolutional networks: a new framework for spatial–temporal network data forecasting. In: Proceedings of the AAAI conference on artificial intelligence, 2020; vol. 34, pp. 914–921.
- 22.Choi E, Du N, Chen R, Song L, Sun J. Constructing disease network and temporal progression model via context-sensitive Hawkes process. In: 2015 IEEE international conference on data mining, 2015; pp. 721–726. IEEE.
- 23.Wang X, Sontag D, Wang F. Unsupervised learning of disease progression models. In: Proceedings of the 20th ACM SIGKDD international conference on knowledge discovery and data mining, 2014; pp. 85–94.
- 24.Xiao H, Gao J, Vu L, Turaga DS. Learning temporal state of diabetes patients via combining behavioral and demographic data. In: Proceedings of the 23rd ACM SIGKDD international conference on knowledge discovery and data mining, 2017; pp. 2081–2089.
- 25.Zhou J, Yuan L, Liu J, Ye J. A multi-task learning formulation for predicting disease progression. In: Proceedings of the 17th ACM SIGKDD international conference on knowledge discovery and data mining, 2011; pp. 814–822.
- 26.Che Z, Kale D, Li W, Bahadori MT, Liu Y. Deep computational phenotyping. In: Proceedings of the 21st ACM SIGKDD international conference on knowledge discovery and data mining, 2015; pp. 507–516.
- 27.Liu C, Wang F, Hu J, Xiong H. Temporal phenotyping from longitudinal electronic health records: a graph based framework. In: Proceedings of the 21st ACM SIGKDD international conference on knowledge discovery and data mining, 2015; pp. 705–714.
- 28.Zhou J, Wang F, Hu J, Ye J. From micro to macro: data driven phenotyping by densification of longitudinal electronic medical records. In: Proceedings of the 20th ACM SIGKDD international conference on knowledge discovery and data mining, 2014; pp. 135–144.
- 29.Che Z, Purushotham S, Cho K, Sontag D, Liu Y. Recurrent neural networks for multivariate time series with missing values. Sci Rep. 2018;8(1):6085.
- 30.Lipton ZC, Kale DC, Wetzel R, et al. Modeling missing data in clinical time series with RNNs. Mach Learn Healthcare. 2016;56:253–70.
- 31.Nguyen P, Tran T, Wickramasinghe N, Venkatesh S. Deepr: a convolutional net for medical records. IEEE J Biomed Health Inf. 2017;21(1):22–30.
- 32.Cheng Y, Wang F, Zhang P, Hu J. Risk prediction with electronic health records: a deep learning approach. In: Proceedings of the 2016 SIAM international conference on data mining, 2016; pp. 432–440. SIAM.
- 33.Lee W, Park S, Joo W, Moon I-C. Diagnosis prediction via medical context attention networks using deep generative modeling. In: 2018 IEEE international conference on data mining (ICDM), 2018; pp. 1104–1109. IEEE.
- 34.Miller AH, Fisch A, Dodge J, Karimi A, Bordes A, Weston J. Key-value memory networks for directly reading documents. In: Proceedings of the 2016 conference on empirical methods in natural language processing, EMNLP 2016, Austin, Texas, USA, November 1–4, 2016; pp. 1400–1409.
- 35.Yan S, Xiong Y, Lin D. Spatial temporal graph convolutional networks for skeleton-based action recognition. In: Proceedings of the AAAI conference on artificial intelligence, 2018; vol. 32.
- 36.Liang K, Meng L, Liu M, Liu Y, Tu W, Wang S, Zhou S, Liu X, Sun F. Reasoning over different types of knowledge graphs: static, temporal and multi-modal. arXiv:2212.05767.
- 37.Leblay J, Chekol MW. Deriving validity time in knowledge graph. In: Companion proceedings of the web conference 2018; pp. 1771–1776.
- 38.Dasgupta SS, Ray SN, Talukdar PP. HyTE: hyperplane-based temporally aware knowledge graph embedding. In: EMNLP, 2018; pp. 2001–2011.
- 39.García-Durán A, Dumancic S, Niepert M. Learning sequence encoders for temporal knowledge graph completion. In: Proceedings of the 2018 conference on empirical methods in natural language processing, Brussels, Belgium, October 31–November 4, 2018; pp. 4816–4821.
- 40.Leblay J, Chekol MW, Liu X. Towards temporal knowledge graph embeddings with arbitrary time precision. In: Proceedings of the 29th ACM international conference on information & knowledge management, 2020; pp. 685–694.
- 41.Nayyeri M, Vahdati S, Khan MT, Alam MM, Wenige L, Behrend A, Lehmann J. Dihedron algebraic embeddings for spatio-temporal knowledge graph completion. In: The semantic web—19th international conference, ESWC 2022, Hersonissos, Crete, Greece, May 29–June 2, 2022, Proceedings. Lecture Notes in Computer Science, 2022; vol. 13261, pp. 253–269.
- 42.Trivedi R, Dai H, Wang Y, Song L. Know-Evolve: deep temporal reasoning for dynamic knowledge graphs. In: International conference on machine learning, 2017; pp. 3462–3471. PMLR.
- 43.Jin W, Qu M, Jin X, Ren X. Recurrent event network: autoregressive structure inference over temporal knowledge graphs. In: Proceedings of the 2020 conference on empirical methods in natural language processing, EMNLP 2020, Online, November 16–20, 2020; pp. 6669–6683.
- 44.Li Z, Jin X, Li W, Guan S, Guo J, Shen H, Wang Y, Cheng X. Temporal knowledge graph reasoning based on evolutional representation learning. In: SIGIR '21: the 44th international ACM SIGIR conference on research and development in information retrieval, virtual event, Canada, July 11–15, 2021; pp. 408–417.
- 45.Zhang M, Xia Y, Liu Q, Wu S, Wang L. Learning latent relations for temporal knowledge graph reasoning. In: Proceedings of the 61st annual meeting of the association for computational linguistics (volume 1: long papers), ACL 2023, Toronto, Canada, July 9–14, 2023; pp. 12617–12631.
- 46.Vashishth S, Sanyal S, Nitin V, Talukdar PP. Composition-based multi-relational graph convolutional networks. In: 8th international conference on learning representations, ICLR 2020, Addis Ababa, Ethiopia, April 26–30, 2020.
- 47.Pareja A, Domeniconi G, Chen J, Ma T, Suzumura T, Kanezashi H, Kaler T, Schardl T, Leiserson C. EvolveGCN: evolving graph convolutional networks for dynamic graphs. In: Proceedings of the AAAI conference on artificial intelligence, 2020; vol. 34, pp. 5363–5370.
- 48.You J, Du T, Leskovec J. ROLAND: graph learning framework for dynamic graphs. In: KDD '22: the 28th ACM SIGKDD conference on knowledge discovery and data mining, Washington, DC, USA, August 14–18, 2022; pp. 2358–2366.
- 49.Zhang C, Yao Z, Yao H, Huang F, Chen CLP. Dynamic representation learning via recurrent graph neural networks. IEEE Trans Syst Man Cybern Syst. 2023;53(2):1284–97.
- 50.Gehring J, Auli M, Grangier D, Yarats D, Dauphin YN. Convolutional sequence to sequence learning. In: International conference on machine learning, 2017; pp. 1243–1252. PMLR.
- 51.Johnson AE, Pollard TJ, Shen L, Lehman L-WH, Feng M, Ghassemi M, Moody B, Szolovits P, Anthony Celi L, Mark RG. MIMIC-III, a freely accessible critical care database. Sci Data. 2016;3(1):1–9.
- 52.Goldberger AL, Amaral LA, Glass L, Hausdorff JM, Ivanov PC, Mark RG, Mietus JE, Moody GB, Peng C-K, Stanley HE. PhysioBank, PhysioToolkit, and PhysioNet: components of a new research resource for complex physiologic signals. Circulation. 2000;101(23):e215–20.
- 53.Cho K, van Merriënboer B, Gülçehre Ç, Bahdanau D, Bougares F, Schwenk H, Bengio Y. Learning phrase representations using RNN encoder–decoder for statistical machine translation. In: Proceedings of the 2014 conference on empirical methods in natural language processing, EMNLP 2014, October 25–29, 2014, Doha, Qatar; pp. 1724–1734.
- 54.Kingma DP, Ba J. Adam: a method for stochastic optimization. In: 3rd international conference on learning representations, ICLR 2015, San Diego, CA, USA, May 7–9, 2015, Conference Track Proceedings.
- 55.He K, Zhang X, Ren S, Sun J. Deep residual learning for image recognition. In: Proceedings of the IEEE conference on computer vision and pattern recognition, 2016; pp. 770–778.