Simulation of Federated Learning

We present a step-by-step description of how to simulate federated learning on the MNIST dataset.

Installation

To this end, we first make sure that the required dependencies are installed.

[1]:
# !pip install "appfl[analytics,examples]"

You can also install the package from the GitHub repository.

[2]:
# !git clone git@github.com:APPFL/APPFL.git
# Use the %cd magic (not !cd): !cd runs in a subshell and does not persist.
# %cd APPFL
# !pip install -e ".[analytics,examples]"

Import dependencies

We put all the imports here. Our framework appfl is built on torch and its neural network module torch.nn. We also import torchvision to download the MNIST dataset.

[3]:
import numpy as np
import math
import torch
import torch.nn as nn
import torchvision
from torchvision.transforms import ToTensor

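from omegaconf import OmegaConf  # used below to load and print the configuration

import appfl.run as ppfl
from appfl.config import *  # provides the Config class used below
from appfl.misc.data import Dataset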

Train datasets

Since this is a simulation of federated learning, we split the training dataset among the clients manually. In practice this is not necessary, because each client already holds its own data. In this example we simulate only two clients, but num_clients can be set to a larger value to simulate more.

[4]:
num_clients = 2
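The next cell partitions the sample indices with np.array_split, which cuts a sequence into num_clients contiguous pieces whose sizes differ by at most one. The sketch below illustrates that behaviour on a toy length of 10 instead of the 60,000 MNIST training images; it only demonstrates the NumPy call and is not part of appfl.

```python
import numpy as np

# Toy stand-in for len(train_data_raw); the MNIST training split has 60,000 images.
num_samples = 10
num_clients = 3

# Split the index range into num_clients contiguous, near-equal shards.
shards = np.array_split(np.arange(num_samples), num_clients)

for client_id, idx in enumerate(shards):
    print(client_id, idx.tolist())
```

Here the first shard receives the one leftover index (sizes 4, 3, 3). With num_clients = 2 and 60,000 samples, each client receives 30,000 consecutive indices.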

Each client needs to create a Dataset object with its training data. Here, we create the objects for all the clients.

[5]:
train_data_raw = torchvision.datasets.MNIST(
    "./_data", train=True, download=True, transform=ToTensor()
)
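# Split the sample indices into num_clients contiguous, near-equal shards.
split_train_data_raw = np.array_split(range(len(train_data_raw)), num_clients)
train_datasets = []
for i in range(num_clients):
    # Collect this client's images (as nested lists) and integer labels.
    train_data_input = []
    train_data_label = []
    for idx in split_train_data_raw[i]:
        train_data_input.append(train_data_raw[idx][0].tolist())
        train_data_label.append(train_data_raw[idx][1])

    # Wrap the shard in appfl's Dataset (input tensor, label tensor).
    train_datasets.append(
        Dataset(
            torch.FloatTensor(train_data_input),
            torch.tensor(train_data_label),
        )
    )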
Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to ./_data/MNIST/raw/train-images-idx3-ubyte.gz
9913344it [00:00, 32348560.75it/s]
Extracting ./_data/MNIST/raw/train-images-idx3-ubyte.gz to ./_data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to ./_data/MNIST/raw/train-labels-idx1-ubyte.gz
29696it [00:00, 14584783.56it/s]
Extracting ./_data/MNIST/raw/train-labels-idx1-ubyte.gz to ./_data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to ./_data/MNIST/raw/t10k-images-idx3-ubyte.gz
1649664it [00:00, 18256781.30it/s]
Extracting ./_data/MNIST/raw/t10k-images-idx3-ubyte.gz to ./_data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to ./_data/MNIST/raw/t10k-labels-idx1-ubyte.gz
5120it [00:00, 8140574.86it/s]
Extracting ./_data/MNIST/raw/t10k-labels-idx1-ubyte.gz to ./_data/MNIST/raw
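The loop above converts every sample to a Python list with .tolist() before rebuilding a tensor, which is slow for 60,000 images. A faster pattern is to hold all samples in one array and slice each client's shard with its index array. The sketch below shows this with NumPy on random stand-in data; the helper make_client_shards is hypothetical, and in the notebook the resulting arrays would still need converting (for example with torch.from_numpy) before being passed to appfl's Dataset.

```python
import numpy as np


def make_client_shards(images, labels, num_clients):
    """Hypothetical helper: split (images, labels) into per-client shards.

    images: array of shape (N, C, H, W); labels: array of shape (N,).
    Returns a list of (images_shard, labels_shard) tuples.
    """
    index_shards = np.array_split(np.arange(len(images)), num_clients)
    # Fancy indexing copies each shard in one vectorized step
    # instead of converting samples one at a time.
    return [(images[idx], labels[idx]) for idx in index_shards]


# Random stand-in for MNIST: 100 single-channel 28x28 images with labels 0-9.
rng = np.random.default_rng(0)
images = rng.random((100, 1, 28, 28), dtype=np.float32)
labels = rng.integers(0, 10, size=100)

shards = make_client_shards(images, labels, num_clients=2)
for x, y in shards:
    print(x.shape, y.shape)
```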

Test dataset

The test data also needs to be wrapped in a Dataset object.

[6]:
test_data_raw = torchvision.datasets.MNIST(
    "./_data", train=False, download=False, transform=ToTensor()
)
test_data_input = []
test_data_label = []
for idx in range(len(test_data_raw)):
    test_data_input.append(test_data_raw[idx][0].tolist())
    test_data_label.append(test_data_raw[idx][1])

test_dataset = Dataset(
    torch.FloatTensor(test_data_input), torch.tensor(test_data_label)
)

User-defined model

Users can define their own models by subclassing torch.nn.Module. In this simulation, for example, we define the following convolutional neural network (CNN) with two convolutional layers and two fully connected layers. The loss function is torch.nn.CrossEntropyLoss().

[7]:
class CNN(nn.Module):
    def __init__(self, num_channel=1, num_classes=10, num_pixel=28):
        super().__init__()
        self.conv1 = nn.Conv2d(
            num_channel, 32, kernel_size=5, padding=0, stride=1, bias=True
        )
        self.conv2 = nn.Conv2d(32, 64, kernel_size=5, padding=0, stride=1, bias=True)
        self.maxpool = nn.MaxPool2d(kernel_size=(2, 2))
        self.act = nn.ReLU(inplace=True)

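        # Track the spatial size with the Conv2d output-size formula
        # floor((X + 2*padding - dilation*(kernel_size - 1) - 1) / stride + 1).
        X = num_pixel
        X = math.floor(1 + (X + 2 * 0 - 1 * (5 - 1) - 1) / 1)  # conv1: 28 -> 24
        X = X / 2  # 2x2 max-pool: 24 -> 12
        X = math.floor(1 + (X + 2 * 0 - 1 * (5 - 1) - 1) / 1)  # conv2: 12 -> 8
        X = X / 2  # 2x2 max-pool: 8 -> 4
        X = int(X)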

        self.fc1 = nn.Linear(64 * X * X, 512)
        self.fc2 = nn.Linear(512, num_classes)

    def forward(self, x):
        x = self.act(self.conv1(x))
        x = self.maxpool(x)
        x = self.act(self.conv2(x))
        x = self.maxpool(x)
        x = torch.flatten(x, 1)
        x = self.act(self.fc1(x))
        x = self.fc2(x)
        return x

model = CNN()
loss_fn = torch.nn.CrossEntropyLoss()
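The size arithmetic in CNN.__init__ applies the convolution output-size formula floor((X + 2*padding - dilation*(kernel_size - 1) - 1) / stride + 1) once per convolution and halves the size at each 2x2 max-pool. The pure-Python sketch below, with hypothetical helpers conv_out and pool_out, traces a 28-pixel MNIST image through the network (28 to 24 to 12 to 8 to 4), so fc1 receives 64 * 4 * 4 = 1024 input features.

```python
import math


def conv_out(size, kernel_size, padding=0, stride=1, dilation=1):
    """Spatial size after a Conv2d layer (output-size formula from the PyTorch docs)."""
    return math.floor((size + 2 * padding - dilation * (kernel_size - 1) - 1) / stride + 1)


def pool_out(size, kernel_size=2):
    """Spatial size after a MaxPool2d layer whose stride equals kernel_size."""
    return math.floor((size - kernel_size) / kernel_size + 1)


size = 28                             # MNIST images are 28x28
size = conv_out(size, kernel_size=5)  # conv1 -> 24
size = pool_out(size)                 # max-pool -> 12
size = conv_out(size, kernel_size=5)  # conv2 -> 8
size = pool_out(size)                 # max-pool -> 4

fc1_in_features = 64 * size * size    # conv2 produces 64 channels
print(size, fc1_in_features)
```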

Runs with configuration

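We run appfl training with the data and model defined above. Training parameters such as the number of epochs, the optimizer, and the batch sizes can be changed by editing the configuration values.

We read the default configuration from the appfl.config.Config class and load it with OmegaConf.structured, which returns a dictionary-like configuration object.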

[8]:
cfg = OmegaConf.structured(Config)
print(OmegaConf.to_yaml(cfg))
fed:
  type: fedavg
  servername: FedAvgServer
  clientname: FedAvgClient
  args:
    num_local_epochs: 1
    optim: SGD
    optim_args:
      lr: 0.01
      momentum: 0.9
      weight_decay: 1.0e-05
    epsilon: false
    clip_value: false
    clip_norm: 1
num_epochs: 2
batch_training: false
train_data_batch_size: 64
train_data_shuffle: false
test_data_batch_size: 64
test_data_shuffle: false
result_dir: ./results
device: cpu
validation: true
max_message_size: 10485760
client:
  id: 1
server:
  id: 1
  host: localhost
  port: 50051
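The configuration above selects fed.type: fedavg. In Federated Averaging, clients train locally (here for num_local_epochs) and the server then replaces the global model with an average of the client models, weighted by each client's number of training samples. The sketch below shows only that server-side weighted average, with NumPy arrays standing in for model parameters; the function fedavg and the parameter names are hypothetical, and this is a generic illustration of the FedAvg rule, not appfl's FedAvgServer implementation.

```python
import numpy as np


def fedavg(client_params, client_num_samples):
    """Weighted average of per-client parameter dicts (generic FedAvg rule).

    client_params: list of {name: np.ndarray}, one dict per client.
    client_num_samples: list of local training-set sizes, one per client.
    """
    total = sum(client_num_samples)
    weights = [n / total for n in client_num_samples]
    return {
        name: sum(w * params[name] for w, params in zip(weights, client_params))
        for name in client_params[0]
    }


# Two toy clients with 30,000 samples each, as in the two-client MNIST split.
client_a = {"fc.weight": np.array([1.0, 2.0]), "fc.bias": np.array([0.0])}
client_b = {"fc.weight": np.array([3.0, 4.0]), "fc.bias": np.array([1.0])}

global_params = fedavg([client_a, client_b], [30000, 30000])
for name, value in global_params.items():
    print(name, value)
```

With equal shard sizes, as in this notebook's two-client split, the weighted average reduces to a plain mean of the client parameters.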

We can then start training with the configuration cfg.

[9]:
ppfl.run_serial(cfg, model, train_datasets, test_dataset, "MNIST")
        Iter     Local[s]    Global[s]     Valid[s]      Iter[s]   Elapsed[s]  TestAvgLoss TestAccuracy
           1        41.82         0.00         2.04        43.87        43.87     2.298913        13.43
           2        39.65         0.00         1.94        41.60        85.47     2.298292        13.57
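In the table above, Elapsed[s] is the running total of Iter[s], and TestAvgLoss and TestAccuracy report the test loss and the percentage of correctly classified test images after each round. A loss of about 2.30 is close to ln(10), roughly 2.303, which is the cross-entropy of a uniform guess over 10 classes. The sketch below shows one common way to compute such metrics from model logits with NumPy: a mean cross-entropy over samples (matching torch.nn.CrossEntropyLoss with its default mean reduction) and accuracy in percent. The function evaluate is hypothetical, and appfl's own validation routine may aggregate differently, for example by averaging per batch.

```python
import numpy as np


def evaluate(logits, labels):
    """Mean cross-entropy and accuracy (%) from raw logits of shape (N, C)."""
    # Numerically stable log-softmax: subtract the row-wise max before exp.
    shifted = logits - logits.max(axis=1, keepdims=True)
    log_probs = shifted - np.log(np.exp(shifted).sum(axis=1, keepdims=True))
    # Cross-entropy is the negative log-probability of the true class.
    loss = -log_probs[np.arange(len(labels)), labels].mean()
    accuracy = 100.0 * (logits.argmax(axis=1) == labels).mean()
    return loss, accuracy


# Uniform logits over 10 classes: the loss equals ln(10), and argmax picks class 0.
logits = np.zeros((4, 10))
labels = np.array([0, 1, 2, 3])
avg_loss, acc = evaluate(logits, labels)
print(round(float(avg_loss), 4), acc)
```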