Federated Variational Autoencoders in biomedical data

Here we present an example of a variability analysis using the Multi-channel Variational Autoencoder (MCVAE) proposed by Antelmi et al.

import copy
import pandas as pd

import torch
from torch.utils.data import Subset, DataLoader, random_split
from torchvision import datasets, transforms
N_CENTERS = 4   # Number of centers (clients) the dataset is distributed across
N_ROUNDS = 50   # Number of federated rounds: in each round every center trains locally, then the models are aggregated

N_EPOCHS = 100   # Number of local training epochs each center runs per round, before aggregating
BATCH_SIZE = 48  # NOTE(review): not referenced anywhere in the visible code — confirm whether it is still needed
LR = 1e-2       # Learning rate. NOTE(review): the optimizer further down is built with lr=1e-3, not LR — confirm which one is intended

We define two helper functions: one to distribute our dataset across multiple centers (split_iid) and one to perform federated averaging (federated_averaging).

import numpy as np
def split_iid(df, n_centers, random_state=None):
    """Shuffle a pandas DataFrame and split it into ``n_centers`` parts.

    Args:
        df: DataFrame to distribute across centers.
        n_centers: Number of parts to produce.
        random_state: Optional seed forwarded to ``DataFrame.sample`` so the
            shuffle (and therefore the split) can be reproduced. The default
            ``None`` keeps the original non-deterministic behavior.

    Returns:
        A list of ``n_centers`` DataFrames that together contain every row of
        ``df`` exactly once. Part sizes differ by at most one row (the first
        parts receive the extra rows) and the original index labels are kept.
    """
    shuffled = df.sample(frac=1, random_state=random_state)  # Shuffle dataset

    # Split the row *positions* and slice with iloc instead of handing the
    # DataFrame itself to np.array_split: the latter relies on
    # DataFrame.swapaxes, which recent pandas releases deprecate.
    position_chunks = np.array_split(np.arange(len(shuffled)), n_centers)
    return [shuffled.iloc[chunk] for chunk in position_chunks]
def federated_averaging(models, n_obs_per_client):
    """Aggregate client models into one by weighted parameter averaging.

    Every entry of the models' ``state_dict`` is averaged across clients,
    each client being weighted by its share of the total number of
    observations.

    Args:
        models: Non-empty list of models exposing ``state_dict()`` /
            ``load_state_dict()``; all are expected to share the same keys
            and tensor shapes.
        n_obs_per_client: Number of observations held by each client, in the
            same order as ``models``.

    Returns:
        A deep copy of ``models[0]`` loaded with the averaged state. The
        input models are left untouched.

    Raises:
        AssertionError: If ``models`` is empty or the two lists differ in
            length.
    """
    assert len(models) > 0, 'An empty list of models was passed.'
    assert len(n_obs_per_client) == len(models), 'List with number of observations must have ' \
                                                 'the same number of elements that list of models.'

    # Compute proportions
    n_obs = sum(n_obs_per_client)
    proportions = [n_k / n_obs for n_k in n_obs_per_client]

    # Build each state dict once instead of once per key and per model.
    state_dicts = [model.state_dict() for model in models]

    # Fresh dictionary (with the model's own metadata) whose values are
    # replaced below; the tensors of models[0] themselves are never modified.
    avg_params = models[0].state_dict()

    # Compute average
    for key, reference in state_dicts[0].items():
        if reference.is_floating_point() or reference.is_complex():
            accumulator = torch.zeros_like(reference)
        else:
            # Integer/bool entries (e.g. BatchNorm's num_batches_tracked)
            # cannot accumulate fractional weights in place: average them in
            # float64 and cast back afterwards.
            accumulator = torch.zeros_like(reference, dtype=torch.float64)

        for state, proportion in zip(state_dicts, proportions):
            accumulator += proportion * state[key].to(accumulator.dtype)

        if accumulator.dtype != reference.dtype:
            accumulator = accumulator.round().to(reference.dtype)
        avg_params[key] = accumulator

    # Copy one of the models and load trained params
    avg_model = copy.deepcopy(models[0])
    avg_model.load_state_dict(avg_params)

    return avg_model

Federating dataset

# Remote CSV holding the tabular dataset used in this example
# (presumably synthetic ADNI-like data, judging by the file name — confirm).
csv = 'https://gitlab.inria.fr/ssilvari/flhd/-/raw/master/heterogeneous_data/pseudo_adni.csv?inline=false'
df = pd.read_csv(csv)

