How to add new algorithms
Suppose that we are adding a new algorithm. The new algorithm should be implemented as two classes, one for the server and one for the client, and these classes should be derived from the following two base classes:
- class appfl.algorithm.BaseServer(weights: OrderedDict, model: torch.nn.Module, loss_fn: torch.nn.Module, num_clients: int, device)
Abstract class of PPFL algorithm for server that aggregates and updates model parameters.
- Parameters:
weights (OrderedDict) – aggregation weight assigned to each client
model (nn.Module) – torch neural network model to train
loss_fn (nn.Module) – loss function
num_clients (int) – the number of clients
device (str) – device for computation
- class appfl.algorithm.BaseClient(id: int, weight: Dict, model: torch.nn.Module, loss_fn: torch.nn.Module, dataloader: torch.utils.data.DataLoader, cfg, outfile, test_dataloader)
Abstract class of PPFL algorithm for client that trains local model.
- Parameters:
id (int) – unique ID for each client
weight (Dict) – aggregation weight assigned to each client
model (nn.Module) – torch neural network model to train
loss_fn (nn.Module) – loss function
dataloader – PyTorch data loader
device (str) – device for computation
Example: NewAlgo
Here we give a simple example.
Core algorithm class
We first create classes for the global and local updates in appfl/algorithm:
- See the two classes NewAlgoServer and NewAlgoClient in newalgo.py.
- In NewAlgoServer, the update function conducts a global update by averaging the local model parameters sent from multiple clients.
- In NewAlgoClient, the update function conducts a local update and sends the resulting local model parameters to the server.
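The averaging step performed by the server can be illustrated with a framework-agnostic sketch. It assumes that each client's local state is a mapping from parameter name to value, and that `weights` maps each client id to its aggregation weight, with the weights summing to 1. The function name `average_local_states` is hypothetical and is not part of appfl. Because the code only relies on `*` and `+`, the same logic also works when the values are torch tensors.

```python
from collections import OrderedDict
from typing import Any, Mapping


def average_local_states(
    local_states: Mapping[int, Mapping[str, Any]],
    weights: Mapping[int, float],
) -> "OrderedDict[str, Any]":
    """Hypothetical helper: weighted average of per-client parameter dicts.

    local_states maps client id -> {parameter name -> value};
    weights maps client id -> aggregation weight (assumed to sum to 1).
    """
    if not local_states:
        raise ValueError("no local states to aggregate")
    # Use the first client's parameter names (and their order) as the template.
    first = next(iter(local_states.values()))
    global_state = OrderedDict()
    for name in first:
        total = None
        for client_id, state in local_states.items():
            contribution = weights[client_id] * state[name]
            total = contribution if total is None else total + contribution
        global_state[name] = total
    return global_state


# Usage: two clients, with weights proportional to their (hypothetical) data sizes.
states = {
    0: OrderedDict([("w", 1.0), ("b", 0.0)]),
    1: OrderedDict([("w", 3.0), ("b", 2.0)]),
}
weights = {0: 0.25, 1: 0.75}
result = average_local_states(states, weights)
print(result["w"], result["b"])  # 2.5 1.5
```

Plain weighted averaging is only one choice. A new algorithm would typically replace this step with its own aggregation rule inside `NewAlgoServer.update`.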
The following is example code:
src/appfl/algorithm/newalgo.py

from collections import OrderedDict
from .algorithm import BaseServer, BaseClient


class NewAlgoServer(BaseServer):
    def __init__(self, weights, model, loss_fn, num_clients, device, **kwargs):
        super(NewAlgoServer, self).__init__(weights, model, loss_fn, num_clients, device)
        self.__dict__.update(kwargs)  # extra keyword arguments become attributes
        # Any additional initialization

    def update(self, local_states: OrderedDict):
        # Implement new server update function
        pass


class NewAlgoClient(BaseClient):
    def __init__(self, id, weight, model, loss_fn, dataloader, cfg, outfile, test_dataloader, **kwargs):
        super(NewAlgoClient, self).__init__(
            id, weight, model, loss_fn, dataloader, cfg, outfile, test_dataloader
        )
        self.__dict__.update(kwargs)  # extra keyword arguments become attributes
        # Any additional initialization

    def update(self):
        # Implement new client update function
        pass
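The skeletons above copy any extra keyword arguments onto the instance with `self.__dict__.update(kwargs)`. The stand-alone sketch below shows this pattern together with a toy local update. Because it must run without appfl, it uses a hypothetical stand-in base class, `_StubBaseClient`, which only mirrors the documented `BaseClient` signature. The argument names `num_local_epochs` and `lr` are also hypothetical. This page does not confirm that the algorithm's configuration arguments are what arrive in `**kwargs`; that is an assumption.

```python
from collections import OrderedDict


class _StubBaseClient:
    """Hypothetical stand-in for appfl.algorithm.BaseClient (illustration only)."""

    def __init__(self, id, weight, model, loss_fn, dataloader, cfg, outfile, test_dataloader):
        self.id = id
        self.weight = weight
        self.model = model
        self.loss_fn = loss_fn
        self.dataloader = dataloader
        self.cfg = cfg
        self.outfile = outfile
        self.test_dataloader = test_dataloader


class ToyNewAlgoClient(_StubBaseClient):
    def __init__(self, id, weight, model, loss_fn, dataloader, cfg, outfile,
                 test_dataloader, **kwargs):
        super().__init__(id, weight, model, loss_fn, dataloader, cfg, outfile, test_dataloader)
        # e.g. kwargs {"lr": 0.1} becomes the attribute self.lr == 0.1
        self.__dict__.update(kwargs)

    def update(self):
        # Toy local update: plain gradient descent on one scalar parameter "w"
        # for the squared loss (w * x - y) ** 2 over the local (x, y) samples.
        w = self.model["w"]
        for _ in range(self.num_local_epochs):
            for x, y in self.dataloader:
                grad = 2.0 * (w * x - y) * x
                w -= self.lr * grad
        self.model["w"] = w
        # Return a copy of the local parameters, i.e. what would be sent to the server.
        return OrderedDict(self.model)


# Hypothetical algorithm arguments, forwarded through **kwargs.
algo_args = {"num_local_epochs": 50, "lr": 0.1}
client = ToyNewAlgoClient(
    id=0, weight=1.0, model=OrderedDict(w=0.0), loss_fn=None,
    dataloader=[(1.0, 2.0)], cfg=None, outfile=None, test_dataloader=None,
    **algo_args,
)
local_state = client.update()
print(client.lr, local_state["w"])  # w approaches 2.0, the minimizer for the sample (1.0, 2.0)
```

A real client would iterate over a PyTorch `DataLoader`, use `self.loss_fn`, and step a torch optimizer. The toy version keeps only the control flow.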
Configuration dataclass
The new algorithm also needs its own configuration. This is done by adding a new dataclass under appfl.config.fed.
Let’s say we add the file src/appfl/config/fed/newalgo.py, which implements the dataclass as follows:
src/appfl/config/fed/newalgo.py

from dataclasses import dataclass
from omegaconf import DictConfig, OmegaConf


@dataclass
class NewAlgo:
    type: str = "newalgo"
    servername: str = "NewAlgoServer"
    clientname: str = "NewAlgoClient"
    args: DictConfig = OmegaConf.create(
        {
            # add new arguments
        }
    )
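The `servername` and `clientname` fields are plain strings, so at some point they presumably have to be resolved to the actual classes. This page does not show how appfl performs that lookup, so the sketch below is an assumption. It resolves the names through an explicit registry dict. To stay within the standard library, the dataclass `NewAlgoSketch` stores `args` as a plain dict created with `field(default_factory=...)` instead of an OmegaConf `DictConfig`. `NewAlgoServer` and `NewAlgoClient` here are empty placeholder classes, and `ALGORITHM_CLASSES` and `resolve` are hypothetical names.

```python
from dataclasses import dataclass, field
from typing import Any, Dict


class NewAlgoServer:  # placeholder for the real server class
    pass


class NewAlgoClient:  # placeholder for the real client class
    pass


# Hypothetical registry mapping class names to classes.
ALGORITHM_CLASSES: Dict[str, type] = {
    "NewAlgoServer": NewAlgoServer,
    "NewAlgoClient": NewAlgoClient,
}


@dataclass
class NewAlgoSketch:
    type: str = "newalgo"
    servername: str = "NewAlgoServer"
    clientname: str = "NewAlgoClient"
    # default_factory gives every instance its own dict of algorithm arguments.
    args: Dict[str, Any] = field(default_factory=lambda: {"num_local_epochs": 1})


def resolve(name: str) -> type:
    """Look up a class by the name stored in the configuration."""
    try:
        return ALGORITHM_CLASSES[name]
    except KeyError:
        raise ValueError(f"unknown algorithm class {name!r}") from None


fed = NewAlgoSketch(args={"num_local_epochs": 5})
server_cls = resolve(fed.servername)
client_cls = resolve(fed.clientname)
print(server_cls.__name__, client_cls.__name__)  # NewAlgoServer NewAlgoClient
```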
Then, we need to add the following line to the main configuration file config.py:
from .fed.newalgo import *
This is the main configuration class in src/appfl/config/config.py.
Each algorithm, specified in Config.fed, can be configured in the dataclasses at appfl.config.fed.*.
from dataclasses import dataclass, field
from typing import Any
from omegaconf import DictConfig, OmegaConf


from .fed.federated import *
from .fed.iceadmm import *  ## TODO: combine iceadmm and iiadmm under the name of ADMM.
from .fed.iiadmm import *


@dataclass
class Config:
    # default_factory: Python 3.11+ rejects unhashable defaults such as a dataclass instance
    fed: Any = field(default_factory=Federated)

    # Compute device
    device: str = "cpu"

    # Number of clients
    num_clients: int = 1

    # Number of training epochs
    num_epochs: int = 2

    # Number of workers in DataLoader
    num_workers: int = 0

    # Train data batch info
    batch_training: bool = True  ## TODO: revisit
    train_data_batch_size: int = 64
    train_data_shuffle: bool = False

    # Indication of whether to validate or not using testing data
    validation: bool = True
    test_data_batch_size: int = 64
    test_data_shuffle: bool = False

    # Checking data sanity
    data_sanity: bool = False

    # Reproducibility
    reproduce: bool = True

    # PCA on Trajectory
    pca_dir: str = ""
    params_start: int = 0
    params_end: int = 49
    ncomponents: int = 40

    # Tensorboard
    use_tensorboard: bool = False

    # Loading models
    load_model: bool = False
    load_model_dirname: str = ""
    load_model_filename: str = ""

    # Saving models (server)
    save_model: bool = False
    save_model_dirname: str = ""
    save_model_filename: str = ""
    checkpoints_interval: int = 2

    # Saving state_dict (clients)
    save_model_state_dict: bool = False

    # Logging and recording outputs
    output_dirname: str = "output"
    output_filename: str = "result"

    logginginfo: DictConfig = OmegaConf.create({})
    summary_file: str = ""


    #
    # gRPC configurations
    #

    # 100 MB for gRPC maximum message size
    max_message_size: int = 104857600

    operator: DictConfig = OmegaConf.create({"id": 1})
    server: DictConfig = OmegaConf.create(
        {"id": 1, "host": "localhost", "port": 50051, "use_tls": False, "api_key": None}
    )
    client: DictConfig = OmegaConf.create({"id": 1})
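Since `Config.fed` holds an algorithm dataclass instance, switching to the new algorithm amounts to assigning a different instance to that field. The sketch below illustrates this with trimmed, hypothetical stand-ins (`FederatedStub`, `NewAlgoStub`, `ConfigSketch`) that keep only a few of the fields above. The `type` value `"federated"` in the stub is an assumption. The sketch uses `field(default_factory=...)` for the dataclass-valued default because Python 3.11 and later reject unhashable defaults, such as a regular dataclass instance. It also spells out the `max_message_size` value as 100 × 1024 × 1024 bytes.

```python
from dataclasses import dataclass, field
from typing import Any


@dataclass
class FederatedStub:  # stand-in for the default algorithm dataclass
    type: str = "federated"


@dataclass
class NewAlgoStub:  # stand-in for the NewAlgo dataclass defined earlier
    type: str = "newalgo"
    servername: str = "NewAlgoServer"
    clientname: str = "NewAlgoClient"


@dataclass
class ConfigSketch:
    # default_factory builds a fresh algorithm config for every Config instance.
    fed: Any = field(default_factory=FederatedStub)
    device: str = "cpu"
    num_clients: int = 1
    num_epochs: int = 2
    # 100 * 1024 * 1024 = 104857600 bytes, the same value as max_message_size above.
    max_message_size: int = 100 * 1024 * 1024


cfg = ConfigSketch(num_clients=4)
cfg.fed = NewAlgoStub()  # select the new algorithm
print(cfg.fed.type, cfg.num_clients)  # newalgo 4
```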