Source code for appfl.run_mpi

from cmath import nan

from collections import OrderedDict
import torch.nn as nn
from torch.optim import *
from torch.utils.data import DataLoader

import numpy as np

from omegaconf import DictConfig

import copy
import time
import logging

from .misc import *
from .algorithm import *

from mpi4py import MPI


def run_server(
    cfg: DictConfig,
    comm: MPI.Comm,
    model: nn.Module,
    loss_fn: nn.Module,
    num_clients: int,
    test_dataset: Dataset = Dataset(),
    dataset_name: str = "appfl",
):
    """Run the PPFL simulation server (MPI root) that updates the global model.

    The server takes part in the following collective calls with root 0:
    one ``gather`` + ``scatter`` to compute and distribute per-client data
    weights, then, per epoch, a ``bcast`` of the continue flag, a ``bcast`` of
    the global ``state_dict`` and a ``gather`` of the clients' local states,
    which are handed to ``server.update``. A final ``bcast(False)`` tells the
    clients to stop.

    Args:
        cfg (DictConfig): the configuration for this run. It is mutated:
            ``cfg["logginginfo"]`` is filled with timing/accuracy entries and
            ``cfg.validation`` is set to ``False`` when no test data is given.
        comm: MPI communicator
        model (nn.Module): neural network model to train
        loss_fn (nn.Module): loss function
        num_clients (int): the number of clients used in PPFL simulation
        test_dataset (Dataset): optional testing data. If given (non-empty)
            and ``cfg.validation`` is enabled, validation runs on this data
            after every global update.
        dataset_name (str): optional dataset name, recorded in the log info

    Raises:
        ValueError: if the communicator has fewer than two processes, i.e.
            there is no client rank to talk to.
    """
    comm_size = comm.Get_size()
    if comm_size < 2:
        # Rank 0 is the server; at least one more rank is needed for clients.
        raise ValueError(
            "run_server requires at least 2 MPI processes (1 server + clients)"
        )

    # FIXME: I think it's ok for server to use cpu only.
    device = "cpu"

    # Log for a server
    logger = logging.getLogger(__name__)
    logger = create_custom_logger(logger, cfg)

    cfg["logginginfo"]["comm_size"] = comm_size
    cfg["logginginfo"]["DataSet_name"] = dataset_name

    # Using tensorboard to visualize the test loss
    writer = None
    if cfg.use_tensorboard:
        from tensorboardX import SummaryWriter

        writer = SummaryWriter(
            comment=cfg.fed.args.optim + "_clients_nums_" + str(cfg.num_clients)
        )

    # Run validation if test data is given or the configuration is enabled.
    if cfg.validation and len(test_dataset) > 0:
        test_dataloader = DataLoader(
            test_dataset,
            num_workers=cfg.num_workers,
            batch_size=cfg.test_data_batch_size,
            shuffle=cfg.test_data_shuffle,
        )
    else:
        cfg.validation = False

    # Receive the number of data from clients, compute
    # "weight[client] = data[client] / total_num_data" on the server, and
    # scatter the per-rank weight information back to the clients.
    num_data = comm.gather(0, root=0)
    total_num_data = sum(
        sum(num_data[rank].values()) for rank in range(1, comm_size)
    )

    weights = {}  # client key -> weight, over all ranks (given to the server)
    weight = [0]  # per-rank payload for scatter; entry 0 is the server itself
    for rank in range(1, comm_size):
        rank_weights = {
            key: count / total_num_data for key, count in num_data[rank].items()
        }
        weights.update(rank_weights)
        weight.append(rank_weights)

    # The server's own scattered item (the 0 placeholder) is not needed.
    comm.scatter(weight, root=0)

    # TODO: do we want to use root as a client?
    # NOTE(review): eval() resolves the server class from a configuration
    # string; make sure cfg.fed.servername never comes from untrusted input.
    server = eval(cfg.fed.servername)(
        weights, copy.deepcopy(model), loss_fn, num_clients, device, **cfg.fed.args
    )

    do_continue = True
    start_time = time.time()
    test_loss = 0.0
    test_accuracy = 0.0
    # Tracked across all epochs; must not be reset inside the loop.
    best_accuracy = 0.0
    for t in range(cfg.num_epochs):
        per_iter_start = time.time()
        do_continue = comm.bcast(do_continue, root=0)

        # We need to load the model on cpu, before communicating.
        # Otherwise, out-of-memory error from GPU
        server.model.to("cpu")

        global_state = server.model.state_dict()

        # "Local update" time covers broadcasting the global state and waiting
        # for every rank's local states to be gathered.
        local_update_start = time.time()
        global_state = comm.bcast(global_state, root=0)
        local_states = comm.gather(None, root=0)
        cfg["logginginfo"]["LocalUpdate_time"] = time.time() - local_update_start

        global_update_start = time.time()
        server.update(local_states)
        cfg["logginginfo"]["GlobalUpdate_time"] = time.time() - global_update_start

        validation_start = time.time()
        if cfg.validation:
            test_loss, test_accuracy = validation(server, test_dataloader)

            if writer is not None:
                # Add them to tensorboard
                writer.add_scalar("server_test_accuracy", test_accuracy, t)
                writer.add_scalar("server_test_loss", test_loss, t)

            if test_accuracy > best_accuracy:
                best_accuracy = test_accuracy

        cfg["logginginfo"]["Validation_time"] = time.time() - validation_start
        cfg["logginginfo"]["PerIter_time"] = time.time() - per_iter_start
        cfg["logginginfo"]["Elapsed_time"] = time.time() - start_time
        cfg["logginginfo"]["test_loss"] = test_loss
        cfg["logginginfo"]["test_accuracy"] = test_accuracy
        cfg["logginginfo"]["BestAccuracy"] = best_accuracy

        server.logging_iteration(cfg, logger, t)

        # Saving model
        if (t + 1) % cfg.checkpoints_interval == 0 or t + 1 == cfg.num_epochs:
            if cfg.save_model:
                save_model_iteration(t + 1, server.model, cfg)

        # Stop early once the test loss is NaN.
        if np.isnan(test_loss):
            break

    # Summary
    server.logging_summary(cfg, logger)

    # Tell the clients (blocked on their end-of-round bcast) to stop.
    do_continue = False
    do_continue = comm.bcast(do_continue, root=0)

    if writer is not None:
        # Flush pending tensorboard events and release the event file.
        writer.close()