# Display one randomly drawn row as a quick look at the data
df.sample()
SEX AGE PTEDUCAT CDRSB.bl ADAS11.bl MMSE.bl RAVLT.immediate.bl RAVLT.learning.bl RAVLT.forgetting.bl FAQ.bl WholeBrain.bl Ventricles.bl Hippocampus.bl MidTemp.bl Entorhinal.bl APOE4 ABETA.MEDIAN.bl PTAU.MEDIAN.bl TAU.MEDIAN.bl
669 1.0 69.0 14.945415 0 7 27.0 28.482003 3.0 3.349521 0 0.687166 0.007506 0.005003 0.013367 0.002208 0 200.999667 9.887908 7.109941
df.columns
Index(['SEX', 'AGE', 'PTEDUCAT', 'CDRSB.bl', 'ADAS11.bl', 'MMSE.bl',
       'RAVLT.immediate.bl', 'RAVLT.learning.bl', 'RAVLT.forgetting.bl',
       'FAQ.bl', 'WholeBrain.bl', 'Ventricles.bl', 'Hippocampus.bl',
       'MidTemp.bl', 'Entorhinal.bl', 'APOE4', 'ABETA.MEDIAN.bl',
       'PTAU.MEDIAN.bl', 'TAU.MEDIAN.bl'],
      dtype='object')
df["RAVLT.learning.bl"].value_counts()
4.0     153
5.0     128
3.0     117
6.0     112
0.0     100
2.0      98
1.0      96
7.0      90
8.0      62
9.0      31
10.0     10
12.0      2
11.0      1
Name: RAVLT.learning.bl, dtype: int64

Now, federated_dataset is a list of subsets of the main dataset.

# Shuffle the full dataframe and split it into N_CENTERS parts, one per center
federated_dataset = split_iid(df, n_centers=N_CENTERS)
print('Number of centers:', len(federated_dataset))
Number of centers: 4

Finally, we must convert each dataframe into a list of tensors (one tensor per channel), grouping the variables into 5 channels:

  • Volumetric data

  • Demographics

  • Cognition

  • Genetics: Apolipoprotein E (APOE)

  • Fluid biomarkers: Amyloid beta (Abeta) and Tau concentrations in the Cerebrospinal fluid (CSF).

def get_channels():
    """Return the column names of the dataset grouped into five channels.

    The tuple is ordered as: volumetric data, demographics, cognition,
    genetics and fluid biomarkers. New lists are built on every call, so
    callers may modify the result freely.
    """
    return (
        # Channel 1: volumetric data
        ['WholeBrain.bl', 'Ventricles.bl', 'Hippocampus.bl', 'MidTemp.bl', 'Entorhinal.bl'],
        # Channel 2: demographics
        ['SEX', 'AGE', 'PTEDUCAT'],
        # Channel 3: cognition
        [
            'CDRSB.bl', 'ADAS11.bl', 'MMSE.bl', 'RAVLT.immediate.bl',
            'RAVLT.learning.bl', 'RAVLT.forgetting.bl', 'FAQ.bl',
        ],
        # Channel 4: genetics
        ['APOE4'],
        # Channel 5: fluid biomarkers
        ['ABETA.MEDIAN.bl', 'PTAU.MEDIAN.bl', 'TAU.MEDIAN.bl'],
    )

def get_data_as_multichannel_tensor_dataset(df):
    """Standardize a dataframe and return it as a list of per-channel tensors.

    Every column is z-scored with the mean and sample standard deviation of
    ``df`` itself; the columns of each channel given by ``get_channels()``
    are then gathered into one float32 tensor.

    Columns whose standard deviation is zero or undefined (a constant column,
    or a single-row dataframe) are only centered, so the result contains
    zeros for them instead of NaN/inf values.

    Args:
        df: pandas DataFrame of numeric columns containing every column
            listed by ``get_channels()``.

    Returns:
        A list with one tensor per channel, in channel order, each of shape
        ``(len(df), number_of_columns_in_channel)``.
    """
    channels = get_channels()

    std = df.std()
    # Keep only strictly positive deviations; 0 and NaN both fail ``std > 0``
    # and are replaced by 1 so the division below is always well defined.
    safe_std = std.where(std > 0, 1.0)
    standardized = (df - df.mean()) / safe_std

    return [torch.tensor(standardized[cols].values).float() for cols in channels]

Defining and distributing a model: Variational Autoencoder

In this exercise we will use the Multi-channel Variational Autoencoder proposed by Antelmi et al.

!pip install -q git+https://gitlab.inria.fr/epione_ML/mcvae.git
from mcvae.models import Mcvae, ThreeLayersVAE, VAE

First, it is necessary to define a model.

# Build the initial (global) model.
# A single sampled row is converted to per-channel tensors and zeroed out:
# presumably Mcvae only uses it to infer the input size of every channel — confirm.
dummy_data = [torch.zeros_like(x) for x in get_data_as_multichannel_tensor_dataset(df.sample())]  # Dummy data to initialize the input layer size
lat_dim = 1  # Size of the latent space for this autoencoder
vae_class = VAE  # Architecture of the autoencoder (VAE: Single layer)
model = Mcvae(data=dummy_data, lat_dim=lat_dim, vaeclass=vae_class)
# NOTE(review): the learning rate is hard-coded to 1e-3 here although the LR
# constant (1e-2) is defined at the top of the file — confirm which is intended.
model.optimizer = torch.optim.Adam(params=model.parameters(), lr=1e-3)
model.init_loss()

