# Source code for appfl.misc.data

import torch
from torch.utils import data


class Dataset(data.Dataset):
    """This class provides a simple way to define client dataset for supervised learning.

    This is derived from ``torch.utils.data.Dataset`` so that it can be loaded
    into ``torch.utils.data.DataLoader``. Users may also create their own dataset
    class derived from this one for more data processing steps.

    An empty ``Dataset`` is created if no argument is given (i.e., ``Dataset()``).

    Args:
        data_input (torch.FloatTensor): optional data inputs; when omitted, a new
            empty ``torch.FloatTensor`` is created for this instance
        data_label (torch.Tensor): optional data outputs (or labels); when omitted,
            a new empty ``torch.Tensor`` is created for this instance
    """

    def __init__(
        self,
        data_input: "torch.FloatTensor | None" = None,
        data_label: "torch.Tensor | None" = None,
    ):
        # Create the empty defaults per instance. A tensor written directly in the
        # signature is built once at import time and would be shared (and mutable
        # through in-place ops) by every ``Dataset()`` constructed without arguments.
        self.data_input = torch.FloatTensor() if data_input is None else data_input
        self.data_label = torch.Tensor() if data_label is None else data_label

    def __len__(self):
        """This returns the sample size (the number of entries in ``data_label``)."""
        return len(self.data_label)

    def __getitem__(self, idx):
        """This returns the sample point ``(input, label)`` for given ``idx``."""
        return self.data_input[idx], self.data_label[idx]
# TODO: This is very specific to certain data format.
def _check_loader_shapes(dataset, num_channel, num_pixel):
    """Iterate ``dataset`` in batches of 64 and verify every batch's shape."""
    loader = data.DataLoader(dataset, batch_size=64, shuffle=False)
    expected = (num_channel, num_pixel, num_pixel)
    for inputs, labels in loader:
        # Explicit raise instead of ``assert``: assert statements are stripped
        # when Python runs with -O, which would silently disable this check.
        if inputs.shape[0] != labels.shape[0]:
            raise AssertionError(
                f"batch size mismatch: {inputs.shape[0]} inputs vs "
                f"{labels.shape[0]} labels"
            )
        # Compare dims 1..3 as one tuple; an input with fewer than 4 dims yields
        # a shorter tuple and is reported here rather than raising IndexError.
        if tuple(inputs.shape[1:4]) != expected:
            raise AssertionError(
                f"unexpected input shape {tuple(inputs.shape)}; expected "
                f"(batch, {num_channel}, {num_pixel}, {num_pixel})"
            )


def data_sanity_check(train_datasets, test_dataset, num_channel, num_pixel):
    """Check that PyTorch ``DataLoader`` can batch the datasets with the expected shapes.

    Each batch must have as many labels as inputs, and inputs of shape
    ``(batch, num_channel, num_pixel, num_pixel)`` in their first four dims.

    Args:
        train_datasets: indexable collection of training datasets; only
            ``train_datasets[0]`` is checked.
        test_dataset: the test dataset.
        num_channel (int): expected size of input dim 1.
        num_pixel (int): expected size of input dims 2 and 3.

    Raises:
        AssertionError: if any batch fails the checks above.
    """
    _check_loader_shapes(train_datasets[0], num_channel, num_pixel)
    _check_loader_shapes(test_dataset, num_channel, num_pixel)