How to add new algorithms

Suppose that we want to add a new algorithm. A new algorithm should be implemented as two classes, one for the server and one for the client, each derived from the following two base classes:

class appfl.algorithm.BaseServer(weights: OrderedDict, model: torch.nn.Module, loss_fn: torch.nn.Module, num_clients: int, device)[source]

Abstract class of PPFL algorithm for server that aggregates and updates model parameters.

Parameters:
  • weights (OrderedDict) – aggregation weights assigned to each client

  • model (nn.Module) – torch neural network model to train

  • loss_fn (nn.Module) – loss function

  • num_clients (int) – the number of clients

  • device (str) – device for computation

get_model() torch.nn.Module[source]

Get the model

Returns:

a deepcopy of self.model

Return type:

nn.Module

class appfl.algorithm.BaseClient(id: int, weight: Dict, model: torch.nn.Module, loss_fn: torch.nn.Module, dataloader: torch.utils.data.DataLoader, cfg, outfile, test_dataloader)[source]

Abstract class of PPFL algorithm for client that trains local model.

Parameters:
  • id (int) – unique ID for each client

  • weight (Dict) – aggregation weight assigned to each client

  • model (nn.Module) – torch neural network model to train

  • loss_fn (nn.Module) – loss function

  • dataloader (DataLoader) – PyTorch data loader for training data

  • cfg – configuration object

  • outfile – output file for logging results

  • test_dataloader (DataLoader) – PyTorch data loader for test data

get_model()[source]

Get the model

Returns:

the state_dict of local model

laplace_mechanism_output_perturb(scale_value)[source]

Differential privacy for output perturbation based on Laplacian distribution. This output perturbation adds Laplace noise to primal_state.

Parameters:

scale_value – scaling vector to control the variance of Laplacian distribution

update()[source]

Update local model parameters
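As an illustration of what laplace_mechanism_output_perturb does, the following standalone sketch adds Laplace noise to every tensor in a state dict. The helper name laplace_output_perturb is hypothetical, and treating scale_value directly as the Laplace scale parameter is an assumption based on the description above:

```python
import copy
import torch

def laplace_output_perturb(primal_state, scale_value):
    """Return a copy of primal_state with Laplace(0, scale_value) noise added
    to each tensor (scale_value as the Laplace scale is an assumption)."""
    perturbed = copy.deepcopy(primal_state)
    for name, param in perturbed.items():
        loc = torch.zeros_like(param.float())
        scale = torch.zeros_like(param.float()) + scale_value
        noise = torch.distributions.Laplace(loc, scale).sample()
        perturbed[name] = param + noise
    return perturbed
```

The original primal_state is left untouched; only the returned copy is perturbed, so the client can keep its exact local state while sharing the noisy one.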

Example: NewAlgo

Here we give a simple example.

Core algorithm class

We first create classes for the global and local updates in appfl/algorithm:

  • See two classes NewAlgoServer and NewAlgoClient in newalgo.py

  • In NewAlgoServer, the update function conducts a global update by averaging the local model parameters sent from multiple clients

  • In NewAlgoClient, the update function conducts a local update and sends the resulting local model parameters to the server

This is an example code:

Example code for src/appfl/algorithm/newalgo.py
from collections import OrderedDict

from .algorithm import BaseServer, BaseClient

class NewAlgoServer(BaseServer):
    def __init__(self, weights, model, loss_fn, num_clients, device, **kwargs):
        super(NewAlgoServer, self).__init__(weights, model, loss_fn, num_clients, device)
        self.__dict__.update(kwargs)
        # Any additional initialization

    def update(self, local_states: OrderedDict):
        # Implement new server update function
        pass

class NewAlgoClient(BaseClient):
    def __init__(self, id, weight, model, loss_fn, dataloader, cfg, outfile, test_dataloader, **kwargs):
        super(NewAlgoClient, self).__init__(
            id, weight, model, loss_fn, dataloader, cfg, outfile, test_dataloader
        )
        self.__dict__.update(kwargs)
        # Any additional initialization

    def update(self):
        # Implement new client update function
        pass
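A minimal sketch of the server-side averaging described above. The helper name weighted_average is hypothetical, and it assumes that local_states maps client IDs to model state_dicts and weights maps client IDs to aggregation weights that sum to one:

```python
from collections import OrderedDict

import torch

def weighted_average(local_states, weights):
    """Weighted average of client state_dicts.

    local_states: dict mapping client id -> model state_dict
    weights: dict mapping client id -> aggregation weight (assumed to sum to 1)
    """
    client_ids = list(local_states.keys())
    global_state = OrderedDict()
    # Every client is assumed to report the same set of parameter names
    for name in local_states[client_ids[0]]:
        global_state[name] = sum(
            weights[cid] * local_states[cid][name] for cid in client_ids
        )
    return global_state
```

A NewAlgoServer.update implementation could compute such an average and then load it into the global model with self.model.load_state_dict.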

Configuration dataclass

The new algorithm also needs some configuration. This can be done by adding a new dataclass under appfl.config.fed. Let’s say we add a src/appfl/config/fed/newalgo.py file that implements the dataclass as follows:

Example code for src/appfl/config/fed/newalgo.py
from dataclasses import dataclass
from omegaconf import DictConfig, OmegaConf

@dataclass
class NewAlgo:
    type: str = "newalgo"
    servername: str = "NewAlgoServer"
    clientname: str = "NewAlgoClient"
    args: DictConfig = OmegaConf.create(
        {
            # add new arguments
        }
    )

Then, we need to add the following line to the main configuration file config.py:

from .fed.newalgo import *

This is the main configuration class in src/appfl/config/config.py. Each algorithm, specified in Config.fed, can be configured in the dataclasses at appfl.config.fed.*.

The main configuration class
from dataclasses import dataclass, field
from typing import Any
from omegaconf import DictConfig, OmegaConf


from .fed.federated import *
from .fed.iceadmm import *  ## TODO: combine iceadmm and iiadmm under the name of ADMM.
from .fed.iiadmm import *


@dataclass
class Config:
    fed: Any = Federated()

    # Compute device
    device: str = "cpu"

    # Number of clients
    num_clients: int = 1

    # Number of training epochs
    num_epochs: int = 2

    # Number of workers in DataLoader
    num_workers: int = 0

    # Train data batch info
    batch_training: bool = True  ## TODO: revisit
    train_data_batch_size: int = 64
    train_data_shuffle: bool = False

    # Indication of whether to validate or not using testing data
    validation: bool = True
    test_data_batch_size: int = 64
    test_data_shuffle: bool = False

    # Checking data sanity
    data_sanity: bool = False

    # Reproducibility
    reproduce: bool = True

    # PCA on Trajectory
    pca_dir: str = ""
    params_start: int = 0
    params_end: int = 49
    ncomponents: int = 40

    # Tensorboard
    use_tensorboard: bool = False

    # Loading models
    load_model: bool = False
    load_model_dirname: str = ""
    load_model_filename: str = ""

    # Saving models (server)
    save_model: bool = False
    save_model_dirname: str = ""
    save_model_filename: str = ""
    checkpoints_interval: int = 2

    # Saving state_dict (clients)
    save_model_state_dict: bool = False

    # Logging and recording outputs
    output_dirname: str = "output"
    output_filename: str = "result"

    logginginfo: DictConfig = OmegaConf.create({})
    summary_file: str = ""


    #
    # gRPC configurations
    #

    # 100 MB for gRPC maximum message size
    max_message_size: int = 104857600

    operator: DictConfig = OmegaConf.create({"id": 1})
    server: DictConfig = OmegaConf.create(
        {"id": 1, "host": "localhost", "port": 50051, "use_tls": False, "api_key": None}
    )
    client: DictConfig = OmegaConf.create({"id": 1})
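Putting the configuration pieces together, the following sketch mirrors the NewAlgo dataclass from the example and shows how such a dataclass becomes an OmegaConf structured config that can be assigned to Config.fed; the num_local_epochs argument is purely illustrative:

```python
from dataclasses import dataclass

from omegaconf import DictConfig, OmegaConf

@dataclass
class NewAlgo:
    # Mirrors the example dataclass above; num_local_epochs is illustrative
    type: str = "newalgo"
    servername: str = "NewAlgoServer"
    clientname: str = "NewAlgoClient"
    args: DictConfig = OmegaConf.create({"num_local_epochs": 1})

# Build a structured config for the algorithm, as one would do
# with the main Config by setting cfg.fed = NewAlgo()
cfg = OmegaConf.structured(NewAlgo)
print(cfg.servername)  # -> NewAlgoServer
```

The servername and clientname fields are what allow the framework to locate the NewAlgoServer and NewAlgoClient classes at run time, so they must match the class names in src/appfl/algorithm/newalgo.py exactly.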