from typing import Callable, Iterator, List, Optional

import numpy as np
from torch.utils.data import Dataset, Subset

from .dataset_partitioner import DatasetPartitioner
class ShardPartitioner(DatasetPartitioner):
    """Partitioner that splits a dataset into label-sorted shards and assigns
    shards to clients.

    Sample indices are sorted by label and cut into ``num_shards`` contiguous
    shards whose sizes differ by at most one. The shard order is shuffled,
    and each shard is handed to the client currently holding the fewest
    samples. Each shard spans a contiguous range of labels, so clients
    typically receive label-skewed subsets.

    Args:
        num_clients: Number of client subsets produced per dataset.
        random_seed: Seed for the shard shuffle; also forwarded to
            ``DatasetPartitioner``.
        client_id_generator: Forwarded unchanged to ``DatasetPartitioner``.
        num_shards: Number of shards each dataset is cut into; must be at
            least ``num_clients``.

    Raises:
        AssertionError: If ``num_shards < num_clients``.
    """

    def __init__(
        self,
        num_clients: int,
        random_seed: int = 123,
        client_id_generator: Optional[Callable[[], Iterator]] = None,
        num_shards: int = 4,
    ):
        super().__init__(num_clients, random_seed, client_id_generator)
        assert num_shards >= num_clients, (
            "Number of shards cannot be smaller than " "clients"
        )
        self.num_shards = num_shards
        # Private generator: makes the shard shuffle reproducible from
        # ``random_seed`` and independent of the global NumPy RNG, which
        # other code may draw from between calls.
        self._shard_rng = np.random.RandomState(random_seed)

    def split_dataset(self, dataset: Dataset) -> List[Subset]:
        """Split ``dataset`` into ``num_clients`` label-skewed subsets.

        Args:
            dataset: Dataset exposing a ``targets`` attribute with one label
                per sample (anything ``np.asarray`` accepts, e.g. a list, a
                NumPy array or a CPU tensor).

        Returns:
            A list of exactly ``num_clients`` ``Subset`` objects over
            ``dataset``. Some subsets are empty when ``dataset`` has fewer
            samples than ``num_clients``.
        """
        # Stable sort: equal labels keep their dataset order, so shard
        # contents do not depend on the platform's unstable-sort kernel.
        # ``np.asarray`` also keeps tensor targets from being wrapped back
        # into a tensor.
        sorted_indices = np.argsort(np.asarray(dataset.targets), kind="stable")

        # The first ``len % num_shards`` shards get one extra element.
        shards = np.array_split(sorted_indices, self.num_shards)

        # Shuffle shard order so adjacent label ranges do not systematically
        # land on the same clients.
        self._shard_rng.shuffle(shards)

        # Greedy balancing: each shard goes to the client with the fewest
        # samples so far (ties resolved to the lowest client index).
        client_data_indices = [[] for _ in range(self.num_clients)]
        for shard in shards:
            sizes = [len(bucket) for bucket in client_data_indices]
            client_data_indices[int(np.argmin(sizes))].extend(shard.tolist())

        # Create one client dataset per index list.
        return [
            Subset(dataset, subset_indices)
            for subset_indices in client_data_indices
        ]

    def split_datasets(
        self, train_dataset: Dataset, test_dataset: Dataset
    ) -> tuple[list[Subset], list[Subset]]:
        """Partition a training and a testing dataset.

        Each dataset goes through ``split_dataset`` separately, so the two
        shard shuffles are separate draws: client ``i``'s test shards are
        not chosen to match its training shards.

        Returns:
            ``(train_subsets, test_subsets)``, each a list of
            ``num_clients`` subsets.
        """
        train_subsets = self.split_dataset(train_dataset)
        test_subsets = self.split_dataset(test_dataset)
        return train_subsets, test_subsets