Source code for fedlib.datasets.partitioners.dirichlet_partitioner

from typing import List, Callable, Iterator, Any

import torch
from torch.utils.data import Dataset, Subset

from .dataset_partitioner import DatasetPartitioner


[docs]class DirichletPartitioner(DatasetPartitioner): """Partitioner that uses Dirichlet distribution to allocate samples to clients.""" def __init__( self, num_clients: int = 4, random_seed: int = 123, client_id_generator: Callable[[], Iterator] = None, alpha: float = 1.0, same_proportions: bool = True, ): super().__init__(num_clients, random_seed, client_id_generator) self.alpha = alpha self.same_proportions = same_proportions
[docs] def split_dataset(self, dataset: Dataset) -> List[Subset]: subsets, _ = self._split_single_dataset(dataset) return subsets
[docs] def split_datasets( self, train_dataset: Dataset, test_dataset: Dataset ) -> tuple[list[Subset[Any]], list[Subset[Any]]]: # Split train dataset and get the class proportions train_subsets, class_proportions = self._split_single_dataset(train_dataset) # Optionally, use the same class proportions for test dataset if self.same_proportions: test_subsets, _ = self._split_single_dataset( test_dataset, class_proportions ) else: test_subsets, _ = self._split_single_dataset(test_dataset) return train_subsets, test_subsets
def _split_single_dataset(self, dataset: Dataset, class_proportions=None): targets = torch.tensor(dataset.targets) self._validate_dataset(targets) num_classes = len(torch.unique(targets)) class_indices = self._get_class_indices(targets, num_classes) # Sample from Dirichlet distribution if class_proportions are not provided if class_proportions is None: class_proportions = torch.distributions.Dirichlet( torch.ones(self.num_clients) * self.alpha ).sample((num_classes,)) client_indices = self._allocate_samples_to_clients( class_indices, class_proportions ) return [ Subset(dataset, indices) for indices in client_indices ], class_proportions def _validate_dataset(self, targets: torch.Tensor): assert ( len(targets) >= self.num_clients ), "Not enough samples to distribute to each client" @staticmethod def _get_class_indices( targets: torch.Tensor, num_classes: int ) -> List[torch.Tensor]: return [torch.where(targets == i)[0] for i in range(num_classes)] def _allocate_samples_to_clients( self, class_indices: List[torch.Tensor], class_proportions: torch.Tensor ) -> List[List[int]]: # Sample from Dirichlet distribution for class proportions client_indices = [[] for _ in range(self.num_clients)] # Allocate samples to clients class by class for c, indices in enumerate(class_indices): shuffled_indices = indices[torch.randperm(len(indices))] allocations = self._calculate_allocations( shuffled_indices, class_proportions[c] ) # Assign indices to clients based on allocations prev_proportion = 0 for client_idx, allocation in enumerate(allocations): client_indices[client_idx].extend( shuffled_indices[ prev_proportion : prev_proportion + allocation ].tolist() ) prev_proportion += allocation # Ensure at least one sample per client return self._ensure_minimum_allocation_per_client(client_indices) def _calculate_allocations( self, indices: torch.Tensor, proportions: torch.Tensor ) -> torch.Tensor: allocations = (proportions * len(indices)).round().int() while allocations.sum() < len(indices): allocations[torch.argmin(allocations)] += 1 return allocations def _ensure_minimum_allocation_per_client( self, client_indices: List[List[int]] ) -> List[List[int]]: # Check and redistribute one sample to any empty client list max_length, max_subset_idx = max( (len(indices), idx) for idx, indices in enumerate(client_indices) ) for i in range(self.num_clients): if len(client_indices[i]) == 0: if max_length > 1: client_indices[i].append(client_indices[max_subset_idx].pop()) max_length -= 1 else: raise ValueError( "Cannot ensure minimum allocation for each client." ) return client_indices