Federated Variational Autoencoders¶
We study an example of federated latent variable modeling that combines federated learning with variational autoencoders (VAEs). This example illustrates an IID scenario, in which the data are split uniformly at random across the centers.
import copy
from tqdm.auto import tqdm
import torch
from torch.utils.data import Subset, DataLoader, random_split
from torchvision import datasets, transforms
N_CENTERS = 4
N_ROUNDS = 10 # Number of federated rounds (local training on every center, followed by aggregation)
N_EPOCHS = 15 # Number of local training epochs per round, before aggregating
BATCH_SIZE = 48
LR = 1e-3 # Learning rate
We define two functions: one to distribute our dataset across multiple centers (split_iid) and one to compute the federated average of the trained models (federated_averaging).
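def split_iid(dataset, n_centers):
    """ Split PyTorch dataset randomly into n_centers """
    # Spread the remainder over the first centers so that the sizes sum to len(dataset)
    base, extra = divmod(len(dataset), n_centers)
    n_obs_per_center = [base + (1 if i < extra else 0) for i in range(n_centers)]
    return random_split(dataset, n_obs_per_center)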
def federated_averaging(models, n_obs_per_client):
    assert len(models) > 0, 'An empty list of models was passed.'
    assert len(n_obs_per_client) == len(models), 'List with number of observations must have ' \
                                                 'the same number of elements as the list of models.'
    # Compute proportions
    n_obs = sum(n_obs_per_client)
    proportions = [n_k / n_obs for n_k in n_obs_per_client]
    # Empty model parameter dictionary
    avg_params = models[0].state_dict()
    for key, val in avg_params.items():
        avg_params[key] = torch.zeros_like(val)
    # Compute the weighted average
    for model, proportion in zip(models, proportions):
        for key in avg_params.keys():
            avg_params[key] += proportion * model.state_dict()[key]
    # Copy one of the models and load trained params
    avg_model = copy.deepcopy(models[0])
    avg_model.load_state_dict(avg_params)
    return avg_model
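To make the weighting concrete, here is a minimal sketch of the same idea on two toy models. The helper name `weighted_average`, the constant weights, and the sample counts are illustrative assumptions, not part of the original notebook. Each parameter of the aggregated model is the sum of the clients' parameters weighted by n_k / n, where n_k is the number of observations held by client k and n is the total.

```python
import copy

import torch
from torch import nn


def weighted_average(models, n_obs_per_client):
    """Sample-size-weighted average of model parameters (hypothetical helper)."""
    n_obs = sum(n_obs_per_client)
    # Start from zero tensors shaped like the first model's parameters
    avg_params = {key: torch.zeros_like(val)
                  for key, val in models[0].state_dict().items()}
    for model, n_k in zip(models, n_obs_per_client):
        for key, val in model.state_dict().items():
            avg_params[key] += (n_k / n_obs) * val
    avg_model = copy.deepcopy(models[0])
    avg_model.load_state_dict(avg_params)
    return avg_model


# Two toy "client" models with constant parameters (illustrative values only)
client_a, client_b = nn.Linear(2, 1), nn.Linear(2, 1)
with torch.no_grad():
    for p in client_a.parameters():
        p.fill_(1.0)
    for p in client_b.parameters():
        p.fill_(3.0)

# Client A holds 100 observations and client B holds 300,
# so the weights are 0.25 and 0.75: 0.25 * 1 + 0.75 * 3 = 2.5
avg = weighted_average([client_a, client_b], [100, 300])
print(avg.weight, avg.bias)
```

With equal sample counts, as in the IID split used here, the weighted average reduces to a plain mean of the client parameters.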
Federating dataset¶
transform = transforms.Compose([transforms.ToTensor(),
transforms.Normalize((0,), (1,))])
dataset = datasets.MNIST('~/data/', train=True, download=True, transform=transform)
Now, federated_dataset is a list of subsets of the main dataset.
federated_dataset = split_iid(dataset, n_centers=N_CENTERS)
print('Number of centers:', len(federated_dataset))
Number of centers: 4
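As a quick sanity check of what such a split produces, the sketch below uses a toy dataset of 10 observations (an assumption for illustration; MNIST's 60,000 training images divide evenly by 4). Since `random_split` requires the subset sizes to sum to the dataset length, the remainder is spread over the first centers. The resulting subsets are disjoint and together cover every observation.

```python
import torch
from torch.utils.data import TensorDataset, random_split

# Toy dataset of 10 observations; each observation stores its own index
toy = TensorDataset(torch.arange(10))
n_centers = 4

# Base size per center, with the remainder spread over the first centers
base, extra = divmod(len(toy), n_centers)
sizes = [base + (1 if i < extra else 0) for i in range(n_centers)]

# A seeded generator makes the random assignment reproducible
subsets = random_split(toy, sizes, generator=torch.Generator().manual_seed(0))

print(sizes)  # [3, 3, 2, 2]

# Every original index appears in exactly one subset
all_indices = sorted(i for s in subsets for i in s.indices)
print(all_indices == list(range(10)))  # True
```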
Defining and distributing a model: Variational Autoencoder¶
In this exercise we will use the Multi-Channel Variational Autoencoder (MCVAE) proposed by Antelmi et al. (ICML 2019).
!pip install -q git+https://gitlab.inria.fr/epione_ML/mcvae.git
Building wheel for mcvae (setup.py) ... done
from mcvae.models import Mcvae, ThreeLayersVAE, VAE
First, it is necessary to define a model.
N_FEATURES = 784 # Number of pixels in MNIST
dummy_data = [torch.zeros(1, N_FEATURES)] # Dummy data to initialize the input layer size
lat_dim = 3 # Size of the latent space for this autoencoder
vae_class = ThreeLayersVAE # Architecture of the autoencoder (the imported VAE class is the single-layer alternative)
model = Mcvae(data=dummy_data, lat_dim=lat_dim, vaeclass=vae_class)
model.optimizer = torch.optim.Adam(params=model.parameters(), lr=LR)
model.init_loss()
Now replicate the model so that each center holds its own copy.
models = [copy.deepcopy(model) for _ in range(N_CENTERS)]
n_obs_per_client = [len(client_data) for client_data in federated_dataset]
Train in a federated fashion
def get_data(subset, shuffle=True):
    """ Extracts data from a Subset torch dataset in the form of a tensor"""
    # A single batch containing the whole subset
    loader = DataLoader(subset, batch_size=len(subset), shuffle=shuffle)
    return next(iter(loader))
init_params = model.state_dict()
for round_i in range(N_ROUNDS):
    for client_dataset, client_model in zip(federated_dataset, models):
        # Load client data in the form of a tensor
        X, y = get_data(client_dataset)
        client_model.data = [X.view(-1, N_FEATURES)] # Set data attribute in client's model (list wraps the number of channels)
        # Load client's model parameters and train
        client_model.load_state_dict(init_params)
        client_model.optimize(epochs=N_EPOCHS, data=client_model.data)
    # Aggregate models using federated averaging
    trained_model = federated_averaging(models, n_obs_per_client)
    init_params = trained_model.state_dict()
====> Epoch: 0/15 (0%) Loss: 528.5617 LL: -528.5566 KL: 0.0050 LL/KL: -104751.2169
====> Epoch: 10/15 (67%) Loss: 77.4644 LL: -74.0843 KL: 3.3801 LL/KL: -21.9175
====> Epoch: 0/15 (0%) Loss: 527.9626 LL: -527.9576 KL: 0.0051 LL/KL: -104216.8695
====> Epoch: 10/15 (67%) Loss: 77.4401 LL: -73.8392 KL: 3.6009 LL/KL: -20.5055
====> Epoch: 0/15 (0%) Loss: 525.0997 LL: -525.0947 KL: 0.0051 LL/KL: -103884.9293
====> Epoch: 10/15 (67%) Loss: 76.7743 LL: -73.2086 KL: 3.5656 LL/KL: -20.5319
====> Epoch: 0/15 (0%) Loss: 528.4604 LL: -528.4553 KL: 0.0051 LL/KL: -104527.3463
====> Epoch: 10/15 (67%) Loss: 77.8238 LL: -74.3751 KL: 3.4487 LL/KL: -21.5661
====> Epoch: 20/30 (67%) Loss: 49.0290 LL: -37.6401 KL: 11.3889 LL/KL: -3.3050
====> Epoch: 20/30 (67%) Loss: 48.4610 LL: -37.1718 KL: 11.2892 LL/KL: -3.2927
====> Epoch: 20/30 (67%) Loss: 48.2056 LL: -36.9949 KL: 11.2106 LL/KL: -3.3000
====> Epoch: 20/30 (67%) Loss: 49.2486 LL: -37.8945 KL: 11.3541 LL/KL: -3.3375
====> Epoch: 30/45 (67%) Loss: 15.8485 LL: -9.6083 KL: 6.2402 LL/KL: -1.5397
====> Epoch: 40/45 (89%) Loss: -12.3623 LL: 18.9352 KL: 6.5729 LL/KL: 2.8808
====> Epoch: 30/45 (67%) Loss: 15.8072 LL: -9.5787 KL: 6.2285 LL/KL: -1.5379
====> Epoch: 40/45 (89%) Loss: -12.5916 LL: 19.1550 KL: 6.5635 LL/KL: 2.9184
====> Epoch: 30/45 (67%) Loss: 16.1558 LL: -9.9610 KL: 6.1948 LL/KL: -1.6080
====> Epoch: 40/45 (89%) Loss: -11.9684 LL: 18.5574 KL: 6.5890 LL/KL: 2.8164
====> Epoch: 30/45 (67%) Loss: 16.5887 LL: -10.3528 KL: 6.2359 LL/KL: -1.6602
====> Epoch: 40/45 (89%) Loss: -11.1709 LL: 17.6888 KL: 6.5179 LL/KL: 2.7139
====> Epoch: 50/60 (83%) Loss: -43.2949 LL: 51.9390 KL: 8.6441 LL/KL: 6.0086
====> Epoch: 50/60 (83%) Loss: -43.9839 LL: 52.5778 KL: 8.5938 LL/KL: 6.1181
====> Epoch: 50/60 (83%) Loss: -43.4652 LL: 51.9832 KL: 8.5180 LL/KL: 6.1028
====> Epoch: 50/60 (83%) Loss: -42.8195 LL: 51.4198 KL: 8.6003 LL/KL: 5.9789
====> Epoch: 60/75 (80%) Loss: -63.5757 LL: 71.6305 KL: 8.0549 LL/KL: 8.8928
====> Epoch: 70/75 (93%) Loss: -77.9186 LL: 86.5862 KL: 8.6676 LL/KL: 9.9896
====> Epoch: 60/75 (80%) Loss: -64.2308 LL: 72.2830 KL: 8.0523 LL/KL: 8.9767
====> Epoch: 70/75 (93%) Loss: -79.0930 LL: 87.7227 KL: 8.6297 LL/KL: 10.1652
====> Epoch: 60/75 (80%) Loss: -63.4257 LL: 71.4438 KL: 8.0180 LL/KL: 8.9104
====> Epoch: 70/75 (93%) Loss: -78.2753 LL: 87.0201 KL: 8.7449 LL/KL: 9.9510
====> Epoch: 60/75 (80%) Loss: -63.2790 LL: 71.3351 KL: 8.0560 LL/KL: 8.8549
====> Epoch: 70/75 (93%) Loss: -77.4719 LL: 86.1749 KL: 8.7030 LL/KL: 9.9018
====> Epoch: 80/90 (89%) Loss: -92.5114 LL: 100.9692 KL: 8.4578 LL/KL: 11.9380
====> Epoch: 80/90 (89%) Loss: -92.9820 LL: 101.3774 KL: 8.3955 LL/KL: 12.0753
====> Epoch: 80/90 (89%) Loss: -92.5196 LL: 100.8786 KL: 8.3591 LL/KL: 12.0682
====> Epoch: 80/90 (89%) Loss: -92.8745 LL: 101.2884 KL: 8.4139 LL/KL: 12.0382
====> Epoch: 90/105 (86%) Loss: -106.0539 LL: 114.7090 KL: 8.6551 LL/KL: 13.2534
====> Epoch: 100/105 (95%) Loss: -117.9297 LL: 126.6995 KL: 8.7698 LL/KL: 14.4472
====> Epoch: 90/105 (86%) Loss: -106.8086 LL: 115.4679 KL: 8.6593 LL/KL: 13.3345
====> Epoch: 100/105 (95%) Loss: -118.4944 LL: 127.2676 KL: 8.7733 LL/KL: 14.5063
====> Epoch: 90/105 (86%) Loss: -106.5121 LL: 115.1398 KL: 8.6277 LL/KL: 13.3453
====> Epoch: 100/105 (95%) Loss: -118.4516 LL: 127.2390 KL: 8.7874 LL/KL: 14.4797
====> Epoch: 90/105 (86%) Loss: -106.7659 LL: 115.4336 KL: 8.6677 LL/KL: 13.3177
====> Epoch: 100/105 (95%) Loss: -118.5240 LL: 127.3401 KL: 8.8160 LL/KL: 14.4442
====> Epoch: 110/120 (92%) Loss: -126.0735 LL: 134.7997 KL: 8.7261 LL/KL: 15.4478
====> Epoch: 110/120 (92%) Loss: -128.5597 LL: 137.4358 KL: 8.8761 LL/KL: 15.4838
====> Epoch: 110/120 (92%) Loss: -128.5418 LL: 137.4097 KL: 8.8679 LL/KL: 15.4952
====> Epoch: 110/120 (92%) Loss: -128.5736 LL: 137.4403 KL: 8.8667 LL/KL: 15.5007
====> Epoch: 120/135 (89%) Loss: -136.8414 LL: 145.7950 KL: 8.9536 LL/KL: 16.2834
====> Epoch: 130/135 (96%) Loss: -145.7557 LL: 154.8592 KL: 9.1035 LL/KL: 17.0110
====> Epoch: 120/135 (89%) Loss: -137.7465 LL: 146.7102 KL: 8.9637 LL/KL: 16.3671
====> Epoch: 130/135 (96%) Loss: -147.1581 LL: 156.2683 KL: 9.1102 LL/KL: 17.1531
====> Epoch: 120/135 (89%) Loss: -137.8340 LL: 146.7747 KL: 8.9407 LL/KL: 16.4166
====> Epoch: 130/135 (96%) Loss: -146.9022 LL: 155.9727 KL: 9.0705 LL/KL: 17.1955
====> Epoch: 120/135 (89%) Loss: -137.7606 LL: 146.7392 KL: 8.9785 LL/KL: 16.3433
====> Epoch: 130/135 (96%) Loss: -146.8978 LL: 155.9755 KL: 9.0777 LL/KL: 17.1823
====> Epoch: 140/150 (93%) Loss: -153.9218 LL: 163.0297 KL: 9.1080 LL/KL: 17.8997
====> Epoch: 140/150 (93%) Loss: -154.8601 LL: 164.0375 KL: 9.1775 LL/KL: 17.8740
====> Epoch: 140/150 (93%) Loss: -155.0774 LL: 164.2348 KL: 9.1574 LL/KL: 17.9346
====> Epoch: 140/150 (93%) Loss: -154.5916 LL: 163.7941 KL: 9.2025 LL/KL: 17.7989
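The values in the log above are consistent with Loss = KL - LL: in the first line, for instance, 0.0050 - (-528.5566) ≈ 528.5617. Assuming the standard VAE objective (the exact form implemented by mcvae is not shown in this notebook), this is the negative evidence lower bound for an observation x with latent variable z:

```latex
\mathcal{L}(\theta, \phi; x)
  = \underbrace{\mathrm{KL}\big(q_\phi(z \mid x) \,\|\, p(z)\big)}_{\text{KL}}
  - \underbrace{\mathbb{E}_{q_\phi(z \mid x)}\big[\log p_\theta(x \mid z)\big]}_{\text{LL}}
```

Under this reading, the loss decreases across rounds mainly because the reconstruction term LL grows, while the KL term settles at around 9.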
Results visualization¶
Using the final parameters, we can evaluate the model by projecting the test set onto the latent space and visualizing the result.
import numpy as np
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
sns.set()
dataset_test = datasets.MNIST('~/data/', train=False, download=True)
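# Note: no transform is applied here, so .data holds raw pixel values in [0, 255],
# whereas the training inputs were scaled to [0, 1] by ToTensor
X_test, y_test = [dataset_test.data.view(-1, N_FEATURES).float()], dataset_test.targets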
Z_test = np.hstack([z.loc.detach().numpy() for z in trained_model.encode(X_test)])
col_names = [f'$Z_{i}$' for i in range(Z_test.shape[1])]
latent_df = pd.DataFrame(Z_test, columns=col_names)
latent_df['label'] = y_test
latent_df['label'] = latent_df['label'].astype('category')
latent_df.head()
  | $Z_0$ | $Z_1$ | $Z_2$ | label |
---|---|---|---|---|
0 | 794.841309 | -325.295868 | 174.970810 | 7 |
1 | -357.250793 | 576.667358 | -40.315079 | 2 |
2 | -131.488586 | 354.551086 | 350.653107 | 1 |
3 | -160.370346 | -291.653625 | -659.524414 | 0 |
4 | 160.157776 | -210.283234 | -121.650696 | 4 |
sns.pairplot(latent_df, hue='label', corner=True)
plt.show()