Abstract
Objective
Prediction of disease phenotypes and their outcomes is a difficult task. In practice, patients routinely seek second opinions from multiple clinical experts for complex disease diagnosis. Our objective is to mimic such a practice of seeking second opinions by training 2 agents with different focuses: the primary agent studies the most recent visit of the patient to learn the current health status, and then the second-opinion agent considers the entire patient history to obtain a more global view.
Materials and Methods
Our approach, Dr. Agent, augments recurrent neural networks with 2 policy gradient agents. Moreover, Dr. Agent is customized with various patient demographics information and learns a dynamic skip connection to focus on the relevant information over time. We trained Dr. Agent to perform 4 clinical prediction tasks on the publicly available MIMIC-III (Medical Information Mart for Intensive Care) database: (1) in-hospital mortality prediction, (2) acute care phenotype classification, (3) physiologic decompensation prediction, and (4) forecasting length of stay. We compared the performance of Dr. Agent against 4 baseline clinical predictive models.
Results
Dr. Agent outperforms baseline clinical prediction models across all 4 tasks in terms of nearly all metrics. Compared with the best baseline model, Dr. Agent achieves up to 15% higher area under the precision-recall curve on different tasks.
Conclusions
Dr. Agent can comprehensively model the long-term dependencies of patients’ health status while considering patients’ demographics using 2 agents, and therefore achieves better prediction performance on different clinical prediction tasks.
Keywords: clinical prediction, deep learning, recurrent neural network, electronic health records, reinforcement learning, intensive care
INTRODUCTION
OBJECTIVE
Predicting complex disease phenotypes and their outcomes is difficult even for the best human doctors. Patients routinely seek second opinions from other clinical experts to confirm an existing diagnosis, which helps mitigate the high variance among individual diagnoses. For example, a study from the Mayo Clinic shows that only 12% of patients who sought a second opinion received the same diagnosis as their original one.1 In those cases, it is clearly beneficial to have multiple clinical experts evaluate the patient. Can we leverage this idea to improve artificial intelligence agents so that they perform better in clinical prediction? In particular, can multiple predictive agents work together to improve prediction accuracy?
In this work, we propose a new multiagent reinforcement learning-based model, named Dr. Agent, which trains 2 complementary reinforcement learning (RL) agents to comprehensively extract temporal dependencies in patients’ longitudinal electronic health record (EHR) data. Each agent is designed to focus on a certain perspective of patients’ health status. We applied Dr. Agent to the publicly available MIMIC-III (Medical Information Mart for Intensive Care III) database.2 We trained Dr. Agent to perform 4 clinical prediction tasks: (1) in-hospital mortality prediction, (2) acute care phenotype classification, (3) physiologic decompensation prediction, and (4) forecasting length of stay (LOS). Improving the accuracy of clinical prediction tasks could help clinicians delay or prevent risk events and reduce costs.
BACKGROUND AND SIGNIFICANCE
Deep phenotyping
EHR data from millions of patients are now routinely collected across diverse healthcare institutions. The availability of massive EHR data enables the training of complex deep learning models for clinical prediction.3,4 Initial success has been demonstrated in estimating length of stay,5 patient phenotype identification,6 and disease or risk predictions.7
Extracting long-term dependencies and multimodal learning are challenges in EHR data analysis. To address the former challenge, a diagnosis prediction model uses bidirectional recurrent neural networks (RNNs) and an attention mechanism to extract the relationships among different visits.8 Time-aware long short-term memory networks (T-LSTMs) learn a representation that captures the dependencies between visits in the presence of time irregularities.6 For multimodal learning, both Esteban et al9 and Lee et al10 leveraged hierarchical structures to incorporate different modalities.
Recently, in natural language processing tasks, Gui et al11 proposed an RL-based dynamic skip connection mechanism to model word dependencies. It allows LSTM cells to compute recurrent transition functions based on 1 optimal set of hidden and cell states chosen from the past few states, and thus achieves higher prediction performance than traditional attention-based approaches. In this work, we draw an analogy between “words” and “medical codes” and extend the dynamic skip connection mechanism to EHR analysis tasks. In addition, we incorporate patient demographic information (eg, age, gender) into the learning environment so that the agents have better contextual information about patients when deciding which historical visits are more important for prediction.
Multiagent reinforcement learning
Recently, reinforcement learning algorithms have been proposed to identify decision-making strategies for clinical tasks.12 Nemati et al13 leveraged hidden Markov models and deep Q-networks to predict optimal heparin dosing for patients in intensive care units. Wang et al14 proposed supervised reinforcement learning with RNNs for treatment recommendation. However, all these existing works tried to learn a single model or a single policy for prediction.
Multiagent reinforcement learning is a subfield of reinforcement learning and has shown great success in the realm of Markov decision problems.15 In this work, we train 2 complementary RL agents to mimic the second-opinion situation. Dr. Agent has 2 policy gradient agents: a primary agent that observes the patient’s current health condition to learn a patient representation, and a secondary agent that summarizes the patient’s historical EHR to learn another patient representation. The 2 agents can thus summarize long-term health history from complementary perspectives. The learning of the agents is optimized by a long-term reward function related to the accuracy of clinical prediction, which transfers knowledge of the observed disease progression trajectory to help the agents make correct predictions.
MATERIALS AND METHODS
As illustrated in Figure 1, Dr. Agent includes the following components: multiagent skip connection component, demographic information–enriched environment representation, and reward function for disease progression knowledge transfer. Next, we first introduce each component and then provide details of training and inference of Dr. Agent.
Figure 1.
The Dr. Agent model. We use the patient’s demographics to initialize the hidden state of the gated recurrent unit (GRU). We use 2 policy gradient agents to extract long-term history information of patients through a dynamic skip connection mechanism. At each time step, the 2 agents generate an optimal state from the historical hidden states, which is used instead of the last state to compute the current state. We use the model’s output and the prediction target to compute long-term rewards that optimize the agents’ policies. FC: fully connected.
Multiagent skip connections
We adopt the dynamic skip connection mechanism11 with a gated recurrent unit (GRU) network16 to capture long-term dependencies among clinical events in patient EHR data.
Concretely, given the current visit at the t-th timestep, a regular GRU cell uses the current visit and the hidden state of the previous timestep to compute the current hidden state. To achieve a dynamic skip connection, instead of using the previous hidden state directly, we use 2 policy gradient agents (a primary agent and a second-opinion agent) to generate an optimal state from a set of historical hidden states, where a window-length hyperparameter determines how far back the agents can observe the patient’s historical health status. The choice of window length depends on the task; a larger window means that the model can capture longer dependencies. In our experiment, the optimal window length is between 5 and 10 for most tasks.
We provide the 2 agents with different environment variables so that they observe patients’ health conditions from complementary perspectives.
- Primary agent: Patients’ current lab tests and diagnoses provide direct information about their health status for clinicians to make treatment plans or prognoses. We use the primary agent to mimic this practice. The environment variable of this agent is the health status recorded at the patient’s current visit.
- Second-opinion agent: Making accurate diagnoses and informing treatment based on a prediction of an individual patient’s risk of a negative health outcome can be difficult, especially for complex or rare diseases, and patients routinely seek second opinions to understand their disease and their risk. We therefore design the second-opinion agent to consider the entire patient history and obtain a more global view: it summarizes the patient’s historical visits by using the historical hidden states of the GRU as its environment.
We want each agent to focus on its own environment and learn environment-specific policies, rather than having a single agent process all the information. This design allows the agents to comprehensively understand patients’ status and obtain complementary patient embeddings. Experiments show that our multiagent strategy outperforms each single-agent model.
For each agent, after observing its environment, the agent takes an action by selecting an optimal state from the historical hidden state set. The action is sampled from a multinomial distribution over the candidate states. To obtain this distribution, a 2-layer perceptron transforms the environment representation into a vector with one entry per candidate state, and a softmax converts that vector into a probability distribution; each element of the distribution is the probability of selecting the corresponding historical state.
The 2 agents obtain their optimal states by sampling actions from their respective distributions, and the states selected by the 2 agents are then combined into a single state. We find that in some tasks, adding further information to this combined state yields better performance, and we use a hyperparameter to control this contribution.
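As an illustration of the selection step described above, the following Python sketch samples one historical state per agent and merges the two selections. It is a minimal sketch under stated assumptions, not the authors' implementation: the function names (`softmax`, `sample_action`, `combine_states`), the stand-in score vectors used in place of each agent's 2-layer perceptron output, the averaging of the two selected states, and the weight `lam` that mixes the selection with the most recent hidden state are all hypothetical choices made for the example.

```python
import math
import random

def softmax(scores):
    """Turn raw scores into a probability distribution."""
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]

def sample_action(probs, rng):
    """Sample one index from a multinomial (categorical) distribution."""
    return rng.choices(range(len(probs)), weights=probs, k=1)[0]

def combine_states(history, scores_primary, scores_second, lam, rng):
    """Pick one historical state per agent and merge the two picks.

    history: the K most recent hidden states, oldest first.
    scores_primary, scores_second: hypothetical K-dimensional policy outputs.
    lam: hypothetical weight on the agents' selection versus the last state.
    """
    a_p = sample_action(softmax(scores_primary), rng)
    a_s = sample_action(softmax(scores_second), rng)
    last = history[-1]
    merged = []
    for h_p, h_s, h_last in zip(history[a_p], history[a_s], last):
        selected = 0.5 * (h_p + h_s)  # assumed: average the two selections
        merged.append(lam * selected + (1.0 - lam) * h_last)
    return a_p, a_s, merged

rng = random.Random(0)
K, hidden = 5, 3
history = [[float(k)] * hidden for k in range(K)]  # toy hidden states
a_p, a_s, state = combine_states(
    history, [0.1] * K, [0.5, 0.0, 0.0, 0.0, 2.0], 0.8, rng
)
```

Setting `lam` to 0 in this sketch recovers the behavior of a regular GRU step, which always uses the last hidden state.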
Demographic enhancement
In clinical practice, the demographic information of patients, such as gender and age, can serve as important baseline data that depict the original health status and subtypes of patients. In naïve RNN models, the hidden state of the cell is initialized to zero by default. Instead, we compute the initial hidden state of the GRU from the patient’s demographic information. With a more reasonable initial hidden state, our model can better represent patients’ subsequent health status.
In addition, patients’ demographics provide important context for their historical health status and help the agents decide which historical visits are more important for prediction. To better incorporate this information, we concatenate the original environment representation of each agent with the patient’s demographic data. The 2 agents observe these demographics-enriched environments and generate the optimal state as described above. The GRU cell then uses this optimal state and the current visit to compute the output at the t-th timestep.
Finally, we use a 2-layer perceptron to compress the current hidden state and the patient’s demographics into the final prediction output. The activation function of the output layer depends on the specific task: we use a softmax layer for the multiclass classification task (ie, length of stay) and a sigmoid layer for the multilabel classification task (ie, phenotyping) and the binary classification tasks (ie, decompensation and mortality).
Reward function
The reward function evaluates the quality of the agents’ choices. We want the agents to choose the optimal states that help the model make correct predictions. In this study, for decompensation, mortality, and length of stay (LOS) prediction, we define the reward at the t-th timestep based on the model’s prediction and the ground truth label. Because phenotype prediction is cast as a multilabel classification task, we design a separate reward function for it that accounts for the number of the patient’s phenotypes, which prevents the rewards from being affected by the number of labels. These reward functions give the agents higher rewards when the model outputs a higher prediction for clinical events that occur and penalize them when the model fails to predict these events.
Furthermore, we want our agents to have a long-term vision, meaning that the selected states are optimal not only for current predictions but also for future predictions. To achieve this, we use a long-term reward function that transfers knowledge of the observed progression trajectory to help the agents make correct predictions. Instead of using only the instantaneous reward, the long-term reward combines the current reward with future rewards weighted by a discount factor. The discount factor, which varies from 0 to 1, quantifies how much importance we place on future rewards. If it is close to 0, the agent tends to consider only instantaneous rewards. If it is close to 1, the agent gives future rewards greater weight. Our experiments show that using the long-term reward function allows the model to better extract long-term dependencies in patient data and thus predict more accurately.
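The effect of the discount factor can be illustrated with a short Python sketch. It assumes the standard discounted-return form, in which the long-term reward at a timestep is the instantaneous reward plus the discount factor times the long-term reward of the next timestep; the function name `long_term_rewards` and the toy reward values are hypothetical.

```python
def long_term_rewards(rewards, gamma):
    """Discounted sum of the current and all later rewards, per timestep.

    Assumed form: R_t = r_t + gamma * R_{t+1}, with R after the last step = 0.
    """
    returns = [0.0] * len(rewards)
    running = 0.0
    for t in range(len(rewards) - 1, -1, -1):
        running = rewards[t] + gamma * running
        returns[t] = running
    return returns

instant = [1.0, -1.0, 1.0]             # toy instantaneous rewards
print(long_term_rewards(instant, 0.0))  # [1.0, -1.0, 1.0]: only current rewards count
print(long_term_rewards(instant, 0.5))  # [0.75, -0.5, 1.0]: future rewards contribute
```

With a discount factor of 0 the long-term reward reduces to the instantaneous reward, matching the behavior described in the text.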
Joint learning and policy optimization
Our first task is to optimize the parameters of the 2 policy gradient agents. We must also optimize the parameters of the standard GRU and the remaining parameters of the model. The latter optimization is straightforward: we use the cross-entropy loss for our tasks.
To train the agents, we maximize the expected reward under the skip policy distribution of each agent at each timestep.
Furthermore, we want to prevent premature entropy collapse and encourage the policy to explore a more diverse space, and we also want to reduce the policy variance.17 To address these issues, we add an entropy term and subtract a baseline, which is the on-policy value function.18
Because the discrete skips are nondifferentiable, we use the REINFORCE method19 to optimize the agents’ parameters. The final loss combines the cross-entropy loss with the policy losses of the 2 agents.
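One common way to write a REINFORCE objective with a baseline and an entropy bonus is sketched below in Python. This is a generic illustration rather than the exact loss used by Dr. Agent: the function names `policy_loss` and `joint_loss`, the weights `entropy_weight` and `policy_weight`, and the simple additive combination of the terms are assumptions.

```python
import math

def policy_loss(probs, action, long_term_reward, baseline, entropy_weight):
    """Surrogate loss (to be minimized) for one agent at one timestep.

    probs: the agent's probability distribution over candidate states.
    action: index of the state the agent actually sampled.
    baseline: estimate of the expected reward, subtracted to reduce variance.
    """
    advantage = long_term_reward - baseline
    log_prob = math.log(probs[action])
    entropy = -sum(p * math.log(p) for p in probs if p > 0.0)
    # Minimizing this raises the probability of actions with positive
    # advantage and rewards more spread-out (higher-entropy) policies.
    return -(advantage * log_prob) - entropy_weight * entropy

def joint_loss(cross_entropy, policy_losses, policy_weight=1.0):
    """Assumed combination: prediction loss plus weighted agent losses."""
    return cross_entropy + policy_weight * sum(policy_losses)

uniform = [0.25] * 4
loss_p = policy_loss(uniform, 0, 1.0, 0.2, 0.01)
loss_s = policy_loss([0.7, 0.1, 0.1, 0.1], 0, 1.0, 0.2, 0.01)
total = joint_loss(0.6, [loss_p, loss_s])
```

In an automatic differentiation framework, the advantage would be treated as a constant so that gradients flow only through the log-probability and entropy terms.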
Data description and prediction tasks
Source of data, sample size, predictors, and data preprocessing
We use EHR data from the publicly available MIMIC-III database.2 We process the raw data according to Harutyunyan et al.20 We use a cohort of 33,678 unique patients. The raw data include 17 physiologic variables as predictors at each timestep, which are transformed into a 76-dimensional vector including numerical and one-hot encoded categorical clinical features. The demographics of each patient are encoded as a 12-dimensional vector including ethnicity, gender, age, height, and weight. Details of the predictors, data-preprocessing steps, and sample sizes are provided in the Supplementary Appendix.
Prediction tasks and outcomes
We perform 4 prediction tasks on the MIMIC-III dataset:
In-hospital mortality prediction predicts in-hospital mortality from observations recorded early in an intensive care unit admission. We formulated this task as binary classification using observations recorded from a limited window of time (eg, here 48 hours) following admission. The ground truth label is a binary value that indicates whether the patient died before hospital discharge.
Acute care phenotype classification is to classify which acute care conditions (of 25 chronic or critical conditions) are present in a given patient record. Because diseases can co-occur, this task is formulated as a multilabel classification problem. The target phenotypes are shown in the Supplementary Appendix.
Physiologic decompensation prediction involves the detection of patients who are physiologically decompensating, which means that conditions are deteriorating rapidly. The ground truth of the task is a binary label that indicates whether the patient’s date of death falls within the next 24 hours of the current time point.
Forecasting length of stay is to forecast the patient’s LOS. In this task, we predict the remaining LOS once per hour for every hour after admission. We formulate this task as a multiclass classification problem as previous works do.20,21 We divide the range of values into 10 buckets, 1 bucket for extremely short visits (<1 day), 7 day-long buckets for each day of the first week, and 2 buckets for outliers (one for stays of over 1 week but <2 weeks and one for stays of over 2 weeks).
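The 10-bucket formulation of the LOS task can be sketched as a small Python function. The exact bucket boundaries are an assumption: we read the description as 1 bucket for stays under 1 day, 7 day-long buckets covering days 1 through 7, 1 bucket for longer stays under 2 weeks, and 1 bucket for stays of 2 weeks or more. The function name `los_bucket` is hypothetical.

```python
def los_bucket(remaining_hours):
    """Map a remaining length of stay (in hours) to one of 10 class indices."""
    days = remaining_hours / 24.0
    if days < 1:
        return 0          # extremely short: less than 1 day
    if days < 8:
        return int(days)  # day-long buckets 1..7 (assumed boundaries)
    if days < 14:
        return 8          # outlier: over 1 week but under 2 weeks
    return 9              # outlier: 2 weeks or more

buckets = [los_bucket(h) for h in (12, 24, 100, 7.5 * 24, 10 * 24, 30 * 24)]
print(buckets)  # [0, 1, 4, 7, 8, 9]
```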
Model evaluation
Baseline models for performance comparison
We compare Dr. Agent with the following baselines. These models focus on utilizing long-term dependencies in patients’ longitudinal EHR data and are also designed to utilize demographic information. It is worth noting that many state-of-the-art clinical prediction models use extra modules, such as convolution layers or de-correlation modules, to improve prediction.7,22 However, their contributions are orthogonal to ours: we focus on better capturing temporal dependencies in EHR data, and our model can be easily combined with such extra modules.
Logistic regression (LR): We use LR with L2 regularization based on the hand-engineered features described in Lipton et al23: for each variable, we compute 6 different sample statistic features on 7 different subsequences of a given time series. The per-subsequence features include minimum, maximum, mean, standard deviation, skew, and number of measurements. The 7 subsequences include the full time series, the first 10%/25%/50% of time, and the last 50%/25%/10% of time. In total, we obtain 714 features (17 variables × 7 subsequences × 6 statistics) per time series, and we also use demographic features. For the 10-class bucketed LOS prediction problem, we train a softmax regression model.
Multilayer perceptron (MLP): We use a 2-layer MLP network with L2 regularization. The input features are the same as those of the LR model. Note that neither MLP nor LR can process time series data directly, because they treat each timestep as a discrete data sample.
GRU: Patients’ demographics are duplicated at each timestep and simply concatenated to the original inputs.
Dipole8: Uses GRU with the additive attention mechanism.
T-LSTM6: Handles irregular time intervals by enabling time decay inside the RNN cell, which makes older information less important. The original T-LSTM model is used for unsupervised clustering, and we modify it into a supervised learning model.
Simply Attend and Diagnose (SAnD)*21: Uses a self-attention mechanism to predict clinical targets. We use the same hyperparameter settings as in the original article. Unlike the original article, we modify the padding strategy in SAnD: to prevent the use of future information, we use causal padding to obtain the input embedding.24
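The hand-engineered features used by the LR and MLP baselines above can be sketched in Python to confirm the count of 714. This is an illustrative sketch only: the helper names, the use of the population standard deviation and moment-based skew, the index-based (rather than time-based) definition of the subsequences, and the toy data are assumptions.

```python
import math

def six_stats(values):
    """Minimum, maximum, mean, standard deviation, skew, and count."""
    n = len(values)
    mean = sum(values) / n
    std = math.sqrt(sum((v - mean) ** 2 for v in values) / n)
    if std == 0:
        skew = 0.0
    else:
        skew = sum((v - mean) ** 3 for v in values) / (n * std ** 3)
    return [min(values), max(values), mean, std, skew, float(n)]

def subsequences(values):
    """Full series, first 10%/25%/50%, and last 50%/25%/10% of the series."""
    n = len(values)

    def head(frac):
        return values[:max(1, int(n * frac))]

    def tail(frac):
        return values[n - max(1, int(n * frac)):]

    return [values, head(0.10), head(0.25), head(0.50),
            tail(0.50), tail(0.25), tail(0.10)]

def features_for_variable(values):
    feats = []
    for sub in subsequences(values):
        feats.extend(six_stats(sub))
    return feats

N_VARIABLES = 17  # physiologic variables per timestep
# Toy stay: 48 hourly measurements for each variable.
toy_stay = [[float((t * (v + 1)) % 7) for t in range(48)]
            for v in range(N_VARIABLES)]
feature_vector = [f for series in toy_stay
                  for f in features_for_variable(series)]
print(len(feature_vector))  # 714 = 17 variables x 7 subsequences x 6 statistics
```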
The implementation details of these models are provided in the Supplementary Appendix. We have made our code available in a public repository (https://github.com/v1xerunt/Dr.Agent). We also compare Dr. Agent against its reduced variants, for example to measure the performance difference between the long-term reward function and the instantaneous reward function. The results are shown in the Supplementary Appendix.
Evaluation strategy
We truncate the length of samples to a reasonable limit (ie, 400) for all tasks except in-hospital mortality prediction. We fix a test set of 15% of patients and divide the rest of the dataset into a training set and a validation set in a proportion of 85%:15%. We select the model that performs best on the validation set within 40 epochs and report its performance on the test set. To estimate a 95% confidence interval, we resample the test set 10 000 times, calculate the score on each resampled set, and then use the 2.5th and 97.5th percentiles of these scores as our confidence interval estimate.
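The percentile bootstrap described above can be sketched in Python as follows. The sketch uses accuracy as a stand-in metric to stay self-contained (the article reports AUROC, AUPRC, and other scores), and the function names, the index-based percentile lookup, and the toy labels are assumptions.

```python
import random

def accuracy(labels, preds):
    """Stand-in metric for the example; any score function could be used."""
    return sum(1 for y, p in zip(labels, preds) if y == p) / len(labels)

def bootstrap_ci(labels, preds, metric, n_resamples=10000, seed=0):
    """Percentile bootstrap: resample the test set with replacement."""
    rng = random.Random(seed)
    n = len(labels)
    scores = []
    for _ in range(n_resamples):
        idx = [rng.randrange(n) for _ in range(n)]
        scores.append(metric([labels[i] for i in idx],
                             [preds[i] for i in idx]))
    scores.sort()
    lower = scores[int(0.025 * (n_resamples - 1))]  # about the 2.5th percentile
    upper = scores[int(0.975 * (n_resamples - 1))]  # about the 97.5th percentile
    return lower, upper

# Toy test set of 100 samples on which 80 predictions are correct.
labels = [1] * 50 + [0] * 50
preds = [0] * 10 + [1] * 40 + [1] * 10 + [0] * 40
point = accuracy(labels, preds)
lower, upper = bootstrap_ci(labels, preds, accuracy, n_resamples=2000)
```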
For in-hospital mortality and decompensation prediction tasks, we use the area under the receiver-operating characteristic curve (AUROC), area under precision-recall curve (AUPRC), and Min(Re, P+) as our evaluation metrics. Min(Re, P+) is calculated as the maximum over the precision-recall curve of the minimum of recall and precision, and we provide an illustrative example in the Supplementary Appendix. For the phenotype classification task, we consider the following metrics: macro-averaged AUROC, which averages per-label AUROC; micro-averaged AUROC, which computes a single AUROC score for all classes together; and weighted AUROC, which takes disease prevalence into account. For the LOS prediction task, we use Cohen’s linear weighted kappa, which measures the agreement between true and predicted labels. We also assign the mean LOS of each bin to the samples assigned to that class and use mean squared error and mean absolute percentage error as evaluation metrics.
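A minimal Python sketch of Min(Re, P+) is given below. It sweeps the decision threshold over the ranked predictions and keeps the best value of min(recall, precision). The function name `min_re_pplus` is hypothetical, and the sketch assumes distinct prediction scores (tied scores would need to be grouped into a single threshold).

```python
def min_re_pplus(labels, scores):
    """Maximum over thresholds of min(recall, precision)."""
    total_pos = sum(labels)
    if total_pos == 0:
        raise ValueError("at least one positive label is required")
    # Rank samples from the highest to the lowest predicted score.
    order = sorted(range(len(scores)), key=lambda i: scores[i], reverse=True)
    best, true_pos = 0.0, 0
    for rank, i in enumerate(order, start=1):
        true_pos += labels[i]
        precision = true_pos / rank      # among the top-`rank` predictions
        recall = true_pos / total_pos
        best = max(best, min(precision, recall))
    return best

labels = [1, 0, 1, 0, 0, 1]
scores = [0.9, 0.8, 0.7, 0.4, 0.3, 0.2]
value = min_re_pplus(labels, scores)  # 2/3, reached at the third-ranked sample
```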
We also assess model calibration visually using a calibration plot and statistically by estimating the calibration slope and intercept,25,26 except for the LOS task, as the LOS prediction is a multiclass classification task. The closer the slope is to 1, the better the model is calibrated. We compute the average calibration performance of all phenotypes for the phenotyping task.
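One simple way to obtain a calibration slope and intercept is sketched below in Python: predictions are grouped into equal-width probability bins, and a least-squares line is fitted through the (mean predicted probability, observed proportion) points. This estimator is an assumption made for illustration and may differ from the procedure in the cited references; the function names and toy data are hypothetical.

```python
def calibration_bins(labels, probs, n_bins=10):
    """Return (mean predicted probability, observed proportion) per bin."""
    sums = [[0.0, 0.0, 0] for _ in range(n_bins)]
    for y, p in zip(labels, probs):
        b = min(int(p * n_bins), n_bins - 1)  # p == 1.0 goes to the last bin
        sums[b][0] += p
        sums[b][1] += y
        sums[b][2] += 1
    return [(sp / n, sy / n) for sp, sy, n in sums if n > 0]

def fit_line(points):
    """Ordinary least-squares slope and intercept through the bin points."""
    n = len(points)
    mx = sum(x for x, _ in points) / n
    my = sum(y for _, y in points) / n
    sxx = sum((x - mx) ** 2 for x, _ in points)
    if sxx == 0:
        raise ValueError("need at least 2 bins with different mean probability")
    sxy = sum((x - mx) * (y - my) for x, y in points)
    slope = sxy / sxx
    return slope, my - slope * mx

# Toy, perfectly calibrated predictions: observed proportion equals p.
probs, labels = [], []
for p, positives in [(0.1, 1), (0.5, 5), (0.9, 9)]:
    probs += [p] * 10
    labels += [1] * positives + [0] * (10 - positives)
points = calibration_bins(labels, probs)
slope, intercept = fit_line(points)  # close to slope 1 and intercept 0
```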
RESULTS
Table 1 shows the performance of our model and all baseline models. We also report the parameters of each model in the LOS task. Dr. Agent outperforms all baseline models on the phenotyping, mortality prediction, and decompensation prediction tasks. On the decompensation task, the relative improvement of Dr. Agent is up to 15% in AUPRC and 6% in Min(Re, P+) compared with the best baseline models (T-LSTM for AUPRC and Dipole for Min(Re, P+)). On the mortality prediction task, Dr. Agent achieves more than 4% higher AUPRC compared with the best baseline, GRU. On the phenotyping task, Dr. Agent achieves a 1% relative improvement in AUC compared with the best baseline, SAnD. On the LOS task, Dr. Agent achieves a 1% higher kappa score and a lower mean squared error than the best baseline model, while its mean absolute percentage error is slightly higher than that of the GRU (by <1%). Among all baselines in Table 1, Dipole and SAnD achieve better performance in most cases because they handle long-term dependencies better. LR and MLP show low performance, as they cannot capture temporal patterns in longitudinal EHR records. The per-phenotype performance in the phenotyping task is shown in the Supplementary Appendix.
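The relative differences quoted above can be reproduced from the point estimates in Table 1 with a few lines of Python. The helper name `relative_improvement` is hypothetical; the numeric inputs are copied from Table 1.

```python
def relative_improvement(ours, baseline):
    """Relative difference of `ours` over `baseline`, in percent."""
    return 100.0 * (ours - baseline) / baseline

decomp_auprc = relative_improvement(0.3074, 0.2670)     # vs T-LSTM, about 15.1
decomp_min = relative_improvement(0.3531, 0.3340)       # vs Dipole, about 5.7
mortality_auprc = relative_improvement(0.5311, 0.5104)  # vs GRU, about 4.1
pheno_macro = relative_improvement(0.7821, 0.7735)      # vs SAnD*, about 1.1
los_mape = relative_improvement(188.8, 187.2)           # vs GRU, about 0.9
```

For MAPE, lower values are better, so the positive difference in the last line means that the GRU is marginally ahead on that metric.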
Table 1.
Performance comparison for the MIMIC-III tasks
Task 1: Phenotyping | |||
---|---|---|---|
Model | Macro AUC | Micro AUC | Weighted AUC |
LR | 0.7285 (0.7257-0.7317) | 0.7755 (0.7727-0.7786) | 0.7242 (0.7209-0.7277) |
MLP | 0.7371 (0.7342-0.7400) | 0.7863 (0.7832-0.7895) | 0.7311 (0.7282-0.7340) |
GRU | 0.7684 (0.7650-0.7719) | 0.8174 (0.8147-0.8202) | 0.7601 (0.7574-0.7629) |
Dipole | 0.7484 (0.7453-0.7517) | 0.8012 (0.7979-0.8045) | 0.7358 (0.7328-0.7389) |
T-LSTM | 0.7579 (0.7550-0.7609) | 0.8091 (0.8062-0.8122) | 0.7524 (0.7497-0.7552) |
SAnD* | 0.7735 (0.7696-0.7776) | 0.8232 (0.8205-0.8258) | 0.7660 (0.7631-0.7689) |
Dr. Agent | 0.7821 (0.7792-0.7850) | 0.8276 (0.8242-0.8311) | 0.7748 (0.7717-0.7779) |
Task 2: Mortality prediction | |||
Model | AUROC | AUPRC | Min(Re, P+) |
LR | 0.8462 (0.8282-0.8641) | 0.4797 (0.4413-0.5184) | 0.4833 (0.4445-0.5219) |
MLP | 0.8521 (0.8326-0.8722) | 0.4851 (0.4547-0.5162) | 0.4849 (0.4517-0.5186) |
GRU | 0.8600 (0.8401-0.8709) | 0.5104 (0.4751-0.5458) | 0.4977 (0.4662-0.5293) |
Dipole | 0.8628 (0.8492-0.8769) | 0.4989 (0.4688-0.5293) | 0.5026 (0.4709-0.5350) |
T-LSTM | 0.8617 (0.8480-0.8749) | 0.4964 (0.4617-0.5308) | 0.4977 (0.4606-0.5337) |
SAnD* | 0.8382 (0.8188-0.8577) | 0.4545 (0.4144-0.4941) | 0.4885 (0.4516-0.5246) |
Dr. Agent | 0.8658 (0.8499-0.8822) | 0.5311 (0.4955-0.5662) | 0.5054 (0.4724-0.5393) |
Task 3: Decompensation prediction | |||
Model | AUROC | AUPRC | Min(Re, P+) |
LR | 0.8733 (0.8704-0.8762) | 0.2233 (0.2179-0.2295) | 0.2374 (0.2301-0.2453) |
MLP | 0.8779 (0.8741-0.8807) | 0.2291 (0.2225-0.2359) | 0.2350 (0.2270-0.2428) |
GRU | 0.9013 (0.8979-0.9048) | 0.2611 (0.2536-0.2682) | 0.3212 (0.3138-0.3290) |
Dipole | 0.8970 (0.8933-0.9008) | 0.2626 (0.2556-0.2692) | 0.3340 (0.3261-0.3415) |
T-LSTM | 0.8956 (0.8925-0.8987) | 0.2670 (0.2591-0.2748) | 0.3275 (0.3194-0.3352) |
SAnD* | 0.8953 (0.8924-0.8988) | 0.2353 (0.2277-0.2432) | 0.3051 (0.2979-0.3127) |
Dr. Agent | 0.9071 (0.9037-0.9103) | 0.3074 (0.3004-0.3149) | 0.3531 (0.3450-0.3611) |
Task 4: LOS prediction | |||
Model | Kappa | MSE | MAPE |
LR | 0.3504 (0.3477-0.3534) | 47928 (47557-48335) | 237.4 (234.9-239.7) |
MLP | 0.3577 (0.3545-0.3599) | 46804 (46227-47364) | 230.0 (227.4-233.1) |
GRU | 0.3864 (0.3831-0.3898) | 39310 (38976-39651) | 187.2 (185.4-189.2) |
Dipole | 0.3965 (0.3935-0.3996) | 38971 (38627-39321) | 193.2 (191.5-194.6) |
T-LSTM | 0.3872 (0.3843-0.3902) | 40399 (40042-40756) | 189.7 (188.6-191.0) |
SAnD* | 0.3876 (0.3845-0.3908) | 39719 (39338-40094) | 191.8 (190.5-193.2) |
Dr. Agent | 0.3988 (0.3955-0.4022) | 38446 (38084-38816) | 188.8 (186.9-190.5) |
AUC: area under the curve; AUROC: area under the receiver-operating characteristic curve; AUPRC: area under precision-recall curve; GRU: gated recurrent unit; LOS: length of stay; LR: logistic regression; MAPE: mean absolute percentage error; MIMIC-III: Medical Information Mart for Intensive Care III; MLP: multilayer perceptron; MSE: mean squared error; SAnD: Simply Attend and Diagnose; T-LSTM: time-aware long short-term memory network.
Figures 2 and 3 show the calibration plots for the decompensation and mortality prediction tasks. In a calibration plot, the predictions are grouped into bins based on their predicted probabilities. For each bin, the y-axis indicates the proportion of true outcomes, and the x-axis is the mean predicted probability. Table 2 shows the calibration performance for the phenotyping, mortality, and decompensation prediction tasks, and the detailed per-phenotype performance is shown in the Supplementary Appendix. The results show that Dr. Agent achieves the best calibration performance on all 3 tasks. The baseline models are generally well calibrated on the mortality prediction task, where they are better calibrated than on the other tasks. For the phenotyping task, Dipole and SAnD* achieve the worst calibration performance, whereas GRU is better calibrated (slope between 0.9 and 1). For decompensation prediction, none of the baseline models are well calibrated (slope < 0.8), and Dipole, LR, and T-LSTM have the worst calibration (slope < 0.7).
Figure 2.
The calibration plot for the decompensation prediction. GRU: gated recurrent unit; LR: logistic regression; MLP: multilayer perceptron; SAnD: Simply Attend and Diagnose; T-LSTM: time-aware long short-term memory network.
Figure 3.
The calibration plot for the mortality prediction. GRU: gated recurrent unit; LR: logistic regression; MLP: multilayer perceptron; SAnD: Simply Attend and Diagnose; T-LSTM: time-aware long short-term memory network.
Table 2.
Calibration performance comparison for the MIMIC-III tasks
Task 1: Phenotyping | ||
---|---|---|
Model | Slope | Intercept |
LR | 0.7149 (0.7062 to 0.7235) | 0.0138 (0.0024 to 0.0263) |
MLP | 0.7155 (0.7132 to 0.7175) | 0.0164 (0.0291 to 0.0029) |
GRU | 0.9136 (0.9085 to 0.9192) | 0.0253 (0.0115 to 0.0402) |
Dipole | 1.6745 (1.6682 to 1.6810) | 0.1620 (0.1788 to 0.1485) |
T-LSTM | 0.8211 (0.8119 to 0.8308) | 0.0285 (0.0104 to 0.0467) |
SAnD* | 0.6773 (0.6652 to 0.6891) | 0.0791 (0.0629 to 0.0918) |
Dr. Agent | 0.9947 (0.9883 to 1.0025) | −0.0080 (-0.0196 to 0.0035) |
Task 2: Mortality prediction | ||
Model | Slope | Intercept |
LR | 0.9125 (0.9048 to 0.9205) | 0.0003 (-0.0074 to 0.0082) |
MLP | 1.1347 (1.1268 to 1.1426) | 0.0075 (-0.0010 to 0.0168) |
GRU | 0.9011 (0.8897 to 0.9144) | −0.0253 (-0.0392 to -0.0127) |
Dipole | 0.8658 (0.8530 to 0.8777) | 0.0014 (-0.0083 to 0.0115) |
T-LSTM | 1.0744 (1.0612 to 1.0870) | −0.0480 (-0.0569 to -0.0392) |
SAnD* | 0.9291 (0.9185 to 0.9406) | −0.0191 (-0.0301 to -0.0079) |
Dr. Agent | 0.9774 (0.9688 to 0.9854) | 0.0060 (-0.0015 to 0.0141) |
Task 3: Decompensation prediction | ||
Model | Slope | Intercept |
LR | 0.6762 (0.6738 to 0.6789) | −0.0463 (-0.0588 to -0.0431) |
MLP | 0.7028 (0.6999 to 0.7057) | −0.0542 (-0.0675 to -0.0429) |
GRU | 0.7976 (0.7945 to 0.8002) | −0.0902 (-0.1004 to -0.0798) |
Dipole | 0.5783 (0.5760 to 0.5804) | −0.0659 (-0.0740 to -0.0565) |
T-LSTM | 0.6925 (0.6897 to 0.6954) | −0.0707 (-0.0793 to -0.0614) |
SAnD* | 0.7048 (0.7020 to 0.7073) | −0.0495 (-0.0586 to -0.0401) |
Dr. Agent | 0.9599 (0.9561 to 0.9634) | −0.0632 (-0.0745 to -0.0522) |
GRU: gated recurrent unit; LOS: length of stay; LR: logistic regression; MIMIC-III: Medical Information Mart for Intensive Care III; MLP: multilayer perceptron; T-LSTM: time-aware long short-term memory network.
DISCUSSION
Here, we explore how our agents choose optimal historical states. Taking the decompensation task as an example, we extract all the decompensated visits of all patients and compute, for each patient, the average position of the historical state selected by each agent. A larger value indicates that the agent selects a more distant visit as the optimal choice, and vice versa. Each average represents the selection pattern of an agent for one patient. We then observe how the different agents make choices. All observations are made on the test dataset.
In Table 3, we compute the average and standard deviation of the positions chosen by the 2 agents in Dr. Agent. We can observe that one agent tends to select more distant states than the other. This may be because the environments of the 2 agents differ: the primary agent uses the patient’s current health status as its environment, whereas the second-opinion agent uses the patient’s historical health status. This makes one agent more sensitive to variations in patients’ status, so it is more inclined to choose a relatively closer state as the optimal choice.
Table 3.
Selection patterns of 2 agents in the Dr. Agent model
Agent | Average ± SD of selected position |
---|---|
5.72 ± 2.91 | |
4.97 ± 2.83 |
We also explore how the settings of different environments and reward functions affect the agents’ choices. The results and discussions are shown in the Supplementary Appendix.
LIMITATIONS OF PROPOSED METHOD
Deep learning–based models have shown great success in different clinical tasks because of their strong ability to process high-dimensional longitudinal EHR data. Although they can achieve better performance than traditional statistical methods, interpretability is a major limitation of these deep learning methods, including Dr. Agent, as deep learning models are considered black boxes. Recent work has focused on improving the interpretability of deep learning models.7,27 Dr. Agent could potentially be combined with those models to provide model explanations in the future. Another drawback of Dr. Agent is the cold-start issue (ie, our model may have lower accuracy when applied to patients with very few visits), which is also a common issue for recurrent-based methods.28,29 According to the benchmark requirements,20 we excluded patients with few visits or a high missing rate (detailed exclusion rules are given in the Supplementary Appendix), which may affect model generalizability. In future work, this issue may be addressed by using transfer learning or introducing expert knowledge into the model.
CONCLUSION
In this work, we propose a reinforcement learning-based clinical predictive model, Dr. Agent. Dr. Agent has 2 policy gradient agents (ie, a primary agent and a second-opinion agent) to mimic the practice of patients seeking second opinions from multiple clinical experts for diagnosis. The 2 agents implement a dynamic skip connection mechanism, which dynamically chooses the optimal hidden states from the past few states of the GRU. Dr. Agent can comprehensively model the long-term dependencies of patients’ health status while considering patients’ demographics. Furthermore, we use a long-term reward function, which gives Dr. Agent the ability to transfer knowledge of disease progression. In addition, Dr. Agent has a shorter gradient backpropagation path and can thus alleviate the problem of vanishing gradients. Dr. Agent achieves better performance than state-of-the-art methods on multiple prediction tasks, including decompensation, LOS, mortality, and most of the phenotyping tasks, on the MIMIC-III dataset. We hope that our model can help physicians identify patients at high risk in order to prevent or delay adverse outcomes.
FUNDING
This work was supported in part by National Science Foundation awards IIS-1418511, CCF-1533768 and IIS-1838042 (to JS), and National Institutes of Health awards R01 1R01NS107291-01 and R56HL138415 (to JS).
AUTHOR CONTRIBUTIONS
JG implemented the method and conducted the experiments. All authors were involved in developing the ideas and writing the article.
SUPPLEMENTARY MATERIAL
Supplementary material is available at Journal of the American Medical Informatics Association online.
CONFLICT OF INTEREST STATEMENT
The authors have no competing interests to declare.
REFERENCES
- 1. Mayo Clinic. The value of second opinions demonstrated in study. Science Daily 2017. https://www.sciencedaily.com/releases/2017/04/170404084442.htm Accessed December 20, 2019.
- 2. Johnson AE, Pollard TJ, Shen L, et al. MIMIC-III, a freely accessible critical care database. Sci Data 2016; 3 (1): 16–35.
- 3. Choi E, Bahadori MT, Searles E, et al. Multi-layer representation learning for medical concepts. In: proceedings of the 22nd ACM SIGKDD International Conference on Knowledge Discovery and Data Mining; 2016: 1495–504; San Francisco, CA.
- 4. Shang J, Xiao C, Ma T, Li H, Sun J. GAMENet: Graph augmented memory networks for recommending medication combination. In: proceedings of the AAAI Conference on Artificial Intelligence; 2019: 1126–33; Honolulu, HI.
- 5. Gong JJ, Naumann T, Szolovits P, Guttag JV. Predicting clinical outcomes across changing electronic health record systems. In: proceedings of the 23rd ACM SIGKDD International Conference on Knowledge Discovery and Data Mining; 2017: 1497–505; Halifax, NS, Canada.
- 6. Baytas IM, Xiao C, Zhang X, Wang F, Jain AK, Zhou J. Patient subtyping via time-aware LSTM networks. In: proceedings of the 23rd ACM SIGKDD International Conference on Knowledge Discovery and Data Mining; 2017: 65–74; Halifax, NS, Canada.
- 7. Ma L, Gao J, Wang Y, et al. AdaCare: explainable clinical health status representation learning via scale-adaptive feature extraction and recalibration. arXiv 1911.12205; 2019.
- 8. Ma F, Chitta R, Zhou J, You Q, Sun T, Gao J. Dipole: diagnosis prediction in healthcare via attention-based bidirectional recurrent neural networks. In: proceedings of the 23rd ACM SIGKDD International Conference on Knowledge Discovery and Data Mining; 2017: 1903–11; Halifax, NS, Canada.
- 9. Esteban C, Staeck O, Baier S, Yang Y, Tresp V. Predicting clinical events by combining static and dynamic information using recurrent neural networks. In: 2016 IEEE International Conference on Healthcare Informatics; 2016: 93–101; Chicago, IL.
- 10. Lee W, Park S, Joo W, Moon I-C. Diagnosis prediction via medical context attention networks using deep generative modeling. In: 2018 IEEE International Conference on Data Mining (ICDM); 2018.
- 11. Gui T, Zhang Q, Zhao L, et al. Long short-term memory with dynamic skip connections. In: proceedings of the 33rd AAAI Conference on Artificial Intelligence (AAAI-19); 2019: 6481–8.
- 12. Gottesman O, Johansson F, Meier J, et al. Evaluating reinforcement learning algorithms in observational health settings. arXiv 1805.12298v1; 2018.
- 13. Nemati S, Ghassemi MM, Clifford GD. Optimal medication dosing from suboptimal clinical examples: A deep reinforcement learning approach. Conf Proc IEEE Eng Med Biol Soc 2016; 2016: 2978–81.
- 14. Wang L, Zhang W, He X, Zha H. Supervised reinforcement learning with recurrent neural network for dynamic treatment recommendation. In: proceedings of the 24th ACM SIGKDD International Conference on Knowledge Discovery & Data Mining; 2018: 2447–56.
- 15. Shoham Y, Powers R, Grenager T. Multi-agent reinforcement learning: a critical survey. 2003. https://www.cc.gatech.edu/classes/AY2009/cs7641_spring/handouts/MALearning_ACriticalSurvey_2003_0516.pdf Accessed December 10, 2019.
- 16. Chung J, Gulcehre C, Cho K, Bengio Y. Empirical evaluation of gated recurrent neural networks on sequence modeling. arXiv 1412.3555; 2014.
- 17. Nachum O, Norouzi M, Xu K, Schuurmans D. Bridging the gap between value and policy based reinforcement learning. In: proceedings of the 31st International Conference on Neural Information Processing Systems; 2017: 2772–82.
- 18. Sutton RS, Barto AG. Reinforcement Learning: An Introduction. Cambridge, MA: MIT Press; 2011.
- 19. Williams RJ. Simple statistical gradient-following algorithms for connectionist reinforcement learning. Mach Learn 1992; 8 (3–4): 229–56.
- 20. Harutyunyan H, Khachatrian H, Kale DC, Galstyan A. Multitask learning and benchmarking with clinical time series data. Sci Data 2019; 6: 96.
- 21. Song H, Rajan D, Thiagarajan JJ, Spanias A. Attend and diagnose: clinical time series analysis using attention models. arXiv 1711.03905v2; 2017.
- 22. Ma L, Zhang C, Wang Y, et al. ConCare: personalized clinical feature embedding via capturing the healthcare context. arXiv 1911.12216v1; 2019.
- 23. Lipton ZC, Kale DC, Elkan C, Wetzel R. Learning to diagnose with LSTM recurrent neural networks. arXiv 1511.03677; 2015.
- 24. Oord A, Dieleman S, Zen H, et al. Wavenet: A generative model for raw audio. arXiv 1609.03499v2; 2016.
- 25. Steyerberg EW, Vickers AJ, Cook NR, et al. Assessing the performance of prediction models: a framework for some traditional and novel measures. Epidemiology 2010; 21 (1): 128–38.
- 26. Cox DR. Two further applications of a model for binary regression. Biometrika 1958; 45 (3–4): 562–5.
- 27. Choi E, Bahadori MT, Sun J, Kulas J, Schuetz A, Stewart W. Retain: An interpretable predictive model for healthcare using reverse time attention mechanism. In: proceedings of the 30th International Conference on Neural Information Processing Systems; 2016: 3512–20.
- 28. Volkovs M, Yu G, Poutanen T. Dropoutnet: Addressing cold start in recommender systems. In: proceedings of the 31st International Conference on Neural Information Processing Systems; 2017: 4964–73.
- 29. Bansal T, Belanger D, McCallum A. Ask the gru: multi-task learning for deep text recommendations. In: proceedings of the 10th ACM Conference on Recommender Systems; 2016: 107–14.