Now replicate the model so that each center gets its own independent copy.

# One independent deep copy of the initial model per center
models = list(map(copy.deepcopy, [model] * N_CENTERS))
# Number of rows held by each center, in the same order as `models`
n_obs_per_client = list(map(len, federated_dataset))

Train in a federated fashion

# Federated training: every round, each center starts from the current global
# parameters, trains locally for N_EPOCHS epochs, and the resulting models are
# then merged with federated averaging to form the next global parameters.

# The tensor form of each center's data does not change between rounds, so it
# is built once here instead of being recomputed in every round.
client_tensors = [get_data_as_multichannel_tensor_dataset(client_dataset)
                  for client_dataset in federated_dataset]

init_params = model.state_dict()
for round_i in range(N_ROUNDS):
    for X, client_model in zip(client_tensors, models):
        # Load the current global parameters into the client's model and train
        client_model.load_state_dict(init_params)
        client_model.optimize(epochs=N_EPOCHS, data=X)

    # Aggregate models using federated averaging
    trained_model = federated_averaging(models, n_obs_per_client)
    init_params = trained_model.state_dict()

Results visualization

Using the final parameters we can evaluate the performance of the model by visualizing the testing set onto the latent space.

import numpy as np
import pandas as pd

import seaborn as sns
import matplotlib.pyplot as plt

sns.set()

Here we store in a dictionary the decoding weights estimated for each modality. We are interested in the decoding weights corresponding to each dimension of the latent space \(Z_i\).

# Collect the decoder output weights of the aggregated model (one state-dict
# entry per channel) and arrange them in a single table with one row per
# biomarker and one column per latent dimension.
decoding_weights_dict = {
    k: w.detach().cpu().numpy()  # .cpu() so this also works if the model is not on the CPU
    for k, w in trained_model.state_dict().items()
    if 'W_out.weight' in k
}
lat_dim_names = [f'$Z_{{{i}}}$' for i in range(lat_dim)]
col_names = lat_dim_names + ["biomarker"]

channels = get_channels()

# NOTE: this relies on the matching state-dict entries being in the same order
# as the channels returned by get_channels().
channel_frames = []
for channel_i, weights_i in enumerate(decoding_weights_dict.values()):
    # Keep the weights numeric for every latent dimension instead of mixing
    # them with the biomarker names in one string array.
    channel_df = pd.DataFrame(weights_i, columns=lat_dim_names).astype('float32')
    channel_df['biomarker'] = channels[channel_i]
    channel_df['channel'] = channel_i + 1
    channel_frames.append(channel_df)

# DataFrame.append was removed in pandas 2.0; concatenate all channels at once.
weights = pd.concat(channel_frames, ignore_index=True)
weights.head()
$Z_{0}$ biomarker channel
0 -0.243616 WholeBrain.bl 1
1 0.189885 Ventricles.bl 1
2 -0.321737 Hippocampus.bl 1
3 -0.276269 MidTemp.bl 1
4 -0.290797 Entorhinal.bl 1

We reshape the table into long format so that it can be easily visualized.

# Long format: one row per (biomarker, channel, latent variable), with the
# corresponding decoding weight stored in the `value` column.
weights_melt = weights.melt(id_vars=['biomarker', 'channel'], var_name='latent_var')
weights_melt.sample()
biomarker channel latent_var value
17 PTAU.MEDIAN.bl 5 $Z_{0}$ 0.16061
weights
$Z_{0}$ biomarker channel
0 -0.243616 WholeBrain.bl 1
1 0.189885 Ventricles.bl 1
2 -0.321737 Hippocampus.bl 1
3 -0.276269 MidTemp.bl 1
4 -0.290797 Entorhinal.bl 1
5 -0.103445 SEX 2
6 0.143066 AGE 2
7 -0.102563 PTEDUCAT 2
8 0.274100 CDRSB.bl 3
9 0.328286 ADAS11.bl 3
10 -0.306241 MMSE.bl 3
11 -0.310553 RAVLT.immediate.bl 3
12 -0.244175 RAVLT.learning.bl 3
13 0.070348 RAVLT.forgetting.bl 3
14 0.272759 FAQ.bl 3
15 0.180560 APOE4 4
16 -0.268966 ABETA.MEDIAN.bl 5
17 0.160610 PTAU.MEDIAN.bl 5
18 0.220941 TAU.MEDIAN.bl 5
# Bar plot of the decoding weights: one panel per channel (stacked in a single
# column, each with its own x axis), one bar per biomarker, colored by latent variable.
sns.catplot(data=weights_melt, x='biomarker', y='value', hue='latent_var', kind='bar', col='channel', col_wrap=1, aspect=2.5, sharex=False, palette='Blues_r')
plt.show()
../_images/federated_mcvae_adni_32_0.png