Source code for appfl.run_serial

from cmath import nan

from collections import OrderedDict
import torch.nn as nn
from torch.optim import *
from torch.utils.data import DataLoader

import numpy as np

from omegaconf import DictConfig

import copy
import time
import logging

from .misc import *
from .algorithm import *


def run_serial(
    cfg: DictConfig,
    model: nn.Module,
    loss_fn: nn.Module,
    train_data: Dataset,
    test_data: Dataset = Dataset(),
    dataset_name: str = "appfl",
):
    """Run serial simulation of PPFL.

    One server and ``cfg.num_clients`` clients are created in this process.
    For each of ``cfg.num_epochs`` rounds, every client loads the server's
    current model state, runs its local update, and the server then
    aggregates the collected local states. Timing, loss and accuracy figures
    are written into ``cfg["logginginfo"]`` every round.

    Args:
        cfg (DictConfig): the configuration for this run. It is modified in
            place: ``cfg["logginginfo"]`` is filled in, and ``cfg.validation``
            is set to ``False`` when validation cannot run.
        model (nn.Module): neural network model to train. It is deep-copied
            for the server and for every client.
        loss_fn (nn.Module): loss function
        train_data (Dataset): training data, indexed by client id
            (``train_data[k]`` is the data of client ``k``).
        test_data (Dataset): optional testing data. Validation runs only if
            this is non-empty and ``cfg.validation`` is enabled.
        dataset_name (str): optional dataset name

    Raises:
        ZeroDivisionError: if there is at least one client but the clients
            hold no training data in total.
    """
    # Log for the server.
    logger = logging.getLogger(__name__)
    logger = create_custom_logger(logger, cfg)

    cfg["logginginfo"]["comm_size"] = 1
    cfg["logginginfo"]["DataSet_name"] = dataset_name

    num_clients = cfg.num_clients

    writer = None
    outfile = {}
    clients = []
    try:
        # Use tensorboard to visualize the test loss.
        if cfg.use_tensorboard:
            from tensorboardX import SummaryWriter

            writer = SummaryWriter(
                comment=cfg.fed.args.optim + "_clients_nums_" + str(num_clients)
            )

        # Logs for the clients.
        for k in range(num_clients):
            output_filename = cfg.output_filename + "_client_%s" % (k)
            outfile[k] = client_log(cfg.output_dirname, output_filename)

        # Weight calculation: each client's share of the total sample count.
        total_num_data = sum(len(train_data[k]) for k in range(num_clients))
        weights = {
            k: len(train_data[k]) / total_num_data for k in range(num_clients)
        }

        # Run validation only if test data is given and the configuration
        # enables it; otherwise switch validation off for the rest of the run.
        test_dataloader = None
        if cfg.validation and len(test_data) > 0:
            test_dataloader = DataLoader(
                test_data,
                num_workers=cfg.num_workers,
                batch_size=cfg.test_data_batch_size,
                shuffle=cfg.test_data_shuffle,
            )
        else:
            cfg.validation = False

        # NOTE(security): eval() resolves the server/client class from a name
        # stored in the configuration. This is only safe while the
        # configuration comes from a trusted source.
        server = eval(cfg.fed.servername)(
            weights,
            copy.deepcopy(model),
            loss_fn,
            num_clients,
            cfg.device,
            **cfg.fed.args
        )
        server.model.to(cfg.device)

        client_cls = eval(cfg.fed.clientname)
        for k in range(num_clients):
            # Without batch training, a client's whole dataset is one batch.
            if cfg.batch_training:
                batch_size = cfg.train_data_batch_size
            else:
                batch_size = len(train_data[k])

            clients.append(
                client_cls(
                    k,
                    weights[k],
                    copy.deepcopy(model),
                    loss_fn,
                    DataLoader(
                        train_data[k],
                        num_workers=cfg.num_workers,
                        batch_size=batch_size,
                        shuffle=cfg.train_data_shuffle,
                        pin_memory=True,
                    ),
                    cfg,
                    outfile[k],
                    test_dataloader,
                    **cfg.fed.args,
                )
            )

        start_time = time.time()
        test_loss = 0.0
        test_accuracy = 0.0
        best_accuracy = 0.0
        for t in range(cfg.num_epochs):
            per_iter_start = time.time()

            local_states = [OrderedDict()]
            global_state = server.model.state_dict()

            local_update_start = time.time()
            for k, client in enumerate(clients):
                # Initial point for a client model.
                client.model.load_state_dict(global_state)
                # Client update.
                local_states[0][k] = client.update()
            cfg["logginginfo"]["LocalUpdate_time"] = (
                time.time() - local_update_start
            )

            global_update_start = time.time()
            server.update(local_states)
            cfg["logginginfo"]["GlobalUpdate_time"] = (
                time.time() - global_update_start
            )

            validation_start = time.time()
            if cfg.validation:
                test_loss, test_accuracy = validation(server, test_dataloader)

                if writer is not None:
                    # Add them to tensorboard.
                    writer.add_scalar("server_test_accuracy", test_accuracy, t)
                    writer.add_scalar("server_test_loss", test_loss, t)

                if test_accuracy > best_accuracy:
                    best_accuracy = test_accuracy
            cfg["logginginfo"]["Validation_time"] = time.time() - validation_start

            cfg["logginginfo"]["PerIter_time"] = time.time() - per_iter_start
            cfg["logginginfo"]["Elapsed_time"] = time.time() - start_time
            cfg["logginginfo"]["test_loss"] = test_loss
            cfg["logginginfo"]["test_accuracy"] = test_accuracy
            cfg["logginginfo"]["BestAccuracy"] = best_accuracy

            server.logging_iteration(cfg, logger, t)

            # Save the model at every checkpoint interval and after the
            # final round.
            if (t + 1) % cfg.checkpoints_interval == 0 or t + 1 == cfg.num_epochs:
                if cfg.save_model:
                    save_model_iteration(t + 1, server.model, cfg)

        server.logging_summary(cfg, logger)
    finally:
        # Release the log handles even when the run aborts part-way, so a
        # failure does not leak open files.
        for client in clients:
            client.outfile.close()
        # Handles opened for clients that were never constructed.
        for k in range(len(clients), len(outfile)):
            outfile[k].close()
        # Flush and close the tensorboard event file if one was opened.
        if writer is not None:
            writer.close()