Abstract
Deep learning for causal inference is a promising technique that leverages deep neural networks to infer counterfactuals and estimate treatment effects. Liu et al. proposed CURE (causal treatment effect estimation), a new pre-training and fine-tuning framework for treatment effect estimation using large-scale patient data.
Main text
Causal questions arise frequently in many domains, including education, public policy, economics, healthcare, etc. Answering causal questions requires solving tasks such as identifying causal relationships between covariates, estimating the treatment effects of an intervention, and inferring counterfactual outcomes. Experimental studies have been commonly recognized as the gold standard for answering various types of causal questions, such as randomized controlled trials in healthcare and A/B testing in online advertising. However, experimental studies usually suffer from cost, time, and ethical constraints in practice. On the other hand, observational studies offer a promising solution for estimating causal effects by only analyzing observed data without applying any new interventions to individuals.
Following the seminal work of the potential outcome framework (POM) by Rubin,1 numerous statistical methods have been developed to estimate treatment effects from observational data, such as matching methods, reweighting methods, etc. Counterfactuals, a key notion in the POM, have been well received by the research community, and significant efforts have been devoted to inferring counterfactuals with theoretical guarantees.2 Owing to their easy implementation and intuitive model interpretations, POM-based causal inference methods, such as propensity score matching, have been widely adopted by domain experts. Meanwhile, researchers have found that many traditional causal inference methods may not effectively handle high-dimensional datasets due to their limited modeling capability.
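To make the POM-style workflow concrete, the sketch below shows a minimal propensity score matching estimator on simulated data; the covariates, the logistic propensity model, and the one-to-one nearest-neighbor matching rule are illustrative choices rather than a prescription from the cited works.

```python
# A minimal sketch of propensity score matching under the potential outcome
# framework; the simulated data and matching rule are illustrative only.
import numpy as np
from sklearn.linear_model import LogisticRegression
from sklearn.neighbors import NearestNeighbors

rng = np.random.default_rng(0)
n, d = 2000, 5
X = rng.normal(size=(n, d))                      # observed covariates
p_treat = 1 / (1 + np.exp(-X[:, 0]))             # treatment depends on X (confounding)
T = rng.binomial(1, p_treat)                     # treatment assignment
Y = 2.0 * T + X[:, 0] + rng.normal(size=n)       # outcome with true effect 2.0

# Step 1: estimate propensity scores e(x) = P(T = 1 | X = x).
ps = LogisticRegression(max_iter=1000).fit(X, T).predict_proba(X)[:, 1]

# Step 2: match each treated unit to its nearest control on the propensity score.
treated, control = np.where(T == 1)[0], np.where(T == 0)[0]
nn = NearestNeighbors(n_neighbors=1).fit(ps[control].reshape(-1, 1))
_, idx = nn.kneighbors(ps[treated].reshape(-1, 1))
matched_control = control[idx.ravel()]

# Step 3: average treatment effect on the treated (ATT) from matched pairs.
att = np.mean(Y[treated] - Y[matched_control])
print(f"Estimated ATT: {att:.2f} (true effect: 2.0)")
```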
To address these limitations, researchers identified connections between causal inference and machine learning and proposed various machine learning approaches for the treatment effect estimation task.3,4 Some early approaches leveraged tree-based methods5 to predict counterfactuals or employed subspace learning and variable selection methods6 to seek intermediate feature spaces that facilitate causal inference tasks such as matching. More recently, deep neural networks have demonstrated remarkable performance in predicting counterfactuals at the group or individual level for discrete or continuous treatments.7,8 These findings further motivated researchers to explore advanced deep learning architectures for causal inference, such as generative adversarial networks, graph neural networks, and transformers.
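As a deliberately simplified illustration of the deep learning route, the sketch below implements a two-headed network in the spirit of TARNet-style representation learning: a shared encoder feeds separate outcome heads for treated and control units, and only the factual head is trained for each unit. The architecture, layer sizes, and loss are illustrative assumptions, not the exact models of references 7 and 8.

```python
# A minimal sketch of a two-headed network for counterfactual prediction
# (TARNet-style); hyperparameters and training details are illustrative.
import torch
import torch.nn as nn

class TwoHeadNet(nn.Module):
    def __init__(self, in_dim, hidden=64):
        super().__init__()
        # Shared representation learned from covariates of both groups.
        self.encoder = nn.Sequential(nn.Linear(in_dim, hidden), nn.ReLU())
        # Separate outcome heads for control (t = 0) and treated (t = 1).
        self.head0 = nn.Sequential(nn.Linear(hidden, hidden), nn.ReLU(), nn.Linear(hidden, 1))
        self.head1 = nn.Sequential(nn.Linear(hidden, hidden), nn.ReLU(), nn.Linear(hidden, 1))

    def forward(self, x):
        phi = self.encoder(x)
        return self.head0(phi).squeeze(-1), self.head1(phi).squeeze(-1)

def factual_loss(model, x, t, y):
    """Train each head only on the factual outcomes it is responsible for."""
    y0_hat, y1_hat = model(x)
    y_hat = torch.where(t.bool(), y1_hat, y0_hat)
    return nn.functional.mse_loss(y_hat, y)

# At inference, the difference of the two heads estimates the individual effect:
# y0_hat, y1_hat = model(x_new); ite_hat = y1_hat - y0_hat
```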
The use of transformers for treatment effect estimation raises an attractive but challenging research question: would it be possible to train large foundation models for treatment effect estimation in healthcare? A large foundation model with emergent capabilities for treatment effect estimation could significantly assist clinical studies and practices, saving costs and labor. Transformers are core building blocks of state-of-the-art large language models (e.g., ChatGPT, Llama) and large multimodal models (e.g., GPT-4V, Gemini) and will likely help shape the future of artificial general intelligence. However, adapting the knowledge and experience of training large models from the general language or vision domains to the healthcare domain is nontrivial. In the language and vision domains, the pre-training and downstream tasks are well defined and associated with abundant well-labeled data. In healthcare, by contrast, it is quite difficult to obtain large-scale, well-labeled patient data and even more challenging to design reasonable strategies for model pre-training and fine-tuning.
To address these challenges, Liu et al. proposed CURE (causal treatment effect estimation),9 a new pre-training and fine-tuning framework for estimating the causal effect of a treatment. They proposed novel strategies for the entire process of developing large models for treatment effect estimation, from the perspectives of data collection, data encoding, pre-training methods, and downstream tasks for model fine-tuning.
High-quality data are always a key factor in training large models. Liu et al. acquired and pre-processed large-scale patient data (i.e., about three million unlabeled patient sequences) from real-world medical claims containing individual-level, de-identified healthcare information from employers, health plans, and hospitals. Considering the unique characteristics of the longitudinal patient data, they proposed a new approach to effectively encode hierarchically structured patient data. In particular, patient records were flattened by chronologically going through each medication and diagnosis in each visit, with each medication or diagnosis treated as a token. This new encoding approach enables a comprehensive representation of patient information.
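A rough sketch of this kind of flattening is shown below. The record schema, field names, and example codes are hypothetical and only approximate the idea of turning hierarchical visits into a chronological token sequence with accompanying code-type and visit information; they are not the authors' exact data format.

```python
# A minimal sketch of flattening hierarchical visit records into a token
# sequence; the record format and example codes are illustrative assumptions.
from typing import Dict, List, Tuple

def flatten_patient(visits: List[Dict]) -> Tuple[List[str], List[str], List[int]]:
    """Walk visits chronologically; each diagnosis or medication code becomes a token."""
    tokens, code_types, visit_ids = [], [], []
    for visit_id, visit in enumerate(sorted(visits, key=lambda v: v["date"])):
        for code in visit.get("diagnoses", []):
            tokens.append(code); code_types.append("diagnosis"); visit_ids.append(visit_id)
        for code in visit.get("medications", []):
            tokens.append(code); code_types.append("medication"); visit_ids.append(visit_id)
    return tokens, code_types, visit_ids

# Hypothetical patient with two visits.
patient = [
    {"date": "2021-03-01", "diagnoses": ["I25.10"], "medications": ["clopidogrel"]},
    {"date": "2021-06-15", "diagnoses": ["E11.9", "I10"], "medications": ["metformin"]},
]
print(flatten_patient(patient))
```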
The pre-training of the CURE framework differs from the traditional procedures used by BERT-based models due to the complex hierarchical structure and irregularity of the observational patient data. Liu et al. proposed a new design of the embedding layer, incorporating the associated code type information and time information that are crucial for the treatment effect estimation task. The masked language modeling loss is then adopted to pre-train the model. Subsequently, the fine-tuning of CURE involves four downstream treatment effect estimation tasks derived from four randomized clinical trials with small-scale labeled datasets, following a retrospective study design. These tasks evaluate the comparative effectiveness of two treatments in reducing the risk of stroke for patients with coronary artery disease. Small-scale labeled datasets, about 10,000 to 20,000 samples per task, are used in the fine-tuning stage.
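The sketch below illustrates the general idea of such an embedding layer and BERT-style masking, not the authors' exact implementation: code, code-type, and time embeddings are summed before the transformer encoder, and a fraction of tokens is randomly masked for the masked-language-modeling objective. The dimensions, time bucketing, and 15% masking rate are illustrative assumptions.

```python
# A minimal sketch of an embedding layer summing code, code-type, and time
# embeddings, plus random masking for a masked-language-modeling objective.
# All sizes and the masking rate are illustrative assumptions.
import torch
import torch.nn as nn

class PatientEmbedding(nn.Module):
    def __init__(self, vocab_size, n_code_types, n_time_buckets, dim=128):
        super().__init__()
        self.code = nn.Embedding(vocab_size, dim)         # medical code tokens
        self.code_type = nn.Embedding(n_code_types, dim)  # e.g., diagnosis vs. medication
        self.time = nn.Embedding(n_time_buckets, dim)     # discretized visit time
        self.norm = nn.LayerNorm(dim)

    def forward(self, code_ids, type_ids, time_ids):
        # Sum the three embeddings element-wise, then normalize.
        return self.norm(self.code(code_ids) + self.code_type(type_ids) + self.time(time_ids))

def mask_tokens(code_ids, mask_token_id, mask_prob=0.15):
    """Randomly replace tokens with a [MASK] id; return masked inputs and MLM labels."""
    labels = code_ids.clone()
    masked = torch.rand_like(code_ids, dtype=torch.float) < mask_prob
    labels[~masked] = -100                 # positions ignored by cross-entropy
    inputs = code_ids.clone()
    inputs[masked] = mask_token_id
    return inputs, labels
```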
Liu et al. conducted extensive experiments and demonstrated that the CURE framework outperforms state-of-the-art treatment effect estimation methods on all four downstream tasks. For instance, the improvements over baselines are about 4% and 7% in terms of the area under the receiver operating characteristic curve (AUROC) and the area under the precision-recall curve (AUPRC), respectively. These results, verified against the published randomized clinical trials, showcase the effectiveness of the CURE framework for treatment effect estimation. Liu et al. also conducted ablation studies to reveal the impact of different model configurations and experimental settings, such as the patient embedding strategies and the size of the pre-training data. These studies further support the rationale and effectiveness of the proposed data encoding, pre-training, and fine-tuning strategies.
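For readers less familiar with these two metrics, the snippet below computes AUROC and AUPRC with scikit-learn on toy predictions; the numbers are illustrative and unrelated to the reported results.

```python
# A minimal sketch of the two evaluation metrics mentioned above, computed
# with scikit-learn on illustrative labels and scores.
from sklearn.metrics import average_precision_score, roc_auc_score

y_true = [0, 0, 1, 1, 0, 1]
y_score = [0.1, 0.4, 0.35, 0.8, 0.2, 0.7]

print("AUROC:", roc_auc_score(y_true, y_score))            # ranking quality over all thresholds
print("AUPRC:", average_precision_score(y_true, y_score))  # precision-recall trade-off
```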
Overall, the work by Liu et al. contributes to the literature on causal inference and the development of large models with domain-specific applications. In practice, the CURE framework could be adopted as a complementary tool to help generate hypotheses about treatment effects for randomized clinical trials. From a technical point of view, the newly proposed encoding approach could potentially be applied to a wide range of biomedical studies beyond the treatment effect estimation task. In addition, the pre-training and fine-tuning approaches are well suited for a wide range of healthcare applications.
Looking forward, the development of large foundation models for treatment effect estimation in healthcare could benefit other biomedical studies as well, such as inferring causal relations among patient covariates (i.e., causal discovery),10 improving the performance of patient outcome prediction with causal knowledge, and enhancing model explainability with causal pathways in patient data and models. In addition, I firmly believe that many remaining research questions require continued efforts from interdisciplinary research teams in machine learning, data science, statistics, and healthcare. For instance, a coherent way of dealing with complex treatments (e.g., sequential treatments and structured treatments) is still unclear. The ChatGPT moment for causal inference is not here yet, but recent research progress sheds light on the way forward. Altogether, these investigations and insights could inspire researchers in causal inference and healthcare.
Acknowledgments
This work is supported in part by the Agency for Healthcare Research and Quality under grant R01HS029009 and the National Science Foundation under grant IIS-2316306.
Declaration of interests
The author declares no competing interests.
References
- 1. Rubin D.B. Causal inference using potential outcomes: Design, modeling, decisions. J. Am. Stat. Assoc. 2005;100:322–331.
- 2. Stuart E.A. Matching methods for causal inference: A review and a look forward. Stat. Sci. 2010;25:1–21. doi: 10.1214/09-STS313.
- 3. Yao L., Chu Z., Li S., Li Y., Gao J., Zhang A. A survey on causal inference. ACM Trans. Knowl. Discov. Data. 2021;15:1–46.
- 4. Feuerriegel S., Frauen D., Melnychuk V., Schweisthal J., Hess K., Curth A., Bauer S., Kilbertus N., Kohane I.S., van der Schaar M. Causal machine learning for predicting treatment outcomes. Nat. Med. 2024;30:958–968. doi: 10.1038/s41591-024-02902-1.
- 5. Wager S., Athey S. Estimation and inference of heterogeneous treatment effects using random forests. J. Am. Stat. Assoc. 2018;113:1228–1242.
- 6. Li S., Fu Y. Matching on balanced nonlinear representations for treatment effects estimation. Adv. Neural Inf. Process. Syst. 2017;30:930–940.
- 7. Yao L., Li S., Li Y., Huai M., Gao J., Zhang A. Representation learning for treatment effect estimation from observational data. Adv. Neural Inf. Process. Syst. 2018;31:2638–2648.
- 8. Schwab P., Linhardt L., Bauer S., Buhmann J.M., Karlen W. Learning counterfactual representations for estimating individual dose-response curves. Proceedings of the AAAI Conference on Artificial Intelligence. 2020;34:5612–5619.
- 9. Liu R., Chen P.-Y., Zhang P. CURE: A deep learning framework pre-trained on large-scale patient data for treatment effect estimation. Patterns. 2024;5. doi: 10.1016/j.patter.2024.100973.
- 10. Wan G., Wu Y., Hu M., Chu Z., Li S. Bridging causal discovery and large language models: A comprehensive survey of integrative approaches and future directions. Preprint at arXiv. 2024. doi: 10.48550/arXiv.2402.11068.
