Note
Go to the end to download the full example code.
A FashionMNIST Example
import ray
from ray import tune
from ray.tune.stopper import MaximumIterationStopper
from blades.algorithms.fedavg import FedavgConfig, Fedavg
from fedlib.trainers import TrainerConfig
class ExampleFedavgConfig(FedavgConfig):
def __init__(self, algo_class=None):
"""Initializes a FedavgConfig instance."""
super().__init__(algo_class=algo_class or ExampleFedavg)
self.dataset_config = {
"type": "FashionMNIST",
"num_clients": 4,
"train_batch_size": 32,
}
self.global_model = "cnn"
class ExampleFedavg(Fedavg):
def __init__(self, config=None, logger_creator=None, **kwargs):
super().__init__(config, logger_creator, **kwargs)
@classmethod
def get_default_config(cls) -> TrainerConfig:
return ExampleFedavgConfig()
if __name__ == "__main__":
ray.init()
config_dict = (
ExampleFedavgConfig()
.resources(
num_gpus_for_driver=0.5,
num_cpus_for_driver=1,
num_remote_workers=0,
num_gpus_per_worker=0.5,
)
.to_dict()
)
print(config_dict)
tune.run(
ExampleFedavg,
config=config_dict,
stop=MaximumIterationStopper(100),
)