from typing import List
import torch
def _compute_scores(distances, i, n, f):
"""Compute scores for node i.
Args:
distances {dict} -- A dict of dict of distance. distances[i][j] = dist.
i, j starts with 0.
i {int} -- index of worker, starting from 0.
n {int} -- total number of workers
f {int} -- Total number of Byzantine workers.
Returns:
float -- krum distance score of i.
"""
s = [distances[j][i] ** 2 for j in range(i)] + [
distances[i][j] ** 2 for j in range(i + 1, n)
]
_s = sorted(s)[: n - f - 2]
return sum(_s)
def _multi_krum(distances, n, f, m):
"""Multi_Krum algorithm.
Arguments:
distances {dict} -- A dict of dict of distance. distances[i][j] = dist.
i, j starts with 0.
n {int} -- Total number of workers.
f {int} -- Total number of Byzantine workers.
m {int} -- Number of workers for aggregation.
Returns:
list -- A list indices of worker indices for aggregation. length <= m
"""
if n < 1:
raise ValueError(
"Number of workers should be positive integer. Got {}.".format(f)
)
if m < 1 or m > n:
raise ValueError(
"Number of workers for aggregation should be >=1 and <= {}. Got {}.".format(
m, n
)
)
if 2 * f + 2 > n:
raise ValueError("Too many Byzantine workers: 2 * {} + 2 >= {}.".format(f, n))
for i in range(n - 1):
for j in range(i + 1, n):
if distances[i][j] < 0:
raise ValueError(
"The distance between node {} and {} should be non-negative: "
"Got {}.".format(i, j, distances[i][j])
)
scores = [(i, _compute_scores(distances, i, n, f)) for i in range(n)]
sorted_scores = sorted(scores, key=lambda x: x[1])
return list(map(lambda x: x[0], sorted_scores))[:m]
def _compute_euclidean_distance(v1, v2):
return (v1 - v2).norm()
def _pairwise_euclidean_distances(vectors):
    """Compute the pairwise Euclidean distances between ``vectors``.

    Arguments:
        vectors {list} -- A list of tensors, or a tensor whose first dimension
            indexes the vectors. Each item is flattened before comparison.

    Returns:
        dict -- A dict of dict of distances ``{i: {j: distance}}`` for ``i < j``
        (upper triangle only; empty for fewer than two vectors).
    """
    n = len(vectors)
    flat = [v.flatten() for v in vectors]
    distances = {}
    for i in range(n - 1):
        # Store the plain (un-squared) distance, as documented. _compute_scores
        # squares these values itself. Squaring here as well would turn the Krum
        # score into a sum of 4th powers, which can change the ranking.
        distances[i] = {
            j: _compute_euclidean_distance(flat[i], flat[j]) for j in range(i + 1, n)
        }
    return distances
class Multikrum(object):
    r"""A robust aggregator from paper `"Machine Learning with Adversaries: Byzantine Tolerant Gradient Descent" <https://papers.nips.cc/paper_files/paper/2017/hash/f4b9ec30ad9f68f89b29639786cb62ef-Abstract.html>`_.

    Each call ranks the input updates by their Krum score and returns the
    element-wise mean of the ``k`` lowest-scoring ones. The score is computed
    from the pairwise Euclidean distances between the flattened updates.

    Args:
        num_byzantine: Assumed number of Byzantine workers ``f``.
        k: Number of selected updates to average (``k = 1`` returns the single
            best-scoring update).
    """

    def __init__(self, num_byzantine: int, k=1):
        # Assumed number of Byzantine workers (f).
        self.f = num_byzantine
        # Number of top-ranked updates averaged on each call (m).
        self.m = k

    def __repr__(self):
        return "{}(num_byzantine={}, k={})".format(type(self).__name__, self.f, self.m)

    def __call__(self, inputs: List[torch.Tensor]) -> torch.Tensor:
        """Aggregate ``inputs`` with Multi-Krum.

        Args:
            inputs: Updates to aggregate. They must all have the same shape,
                because ``torch.stack`` requires it.

        Returns:
            The element-wise mean of the ``k`` selected updates, with the same
            shape as each input.

        Raises:
            ValueError: from the selection step when ``k`` or ``num_byzantine``
                is invalid for ``len(inputs)``.
        """
        updates = torch.stack(inputs, dim=0)
        distances = _pairwise_euclidean_distances(updates)
        top_m_indices = _multi_krum(distances, len(updates), self.f, self.m)
        # Advanced indexing gathers the selected rows in the given order,
        # equivalent to stacking them one by one.
        return updates[top_m_indices].mean(dim=0)