Author manuscript; available in PMC: 2024 Feb 1.
Published in final edited form as: IEEE J Biomed Health Inform. 2022 Apr 14;26(4):1761–1772. doi: 10.1109/JBHI.2021.3134835

Dynamic Neural Graphs Based Federated Reptile for Semi-supervised Multi-Tasking in Healthcare Applications

Anshul Thakur 1, Pulkit Sharma 1, David Clifton 1
PMCID: PMC7615588  EMSID: EMS183748  PMID: 34898443

Abstract

AI healthcare applications rely on sensitive electronic healthcare records (EHRs) that are scarcely labelled and are often distributed across a network of symbiont institutions. It is challenging to train effective machine learning models on such data. In this work, we propose a dynamic neural graphs based federated learning framework to address these challenges. The proposed framework extends Reptile, a model agnostic meta-learning (MAML) algorithm, to a federated setting. However, unlike the existing MAML algorithms, this paper proposes a dynamic variant of neural graph learning (NGL) to incorporate unlabelled examples in the supervised training setup. Dynamic NGL computes a meta-learning update by performing supervised learning on a labelled training example while performing metric learning on its labelled or unlabelled neighbourhood. This neighbourhood of a labelled example is established dynamically using local graphs built over the batches of training examples. Each local graph is constructed by comparing the similarity between embeddings generated by the current state of the model. The introduction of metric learning on the neighbourhood makes this framework semi-supervised in nature. The experimental results on the publicly available MIMIC-III dataset highlight the effectiveness of the proposed framework for both single and multi-task settings under data decentralisation constraints and limited supervision.

Index Terms: federated learning, multi-task learning, semi-supervised learning

I. Introduction

With the advent of informatisation of medical institutions, there is an explosion in the availability of digitised healthcare data such as genetic data, electronic healthcare records (EHRs), and medical research data [1]. Machine learning (ML) models can be trained with this healthcare data to perform various tasks such as developing precision medicine, assisting medical practitioners in diagnosis and predicting physiological deterioration during critical care [2]. Generally, healthcare data is distributed among a network of symbiont institutions. The utilisation of all of this distributed data is essential to train a reliable and effective ML model. However, the digitised healthcare data often contain private information, and the leakage of this data may compromise the patients’ privacy. To avoid such scenarios, strict standards such as the General Data Protection Regulation (GDPR)1 and the Data Protection Act (DPA)2 are in place to restrict access to sensitive healthcare records. As a result, collecting the distributed healthcare data at a third-party-owned centralised location to train ML models is not always feasible. Hence, ML applications are required to strike an equilibrium between data privacy and data analysis.

A large amount of healthcare data is generated every day, but obtaining labels for this rapidly generated data is either unfeasible or highly resource-intensive. Hence, the amount of available labelled training data is significantly smaller than the amount of unlabelled data. As a result, unsupervised and semi-supervised ML methods have garnered significant interest for many healthcare applications [3]–[6]. The unsupervised methods mostly include representation learning [7] and clustering [8] to identify the latent patterns, whereas semi-supervised methods incorporate both labelled and unlabelled data in the learning process. Previous studies have shown that semi-supervised methods help in learning more robust and generalised models when the labelled data is scarce [9], [10]. Thus, there is a need for new techniques that incorporate the unlabelled data in the training process to complement supervised learning.

In this work, we concentrate on EHRs and machine learning (ML) tasks associated with critical care. Generally, these EHRs reflect a common set of features that are essential to perform multiple critical care tasks [11]. Hence, it makes sense to train a single model that performs multiple tasks. Considering the nature of EHRs and the other challenges, we aim to propose an ML framework that has the following characteristics:

  • Data distribution: The framework must be able to alleviate challenges associated with the decentralised nature of training data. EHRs are sensitive and must not be shared with the other symbiont institutions (clients) or the third-party servers.

  • Semi-supervision: The framework must be able to exploit the vast amount of unsupervised data to learn better models.

  • Multi-tasking: The framework must be able to train models for multiple tasks simultaneously. In comparison to the multiple single-task models, the concurrent training of multiple tasks leads to a reduction in client-server communication. The decreased client-server communication also lowers the chances of data-leakage or membership inference attacks.

To address the aforementioned aspects of healthcare applications, we propose a federated learning (FL) framework [2], [12]. This work extends Reptile [13], a first-order model agnostic meta-learning (MAML) algorithm, to the federated setting such that deep learning models can be trained effectively on the distributed data. The security protocols such as secure multi-party computation [14] or differential privacy [15] can easily be incorporated in the proposed framework. Moreover, Reptile allows the proposed framework to train a shared global model targeting multiple similar tasks. The latent representations learned by the model for a particular task may also be relevant for other tasks. Hence, shared training may help in improving generalisation across all tasks. Apart from that, the parameters of the trained global model provide an informed parameter initialisation for training models on new unseen but similar tasks. In comparison to the random initialisation, this informed initialisation may result in faster convergence and better performance.

Reptile and other MAML algorithms [16] are designed for few-shot learning, and primarily focus on obtaining a shared model that can be adapted to new tasks with a few gradient updates. However, in most healthcare applications, a significant amount of data (though unlabelled) is available for training. As a result, few-shot learning and fast adaptation to new similar tasks may not be a priority. Instead, we are interested in obtaining a shared model that can exploit the common meaningful representations across multiple tasks to improve performance. Hence, the proposed framework utilises Reptile to learn this common representation across multiple similar tasks. The faster adaptation to new tasks is just a beneficial side-effect of Reptile.

As discussed earlier, we are interested in obtaining the shared model while effectively utilising the unlabelled data to aid supervised learning across multiple tasks. To achieve this goal, we propose a new variant of neural graph learning (NGL) [17]. In the existing formulation of NGL, the labelled examples are accompanied by their neighbours (labelled or unlabelled), defined by an input graph, to regularise the training process. Along with supervised training, NGL also tries to minimise the deviation between an example and its neighbours in an embedding space. Hence, NGL enforces the same semantic meaning, i.e. class label, on the entire neighbourhood. This behaviour is useful if the neighbourhood examples are unlabelled, as semantic meaning is propagated from a labelled example to its neighbouring examples. Though NGL is an effective semi-supervised learning method, it requires a synthesised graph as an input to the training algorithm. Generally, graphs are synthesised by measuring the similarity between embeddings (of both labelled and unlabelled training examples) learnt using an unsupervised representation method such as auto-encoders (AE) [17]. Hence, training an AE is an overhead, and in the case of decentralised data, a federated learning algorithm such as federated averaging [18] is required to train it. Besides the training of the main model (associated with the target tasks), training an AE further increases the threat of membership inference and other adversarial attacks by exposing more information (gradient updates or entire model states) over the client-server communication channels [19]. Hence, the requirement of input graphs or the training of representation models makes NGL undesirable in a decentralised or federated setup.

In this work, we address the shortcomings of NGL by proposing a dynamic variant that is better suited to a federated setup. The proposed formulation of NGL does not require any synthetic input graph, and dynamically creates a local graph over a batch of training examples. The embeddings used to establish the similarity between the input examples are obtained from an intermediate layer of the main model itself. Since the main model is guided by supervised learning, these embeddings are semantically meaningful and can help in creating effective local graphs or neighbourhoods. The local graphs are created at each client, so this dynamic setup does not increase the client-server communication while reaping the benefits of NGL. More details about the proposed dynamic NGL are in Section III. The proposed framework employs dynamic NGL in a task-specific manner to compute meta-updates, which are used to train the shared model at the server.

The major contributions of this paper are listed below:

  • This paper introduces the dynamic NGL-based federated framework for semi-supervised multi-tasking. Although existing federated learning methods have explored semi-supervision, semi-supervised learning in MAML algorithms is rare. To the best of our knowledge, this is one of the few studies that directly aims to improve the classification performance of a MAML algorithm by exploiting unlabelled data.

  • This paper introduces a dynamic variant of NGL that does not require any input graph and can be used in both centralised as well as federated setups for semi-supervised learning.

  • This paper addresses the task heterogeneity associated with multi-tasking by utilising either the global or the local task-specific layers (see Section III for details).

  • The experimental evaluation on the publicly available MIMIC-III dataset shows that the proposed dynamic NGL based federated framework exhibits either better or comparable performance against state-of-the-art semi-supervised federated learning frameworks.

The rest of this paper is organised as follows: Section II discusses the existing federated meta-learning and semi-supervised federated learning frameworks. In Section III, we describe the proposed framework. Experimental setup and results are discussed in Sections IV and V, respectively. Finally, Section VI concludes this paper.

II. Related Studies

In this section, we discuss the existing studies related to federated and semi-supervised federated learning frameworks. We also analyse the major differences between the proposed framework and the existing methods.

A. Federated learning

Recently, federated deep learning has been explored for many healthcare applications, such as diagnosing Parkinson’s disease [20], predicting adverse drug reactions [21] and early stroke prediction [22]. FL involves training a global ML model, using data distributed across different clients, in a decentralised manner. Most of the existing studies utilise the federated averaging (FedAvg) algorithm [18] to train deep learning models in a distributed or federated setting. In FedAvg, during each round of training, the server asks the clients to train local models using their corresponding local data. These local models are then averaged by the server to obtain a global model.
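The averaging step of a FedAvg round can be illustrated with a minimal sketch. The parameter names (`w`, `b`) and the three clients are hypothetical, and the sketch uses an unweighted mean for brevity; FedAvg as defined in [18] weights each client by its number of local examples.

```python
import numpy as np

def fedavg_round(client_params):
    """Average a list of per-client parameter dicts into one global dict.

    Assumption: unweighted mean; [18] weights by local dataset size.
    """
    keys = client_params[0].keys()
    return {k: np.mean([p[k] for p in client_params], axis=0) for k in keys}

# Three hypothetical clients, each holding locally trained weights.
clients = [
    {"w": np.array([1.0, 2.0]), "b": np.array([0.0])},
    {"w": np.array([3.0, 4.0]), "b": np.array([1.0])},
    {"w": np.array([5.0, 6.0]), "b": np.array([2.0])},
]
global_params = fedavg_round(clients)
```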

To address the statistical heterogeneity of the distributed data, Arivazhagan et al. [23] proposed to divide the model as the shared and the personalisation layers. The shared layers are global and are trained using FedAvg. On the contrary, personalisation layers are client-specific and are trained on the local data. This framework alleviates the requirement of explicit personalisation as the personalisation layers at each client make the trained models sensitive to the local data.

B. Federated meta-learning

MAML [16] and its first-order variants such as Reptile [13] can be interpreted as a centralised case of FedAvg [24]. In MAML, task-specific models are used to compute meta-updates to train a shared global model. This behaviour of MAML is analogous to FedAvg, as the local models in FedAvg are equivalent to the task-specific models of MAML. Hence, MAML algorithms are naturally suited for FL. Chen et al. [25] extended MAML to a federated setting to train a model over distributed data. The experimentation in this study showed that the federated MAML could achieve better performance than FedAvg. Apart from training the global model, meta-learning has also been used for user personalisation. Since MAML can adapt a model using a few gradient updates, it can be used for rapid personalisation of a global model. Jiang et al. [24] exhibited this behaviour by successfully personalising a FedAvg-trained global model using Reptile.

C. Semi-supervised federated learning

Semi-supervised FL methods consider one of the following two scenarios [26], [27]:

  • Data at the server: In this scenario, a few labelled examples are available at the third-party-owned server whereas the clients only have unlabelled examples. It is a common use-case of semi-supervised FL as it is unrealistic in most of the cases to expect a client to have the labelled data.

  • No data at the server: In this scenario, the server has no data and is mainly associated with managing the training of the global model. On the other hand, both unlabelled and labelled examples are available at the clients. This use-case is mainly suited for healthcare application where the semi-supervised FL has to deal with the sensitive healthcare data such as EHRs. The access to this data is restricted by strict government guidelines and any data leakage makes the clients (medical institutions) liable to lawsuits. Hence, it is reasonable for the third-party-owned server to not have any access to the raw data. Here, the clients or medical institutions are expected to provide some labelled examples for the training. In this work, we are only interested in “no data at the server” and all further discussions are centred around this scenario.

In practice, any state-of-the-art semi-supervised method [28]–[30] can be used at clients for training the local models. In comparison to supervised training, these methods can improve the performance of local models and hence, the global model. Two prominent semi-supervised methods, i.e. Yalniz et al. [28] and Xie et al. [29], follow the student-teacher framework. Broadly, in these frameworks, the teacher is trained with supervised data and is used to generate pseudo-labels for the unlabelled examples. These pseudo-labels are then utilised by the student to aid the learning performed with the labelled examples. In addition to these approaches, a few studies have specifically exploited the federated setting to propose semi-supervised methods. One such prominent framework is proposed by Jeong et al. [27]. This method introduced inter-client consistency to generate pseudo-labels and to induce regularisation in the training process using the unlabelled examples. The inter-client consistency loss is a combination of cross-entropy loss (computed using pseudo-labels) and KL-divergence among outputs generated by different client models for an unlabelled example. The server stores the local models (sent by clients) and forwards a subset of these stored models (called helper models) to a client along with the global model parameters. The ensemble of these helper models is used to obtain pseudo-labels for each unlabelled example and to enforce the inter-client consistency.

D. Comparison with the proposed framework

The proposed framework is similar to federated MAML [25] as it also performs model-agnostic meta-learning on the distributed data. However, it differs from federated MAML in two aspects: 1) Federated MAML is completely supervised, whereas the utilisation of dynamic NGL allows the proposed framework to incorporate unlabelled examples in the learning process. 2) The proposed framework utilises Reptile for meta-learning. In comparison to MAML, the classification performance of Reptile is less sensitive to batch sampling, batch size and the number of training iterations at the clients [13]. As a result, Reptile allows the proposed framework to utilise more local training iterations at each client without worrying about batch sampling. More local iterations may result in decreased client-server communication by achieving near-optimum performance in fewer communication rounds. Apart from that, most of the aforementioned semi-supervised methods are not designed or evaluated for multi-tasking. In contrast to these methods, the proposed framework directly targets multi-task learning.

III. Dynamic NGL Based Federated Reptile

This section elaborates the proposed dynamic NGL based federated Reptile (NGL-FedRep) framework. Here, we first describe the problem statement. Then, we discuss the dynamic NGL. Finally, the overall framework is presented.

A. Problem statement

The healthcare data is distributed among the multiple medical institutions that act as clients or nodes. At each client, the local data can be a mixture of both labelled and unlabelled examples. A client is not allowed to share the data with the other clients or with the server to preserve the data privacy. The aim is to process this distributed data to learn a shared global model at the server, which can perform either a single task or multiple related tasks. Fig. 1 illustrates the problem statement graphically.

Fig. 1. An illustration of client-server interaction in the proposed framework.


Unlike the other common use-cases of FL where clients are the handheld devices [24], the medical institutions (in a symbiont network) are usually significantly fewer in number, less-prone to drop-off and have adequate computational resources.

B. Dynamic neural graph learning (NGL)

Dynamic NGL allows the proposed framework to incorporate unlabelled examples to improve supervised learning. In the existing formulation of NGL [17], an undirected weighted graph is given as input (along with the training examples) to establish the semantic relations among the training examples. These semantic relations are used to define the neighbourhood around each training example. The training examples and their neighbourhood are used by NGL training objective to update the model. However, as discussed in Section I, there is no input graph in the proposed dynamic formulation of NGL and semantic structure among training examples is established during training. In this section, we first describe the process of creating local dynamic graphs. Later, we present the NGL loss function.

1). Dynamic graphs

Deep learning models tasked with supervised classification perform semantic clustering in an embedding space obtained after the models’ penultimate layer (or the last few layers). As the training progresses, these clusters become more and more mutually exclusive. This implies that deep learning models learn (or try to learn) a transformation where the class-specific training examples are linearly separable. This behaviour is shown in Fig. 2, which illustrates the 2-d t-SNE representation of embeddings generated by the penultimate layer of a CNN3 being trained over the MNIST dataset [31]. The analysis of this figure shows that the class-specific clusters become more evident as the training progresses.

Fig. 2. 2-d t-SNE representation of embeddings generated by the penultimate layer of a CNN being trained on the MNIST dataset. As the training progresses, the semantic clusters become more evident.

Semantic clustering exhibited by deep learning models provides a convenient way to create graphs that are required for neural graph learning. To perform stochastic gradient descent (SGD) based training, the available training examples are sampled into batches having b labelled examples. If we are dealing with semi-supervision, each batch is also augmented with c unlabelled examples. Before each round or epoch of training, a local dynamic graph is created for each batch as follows:

  • Let $f_l()$ represent the first $l$ layers of the model $f()$. The current state of $f_l()$ (or $f()$) is used to obtain embeddings for all examples of a batch.

  • Based on cosine similarity between embeddings, a local graph is created where each example forms a vertex. Undirected weighted edges connect these vertices or examples if the cosine similarity between them is greater than a pre-defined threshold. The cosine similarity is used as the weight on each edge. Here, the choice of cosine similarity is dictated by its fixed range, i.e. [−1, 1], which makes the thresholding trivial.

Since we are creating a new local graph for a batch before each epoch, these graphs are regarded as dynamic in nature. As the training progresses, the embeddings generated by $f_l()$ become more meaningful, and hence, the semantic structure captured by the local graphs also becomes more accurate.
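The graph construction described above can be sketched as follows. This is a minimal illustration, not the authors' implementation: the function name `create_graph`, the toy embeddings, and the removal of self-loops are assumptions made for the example.

```python
import numpy as np

def create_graph(embeddings, threshold=0.9):
    """Return a symmetric weighted adjacency matrix for one batch.

    An edge (i, j) exists when the cosine similarity of the two
    embeddings exceeds `threshold`; its weight is that similarity.
    Self-loops are removed (an assumption of this sketch).
    """
    norms = np.linalg.norm(embeddings, axis=1, keepdims=True)
    unit = embeddings / np.clip(norms, 1e-12, None)  # avoid division by zero
    sim = unit @ unit.T                               # pairwise cosine similarity
    adj = np.where(sim > threshold, sim, 0.0)
    np.fill_diagonal(adj, 0.0)
    return adj

# Toy batch of four embeddings forming two tight pairs.
emb = np.array([[1.0, 0.0], [0.9, 0.1], [0.0, 1.0], [0.1, 0.9]])
adj = create_graph(emb, threshold=0.9)
```

Re-running `create_graph` on the embeddings produced by the current model state before every epoch is what makes the graph dynamic.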

We analysed the neighbourhood accuracy of local graphs during each epoch to illustrate this behaviour. The neighbourhood accuracy is the ratio of the number of edges between same-class examples to the total number of edges in a local graph. Again, we use a CNN being trained on the MNIST dataset to highlight this behaviour. The training examples are sampled into batches of 64 examples where 50% of examples are considered unlabelled. The model is trained using the NGL loss function (discussed later). A threshold of 0.9 is used to create local graphs. Fig. 3 depicts the average neighbourhood accuracy observed during the first 10 epochs of training. The analysis of this figure highlights that as the training progresses, the average neighbourhood accuracy also increases, signifying the improvement in semantic meaningfulness of the local graphs and embeddings generated by the CNN.
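The neighbourhood accuracy metric can be computed from an adjacency matrix as in the sketch below. It assumes that ground-truth classes are available for every example for evaluation purposes, including those treated as unlabelled during training; the function name and toy graph are hypothetical.

```python
import numpy as np

def neighbourhood_accuracy(adj, labels):
    """Fraction of undirected edges that join same-class examples."""
    i, j = np.nonzero(np.triu(adj, k=1))  # count each undirected edge once
    if len(i) == 0:
        return float("nan")               # no edges, metric undefined
    return float(np.mean(labels[i] == labels[j]))

# Toy graph: edges (0,1) and (2,3) are same-class, edge (0,2) is not.
adj = np.array([
    [0.00, 0.95, 0.92, 0.00],
    [0.95, 0.00, 0.00, 0.00],
    [0.92, 0.00, 0.00, 0.97],
    [0.00, 0.00, 0.97, 0.00],
])
labels = np.array([0, 0, 1, 1])
acc = neighbourhood_accuracy(adj, labels)
```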

Fig. 3. Average neighbourhood accuracy observed by a CNN on the MNIST dataset during first 10 epochs of training.


2). NGL loss function

Let $G(V, E)$ be a local graph whose edge weights are the similarities among the examples (labelled or unlabelled) present in the associated batch. Each example, $x$, is represented by a vertex, $v_x \in V$, in $G(V, E)$. Using the graph, the neighbourhood of a given example, $x$, can be defined as:

$$N_x = \{ v_z \in V : e(v_x, v_z) \in E \}, \quad (1)$$

where $e(v_x, v_z)$ represents an undirected edge connecting two vertices $v_x$ and $v_z$.
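Equation 1 amounts to reading one row of the graph's adjacency matrix. A minimal sketch, assuming the graph is stored as a dense weighted adjacency matrix in which a zero entry means "no edge" (the storage format and the name `get_neighbourhood` are illustrative choices):

```python
import numpy as np

def get_neighbourhood(adj, x):
    """Return the indices z with an edge to x and the edge weights w_xz."""
    idx = np.nonzero(adj[x])[0]
    return idx, adj[x, idx]

# Toy graph on three examples: 0 - 1 - 2.
adj = np.array([
    [0.00, 0.95, 0.00],
    [0.95, 0.00, 0.92],
    [0.00, 0.92, 0.00],
])
idx, w = get_neighbourhood(adj, 1)
```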

Given a labelled example $x$ (with label $y$) and its neighbourhood $N_x$ defined in $G(V, E)$, the NGL loss function can be defined as:

$$\mathcal{L}_{NGL}(y, \hat{y}, N_x) = \mathcal{L}_{CE}(y, \hat{y}) + \alpha \sum_{z \in N_x} w_{xz} D(f_l(x), f_l(z)). \quad (2)$$

Here $\hat{y} = f(x)$ is the prediction generated by the neural network $f()$. $\mathcal{L}_{CE}$ represents the cross-entropy loss function4. $D()$ represents the Euclidean distance, and $f_l()$ represents the first $l$ layers of the neural network $f()$. $w_{xz}$ represents the weight of the edge connecting $v_x$ and $v_z$ in $G(V, E)$. $\alpha$ is a user-defined scalar that decides the weightage of the second term in the loss function.

The NGL loss function, defined in equation 2, is a combination of both supervised loss (first term) and unsupervised loss (second term). By minimising the first term, NGL tries to assign the correct label to the labelled examples. On the other hand, the minimisation of the second term decreases the deviation among a labelled example and its neighbours in the embedding space. Fig. 4 illustrates this behaviour. Each labelled example acts as an anchor that attracts the neighbouring examples towards itself. Hence, the second term performs metric learning by enforcing the same latent representation on the entire neighbourhood. The same representation may result in similar predictions by neural network f(). These similar predictions support the hypothesis that the input and its neighbouring examples are highly likely to belong to the same class as they exhibit high similarity among themselves. This behaviour is of particular interest if the neighbourhood of a labelled example contains unlabelled examples.
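To make the two terms of equation 2 concrete, the sketch below evaluates the loss for a single labelled example. It assumes a single-label softmax output, so the cross-entropy reduces to the negative log-probability of the true class; the argument names and toy values are hypothetical.

```python
import numpy as np

def ngl_loss(probs, label, emb_x, emb_neighbours, weights, alpha=0.1):
    """Equation 2 for one labelled example x.

    probs: predicted class probabilities f(x); label: integer class y;
    emb_x: embedding f_l(x); emb_neighbours: rows f_l(z) for z in N_x;
    weights: edge weights w_xz; alpha: weight of the neighbour term.
    """
    ce = -np.log(probs[label])                               # supervised term
    dists = np.linalg.norm(emb_neighbours - emb_x, axis=1)   # Euclidean D
    return ce + alpha * np.sum(weights * dists)              # add metric term

probs = np.array([0.5, 0.5])
emb_x = np.array([0.0, 0.0])
emb_nb = np.array([[3.0, 4.0], [0.0, 1.0]])  # distances 5 and 1 from emb_x
w = np.array([1.0, 0.5])
loss = ngl_loss(probs, 0, emb_x, emb_nb, w, alpha=0.1)
```

For these toy values the supervised term is ln 2 and the neighbour term is 0.1 × (1.0 × 5 + 0.5 × 1) = 0.55. An example with no neighbours contributes only the cross-entropy term.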

Fig. 4. An illustration of the metric learning performed by the NGL loss function. α and ζ are the labelled examples that act as anchors, and NGL tries to minimise the deviation among their respective neighbourhoods.

C. Proposed framework: Dynamic NGL based FedRep

Algorithm 1 documents the proposed framework (NGL-FedRep). It has two components: server and client-side processing, as described below:

1). Server-side processing

The server is concerned with initialising the shared global model and selecting clients for training. During each round of training, the server forwards the latest parameters of the shared global model to the selected clients. Each client performs the local training and computes the parameter updates for the global model. These parameter updates are forwarded to the server where they are aggregated and applied on the shared global model.
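The aggregation step on the server can be sketched as below. The sketch follows the summation over selected clients that appears in Algorithm 1; the function name, the toy updates, and the value of the learning rate are assumptions for illustration.

```python
import numpy as np

def server_update(theta, client_updates, eta=1.0):
    """Apply theta <- theta + eta * sum_k Phi_k to the global parameters."""
    return theta + eta * np.sum(client_updates, axis=0)

theta = np.array([0.0, 1.0])
# Meta-updates returned by two hypothetical clients.
phis = [np.array([0.2, -0.1]), np.array([0.4, 0.1])]
theta_new = server_update(theta, phis, eta=0.5)
```

Only the updates, never the raw records, reach the server in this exchange.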

2). Client-processing

For task $t$, each client $k$ possesses a labelled dataset $D_t^L = \{(x_i, y_i)\}_{i=1}^{|D_t^L|}$ and an unlabelled dataset $D_t^U = \{x_i\}_{i=1}^{|D_t^U|}$. During a round of training, the client $k$ receives the global model from the server and performs the following operations:

  • A local model is initialised with the parameters or weights ($\theta$) of the global model; let $W_t$ represent the parameters of the model for task $t$.

  • For each task $t$, a set of batches containing the labelled and unlabelled examples is sampled from $D_t^L$ and $D_t^U$.

  • A dynamic graph is created for each training batch.

  • NGL loss is computed over the sampled batches and their corresponding graphs to obtain gradient updates for training $W_t$. The updated parameters are represented by $W'_t$.

  • The updated task-specific models are used to compute the meta-updates for training the global model as:
    $$\Phi = \frac{1}{T} \sum_{t=1}^{T} (W'_t - W_t). \quad (3)$$
    Here $T$ is the total number of tasks. In the end, $\Phi$ is forwarded to the server, where it is used to update the global model.
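The client-side steps can be sketched on a toy problem. In this sketch plain gradient descent on hypothetical quadratic task losses stands in for NGL training, so it only illustrates how the meta-update of equation 3 is formed from the trained and initial task parameters; all names and values are assumptions.

```python
import numpy as np

def local_train(theta, grad_fn, beta=0.1, steps=20):
    """Gradient descent from theta; a stand-in for NGL training."""
    w = theta.copy()
    for _ in range(steps):
        w = w - beta * grad_fn(w)
    return w

def meta_update(theta, grad_fns, beta=0.1, steps=20):
    """Equation 3: mean over tasks of (trained weights - initial weights)."""
    deltas = [local_train(theta, g, beta, steps) - theta for g in grad_fns]
    return np.mean(deltas, axis=0)

# Two toy tasks with losses 0.5 * ||w - c_t||^2, whose gradient is w - c_t.
centres = [np.array([1.0, 0.0]), np.array([0.0, 1.0])]
grad_fns = [lambda w, c=c: w - c for c in centres]
theta = np.zeros(2)
phi = meta_update(theta, grad_fns)
```

Each task pulls its copy of the parameters towards its own optimum, and the returned update points towards the average of the trained task models.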

NGL updates the initial parameters ($\theta$ or $W_t$) of a task-specific model to the trained parameters $W'_t$ by applying a series of gradient updates. Intuitively, we can represent this overall training as a single parameter update that is obtained using the gradient $\nabla\theta_t = \theta - W'_t$. We can aggregate all $T$ task-specific gradients to obtain the meta-gradient as:

$$\nabla\theta = \frac{1}{T} \sum_{t=1}^{T} (\theta - W'_t) = \theta - \frac{1}{T} \sum_{t=1}^{T} W'_t. \quad (4)$$

Hence, a meta-gradient step moves the initial parameters ($\theta$) in the direction of the average of the trained task-specific models. As a result, the framework is able to train over all the tasks simultaneously. A series of meta-gradient based updates will move $\theta$ to a final configuration that is in the proximity of the near-optimal parameters of each task-specific model. Since the global model parameters are near the optimal parameters of each task, only a few gradient updates are required to obtain the optimal task-specific models.
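As a worked step, consider a single client and write $W'_t$ for the trained parameters of task $t$ (the prime is notation assumed here), with every task model initialised at $W_t = \theta$. Substituting the meta-update of equation 3 into the server step of Algorithm 1 with learning rate $\eta$ gives:

```latex
\begin{aligned}
\theta_{\text{new}} &= \theta + \eta\,\Phi
  = \theta + \frac{\eta}{T}\sum_{t=1}^{T}\left(W'_t - \theta\right) \\
 &= (1-\eta)\,\theta + \frac{\eta}{T}\sum_{t=1}^{T} W'_t
  = \theta - \eta\,\nabla\theta .
\end{aligned}
```

The update is therefore an interpolation between the current global parameters and the average of the trained task-specific models; with $\eta = 1$ the global model is replaced by that average.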

D. Overcoming task heterogeneity

Task heterogeneity is a common problem in multi-tasking. Tasks that rely on similar latent representations may still require a few task-specific layers. For example, in computer vision, scene classification and object detection may benefit from the same latent representations describing the objects. However, two CNNs performing these two tasks require different final layers. The proposed framework addresses such cases by considering the task-specific models as a combination of the common layers and the task-specific layers. The common layers are global and are shared across all tasks. These layers are trained at the server using meta-updates provided by clients. On the other hand, the task-specific layers can be trained in two ways:

  • Local training at clients: In this training mechanism, the task-specific layers are local and are trained at each client using the local data. During training, a client trains a task-specific model using NGL, and computes meta-updates for the global layers. The meta-updates are forwarded to the server, while the local layer parameters are stored at the client for the next round of training. Along with handling the difference in the nature of tasks, the use of these local layers also helps in personalising NGL-FedRep to the local data [23]. However, this mechanism is not able to fully utilise the data distributed across different clients.

  • Global training at server: In this mechanism, the task-specific layers are also shared between server and clients. During each round of training, the clients compute meta-updates for the common layers and the task-specific layers. The server aggregates these meta-updates and applies them to train the common and task-specific layers. This mechanism allows NGL-FedRep to fully utilise the distributed data. Moreover, user personalisation can be obtained by fine-tuning the trained task-specific model on the local data of a particular client.
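The difference between the two mechanisms lies in which parameter updates leave the client. A minimal sketch, assuming parameters are stored in a dictionary whose keys carry a layer prefix (the `enc` and `head_*` names are hypothetical):

```python
import numpy as np

COMMON_PREFIXES = ("enc",)  # hypothetical prefix marking the shared layers

def select_updates(delta, share_task_layers):
    """Pick the parameter updates a client forwards to the server."""
    if share_task_layers:
        # Global training at server: common and task-specific layers are sent.
        return dict(delta)
    # Local training at clients: task-specific heads never leave the client.
    return {k: v for k, v in delta.items() if k.startswith(COMMON_PREFIXES)}

delta = {
    "enc.w": np.ones(2),
    "head_mortality.w": np.ones(1),
    "head_phenotype.w": np.ones(3),
}
sent_local = select_updates(delta, share_task_layers=False)
sent_global = select_updates(delta, share_task_layers=True)
```

In the local mechanism the client would keep the head parameters between rounds, which is what provides the personalisation described above.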

IV. Experimental Setup

In this section, we describe the dataset, model architectures and experiments used for the performance evaluation of the proposed framework.

A. Dataset

The performance of the proposed framework is evaluated on the publicly available MIMIC-III dataset [32]. It is a large database that contains information on patients admitted to critical care units at a tertiary care hospital. Data includes vital signs, medications, laboratory measurements, fluid balance, hospital length of stay, survival data, and more. As described in [11], this data is pre-processed to create sub-datasets for four different tasks5. In each sub-dataset, an example is an evenly spaced time-series where different clinical measurements resulting in 76 features are sampled at each time-step. In this work, we have specified a time-step of one hour. The four tasks are described below:

  • In-hospital mortality prediction: This is a binary classification task that deals with predicting in-hospital mortality based on the first 48 hours of ICU stay. The corresponding sub-dataset contains 18,342 negative (no mortality) and 2,797 positive (mortality) examples.

  • Decompensation prediction: This is also a binary classification task that deals with predicting whether the patient’s health will deteriorate in the next 24 hours. The corresponding sub-dataset contains 3,360,926 negative (no deterioration observed) and 70,696 positive (deterioration observed) examples.

  • Phenotype classification: This is a multi-label classification task that deals with identifying which acute care conditions, such as acute cerebrovascular disease and acute renal failure, are present in a given patient’s ICU stay record. There are 25 different conditions considered for this task, and their details can be found in [11]. The phenotype classification sub-dataset contains 41,902 examples.

  • Length-of-stay prediction: This task deals with predicting the remaining length of stay in ICU at each hour. The remaining length of stay is quantised into ten classes: less than a day, one class for each of the seven days of the first week, over one week but less than two, and over two weeks. Hence, it is regarded as a multi-class classification problem. The length-of-stay prediction sub-dataset contains 3,451,346 examples that are distributed unevenly among the 10 classes.

    Algorithm 1. Neural graph learning (NGL) based federated Reptile.

     1: // Run on server
     2: Server-Exc(N, η): ⊳ η: learning rate, N: number of rounds
     3:   Initialise the global model with parameters θ_0
     4:   for i ← 1 : N do
     5:     Select a set of K clients for training
     6:     for each k ∈ K do ⊳ Clients execute in parallel
     7:       Φ_k ← Execute-Client(θ_{i−1})
     8:     θ_i ← θ_{i−1} + η Σ_{k∈K} Φ_k ⊳ Updating parameters
     9:   return θ_N
    10:
    11: // Run on clients
    12: Execute-Client(θ):
    13:   D_t^L, D_t^U: labelled and unlabelled datasets for task t
    14:   Let f_θ(·) be the model initialised with θ
    15:   for t ← 1 : T do ⊳ T: number of tasks
    16:     W_t = θ
    17:     Let f_t(·) be a model with parameters W_t
    18:     B ← Sample-Batches(D_t^L, D_t^U) ⊳ |B| batches, each with b labelled examples (having labels l) and c unlabelled examples
    19:     for each (b, l, c) ∈ B do
    20:       G(V, E) ← Create-Graph(b, c, f_θ(·)) ⊳ As discussed in Section III-B.1
    21:       N_b ← Get-Neighbourhood(b, G(V, E)) ⊳ Neighbourhood of each of the b labelled examples, as defined in equation 1
    22:       L = L_NGL(f_t(b), l, N_b) ⊳ NGL loss
    23:       W_t = W_t − β ∇_{W_t} L ⊳ β: learning rate
    24:   Φ = (1/T) Σ_{t=1}^{T} (W_t − θ) ⊳ Meta update: trained W_t minus its initial value θ
    25:   return Φ
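The server and client steps of Algorithm 1 can be illustrated with a minimal NumPy sketch. It is not the authors' implementation: parameters are flattened into a single vector, plain gradient descent stands in for the Adam optimiser, and each task is represented by a hypothetical gradient function (`task_grad_fns`) in place of the gradient of the NGL loss over a batch. The names `execute_client`, `server_round` and `n_batches` are illustrative.

```python
import numpy as np

def execute_client(theta, task_grad_fns, beta=0.001, n_batches=5):
    """Return the meta-update for one client (assumed form: mean over tasks).

    task_grad_fns: one gradient function per task; each is a hypothetical
    stand-in for the gradient of the NGL loss on a local batch.
    """
    deltas = []
    for grad_fn in task_grad_fns:
        w = theta.copy()                  # task model starts from the global parameters
        for _ in range(n_batches):        # one gradient step per local batch
            w = w - beta * grad_fn(w)     # plain SGD here; the paper uses Adam
        deltas.append(w - theta)          # adapted weights minus initial weights
    return np.mean(deltas, axis=0)        # average the differences over tasks

def server_round(theta, clients, eta=0.15, **client_kwargs):
    """One round: add eta times the sum of the client meta-updates."""
    phis = [execute_client(theta, fns, **client_kwargs) for fns in clients]
    return theta + eta * np.sum(phis, axis=0)

# Toy usage: two clients, one quadratic task each, with minima at different points.
targets = [np.array([1.0, 2.0]), np.array([3.0, 4.0])]
clients = [[lambda w, c=c: w - c] for c in targets]
theta = np.zeros(2)
for _ in range(300):
    theta = server_round(theta, clients, eta=0.15, beta=0.1, n_batches=5)
```

In this toy problem the global parameters settle at the mean of the two task minima, which is the behaviour expected from a Reptile-style update that moves the initialisation towards each task's adapted weights.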

B. Experiments

We designed three experiments to evaluate different aspects of the proposed framework:

  • FL for single task scenarios: The proposed framework (NGL-FedRep) is trained separately for the tasks of mortality prediction, decompensation prediction and phenotype classification. The performance of NGL-FedRep is compared against the widely used federated averaging (FedAvg) algorithm [18]. In addition, the performance of both NGL-FedRep and FedAvg is compared to centralised neural network training. The main purpose of this experiment is to analyse the impact of decentralisation of training data on classification performance.

  • FL for Multi-tasking: NGL-FedRep is trained simultaneously for the tasks of mortality prediction, decompensation prediction and phenotype classification. The performance on each task is compared to that of the task-specific centralised neural networks. The local and global mechanisms for training the task-specific layers are also compared.

    The parameters of the trained common layers are further used for initialising the length-of-stay prediction model. This model is trained in a centralised setup, and its performance is compared to a randomly initialised model. Note that in both scenarios, the task-specific layer is randomly initialised.

  • Semi-supervised FL: The datasets considered in this study are entirely labelled. Hence, to highlight the impact of dynamic NGL on the proposed framework, we consider only 10%, 25% or 50% of the available training examples as labelled, and the remaining examples are regarded as unlabelled. The performance of NGL-FedRep is compared to a non-graph version of the proposed framework (FedRep). This version utilises only the available labelled data, and acts as a baseline for supervised training in data-scarce scenarios. Moreover, the performance of NGL-FedRep is compared with off-the-shelf semi-supervised methods such as Yalniz et al. [28] and Xie et al. [29]. These methods are incorporated in the proposed framework (instead of dynamic NGL) to train the local models at each client. In addition, the semi-supervised FL method proposed by Jeong et al. [27] (see Section II for details) is used as a comparative method.

Note that NGL-FedRep with global layers (not the local layers) is used in this experiment.
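The simulation of limited supervision described above can be sketched as follows. The text does not state how the labelled subset is chosen, so uniform random selection with a fixed seed is an assumption here, and the function name `split_labelled` is hypothetical.

```python
import numpy as np

def split_labelled(n_examples, labelled_fraction, seed=0):
    """Split example indices into a labelled and an unlabelled subset.

    Assumption: the labelled subset is drawn uniformly at random.
    """
    rng = np.random.default_rng(seed)
    order = rng.permutation(n_examples)
    n_labelled = int(round(labelled_fraction * n_examples))
    labelled = np.sort(order[:n_labelled])      # indices whose labels are kept
    unlabelled = np.sort(order[n_labelled:])    # indices whose labels are hidden
    return labelled, unlabelled

# Toy usage: the three supervision levels on 1000 local training examples.
splits = {frac: split_labelled(1000, frac) for frac in (0.10, 0.25, 0.50)}
```

The labels of the second index set would simply be withheld from the local training loop, while the features remain available for graph construction.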

C. Data distribution, models and parameter setting

  • Data distribution: For each task, the available data is distributed among 20 clients in a non-IID manner. At each client, 10% of the data is used for validation. On the remaining data, five-fold cross-validation is used to create five train-test datasets. In the semi-supervised experiment, only 10%, 25% or 50% of the total training examples at each client are considered as labelled. Testing is performed at each client, and the predictions on all the test examples (across all the clients) are used to compute the performance metrics.

  • Models: The model architectures for the different tasks are nearly identical and differ only in the last dense layer. Fig. 5 illustrates the model architectures used in this study. Binary cross-entropy is used as the supervised loss function (first term in equation 2) for in-hospital mortality prediction, decompensation prediction and phenotype classification. Similarly, for length-of-stay prediction, categorical cross-entropy is used as the supervised loss function.

  • Parameter setting in NGL-FedRep: During each round of training, all 20 clients are selected for obtaining meta-updates. A fixed step-size, η = 0.15, is used to update the global parameters. At a client, the task-specific models are trained on every available example in each round. The training data is presented in batches of 16 examples. In the case of supervised training, the neighbourhood examples defined by the dynamic graphs are always labelled. In semi-supervised learning, each batch contains 8 labelled and 8 unlabelled examples, and the neighbourhood of a labelled example can contain both labelled and unlabelled examples. The parameter α = 0.2 and the embeddings generated after the LSTM layer are used in dynamic graph creation and in the second term of the NGL loss function (equation 2). For dynamic graphs, a fixed threshold of 0.9 on cosine similarity is used to create an edge between two examples. At each client, the Adam optimiser with a fixed learning rate of β = 0.001 is used to train all the local models. All these values are tuned to provide maximum performance on the validation examples.

  • Parameter details in comparative methods: The parameter settings used to train the local models, such as the optimiser and learning rate (discussed above), are the same in all the comparative methods. In Yalniz et al. [28], the same task-specific model architectures are used as the teacher and the student models. To implement Xie et al. [29], we used the model architectures shown in Fig. 5 as the noisy student. However, to implement the teacher models, the dropout layers were removed from the architectures. In Jeong et al. [27], we used three helper models to implement inter-client consistency. All the parameters used in these comparative methods are also tuned on the validation examples.
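The dynamic graph construction and the neighbourhood term described in the parameter settings can be sketched in NumPy. This is a hedged illustration rather than the authors' code: the exact form of equation 2 is not reproduced here, so the sketch assumes unweighted edges, a mean squared Euclidean distance between a labelled example and its neighbours, α acting as the weight of that term, and an edge whenever the cosine similarity is at least the threshold. All function names are hypothetical, and `emb` stands for the embeddings taken after the LSTM layer, with the labelled examples placed in the first rows.

```python
import numpy as np

def cosine_similarity_matrix(emb):
    """Pairwise cosine similarity between the rows of emb."""
    norms = np.linalg.norm(emb, axis=1, keepdims=True)
    unit = emb / np.clip(norms, 1e-12, None)    # guard against zero vectors
    return unit @ unit.T

def create_graph(emb, threshold=0.9):
    """Boolean adjacency matrix: edge where similarity >= threshold (assumed)."""
    adj = cosine_similarity_matrix(emb) >= threshold
    np.fill_diagonal(adj, False)                # no self-loops
    return adj

def neighbour_penalty(emb, adj, n_labelled):
    """Mean squared distance from each labelled example to its neighbours."""
    total, count = 0.0, 0
    for u in range(n_labelled):                 # labelled examples come first
        for v in np.flatnonzero(adj[u]):        # labelled or unlabelled neighbours
            total += float(np.sum((emb[u] - emb[v]) ** 2))
            count += 1
    return total / count if count else 0.0

def ngl_loss(supervised_loss, emb, adj, n_labelled, alpha=0.2):
    """Supervised term plus alpha times the neighbourhood term (assumed form)."""
    return supervised_loss + alpha * neighbour_penalty(emb, adj, n_labelled)

# Toy usage: 2 labelled + 2 unlabelled examples with 2-D embeddings.
emb = np.array([[1.0, 0.0], [0.99, 0.1], [0.0, 1.0], [-1.0, 0.0]])
adj = create_graph(emb, threshold=0.9)
loss = ngl_loss(0.5, emb, adj, n_labelled=2, alpha=0.2)
```

Because the graph is rebuilt from the current embeddings for every batch, no input graph is needed; in the paper's setting a batch would hold 8 labelled and 8 unlabelled examples instead of the four used in this toy call.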

Fig. 5. Model architectures used for different tasks.

V. Experimental Results & Discussion

In this section, we present and discuss results obtained during single and multi-task experimentation. We also compare the performance of the proposed method with different comparative methods.

A. FL for single task scenarios

Fig. 6 depicts the average loss (across all clients) observed during training of the proposed framework (NGL-FedRep). This figure highlights that the average loss decreases during training, and that NGL-FedRep can train effectively on the distributed data. Fig. 7 illustrates the performance of NGL-FedRep on the test examples. The following inferences can be drawn from the analysis of this figure:

  • The best classification scores exhibited by NGL-FedRep are comparable to the performance of centralised neural networks on all three tasks. Along with overcoming the constraints of data distribution, NGL-FedRep also performs better than FedAvg on all three tasks.

  • During later rounds of training, FedAvg shows over-fitting for decompensation prediction. In contrast, NGL-FedRep does not exhibit such behaviour. This can be attributed to the regularisation imposed by neural graph learning in NGL-FedRep.

  • Effect of the number of batches: As discussed earlier, at each client, we use all the available local batches to train a local model during each round of training. This is in contrast to existing MAML and federated MAML algorithms, where only a few batches are used for training the local models at a time.

    To analyse the impact of the number of batches, we used 5, 10 and 20 batches for training the local models. Fig. 8 documents the performance of NGL-FedRep as a function of the number of batches. The analysis of this figure shows that using a larger number of batches for local training results in better performance. Moreover, comparing Fig. 7 and Fig. 8 shows that using all of the available local batches leads to near-optimum performance in only a few rounds of training. As discussed in Section II, this reduces the required number of interactions between the clients and the server.

Fig. 6. Average loss observed across all clients during training of NGL-FedRep framework under single task scenario. The variance represents the average training loss variations observed across all folds.

Fig. 7. Performance of NGL-FedRep trained models in single task scenario. The average performance (across five folds) during each round of training is presented here.

Fig. 8. Effect of decreasing the number of batches for updating local models at clients during each round of training.

B. FL for Multi-tasking

Fig. 9 shows the performance of NGL-FedRep trained simultaneously for mortality prediction, decompensation prediction and phenotyping. Analysis of this figure highlights that NGL-FedRep with global task-specific layers (NGL-FedRep Global) performs comparably to the centralised task-specific neural networks on all three tasks. On the other hand, NGL-FedRep with local or personalised task-specific layers (NGL-FedRep Local) exhibits comparable performance for decompensation prediction but shows lower classification scores than NGL-FedRep Global for mortality prediction and phenotyping. At each client, the available training data for mortality prediction and phenotyping is significantly less than the data available for decompensation prediction. This may have hindered the effective training of the task-specific layers in NGL-FedRep Local.

Fig. 9. Performance of NGL-FedRep in multi-tasking scenario. The average performance across five folds is presented here.

Fig. 10 depicts the performance for the task of length-of-stay prediction. The model is trained using dynamic NGL in a centralised setup. Before training, the global layer, i.e. the LSTM layer of the model (Fig. 5), is initialised using parameters obtained from the trained NGL-FedRep Global. This initialisation is referred to as “informed initialisation” because the trained parameters encapsulate the useful latent representations learned from similar tasks. The analysis of Fig. 10 shows that informed initialisation results in better performance than random initialisation.

Fig. 10. Effect of initialisation schemes on the length-of-stay prediction. For each scheme, five Kappa scores represent the performance across five folds.
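The two initialisation schemes compared for length-of-stay prediction can be illustrated with a small sketch. It makes strong simplifying assumptions: each layer is represented by a single weight matrix in a dictionary (a real LSTM layer has several weight tensors), the hidden size of 16 is hypothetical, and the function names are illustrative. Only the 76 input features and the 10 output classes come from the text.

```python
import numpy as np

def init_params(rng, n_features=76, hidden=16, n_classes=10):
    """Random initialisation of a shared layer and a task-specific layer.

    'lstm' is a placeholder matrix standing in for the shared LSTM layer.
    """
    return {
        "lstm": rng.normal(scale=0.1, size=(n_features, hidden)),
        "dense": rng.normal(scale=0.1, size=(hidden, n_classes)),
    }

def informed_init(trained_shared, rng, **kwargs):
    """Copy the trained shared layer; keep the task-specific layer random."""
    params = init_params(rng, **kwargs)
    params["lstm"] = trained_shared["lstm"].copy()   # reuse the learned representation
    return params

# Toy usage: 'trained' stands in for the parameters of NGL-FedRep Global.
rng = np.random.default_rng(0)
trained = init_params(rng, n_classes=25)
random_model = init_params(rng)               # random initialisation
informed_model = informed_init(trained, rng)  # informed initialisation
```

In both schemes the dense layer is drawn at random, which matches the note that the task-specific layer is randomly initialised in either scenario.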

C. Semi-supervised FL

Fig. 11 (a), (b) and (c) show bar-plots of the performance of the different comparative methods at different levels of supervision for the tasks of mortality prediction, decompensation prediction and phenotyping, respectively. The models for these three tasks are trained simultaneously in a multi-tasking scenario. The following inferences can be drawn from the analysis of Fig. 11:

  • In all cases, the semi-supervised methods show improvement over the supervised framework (FedRep). In particular, NGL-FedRep outperforms FedRep by a noticeable margin in all cases. This highlights that the proposed dynamic NGL can exploit unlabelled examples to improve the supervised training.

  • At 10% and 25% supervision, NGL-FedRep and Jeong et al. [27] show a noticeable improvement in classification performance over Yalniz et al. [28] and Xie et al. [29]. However, at 50% supervision, the performance of Yalniz et al. [28] and Xie et al. [29] is comparable to that of NGL-FedRep and Jeong et al. [27].

  • The performance of NGL-FedRep and Jeong et al. [27] is comparable in most cases. However, Jeong et al. [27] exhibits slightly better performance than NGL-FedRep at 10% supervision. NGL-FedRep relies on the semantics of the embeddings to create meaningful dynamic graphs. At very low supervision, these semantics and graphs may not be as meaningful as at higher levels of supervision. This may have resulted in the slight drop in the performance of NGL-FedRep at 10% supervision.

  • Although the performance of Jeong et al. [27] and NGL-FedRep is comparable, the client-server communication required in Jeong et al. [27] is four times that of NGL-FedRep: for each task, the server sends three (user-defined) helper models along with the global model parameters to each client. As discussed in Section I, more communication on the client-server channel may increase vulnerability to adversarial or membership inference attacks.

Fig. 11. Performance of NGL-FedRep and other semi-supervised methods at different levels of supervision. The average performance across five folds is presented here.

VI. Conclusions

In this paper, we presented a federated model-agnostic meta-learning framework for multi-tasking in healthcare applications. We exploited meta-learning to learn a common representation across three different critical care tasks to perform effective multi-tasking. To enable semi-supervised learning, we proposed a new dynamic variant of neural graph learning (NGL) that does not require any input graph and can effectively utilise unlabelled data to aid supervised learning. The experimental results on MIMIC-III showed that the proposed framework is capable of overcoming the constraints imposed by data decentralisation and limited supervision while maintaining respectable classification performance. Future work may involve incorporating privacy-preserving mechanisms such as secure multi-party computation and differential privacy in the proposed framework.

Footnotes

3. CNN consists of two conv layers and a dense layer. The architecture and training details are available at: keras.io/examples/vision/mnist_convnet/

4. L_CE can be replaced with any other loss function.

5. Benchmarking code available at https://github.com/YerevaNN/mimic3-benchmarks is used.

Contributor Information

Anshul Thakur, Email: anshul.thakur@eng.ox.ac.uk.

Pulkit Sharma, Email: pulkit.sharma@eng.ox.ac.uk.

David Clifton, Email: david.clifton@eng.ox.ac.uk.

References

  • [1] Groves P, Kayyali B, Knott D, Kuiken SV. The ‘big data’ revolution in healthcare: Accelerating value and innovation. 2016. [Online]. Available: repositorio.colciencias.gov.co/handle/11146/465.
  • [2] Xu J, Wang F. Federated learning for healthcare informatics. arXiv preprint arXiv:1911.06270. 2019.
  • [3] Ma F, Meng C, Xiao H, Li Q, Gao J, Su L, Zhang A. Unsupervised discovery of drug side-effects from heterogeneous data sources. Proceedings of International Conference on Knowledge Discovery and Data Mining; 2017. pp. 967–976.
  • [4] Liu C, Wang F, Hu J, Xiong H. Temporal phenotyping from longitudinal electronic health records: A graph based framework. Proceedings of International Conference on Knowledge Discovery and Data Mining; 2015. pp. 705–714.
  • [5] Yuan C, Wang Y, Shang N, Li Z, Zhao R, Weng C. A graph-based method for reconstructing entities from coordination ellipsis in medical text. Journal of the American Medical Informatics Association. 2020. doi: 10.1093/jamia/ocaa109.
  • [6] Liu M, Zhou M, Zhang T, Xiong N. Semi-supervised learning quantization algorithm with deep features for motor imagery EEG recognition in smart healthcare application. Applied Soft Computing. 2020;89:106071.
  • [7] Che Z, Kale D, Li W, Bahadori MT, Liu Y. Deep computational phenotyping. Proceedings of International Conference on Knowledge Discovery and Data Mining; 2015. pp. 507–516.
  • [8] Delias P, Doumpos M, Grigoroudis E, Manolitzas P, Matsatsinis N. Supporting healthcare management decisions via robust clustering of event logs. Knowledge-Based Systems. 2015;84:203–213.
  • [9] Zhang P, Wang F, Hu J, Sorrentino R. Label propagation prediction of drug-drug interactions based on clinical side effects. Scientific Reports. 2015;5(1):1–10. doi: 10.1038/srep12339.
  • [10] Wang Z, Shah AD, Tate AR, Denaxas S, Shawe-Taylor J, Hemingway H. Extracting diagnoses and investigation results from unstructured text in electronic health records by semi-supervised machine learning. PLoS One. 2012;7(1):e30412. doi: 10.1371/journal.pone.0030412.
  • [11] Harutyunyan H, Khachatrian H, Kale DC, Steeg GV, Galstyan A. Multitask learning and benchmarking with clinical time series data. Scientific Data. 2019;6(96):1–18. doi: 10.1038/s41597-019-0103-9.
  • [12] Brisimi TS, Chen R, Mela T, Olshevsky A, Paschalidis IC, Shi W. Federated learning of predictive models from federated electronic health records. International Journal of Medical Informatics. 2018;112:59–67. doi: 10.1016/j.ijmedinf.2018.01.007.
  • [13] Nichol A, Schulman J. Reptile: a scalable metalearning algorithm. arXiv preprint arXiv:1803.02999. 2018.
  • [14] Hardy S, Henecka W, Ivey-Law H, Nock R, Patrini G, Smith G, Thorne B. Private federated learning on vertically partitioned data via entity resolution and additively homomorphic encryption. arXiv preprint arXiv:1711.10677. 2017.
  • [15] Dwork C, Kenthapadi K, McSherry F, Mironov I, Naor M. Our data, ourselves: Privacy via distributed noise generation. Proceedings of International Conference on the Theory and Applications of Cryptographic Techniques; 2006. pp. 486–503.
  • [16] Finn C, Abbeel P, Levine S. Model-agnostic meta-learning for fast adaptation of deep networks. Proceedings of International Conference on Machine Learning; 2017.
  • [17] Bui TD, Ravi S, Ramavajjala V. Neural graph learning: Training neural networks using graphs. Proceedings of International Conference on Web Search and Data Mining; 2018. pp. 64–71.
  • [18] McMahan B, Moore E, Ramage D, Hampson S, Arcas BAy. Communication-efficient learning of deep networks from decentralized data. Proceedings of Artificial Intelligence and Statistics; 2017. pp. 1273–1282.
  • [19] Leino K, Fredrikson M. Stolen memories: Leveraging model memorization for calibrated white-box membership inference. USENIX Security Symposium; 2020. pp. 1605–1622.
  • [20] Chen Y, Qin X, Wang J, Yu C, Gao W. Fedhealth: A federated transfer learning framework for wearable healthcare. IEEE Intelligent Systems. 2020.
  • [21] Choudhury O, Park Y, Salonidis T, Gkoulalas-Divanis A, Sylla I, et al. Predicting adverse drug reactions on distributed health data using federated learning. Proceedings of AMIA Annual Symposium; 2019. pp. 313–322.
  • [22] Ju C, Zhao R, Sun J, Wei X, Zhao B, Liu Y, Li H, Chen T, Zhang X, Gao D, et al. Privacy-preserving technology to help millions of people: Federated prediction model for stroke prevention. arXiv preprint arXiv:2006.10517. 2020.
  • [23] Arivazhagan MG, Aggarwal V, Singh AK, Choudhary S. Federated learning with personalization layers. arXiv preprint arXiv:1912.00818. 2019.
  • [24] Jiang Y, Konečný J, Rush K, Kannan S. Improving federated learning personalization via model agnostic meta learning. arXiv preprint arXiv:1909.12488. 2019.
  • [25] Chen F, Luo M, Dong Z, Li Z, He X. Federated meta-learning with fast convergence and efficient communication. arXiv preprint arXiv:1802.07876. 2018.
  • [26] Zhang Z, Yao Z, Yang Y, Yan Y, Gonzalez JE, Mahoney MW. Benchmarking semi-supervised federated learning. arXiv preprint arXiv:2008.11364. 2020.
  • [27] Jeong W, Yoon J, Yang E, Hwang SJ. Federated semi-supervised learning with inter-client consistency. arXiv preprint arXiv:2006.12097. 2020.
  • [28] Yalniz IZ, Jégou H, Chen K, Paluri M, Mahajan D. Billion-scale semi-supervised learning for image classification. arXiv preprint arXiv:1905.00546. 2019.
  • [29] Xie Q, Luong M-T, Hovy E, Le QV. Self-training with noisy student improves imagenet classification. Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition; 2020. pp. 10687–10698.
  • [30] Xie Q, Dai Z, Hovy E, Luong M-T, Le QV. Unsupervised data augmentation for consistency training. arXiv preprint arXiv:1904.12848. 2019.
  • [31] Deng L. The MNIST database of handwritten digit images for machine learning research [best of the web]. IEEE Signal Processing Magazine. 2012;29(6):141–142.
  • [32] Johnson AE, Pollard TJ, Shen L, Li-Wei HL, Feng M, Ghassemi M, Moody B, Szolovits P, Celi LA, Mark RG. MIMIC-III, a freely accessible critical care database. Scientific Data. 2016;3(1):1–9. doi: 10.1038/sdata.2016.35.