def run_client(
    cfg: DictConfig,
    comm: MPI.Comm,
    model: nn.Module,
    loss_fn: nn.Module,
    num_clients: int,
    train_data: Dataset,
    test_data: Dataset = Dataset(),
):
    """Run PPFL simulation clients, each of which updates its own local parameters of model

    The ``num_clients`` client ids are split across the non-root MPI ranks
    with ``np.array_split``; this rank (``comm_rank``) simulates the ids in
    group ``comm_rank - 1``. All collective calls use root 0.

    Args:
        cfg (DictConfig): the configuration for this run. ``cfg.validation``
            is set to ``False`` when no test data is given.
        comm: MPI communicator
        model (nn.Module): neural network model to train
        loss_fn (nn.Module): loss function
        num_clients (int): the number of clients used in PPFL simulation
        train_data (Dataset): training data, indexed by client id
        test_data (Dataset): testing data
    """
    comm_size = comm.Get_size()
    comm_rank = comm.Get_rank()

    ## We assume to have as many GPUs as the number of MPI processes.
    # NOTE(review): `device` is computed here but never passed on by this
    # function; confirm whether the client classes are expected to receive it.
    if cfg.device == "cuda":
        device = f"cuda:{comm_rank-1}"
    else:
        device = cfg.device

    num_client_groups = np.array_split(range(num_clients), comm_size - 1)
    # Client ids handled by this rank (rank 0 is the server, hence the -1).
    client_ids = num_client_groups[comm_rank - 1]

    # Log for clients: one output per client id.
    outfile = {}
    for cid in client_ids:
        output_filename = cfg.output_filename + "_client_%s" % (cid)
        outfile[cid] = client_log(cfg.output_dirname, output_filename)

    # Send the number of data to a server
    # Receive "weight_info" from a server
    #     (fedavg)            "weight_info" is not needed as of now.
    #     (iceadmm+iiadmm)    "weight_info" is needed for constructing
    #                         coefficients of the loss_function
    num_data = {cid: len(train_data[cid]) for cid in client_ids}
    comm.gather(num_data, root=0)
    weight = comm.scatter(None, root=0)

    # Without batch training, each client uses its whole local data set as a
    # single batch.
    batchsize = {
        cid: cfg.train_data_batch_size if cfg.batch_training else len(train_data[cid])
        for cid in client_ids
    }

    # Run validation if test data is given or the configuration is enabled.
    if cfg.validation and len(test_data) > 0:
        test_dataloader = DataLoader(
            test_data,
            num_workers=cfg.num_workers,
            batch_size=cfg.test_data_batch_size,
            shuffle=cfg.test_data_shuffle,
        )
    else:
        cfg.validation = False
        test_dataloader = None

    # NOTE(review): eval() resolves the client class from a configuration
    # string; make sure cfg.fed.clientname never comes from untrusted input.
    clients = [
        eval(cfg.fed.clientname)(
            cid,
            weight[cid],
            copy.deepcopy(model),
            loss_fn,
            DataLoader(
                train_data[cid],
                num_workers=cfg.num_workers,
                batch_size=batchsize[cid],
                shuffle=cfg.train_data_shuffle,
                pin_memory=True,
            ),
            cfg,
            outfile[cid],
            test_dataloader,
            **cfg.fed.args,
        )
        for cid in client_ids
    ]

    try:
        do_continue = comm.bcast(None, root=0)

        local_states = OrderedDict()
        while do_continue:
            # Receive "global_state"
            global_state = comm.bcast(None, root=0)

            # Update "local_states" based on "global_state"
            for client in clients:
                cid = client.id

                ## initial point for a client model
                client.model.load_state_dict(global_state)

                ## client update
                local_states[cid] = client.update()

            # Send "local_states" to a server
            comm.gather(local_states, root=0)

            do_continue = comm.bcast(None, root=0)
    finally:
        # Close the per-client logs even if a training round raised.
        for client in clients:
            client.outfile.close()