This section provides an overview of common FL algorithms that have been shown to work well outside of the healthcare domain. Their performance was then evaluated on two machine learning tasks in the ICU, in-hospital mortality prediction and AKI prediction, using a dataset containing EHRs from multiple ICUs.
3.1 Overview of Common FL Algorithms
In general, FL involves each individual participant training a local model on their own dataset and then exchanging model parameters, e.g., weights and/or gradients, at some frequency. There is no exchange of data among different participants. The local model parameters are then aggregated to generate a global model. Aggregation can be conducted with or without the coordination of a central party. Different FL algorithms vary in how the aggregation steps or the local update steps are performed. Among them, FedAvg [50] is the most well known. FedAvg aims to optimize the following objective:
\[ \min _w f(w) = \sum _{k=1}^N p_k F_k(w), \tag{1} \]
where \(N\) is the number of participants, \(p_k\) is the weight of participant \(k\) with \(\sum _{k=1}^N p_k = 1\), and \(F_k(\cdot)\) is the local objective function of participant \(k\). \(p_k\) is usually proportional to the size of each participant's dataset.
At each communication round \(t\), a global model with weights \(w_t\) is sent to all \(K\) participants. Each participant \(k\) performs local training for \(E\) epochs, producing a new local model with weights \(w^k_{t+1}\). Each participant then sends their newly learned local model weights to a central server, where they are aggregated to obtain a new global model with updated weights \(w_{t+1}\) equal to the weighted average of all local models:
\[ w_{t+1} = \sum _{k=1}^K p_k w^k_{t+1}. \tag{2} \]
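To make the round structure concrete, the following is a minimal sketch of one FedAvg communication round in Python with NumPy. It assumes model weights are flat arrays; local_train is a hypothetical routine that runs \(E\) epochs of local training and returns updated weights, and all names are illustrative rather than taken from any particular implementation.

import numpy as np

def fedavg_round(global_weights, clients, local_train, epochs):
    """One FedAvg communication round (Equation (2))."""
    sizes = np.array([c["n_samples"] for c in clients], dtype=float)
    p = sizes / sizes.sum()  # p_k proportional to local dataset sizes
    local_weights = []
    for client in clients:
        # Each client starts from the current global weights w_t and
        # trains for E local epochs on its own data only.
        local_weights.append(local_train(global_weights, client["data"], epochs))
    # Server aggregation: w_{t+1} = sum_k p_k * w^k_{t+1}
    return sum(pk * wk for pk, wk in zip(p, local_weights))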
FedAvg performs well in the case of homogeneity, where all local datasets are independent and identically distributed (IID). In the presence of statistical heterogeneity, where data are not independent and identically distributed (non-IID) across participants, the global model might perform poorly or even fail to converge. A number of approaches have been proposed to counter this problem and improve the convergence rate and performance of FL on non-IID datasets.
FedProx [45] and SCAFFOLD [38] aim to improve the convergence rate of FedAvg by correcting client drift, a phenomenon where client heterogeneity causes the local updates to drift in each round of local training, resulting in slow convergence. FedProx introduces a proximal term that restricts the local updates to stay close to the latest global model. Instead of optimizing \(F_k(\cdot)\), each participant now optimizes the local objective
\[ h_k(w; w_t) = F_k(w) + \frac{\mu }{2} \Vert w - w_t\Vert ^2, \tag{3} \]
where \(\mu \ge 0\) controls the strength of the proximal term.
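As an illustration, the proximal term can be added to a standard PyTorch training loss in a few lines. This is a minimal sketch under the assumption that the latest global weights are available as a list of tensors; the function name and the value of mu are illustrative.

import torch

def fedprox_loss(local_loss, model, global_params, mu=0.01):
    """Local FedProx objective: F_k(w) + (mu/2) * ||w - w_t||^2."""
    prox = 0.0
    for w, w_t in zip(model.parameters(), global_params):
        # w_t is the corresponding parameter of the latest global model;
        # it is treated as a constant (detached from the graph).
        prox = prox + torch.sum((w - w_t.detach()) ** 2)
    return local_loss + 0.5 * mu * prox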
SCAFFOLD works by estimating the amount of drift caused by each client in each round and then adjusting their local updates accordingly. How much a client drifts is measured by the difference between the direction of the global update and the direction of the client's local update.
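A simplified sketch of SCAFFOLD's drift-corrected local update is shown below, assuming weights and control variates are flat NumPy arrays and grad_fn is a hypothetical local gradient oracle. It follows the corrected local step and one variant of the control-variate update from the SCAFFOLD paper; all names are illustrative.

import numpy as np

def scaffold_local_update(x, c_global, c_local, grad_fn, lr, num_steps):
    """Drift-corrected local training for one SCAFFOLD client.

    x: global weights w_t; c_global, c_local: server and client
    control variates; grad_fn(w) returns the local gradient at w.
    """
    y = x.copy()
    for _ in range(num_steps):
        # Corrected step: the term (c_global - c_local) counteracts
        # the client's drift away from the global update direction.
        y -= lr * (grad_fn(y) - c_local + c_global)
    # Update the client control variate from the realized local progress.
    c_local_new = c_local - c_global + (x - y) / (num_steps * lr)
    return y, c_local_new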
Instead of controlling local training, several FL algorithms tackle the slow convergence problem by modifying the server-side optimization. The global model update step specified in Equation (2) can be rewritten as
\[ w_{t+1} = w_t - \Delta w, \tag{4} \]
where \(\Delta w = \sum _{k=1}^K p_k\Delta w_k\) and \(\Delta w_k = w_t - w^k_{t+1}\) is the weight update from client \(k\). Equation (4) has the same form as a gradient-based optimization step in which \(\Delta w\) acts as a pseudo-gradient. Reddi et al. [62] formalized this as a server optimization step that optimizes the model from a global perspective, in addition to the client optimization step that aims to optimize the model from a local perspective. Their proposed FL algorithms FedAdagrad, FedAdam, and FedYogi employ adaptive server optimization by applying the adaptive optimization methods Adagrad, Adam, and Yogi, respectively, in the server optimization step. FedAvgM [30, 31] is another algorithm that modifies the server optimization step, adding momentum: the server computes \(w = w - v\), where \(v = \beta v + \Delta w\) and \(\beta\) is the momentum parameter.
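The following sketch illustrates how a server-side optimizer consumes the pseudo-gradient \(\Delta w\): the first function implements the FedAvgM momentum update \(w \leftarrow w - v\) with \(v = \beta v + \Delta w\), and the second applies an Adam-style update in the spirit of FedAdam. The hyperparameter values and the use of a state dictionary are purely illustrative.

import numpy as np

def server_step_fedavgm(w, delta_w, state, beta=0.9):
    """FedAvgM: momentum applied to the server pseudo-gradient."""
    v = beta * state.get("v", np.zeros_like(w)) + delta_w
    state["v"] = v
    return w - v

def server_step_fedadam(w, delta_w, state, lr=0.01,
                        beta1=0.9, beta2=0.99, eps=1e-3):
    """FedAdam-style adaptive server update on the pseudo-gradient."""
    m = beta1 * state.get("m", np.zeros_like(w)) + (1 - beta1) * delta_w
    v = beta2 * state.get("v", np.zeros_like(w)) + (1 - beta2) * delta_w**2
    state["m"], state["v"] = m, v
    return w - lr * m / (np.sqrt(v) + eps)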
3.2 Experiments
We evaluated the performance of well-known FL algorithms, FedAvg, FedProx, FedAvgM, FedAdagrad, FedAdam, and FedYogi, on two common and clinically crucial machine learning tasks in the ICU: in-hospital mortality prediction and AKI prediction. Their results were compared against those obtained from local learning, centralized learning, and two non-FL methods that also enable collaborative model training without data sharing, namely, IIL and cyclic institutional incremental learning (CIIL) [11, 67, 68]. In IIL, each party trains the model on its local dataset and then passes the model to the next party until all parties have trained it. CIIL repeats the same process over multiple rounds but fixes the number of training epochs carried out by each party in each round; a minimal sketch of both protocols follows this paragraph. The data for both tasks come from the eICU dataset [60], which collected EHRs from more than 200 hospitals and over 139,000 patients across the United States admitted to the ICU in 2014 and 2015. The dataset contains a wide range of data, including demographics, medications, diagnoses, procedures, timestamped vital signs, and lab test results.
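A minimal sketch of the two incremental protocols, reusing a hypothetical local_train routine as in the FedAvg sketch of Section 3.1; all names are illustrative.

def iil(weights, parties, local_train, epochs):
    """IIL: the model is trained once by each party in sequence."""
    for party in parties:
        weights = local_train(weights, party["data"], epochs)
    return weights  # the model produced by the last party

def ciil(weights, parties, local_train, epochs_per_round, num_rounds):
    """CIIL: the IIL pass is repeated over multiple rounds, with a
    fixed number of local epochs per party per round."""
    for _ in range(num_rounds):
        for party in parties:
            weights = local_train(weights, party["data"], epochs_per_round)
    return weights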
For each task, several hospitals in the eICU database were selected as participants. Each hospital's extracted data were split into a train, validation, and test set comprising 80%, 10%, and 10% of its population, respectively. In the local training setting, a separate model was trained for each hospital using only its own local data. Training ran for a fixed number of epochs, and, for each hospital, the model that performed best on the validation set in terms of Area under the ROC Curve (AUC-ROC) became the final model for evaluation, as sketched below.
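A sketch of the per-hospital split and validation-based model selection, using scikit-learn for the splits; train_one_epoch and evaluate_auc are hypothetical helpers, and the use of PyTorch state dictionaries is an illustrative choice.

import copy
from sklearn.model_selection import train_test_split

def split_hospital(X, y, seed=0):
    """80/10/10 train/validation/test split for one hospital."""
    X_tr, X_tmp, y_tr, y_tmp = train_test_split(
        X, y, test_size=0.2, random_state=seed, stratify=y)
    X_val, X_te, y_val, y_te = train_test_split(
        X_tmp, y_tmp, test_size=0.5, random_state=seed, stratify=y_tmp)
    return (X_tr, y_tr), (X_val, y_val), (X_te, y_te)

def train_local(model, train, val, num_epochs):
    """Keep the epoch checkpoint with the best validation AUC-ROC."""
    best_auc, best_state = -1.0, None
    for _ in range(num_epochs):
        train_one_epoch(model, train)   # hypothetical helper
        auc = evaluate_auc(model, val)  # hypothetical helper
        if auc > best_auc:
            best_auc, best_state = auc, copy.deepcopy(model.state_dict())
    model.load_state_dict(best_state)
    return model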
In the centralized setting, the train, validation, and test sets from all participating hospitals were concatenated to produce a single train, validation, and test set. A single model was then trained on the combined training set. As in the local setting, training was conducted for several epochs and the best model was selected based on the AUC-ROC score on the combined validation set.
In the IIL and CIIL settings, since the absence of data sharing among participants meant there was no aggregated global validation set, the model produced by the last party to conduct training was selected as the final model.
In the FL setting, training was done over several communication rounds. Similar to the IIL and CIIL settings, since the central server that coordinated the training and carried out the global model aggregation process did not have access to a global validation set, the final model was the one obtained after all the communication rounds had finished.
Performance among the methods was compared based on global test scores. The metrics used are AUC-ROC and Area under the Precision-Recall Curve (AUC-PR). DeLong's method [15] and the logit method [9] were employed to compute 95% confidence intervals for AUC-ROC and AUC-PR, respectively.
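For reference, below is a sketch of one common form of the logit-method interval: the AUC-PR estimate is mapped to the logit scale, a normal approximation is applied, and the bounds are mapped back. The standard-error form \(1/\sqrt{n\,\theta (1-\theta)}\), with \(n\) the number of positive examples, is our assumed reading of the method rather than the exact implementation used.

import numpy as np
from scipy.special import expit, logit
from scipy.stats import norm

def logit_ci_aucpr(aucpr, n_pos, alpha=0.05):
    """Confidence interval for AUC-PR via the logit transform."""
    z = norm.ppf(1 - alpha / 2)
    eta = logit(aucpr)
    # Assumed delta-method standard error on the logit scale.
    se = 1.0 / np.sqrt(n_pos * aucpr * (1.0 - aucpr))
    return float(expit(eta - z * se)), float(expit(eta + z * se))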
3.2.1 In-hospital Mortality Prediction.
In this experiment, we investigated the performance of FL algorithms on predicting a patient's in-hospital mortality from data collected during the first 24 h of their ICU stay. This is a crucial task in clinical settings. When a patient is admitted to the ICU, predicting their mortality, whether at the end of the ICU stay, the hospital stay, or within a fixed period, e.g., 28 days, one month, or three months, provides a proxy for the severity of their condition and helps healthcare providers plan treatment pathways and allocate resources more effectively. Several works have successfully applied machine learning to predict in-hospital mortality [6, 61, 82].
Data. The same data extraction process as in References [14, 37] was employed. For each hospital in the entire eICU dataset, we extracted a cohort of patients aged 16 and above in their first ICU stay whose in-hospital mortality status was recorded. Patients without an APACHE IVa score were excluded; this criterion serves as a proxy for identifying patients with insufficient data or those who were in the database for administrative purposes only. The twenty hospitals with the largest cohorts were then selected as participants in the study. The combined cohort contains 87,003 ICU stays.
For each patient, data within 24 h from ICU admission were extracted. The set of features includes
• demographic information: gender, age, and ethnicity,
• the first and last results of the following laboratory tests: PaO2, PaCO2, PaO2/FiO2 ratio, pH, base excess, Albumin, the significant band of arterial blood gas, HCO3, Bilirubin, Blood Urea Nitrogen (BUN), Calcium, Creatinine, Glucose, Hematocrit, Hemoglobin, international normalized ratio (INR), Lactate, Platelet, Potassium, Sodium, and white blood cell count,
• the first and last as well as the minimum and maximum measurements of the following vital signs: heart rate, systolic blood pressure, mean blood pressure, respiratory rate, temperature (Celsius), SpO2, and Glasgow Coma Scale (GCS),
• whether the hospital admission was for an elective surgery.
A total of 82 covariates were obtained.
Methods. A neural network consisting of two fully connected hidden layers with ReLU activation function and L2 normalization was used; a sketch of the architecture is given below. The first hidden layer contains 100 nodes and the second 50. In the local and centralized settings, the model was trained for 90 epochs. In the FL settings, training took place over 30 communication rounds, with each hospital training the model locally for ten epochs per round.
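A sketch of this network in PyTorch, with the L2 penalty realized as optimizer weight decay as one plausible reading of "L2 normalization"; the weight-decay value, learning rate, and optimizer choice are illustrative assumptions.

import torch.nn as nn
import torch.optim as optim

class MortalityNet(nn.Module):
    """Two fully connected hidden layers (100 and 50 nodes) with ReLU."""
    def __init__(self, num_features=82):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(num_features, 100), nn.ReLU(),
            nn.Linear(100, 50), nn.ReLU(),
            nn.Linear(50, 1),  # logit of in-hospital mortality
        )

    def forward(self, x):
        return self.net(x)

model = MortalityNet()
# L2 penalty as weight decay (an assumption; value illustrative).
optimizer = optim.Adam(model.parameters(), lr=1e-3, weight_decay=1e-4)
criterion = nn.BCEWithLogitsLoss()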
3.2.2 AKI Prediction.
The purpose of this experiment was to evaluate the performance of FL algorithms on predicting the risk of a patient developing AKI within the next hour based on data collected during the previous 7 h. AKI is a sudden onset of renal damage or kidney failure that develops within a few hours or days and occurs in at least 5% of hospitalized patients [16]. AKI can affect other organs such as the lungs, heart, and brain, and it significantly increases hospitalization costs as well as mortality risk [13]. Timely detection of AKI could prevent patients from developing chronic kidney disease [39, 71]. Several studies have shown strong performance of machine learning models in predicting AKI [26, 53, 58].
Data. We followed the same data extraction process as in Reference [16]. The RIFLE criteria [7] were used to define AKI. Specifically, a patient at time \(t\) is labeled as suffering from AKI if their urine output has been less than 0.5 ml/kg/h for at least the preceding 6 h; a sketch of this rule is given below. The cohort exclusion criteria were (1) patients who were under 16 years old or stayed in the ICU for less than 12 h and (2) patients whose data for the selected variables were not recorded at least once during their ICU stay. A total of 10,967 patients in 168 hospitals remained after filtering. The top 75% of hospitals with the largest number of patients were selected to participate in the study. The final cohort contains 28 hospitals with a total of 6,641 patients.
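A sketch of this labeling rule, assuming hourly urine-output rates (ml/kg/h) are available as a NumPy array for each ICU stay; the windowing details are illustrative.

import numpy as np

def label_aki(urine_rate_per_hour, threshold=0.5, window_h=6):
    """Label hour t as AKI-positive if urine output has stayed below
    `threshold` ml/kg/h for the preceding `window_h` hours."""
    below = urine_rate_per_hour < threshold
    labels = np.zeros_like(below, dtype=bool)
    for t in range(window_h - 1, len(below)):
        labels[t] = below[t - window_h + 1 : t + 1].all()
    return labels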
For each patient, we extracted data in 7-h sliding windows. The full set of covariates includes
• demographic information: age and gender,
• the minimum and maximum values as well as the range (the difference between the maximum and minimum values) of the following vital signs: heart rate, respiratory rate, mean blood pressure,
• the minimum and maximum values as well as the range of the following lab measurements: SpO2/SaO2, pH, Potassium, Calcium, Glucose, Sodium, HCO3, Hemoglobin, white blood cell count, Platelet count, Urea Nitrogen, Creatinine, GCS,
• interventions: use of vasoactive medications, use of sedative medications, and use of mechanical ventilation.
A total of 22 covariates were obtained.
Methods. Similar to the previous task, a fully connected neural network consisting of two hidden layers with ReLU activation function and L2 normalization was used; however, here each of the two hidden layers contains 512 nodes instead of 100 and 50. In the local and centralized settings, the model was trained for 30 epochs. In the FL settings, training took place over four communication rounds: each hospital trained its local model for 10 local epochs during the first round and 5 local epochs during each subsequent round, as in the schedule sketched below.
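Reusing the fedavg_round sketch from Section 3.1, the uneven local-epoch schedule can be expressed as a simple per-round list; initial_weights, clients, and local_train are hypothetical names carried over from that sketch.

# 10 local epochs in the first round, then 5 in each remaining round.
epoch_schedule = [10, 5, 5, 5]
weights = initial_weights  # hypothetical starting weights
for epochs in epoch_schedule:
    weights = fedavg_round(weights, clients, local_train, epochs)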
3.3 Results and Discussion
Global test performance in terms of AUC-ROC and AUC-PR obtained with each method is shown in Table 1 for in-hospital mortality prediction and Table 2 for AKI prediction. Comparison of ROC curves obtained with FL methods versus centralized and local training is visualized in Figures 1 and 2. Similarly, Figures 3 and 4 in Appendix A compare ROC curves obtained with FL methods to those obtained with CIIL and IIL. In both tasks, all FL methods outperform local training on both metrics, with the exception of FedProx on AKI prediction. In particular, for mortality prediction, all FL methods perform significantly better than local training. Compared with IIL and CIIL, all FL methods achieve better results for mortality prediction, and the same is true for most FL methods on AKI prediction; the only exceptions are FedProx, which obtains worse AUC-ROC and AUC-PR than both IIL and CIIL, and FedAdam, whose AUC-PR is slightly lower than that of CIIL. Overall, for both tasks, FL methods improve over IIL and CIIL. This is unsurprising, given that IIL is known to suffer from catastrophic forgetting [23, 42, 68] while it is non-trivial to obtain optimal results with CIIL due to its instability [68]. Results obtained by FL are also comparable to centralized learning, with the best FL method in each task achieving a global AUC-ROC within 0.01 of that of centralized learning in terms of point estimates. FedAvg and FedAvgM perform consistently well and are among the top three FL methods by global AUC-ROC and AUC-PR in both tasks, behind only FedYogi in AKI prediction. In both tasks, FedAvgM obtained slightly better results than FedAvg, whereas FedProx achieved the lowest scores in both mortality and AKI prediction.
These results strongly favor FL as a viable strategy for facilitating collaboration among organizations in clinical research. Even though performance does not vary much among the different FL methods in our experiments, the simple FL algorithms, namely, FedAvg and FedAvgM, perform slightly better than FedProx, FedAdam, and FedAdagrad. FedProx has been shown to work well in the presence of heavy data heterogeneity [45]. In our dataset, all hospitals are located in the United States and are therefore expected to be relatively consistent in clinical practices and patient demographics. Furthermore, they all participated in the Philips eICU program, which guarantees a certain degree of data standardization. Thus, the differences in data distribution among them are not significant enough to benefit from FedProx. Moreover, the total number of participants is small relative to FL in an IoT setting with a large number of participating devices, where FedProx usually shines [45]. Data homogeneity might also explain the lack of performance gain of FedAdam and FedAdagrad over FedAvg. In addition, both mortality prediction and AKI prediction, like most machine learning tasks on tabular EHR data, require only feed-forward fully connected neural networks with a small number of layers, which might not see considerable performance gains from the use of Adam and Adagrad.