Abstract
Accurately predicting a patient’s risk of progressing to late age-related macular degeneration (AMD) is difficult but crucial for personalized medicine. While existing risk prediction models for progression to late AMD are useful for triaging patients, none utilizes the longitudinal color fundus photographs (CFPs) in a patient’s history to estimate the risk of late AMD in a given subsequent time interval. In this work, we evaluate how deep neural networks capture the sequential information in longitudinal CFPs and improve the prediction of 2-year and 5-year risk of progression to late AMD. Specifically, we proposed two deep learning models, CNN-LSTM and CNN-Transformer, which use a Long Short-Term Memory (LSTM) network and a Transformer encoder, respectively, in combination with convolutional neural networks (CNNs), to capture the sequential information in longitudinal CFPs. We evaluated our models against baselines on the Age-Related Eye Disease Study, one of the largest longitudinal AMD cohorts with CFPs. The proposed models outperformed the baseline models that utilize only single-visit CFPs to predict the risk of late AMD (AUC 0.879 vs 0.868 for 2-year prediction, and 0.879 vs 0.862 for 5-year prediction). Further experiments showed that utilizing longitudinal CFPs over a longer time period helps deep learning models predict the risk of late AMD. We made the source code available at https://github.com/bionlplab/AMD_prognosis_mlmi2022 to catalyze future work on deep learning models for late AMD prediction.
Keywords: Age-related macular degeneration, Deep learning, Convolutional neural networks, Recurrent neural networks, Transformer
1. Introduction
Age-related macular degeneration (AMD) is the leading cause of vision loss and severe vision impairment [2,16,18] and is projected to affect approximately 288 million people worldwide by 2040 [21]. The annual healthcare cost incurred by AMD is about $4.6 billion in the United States, imposing extreme burdens on both patients and healthcare systems.
Traditional assessments of AMD severity have depended heavily on the manual analysis of color fundus photographs (CFPs), which are akin to a photographic record of an ophthalmologist’s clinical fundus examination at the slit lamp [3]. CFPs are captured by fundus cameras and analyzed by experts who assess AMD severity based on multiple characteristics of the macula (e.g., presence, type, and size of drusen) [5]. The most widely used method to assess AMD severity and predict the risk of progression to late AMD is the simplified AMD severity score developed by the Age-Related Eye Disease Study (AREDS) Research Group [3]. The score is calculated from the macular characteristics in CFPs (or on clinical examination) from both eyes at a single time point, and classifies an individual on a 0–5 severity scale. This severity score has been the current clinical standard for assessing an individual’s risk of progression to late-stage AMD (late AMD), based on published 5-year risks of progression to late AMD that increase from step 0 to step 4. However, the current clinical standard cannot incorporate longitudinal data for late AMD prediction.
While the characteristics and disease mechanisms of AMD have been well studied, there is no approved therapy for geographic atrophy to prevent the slow progression of AMD to vision loss. The onset and progression of AMD are heterogeneous between patients: some individuals progress from early to intermediate and then to late AMD rapidly, while others progress to late AMD more slowly [17]. Such heterogeneity calls for personalized treatment planning, which may help in justifying medical and lifestyle interventions, vigilant home monitoring, frequent reimaging, and in planning shorter but highly powered clinical trials [15]. At the same time, accurately predicting late AMD risk is challenging because patient data are diverse, with varying time intervals between visits. For example, patients with early AMD may develop late AMD anywhere from a few years to several decades later [13]. This requires a predictive model that leverages temporal information, which is crucial for understanding disease progression [19].
To date, many machine learning methods have been used to predict late AMD onset [8], including logistic regression [14], AdaBoost [1], XGBoost [23], and random forests [11]; the majority of these approaches use static structured features from a single time point. While these methods are straightforward and easy to implement, they typically do not capture the temporal progression information contained naturally in the data. Deep learning models have been successfully adopted for healthcare and medical tasks, and some works have applied Convolutional Neural Networks (CNNs) to AMD image data [22]. Ghahramani et al. recently proposed a framework that combines a CNN and a Recurrent Neural Network (RNN), tailored to capture temporal progression information from CFPs for late AMD prediction [4]. However, the authors considered data spanning at most three years and did not assess the model’s robustness to the varying intervals between patient visits [13].
In this study, we used longitudinal CFPs to predict whether an eye will progress to late AMD within a certain period (2 years or 5 years). Both 2-year and 5-year predictions are common clinical scenarios. These periods were selected in advance and are relative to the time when the fundus photograph was taken, not to the baseline visit. Specifically, for one eye, the inputs are all historical CFPs and the output is the probability of progression to late AMD within the specified time period.
We proposed two deep learning models that combine a CNN with an RNN and a Transformer encoder, respectively. We used a ResNet fine-tuned on a late AMD detection task as a fixed feature extractor, then applied a Long Short-Term Memory (LSTM) network [10] or a Transformer encoder [20] to predict the 2-year or 5-year risk of late AMD progression. We trained and evaluated the models using longitudinal CFPs from the AREDS [7]. Models were evaluated using the area under the receiver operating characteristic curve (AUC). We compared our models with a plain ResNet, which predicts AMD progression from a single image.
Our contributions can be summarized as follows: 1) We proposed two deep learning models (CNN-LSTM and CNN-Transformer) to predict 2-year and 5-year risk of progression to late AMD; 2) We proposed a sequence unrolling strategy that can be applied to individuals with sequences of various lengths; 3) We showed that the proposed deep learning models outperformed baseline models that utilize only a single CFP to predict the risk of progression to late AMD; 4) We showed that longitudinal CFPs over a longer period increased the predictive performance of deep learning models. We also made the source code available at https://github.com/bionlplab/AMD_prognosis_mlmi2022 to catalyze future work.
2. Methods
2.1. Definition of Late AMD Progression Prediction Task
We first formulated the late AMD progression prediction task (Fig 1A). Let T* be the ‘true’ time to late AMD for one participant in a study and C the right-censoring time (e.g., the end of the study). In the discrete context, we have a sequence of time points {t0, t1, …, tT}, where T = min(T*, C) is the observed event time. Since the sequence length varies between individuals, we proposed a sequence unrolling strategy that considers all visit lengths for each individual (Fig 1B). Specifically, we construct (I0, …, IT), where Il = (t0, tl] is a sub-sequence of the entire observation period. In this way, a model can consider every possible sequence length within an individual’s history.
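The unrolling strategy can be sketched in pure Python (the function and variable names below are illustrative, not taken from the paper’s code):

```python
def unroll(visits):
    """Unroll one eye's chronologically ordered visit features into every
    prefix sub-sequence I_l = (t0, t_l], so that a model can be trained on
    all possible history lengths for that individual."""
    return [visits[: l + 1] for l in range(len(visits))]

# Three visits yield three training sub-sequences of increasing length.
samples = unroll(["cfp_t0", "cfp_t1", "cfp_t2"])
# -> [["cfp_t0"], ["cfp_t0", "cfp_t1"], ["cfp_t0", "cfp_t1", "cfp_t2"]]
```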
Fig. 2.

Model architectures. A. CNN-LSTM. B. CNN-Transformer.
Given these definitions, at time tl, our model predicts the risk of late AMD in the prediction window (tl, tl+n] using the longitudinal features in Il. Here, n is the pre-selected inquiry duration. In other words, the label at time tl is 1 if T* ≤ tl + n; otherwise 0. In this study, we focused on 2-year (n = 2) and 5-year (n = 5) prediction because they are two common clinical scenarios [15].
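The labeling rule can be written as a small helper (a sketch under the paper’s definitions; the function name and the use of `None` for censored eyes are our own conventions):

```python
def label_at(t_l, n, t_onset):
    """Label for the sub-sequence ending at visit time t_l: 1 if late AMD
    onset falls inside the prediction window (t_l, t_l + n], else 0.
    t_onset is None for eyes that never developed late AMD (censored)."""
    return int(t_onset is not None and t_l < t_onset <= t_l + n)

label_at(3, 2, 4)   # -> 1: onset at year 4 is inside the window (3, 5]
label_at(3, 2, 6)   # -> 0: onset at year 6 is outside the window
```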
This task is similar to the conditional survival probability S(tl + n | tl) = Pr(z > tl + n | z > tl) = S(tl + n)/S(tl), meaning that a participant will survive an additional n years given a survival history of tl years, where S(t) = Pr(z > t) is the survival probability and z is the time of the late AMD event.
2.2. CNN for Feature Extraction
To extract image features from CFPs, we trained a CNN on a late AMD detection task: classifying a given CFP with a binary label (late AMD or non-late AMD). We then treated this CNN as a fixed feature extractor; it was used only for feature extraction and was not trained jointly with the other layers of CNN-LSTM and CNN-Transformer. Among the many existing CNN architectures, we used ResNet [9] since it outperforms other CNN architectures in predicting the AMD severity score [6]. The extracted image features have a size of 2,048.
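A minimal Keras sketch of the fixed feature extractor is shown below. We pass `weights=None` to avoid a download; in the paper, the ResNet is first fine-tuned on the late AMD detection task before being frozen.

```python
import tensorflow as tf

def build_feature_extractor():
    # ResNet-101 without its classification head; global average pooling
    # yields the 2,048-dimensional feature vector per CFP.
    base = tf.keras.applications.ResNet101(
        include_top=False, weights=None, pooling="avg",
        input_shape=(224, 224, 3))
    base.trainable = False  # used as a fixed feature extractor
    return base

extractor = build_feature_extractor()
features = extractor(tf.zeros((1, 224, 224, 3)))  # shape (1, 2048)
```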
2.3. CNN-LSTM
In the CNN-LSTM model (Fig 2A), we used the pre-trained ResNet (Sect. 2.2) to extract features {f0, …, fl} from an individual’s longitudinal CFPs in the observation window. A single fully connected layer reduced the dimensionality of the extracted features to 256. The features were then fed into a single-layer LSTM, and the last output representation of the LSTM was used for prediction.
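The CNN-LSTM head can be sketched in Keras as follows. The 256-unit LSTM width is our assumption, since the paper specifies a single LSTM layer but not its size:

```python
import tensorflow as tf

def build_cnn_lstm(feat_dim=2048, hidden=256):
    # Variable-length sequence of pre-extracted 2,048-d CFP features.
    seq = tf.keras.Input(shape=(None, feat_dim))
    x = tf.keras.layers.Dense(hidden)(seq)   # reduce features to 256-d
    x = tf.keras.layers.LSTM(hidden)(x)      # keep only the last output
    out = tf.keras.layers.Dense(1, activation="sigmoid")(x)  # risk of late AMD
    return tf.keras.Model(seq, out)

probs = build_cnn_lstm()(tf.zeros((2, 4, 2048)))  # 2 eyes, 4 visits each
```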
Fig. 1.

Examples of the 2-year and 5-year late AMD prediction tasks. Each CFP in an individual’s history is labeled 1 if late AMD onset is detected within the given prediction window (2 or 5 years), otherwise 0. The blue lines represent the observation interval (t0, tl]; the red line represents the prediction window (tl, tl + n]. A. Prediction scenario for 2 and 5 years. B. Training scenario for a single unrolled patient.
2.4. CNN-Transformer
CNN-Transformer uses a multi-layer bidirectional Transformer encoder based on the original implementation described in Vaswani et al. [20] (Fig 2B). The Transformer encoder has the advantage of processing the sequence as a whole instead of recursively. We compute the attention in one Transformer encoder layer as:
Attention(Q, K, V) = softmax(QK^T / √dk) V, with Q = XWQ, K = XWK, V = XWV,

where WQ, WK, and WV are the trainable parameters and dk is the dimension of K. We denote the input as X = F + P, where F is the concatenation of all CFP features in the observation window of an individual and P is the positional encoding. We set the number of Transformer encoder layers to 2 and the number of heads h to 8. We used the same sinusoidal function to generate the positional encoding as used in [20].
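The two building blocks above, scaled dot-product attention and the sinusoidal positional encoding, can be illustrated in NumPy (a single-head sketch with identity projections for readability; the model itself uses 8 heads and 2 encoder layers):

```python
import numpy as np

def positional_encoding(length, d_model):
    """Sinusoidal positional encoding as in Vaswani et al. [20]."""
    pos = np.arange(length)[:, None]
    i = np.arange(d_model)[None, :]
    angles = pos / np.power(10000.0, (2 * (i // 2)) / d_model)
    return np.where(i % 2 == 0, np.sin(angles), np.cos(angles))

def attention(X, W_q, W_k, W_v):
    """Single-head scaled dot-product attention over a feature sequence X."""
    Q, K, V = X @ W_q, X @ W_k, X @ W_v
    scores = Q @ K.T / np.sqrt(K.shape[-1])
    weights = np.exp(scores - scores.max(axis=-1, keepdims=True))
    weights /= weights.sum(axis=-1, keepdims=True)  # row-wise softmax
    return weights @ V

d = 16
F = np.random.default_rng(0).normal(size=(5, d))  # 5 CFP feature vectors
X = F + positional_encoding(5, d)                 # input X = F + P
out = attention(X, *(np.eye(d),) * 3)             # out.shape == (5, 16)
```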
2.5. Baselines
CNN-Single.
In the first baseline, we used two fully connected layers on top of the features extracted by the pre-trained ResNet: the first layer has size 256 with ReLU activation, and the last layer has size 1 with sigmoid activation for prediction. CNN-Single takes only the last feature of the input sequence to predict 2-year and 5-year risk. We refer to this baseline model as CNN-Single since it utilizes only a single CFP (the last CFP in the input sequence).
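The baseline head described above can be sketched in Keras (a minimal sketch of the two fully connected layers on top of the pre-extracted features):

```python
import tensorflow as tf

def build_cnn_single(feat_dim=2048):
    # Takes only the last CFP's pre-extracted 2,048-d feature vector.
    feat = tf.keras.Input(shape=(feat_dim,))
    x = tf.keras.layers.Dense(256, activation="relu")(feat)
    out = tf.keras.layers.Dense(1, activation="sigmoid")(x)
    return tf.keras.Model(feat, out)

p = build_cnn_single()(tf.zeros((3, 2048)))  # 3 eyes -> shape (3, 1)
```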
We also report the prediction performance of the model proposed by Yan et al. [22], in which the Inception-v3 CNN architecture extracts image features from the CFP of an individual’s latest visit to predict whether the eye’s progression time to late AMD exceeds a specific time interval.
3. Experiment
3.1. Dataset
We used the data from AREDS, which was sponsored by the National Eye Institute of the National Institutes of Health. The data is publicly available upon request.1 AREDS was a 12-year multi-center prospective cohort study of the clinical course, prognosis, and risk factors of AMD, as well as a phase III randomized clinical trial to assess the effects of nutritional supplements on AMD progression. The cohort includes 4,757 participants aged 55 to 80 years, recruited between 1992 and 1998 at 11 retinal specialty clinics in the United States. The inclusion criteria were broad, ranging from no AMD in either eye to late AMD in one eye. All CFPs in the cohort were labeled with a severity score on a 0–12 scale calculated by the reading center. Individuals with missing severity scores or no CFPs were excluded, resulting in 4,315 individuals.
For 2-year and 5-year late AMD risk prediction, we constructed two different datasets. We first removed recurring late AMD labels from all individuals in the cohort after the first late AMD onset was detected, because AMD is irreversible: an individual who developed late AMD will necessarily have recurring late AMD labels after onset, which could bias a model and inflate prediction performance. Individuals with an observation period shorter than 2 or 5 years were excluded from the 2-year and 5-year prediction datasets, respectively. The 2-year prediction dataset was then constructed by labeling each CFP to indicate whether late AMD onset was detected within the 2-year prediction window; CFPs with no subsequent CFP inside the 2-year prediction window could not be labeled and were excluded. The 5-year prediction dataset was constructed in the same way with a 5-year prediction window. The characteristics of the entire cohort and the two datasets are shown in Table 1.
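The removal of recurring late AMD labels can be sketched as follows (a hypothetical helper, not the authors’ code, operating on one eye’s chronological label sequence):

```python
def truncate_after_onset(labels):
    """Keep visits only up to and including the first late-AMD onset.
    AMD is irreversible, so visits after onset would repeat the positive
    label and inflate prediction performance."""
    if 1 in labels:
        return labels[: labels.index(1) + 1]
    return labels

truncate_after_onset([0, 0, 1, 1, 1])  # -> [0, 0, 1]
```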
Table 1.
Characteristics of the entire cohort, 2-year prediction dataset and 5-year prediction dataset. Median value (25 percentile, 75 percentile) was reported for length of observation and CFPs per eye.
| | Entire cohort | 2-year | 5-year |
|---|---|---|---|
| Individuals | 4,315 | 3,477 | 2,876 |
| Eyes | 8,630 | 7,661 | 6,258 |
| Eyes developed late AMD | 1,768 | 844 | 428 |
| Length of observation (year) | 8(6,11) | 9(6,11) | 10(7,11) |
| CFPs | 65,480 | 49,361 | 46,558 |
| CFPs labeled as late AMD | 8,422 | 1,240 | 1,573 |
| CFPs per eye | 8(5,10) | 7(3,9) | 8(5,10) |
3.2. Experiment Setup
ResNet-101 was pre-trained to extract features from the CFPs, and the extracted features were then used as input to CNN-LSTM and CNN-Transformer for late AMD risk prediction. All CFPs were resized to 256×256 and then center-cropped to 224×224. Training CFPs were randomly cropped, blurred, rotated, sheared, and horizontally flipped for data augmentation. ResNet-101 was trained for 30 epochs with a learning rate of 0.0005 and a batch size of 32. We observed that ResNet architectures deeper than ResNet-101 did not improve performance. CNN-LSTM and CNN-Transformer were optimized using Adam [12] with a learning rate of 0.0002, a batch size of 32, and 30 epochs. L2 regularization was applied to the last fully connected layer in all models to prevent overfitting. We also applied class weights to the loss and stratified the mini-batches to mitigate label imbalance. All models were implemented in TensorFlow. The experiments were performed on a machine equipped with two Intel Xeon Silver 4110 CPUs and one NVIDIA RTX 2080 GPU.
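The exact weighting scheme is not specified in the paper; a common choice is inverse-frequency class weights, sketched below (hypothetical helper):

```python
def class_weights(labels):
    """Inverse-frequency weights so that each class contributes equally
    to the loss despite label imbalance."""
    n, n_pos = len(labels), sum(labels)
    n_neg = n - n_pos
    return {0: n / (2 * n_neg), 1: n / (2 * n_pos)}

class_weights([0, 0, 0, 1])  # -> {0: 0.666..., 1: 2.0}
```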
We used 5-fold cross validation for evaluation and reported the area under the receiver operating characteristic curve (AUC). All datasets were partitioned into training, validation, and test sets at the participant level with a 3:1:1 ratio. This ensures that no participant appears in more than one partition, avoiding cross-contamination between the training and test datasets. Since we observed severe label imbalance in the datasets (Sect. 3.1 and Table 1), which may cause instability during training, we stratified each batch during training, maintaining the ratio of late AMD to non-late AMD labels.
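A participant-level 3:1:1 split can be sketched as follows (a hypothetical helper; the paper does not publish its splitting code):

```python
import random

def participant_split(participant_ids, seed=0):
    """Partition unique participant IDs 3:1:1 into train/val/test so that
    no participant appears in more than one partition."""
    ids = sorted(set(participant_ids))
    random.Random(seed).shuffle(ids)
    n = len(ids)
    train = ids[: 3 * n // 5]
    val = ids[3 * n // 5 : 4 * n // 5]
    test = ids[4 * n // 5 :]
    return train, val, test

train, val, test = participant_split(range(10))
# 6 train, 2 validation, 2 test participants; the partitions are disjoint
```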
3.3. Results
Overall Prediction Performance.
Table 2 shows the overall 2-year and 5-year prediction performance of all models. CNN-LSTM achieved AUCs of 0.883 and 0.879 in predicting 2-year and 5-year risk of late AMD, respectively; CNN-Transformer achieved 0.879 and 0.873. Both CNN-LSTM and CNN-Transformer outperformed the baseline models in predicting 2-year and 5-year risk of late AMD, indicating that utilizing longitudinal CFPs is helpful for late AMD risk prediction.
Table 2.
AUC in predicting 2-year and 5-year risk of late AMD of all models. Data were reported as: average AUC based on 5-fold cross validation (standard deviation).
| | 2-year prediction | 5-year prediction |
|---|---|---|
| Yan et al. [22] | 0.810 (0.000) | 0.790 (0.000) |
| CNN-Single | 0.868 (0.012) | 0.862 (0.023) |
| CNN-LSTM | 0.883 (0.017) | 0.879 (0.020) |
| CNN-Transformer | 0.879 (0.013) | 0.873 (0.020) |
Prediction Performance Based on the Number of CFPs.
We evaluated the models on subsets of the datasets containing a specific number of longitudinal CFPs to investigate whether the number of longitudinal CFPs affects prediction performance. We first selected individuals with at least 5 longitudinal CFPs from the test set and then sliced each individual’s longitudinal CFPs backward from the last visit to build subsets with a specific number of CFPs, from 2 to 5. For example, the length-2 subset includes only the two longitudinal CFPs closest to the last visit. Details of the subsets are described in the supplementary material.
Table 3 shows the 2-year and 5-year prediction performance by number of CFPs. Both CNN-LSTM and CNN-Transformer showed increasing prediction performance with more longitudinal CFPs. For CNN-LSTM, AUC improved from 0.867 with 2 longitudinal CFPs to 0.873 with 5 longitudinal CFPs in both 2-year and 5-year prediction. For CNN-Transformer, AUC improved from 0.862 and 0.861 with 2 longitudinal CFPs to 0.866 and 0.868 with 5 longitudinal CFPs in 2-year and 5-year prediction, respectively. This indicates that more longitudinal CFPs over a longer observation period are beneficial for late AMD risk prediction.
Table 3.
AUC in predicting 2-year and 5-year risk of late AMD of all models based on the number of CFPs. Data were reported as: average AUC based on 5-fold cross validation (standard deviation).
| # of CFPs | CNN-LSTM (2-year) | CNN-Transformer (2-year) | CNN-LSTM (5-year) | CNN-Transformer (5-year) |
|---|---|---|---|---|
| 2 | 0.867 (0.028) | 0.862 (0.031) | 0.867 (0.028) | 0.861 (0.290) |
| 3 | 0.870 (0.029) | 0.862 (0.035) | 0.870 (0.031) | 0.863 (0.032) |
| 4 | 0.872 (0.029) | 0.864 (0.037) | 0.872 (0.031) | 0.866 (0.028) |
| 5 | 0.873 (0.029) | 0.866 (0.037) | 0.873 (0.030) | 0.868 (0.029) |
4. Conclusion
In this work, we proposed deep learning models that utilize longitudinal color fundus photographs in individuals’ histories to predict 2-year and 5-year risk of progression to late AMD. The two proposed models, CNN-LSTM and CNN-Transformer, used an LSTM and a Transformer encoder with a CNN, respectively, and outperformed baseline models that can utilize only a single-visit CFP to predict the risk of progression to late AMD. The proposed models also showed increasing performance in predicting the risk of progression to late AMD with longitudinal CFPs over a longer period, indicating that deep learning models are effective in capturing sequential information from longitudinal CFPs in individuals’ histories. Future work includes applying survival analysis to the models’ loss objectives, developing end-to-end models, and integrating demographic and genetic information to further improve predictive performance.
Supplementary Material
Acknowledgments.
This material is based upon work supported by the Intramural Research Programs of the National Library of Medicine and National Eye Institute at National Institutes of Health.
References
- 1. Chen P, Pan C: Diabetes classification model based on boosting algorithms. BMC Bioinformatics 19(1), 1–9 (2018). 10.1186/s12859-018-2090-9
- 2. Congdon N, et al.: Causes and prevalence of visual impairment among adults in the United States. Arch. Ophthalmol. 122(4), 477–485 (2004)
- 3. Ferris FL, et al.: A simplified severity scale for age-related macular degeneration: AREDS report no. 18. Arch. Ophthalmol. 123(11), 1570–1574 (2005)
- 4. Ghahramani GC, et al.: Multi-task deep learning-based survival analysis on the prognosis of late AMD using the longitudinal data in AREDS. medRxiv (2021)
- 5. Graham KW, Chakravarthy U, Hogg RE, Muldrew KA, Young IS, Kee F: Identifying features of early and late age-related macular degeneration: a comparison of multicolor versus traditional color fundus photography. Retina 38(9), 1751–1758 (2018)
- 6. Grassmann F, et al.: A deep learning algorithm for prediction of age-related eye disease study severity scale for age-related macular degeneration from color fundus photography. Ophthalmology 125(9), 1410–1420 (2018)
- 7. Age-Related Eye Disease Study Research Group: The age-related eye disease study (AREDS): design implications. AREDS report no. 1. Control. Clin. Trials 20(6), 573 (1999)
- 8. Hao S, et al.: Comparison of machine learning tools for the prediction of AMD based on genetic, age, and diabetes-related variables in the Chinese population. Regen. Ther. 15, 180–186 (2020)
- 9. He K, Zhang X, Ren S, Sun J: Deep residual learning for image recognition. In: Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition, pp. 770–778 (2016)
- 10. Hochreiter S, Schmidhuber J: Long short-term memory. Neural Comput. 9(8), 1735–1780 (1997)
- 11. Hu C, Steingrimsson JA: Personalized risk prediction in clinical oncology research: applications and practical issues using survival trees and random forests. J. Biopharm. Stat. 28(2), 333–349 (2018)
- 12. Kingma DP, Ba J: Adam: a method for stochastic optimization. arXiv preprint arXiv:1412.6980 (2014)
- 13. Klein R: Overview of progress in the epidemiology of age-related macular degeneration. Ophthalmic Epidemiol. 14(4), 184–187 (2007)
- 14. Lorenzoni G, et al.: Comparison of machine learning techniques for prediction of hospitalization in heart failure patients. J. Clin. Med. 8(9), 1298 (2019)
- 15. Peng Y, et al.: Predicting risk of late age-related macular degeneration using deep learning. NPJ Digit. Med. 3, 111 (2020). 10.1038/s41746-020-00317-z
- 16. Quartilho A, Simkiss P, Zekite A, Xing W, Wormald R, Bunce C: Leading causes of certifiable visual loss in England and Wales during the year ending 31 March 2013. Eye 30(4), 602–607 (2016)
- 17. Somasundaran S, Constable IJ, Mellough CB, Carvalho LS: Retinal pigment epithelium and age-related macular degeneration: a review of major disease mechanisms. Clin. Exp. Ophthalmol. 48(8), 1043–1056 (2020)
- 18. Stark K, et al.: The German AugUR study: study protocol of a prospective study to investigate chronic diseases in the elderly. BMC Geriatrics 15(1), 1–8 (2015). 10.1186/s12877-015-0122-0
- 19. Sun W, Rumshisky A, Uzuner O: Annotating temporal information in clinical narratives. J. Biomed. Inform. 46, S5–S12 (2013)
- 20. Vaswani A, et al.: Attention is all you need. Adv. Neural Inf. Process. Syst. 30 (2017)
- 21. Wong WL, et al.: Global prevalence of age-related macular degeneration and disease burden projection for 2020 and 2040: a systematic review and meta-analysis. The Lancet Global Health 2(2), e106–e116 (2014)
- 22. Yan Q, et al.: Deep-learning-based prediction of late age-related macular degeneration progression. Nat. Mach. Intell. 2(2), 141–150 (2020)
- 23. Yu B, et al.: SubMito-XGBoost: predicting protein submitochondrial localization by fusing multiple feature information and extreme gradient boosting. Bioinformatics 36(4), 1074–1081 (2020)