Launch gRPC client
This notebook shows how to launch a gRPC client. To match the server notebook, we use only one client.
[1]:
num_clients = 1
Import dependencies
The dependencies are the same as for the gRPC server, except that here we also import the appfl.run_grpc_client module.
[2]:
import numpy as np
import math
import torch
import torch.nn as nn
import torchvision
from torchvision.transforms import ToTensor
from omegaconf import OmegaConf
from appfl.config import *
from appfl.misc.data import *
import appfl.run_grpc_client as grpc_client
Training datasets
Since this is a simulation of federated learning, we manually split the training dataset among the clients. In practice this is not necessary, because each client already holds its own data. Each client needs to create a Dataset object with its training data; here, we create the objects for all clients.
[3]:
train_data_raw = torchvision.datasets.MNIST(
"./_data", train=True, download=True, transform=ToTensor()
)
# Split the sample indices into num_clients contiguous chunks.
split_train_data_raw = np.array_split(range(len(train_data_raw)), num_clients)
train_datasets = []
for i in range(num_clients):
    train_data_input = []
    train_data_label = []
    # Collect the images (as nested lists) and labels assigned to client i.
    for idx in split_train_data_raw[i]:
        train_data_input.append(train_data_raw[idx][0].tolist())
        train_data_label.append(train_data_raw[idx][1])
    # Wrap client i's data in an APPFL Dataset object.
    train_datasets.append(
        Dataset(
            torch.FloatTensor(train_data_input),
            torch.tensor(train_data_label),
        )
    )
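np.array_split divides the index range into num_clients contiguous chunks whose sizes differ by at most one, with the first len % num_clients chunks receiving one extra sample. The pure-Python sketch below mimics that behavior so the resulting partition is easy to inspect; the helper name split_indices is hypothetical and not part of APPFL or NumPy.

```python
def split_indices(n, k):
    """Split range(n) into k contiguous chunks whose sizes differ by at most one.

    Hypothetical helper that mirrors np.array_split(range(n), k):
    the first n % k chunks get one extra index.
    """
    base, extra = divmod(n, k)
    chunks, start = [], 0
    for i in range(k):
        size = base + (1 if i < extra else 0)
        chunks.append(list(range(start, start + size)))
        start += size
    return chunks


print(split_indices(10, 3))  # [[0, 1, 2, 3], [4, 5, 6], [7, 8, 9]]
```

With num_clients = 1, as in this notebook, the single chunk holds all 60,000 MNIST training indices; with 7 clients the chunk sizes would be 8572 (three times) and 8571 (four times).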
User-defined model
The client must use the same model as the server; see the server notebook.
[4]:
class CNN(nn.Module):
    def __init__(self, num_channel=1, num_classes=10, num_pixel=28):
        super().__init__()
        self.conv1 = nn.Conv2d(
            num_channel, 32, kernel_size=5, padding=0, stride=1, bias=True
        )
        self.conv2 = nn.Conv2d(32, 64, kernel_size=5, padding=0, stride=1, bias=True)
        self.maxpool = nn.MaxPool2d(kernel_size=(2, 2))
        self.act = nn.ReLU(inplace=True)

        # Spatial size after conv1 -> maxpool -> conv2 -> maxpool, using the
        # Conv2d output-size formula (kernel 5, padding 0, dilation 1, stride 1)
        # and halving for each 2x2 max pooling.
        X = num_pixel
        X = math.floor(1 + (X + 2 * 0 - 1 * (5 - 1) - 1) / 1)
        X = X / 2
        X = math.floor(1 + (X + 2 * 0 - 1 * (5 - 1) - 1) / 1)
        X = X / 2
        X = int(X)

        self.fc1 = nn.Linear(64 * X * X, 512)
        self.fc2 = nn.Linear(512, num_classes)

    def forward(self, x):
        x = self.act(self.conv1(x))
        x = self.maxpool(x)
        x = self.act(self.conv2(x))
        x = self.maxpool(x)
        x = torch.flatten(x, 1)
        x = self.act(self.fc1(x))
        x = self.fc2(x)
        return x


model = CNN()
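The lines computing X in __init__ apply the PyTorch Conv2d output-size formula, floor((size + 2*padding - dilation*(kernel_size - 1) - 1) / stride + 1), and halve the size for each 2x2 max pooling. The sketch below works through that arithmetic for 28x28 MNIST images and also counts the parameters of each layer, which gives a quick way to check that client and server define the same architecture. The helper names conv2d_out and maxpool2d_out are hypothetical; only the formulas come from PyTorch.

```python
import math


def conv2d_out(size, kernel_size, padding=0, stride=1, dilation=1):
    # Output size of nn.Conv2d along one spatial dimension.
    return math.floor((size + 2 * padding - dilation * (kernel_size - 1) - 1) / stride + 1)


def maxpool2d_out(size, kernel_size=2):
    # nn.MaxPool2d defaults stride to kernel_size and floors the result.
    return math.floor((size - kernel_size) / kernel_size + 1)


X = 28                # num_pixel for MNIST
X = conv2d_out(X, 5)  # conv1:   28 -> 24
X = maxpool2d_out(X)  # maxpool: 24 -> 12
X = conv2d_out(X, 5)  # conv2:   12 -> 8
X = maxpool2d_out(X)  # maxpool:  8 -> 4
fc1_in = 64 * X * X   # 64 channels * 4 * 4 = 1024 inputs to fc1

# Weights + biases per layer of CNN() with the default arguments.
params = {
    "conv1": 32 * 1 * 5 * 5 + 32,
    "conv2": 64 * 32 * 5 * 5 + 64,
    "fc1": fc1_in * 512 + 512,
    "fc2": 512 * 10 + 10,
}
total = sum(params.values())
print(X, fc1_in, params, total)
```

The notebook's X / 2 matches the floored pooling formula here because both intermediate sizes (24 and 8) are even. On the actual model, sum(p.numel() for p in model.parameters()) should give the same total, 582026, on both the client and the server.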
Run with configuration
We run APPFL training with the data and model defined above. Training parameters can be changed simply by changing the configuration values. The default configuration is defined by the appfl.config.Config class and loaded with OmegaConf.structured into a dictionary-like config object, printed below as YAML.
[5]:
cfg = OmegaConf.structured(Config)
print(OmegaConf.to_yaml(cfg))
fed:
  type: fedavg
  servername: FedAvgServer
  clientname: FedAvgClient
  args:
    num_local_epochs: 1
    optim: SGD
    optim_args:
      lr: 0.01
      momentum: 0.9
      weight_decay: 1.0e-05
    epsilon: false
    clip_value: false
    clip_norm: 1
num_epochs: 2
batch_training: true
train_data_batch_size: 64
train_data_shuffle: false
test_data_batch_size: 64
test_data_shuffle: false
result_dir: ./results
device: cpu
validation: true
max_message_size: 10485760
operator:
  id: 1
server:
  id: 1
  host: localhost
  port: 50051
client:
  id: 1
To see the client-side logs, configure Python logging to print INFO-level messages to stdout.
[6]:
import sys
import logging
logging.basicConfig(stream=sys.stdout, level=logging.INFO)
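logging.basicConfig(level=logging.INFO) shows INFO messages from every library in the process. The log lines in this notebook come from loggers named appfl.protos.client and appfl.run_grpc_client, which share the parent logger "appfl", so one alternative is to keep the root logger at WARNING and lower only the "appfl" logger to INFO. The sketch below demonstrates this with the standard logging module; it writes to an in-memory buffer so the effect is easy to check, and the messages it emits are illustrative, not actual APPFL output.

```python
import io
import logging

# Capture output in a buffer; in the notebook you would use stream=sys.stdout.
buffer = io.StringIO()

# Root logger: WARNING and above only, to silence chatty libraries.
# force=True (Python 3.8+) replaces any handlers configured earlier.
logging.basicConfig(stream=buffer, level=logging.WARNING, force=True)

# Loggers such as "appfl.run_grpc_client" and "appfl.protos.client" are
# children of "appfl", so this single call lowers their threshold to INFO.
logging.getLogger("appfl").setLevel(logging.INFO)

logging.getLogger("appfl.run_grpc_client").info("Start training")  # shown
logging.getLogger("some.other.library").info("hidden")  # filtered out

print(buffer.getvalue())  # INFO:appfl.run_grpc_client:Start training
```

This works because messages propagate from a child logger to the root handler, while the level check happens at the logger where the message is created.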
Now we can start training with the configuration cfg.
[7]:
grpc_client.run_client(cfg, 1, model, train_datasets[0])
INFO:appfl.protos.client:[Client ID: 00] Received JobReponse with (server,round,job)=(1,1,2)
INFO:appfl.run_grpc_client:[Client ID: 00 Round #: 01] Start training
INFO:appfl.run_grpc_client:[Client ID: 00 Round #: 01] Trained (Elapsed 36.7952) and sent results back to the server (Elapsed 1.9958)
INFO:appfl.protos.client:[Client ID: 00] Received JobReponse with (server,round,job)=(1,2,2)
INFO:appfl.run_grpc_client:[Client ID: 00 Round #: 02] Start training
INFO:appfl.run_grpc_client:[Client ID: 00 Round #: 02] Trained (Elapsed 36.7490) and sent results back to the server (Elapsed 2.0801)
INFO:appfl.protos.client:[Client ID: 00] Received JobReponse with (server,round,job)=(1,2,3)
INFO:appfl.run_grpc_client:[Client ID: 00 Round #: 02] Quitting... Learning 73.5442 Sending 4.0734 Receiving 0.0519 Job 0.0041 Total 4.1294
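The final Quitting line summarizes the run. In this output, Learning (73.5442) equals the sum of the two per-round training times, and Total (4.1294) equals Sending + Receiving + Job, so Total appears to cover communication and job handling rather than training. The sketch below parses that line and checks both sums; it assumes the values are in seconds, which the log does not state.

```python
import re

# Final summary line copied from the log output above.
line = (
    "[Client ID: 00 Round #: 02] Quitting... Learning 73.5442 "
    "Sending 4.0734 Receiving 0.0519 Job 0.0041 Total 4.1294"
)

# Extract each "<Name> <value>" pair that follows "Quitting...".
summary = line.split("Quitting...")[1]
times = {name: float(value) for name, value in re.findall(r"(\w+) ([\d.]+)", summary)}

# Per-round training times from the "Trained (Elapsed ...)" lines.
per_round_training = [36.7952, 36.7490]

print(times)
print(round(sum(per_round_training), 4))  # 73.5442, same as "Learning"
print(round(times["Sending"] + times["Receiving"] + times["Job"], 4))  # 4.1294, same as "Total"
```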