Federated Variational Autoencoders

In this example we study federated latent variable modeling by training a variational autoencoder (VAE) with federated learning. We illustrate an IID scenario, in which the data are distributed uniformly at random across centers.

import copy
from tqdm.auto import tqdm

import torch
from torch.utils.data import Subset, DataLoader, random_split
from torchvision import datasets, transforms
N_CENTERS = 4
N_ROUNDS = 10   # Number of federated rounds: local training at every center followed by aggregation

N_EPOCHS = 15   # Number of local epochs at each center before aggregating
BATCH_SIZE = 48
LR = 1e-3       # Learning rate

We define two helper functions: split_iid, which randomly distributes the dataset across multiple centers, and federated_averaging, which aggregates the locally trained models into a weighted average (FedAvg). A small sanity check of the averaging follows the code.

def split_iid(dataset, n_centers):
    """ Split a PyTorch dataset randomly into n_centers subsets of (almost) equal size """
    n_obs_per_center = [len(dataset) // n_centers] * n_centers
    # random_split requires the lengths to sum to len(dataset): assign any remainder to the last center
    n_obs_per_center[-1] += len(dataset) - sum(n_obs_per_center)
    return random_split(dataset, n_obs_per_center)
def federated_averaging(models, n_obs_per_client):
    assert len(models) > 0, 'An empty list of models was passed.'
    assert len(n_obs_per_client) == len(models), 'List with number of observations must have ' \
                                                 'the same number of elements as the list of models.'

    # Compute proportions
    n_obs = sum(n_obs_per_client)
    proportions = [n_k / n_obs for n_k in n_obs_per_client]

    # Zero-initialized parameter dictionary with the same shapes as the model parameters
    avg_params = models[0].state_dict()
    for key, val in avg_params.items():
        avg_params[key] = torch.zeros_like(val)

    # Compute average
    for model, proportion in zip(models, proportions):
        for key in avg_params.keys():
            avg_params[key] += proportion * model.state_dict()[key]

    # Copy one of the models and load the averaged params
    avg_model = copy.deepcopy(models[0])
    avg_model.load_state_dict(avg_params)

    return avg_model
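federated_averaging implements the FedAvg update, i.e. a weighted average of the clients' parameters with weights proportional to their number of observations, $w \leftarrow \sum_k (n_k / n)\, w_k$. As a minimal, illustrative sanity check (not part of the original notebook), two bias-free single-weight linear models can be averaged by hand:

import torch.nn as nn

# Two single-weight linear models with known parameter values
m1, m2 = nn.Linear(1, 1, bias=False), nn.Linear(1, 1, bias=False)
with torch.no_grad():
    m1.weight.fill_(1.0)
    m2.weight.fill_(3.0)

# The second client holds three times more data, so its weight dominates:
# 0.25 * 1.0 + 0.75 * 3.0 = 2.5
avg = federated_averaging([m1, m2], n_obs_per_client=[100, 300])
print(avg.weight.item())  # -> 2.5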

Federating the dataset

transform = transforms.Compose([transforms.ToTensor(),
                                 transforms.Normalize((0,), (1,))])
dataset = datasets.MNIST('~/data/', train=True, download=True, transform=transform)

Now, federated_dataset is a list of subsets of the main dataset.

federated_dataset = split_iid(dataset, n_centers=N_CENTERS)
print('Number of centers:', len(federated_dataset))
Number of centers: 4
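Since the split is IID and MNIST has 60,000 training images, each center should receive an equal share. A quick illustrative check (not in the original notebook):

# Each of the 4 centers should hold 60000 // 4 = 15000 observations
print([len(subset) for subset in federated_dataset])
# expected: [15000, 15000, 15000, 15000]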

Defining and distributing a model: Variational Autoencoder

In this exercise we will use the Multi-channel Variational Autoencoder proposed by Antelmi et al. (ICML 2019).

!pip install -q git+https://gitlab.inria.fr/epione_ML/mcvae.git
  Building wheel for mcvae (setup.py) ... done
from mcvae.models import Mcvae, ThreeLayersVAE, VAE

First, it is necessary to define a model.

N_FEATURES = 784  # Number of pixels in MNIST
dummy_data = [torch.zeros(1, N_FEATURES)]  # Dummy data to initialize the input layer size
lat_dim = 3  # Size of the latent space for this autoencoder
vae_class = ThreeLayersVAE  # Architecture of the autoencoder (VAE: Single layer)
model = Mcvae(data=dummy_data, lat_dim=lat_dim, vaeclass=vae_class)
model.optimizer = torch.optim.Adam(params=model.parameters(), lr=LR)
model.init_loss()
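Before distributing the model, it can be useful to check its size. The snippet below is an illustrative addition that relies only on standard torch.nn.Module introspection:

# Count the trainable parameters of the global model
n_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f'Trainable parameters: {n_params}')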

Now we replicate the model across the different centers and record the number of observations held by each client (a quick consistency check follows the code).

models = [copy.deepcopy(model) for _ in range(N_CENTERS)]
n_obs_per_client = [len(client_data) for client_data in federated_dataset]
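As an optional, illustrative check (not part of the original notebook), every client model should be an independent copy that starts from exactly the same parameters as the global model:

# Each client model is a distinct object initialized with the global parameters
ref = model.state_dict()
for m in models:
    assert m is not model
    assert all(torch.equal(ref[k], v) for k, v in m.state_dict().items())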

Training in a federated fashion

def get_data(subset, shuffle=True):
    """ Extract all observations from a torch Subset as a single batch of tensors """
    loader = DataLoader(subset, batch_size=len(subset), shuffle=shuffle)
    return next(iter(loader))
init_params = model.state_dict()
for round_i in tqdm(range(N_ROUNDS), desc='Federated rounds'):
    for client_dataset, client_model in zip(federated_dataset, models):
        # Load the client's data as a single tensor batch
        X, y = get_data(client_dataset)
        client_model.data = [X.view(-1, N_FEATURES)]  # Set the data attribute on the client's model (the list wraps the single channel)

        # Load client's model parameters and train
        client_model.load_state_dict(init_params)
        client_model.optimize(epochs=N_EPOCHS, data=client_model.data)
        
    # Aggregate models using federated averaging
    trained_model = federated_averaging(models, n_obs_per_client)
    init_params = trained_model.state_dict()
====> Epoch:    0/15 (0%)	Loss: 528.5617	LL: -528.5566	KL: 0.0050	LL/KL: -104751.2169
====> Epoch:   10/15 (67%)	Loss: 77.4644	LL: -74.0843	KL: 3.3801	LL/KL: -21.9175
====> Epoch:    0/15 (0%)	Loss: 527.9626	LL: -527.9576	KL: 0.0051	LL/KL: -104216.8695
====> Epoch:   10/15 (67%)	Loss: 77.4401	LL: -73.8392	KL: 3.6009	LL/KL: -20.5055
====> Epoch:    0/15 (0%)	Loss: 525.0997	LL: -525.0947	KL: 0.0051	LL/KL: -103884.9293
====> Epoch:   10/15 (67%)	Loss: 76.7743	LL: -73.2086	KL: 3.5656	LL/KL: -20.5319
====> Epoch:    0/15 (0%)	Loss: 528.4604	LL: -528.4553	KL: 0.0051	LL/KL: -104527.3463
====> Epoch:   10/15 (67%)	Loss: 77.8238	LL: -74.3751	KL: 3.4487	LL/KL: -21.5661
...
====> Epoch:  140/150 (93%)	Loss: -153.9218	LL: 163.0297	KL: 9.1080	LL/KL: 17.8997
====> Epoch:  140/150 (93%)	Loss: -154.8601	LL: 164.0375	KL: 9.1775	LL/KL: 17.8740
====> Epoch:  140/150 (93%)	Loss: -155.0774	LL: 164.2348	KL: 9.1574	LL/KL: 17.9346
====> Epoch:  140/150 (93%)	Loss: -154.5916	LL: 163.7941	KL: 9.2025	LL/KL: 17.7989
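After the last round, trained_model holds the aggregated parameters. As an optional sketch (the filename is illustrative), they can be persisted with the standard PyTorch serialization API so the result can be reloaded later without retraining:

# Save the aggregated parameters of the final global model
torch.save(trained_model.state_dict(), 'federated_mcvae.pt')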

Results visualization

Using the final aggregated parameters, we can inspect the trained model by projecting the test set onto the latent space and visualizing it.

import numpy as np
import pandas as pd

import seaborn as sns
import matplotlib.pyplot as plt

sns.set()
dataset_test = datasets.MNIST('~/data/', train=False, download=True)
# Note: the raw pixel values (0-255) are used here, without the training-time ToTensor scaling
X_test, y_test = [dataset_test.data.view(-1, N_FEATURES).float()], dataset_test.targets
Z_test = np.hstack([z.loc.detach().numpy() for z in trained_model.encode(X_test)])
col_names = [f'$Z_{i}$' for i in range(Z_test.shape[1])]
latent_df = pd.DataFrame(Z_test, columns=col_names)
latent_df['label'] = y_test
latent_df['label'] = latent_df['label'].astype('category')
latent_df.head()
        $Z_0$       $Z_1$       $Z_2$ label
0  794.841309 -325.295868  174.970810     7
1 -357.250793  576.667358  -40.315079     2
2 -131.488586  354.551086  350.653107     1
3 -160.370346 -291.653625 -659.524414     0
4  160.157776 -210.283234 -121.650696     4
sns.pairplot(latent_df, hue='label', corner=True)
plt.show()
Figure: corner pairplot of the latent dimensions $Z_0$, $Z_1$, $Z_2$ for the MNIST test set, colored by digit label.