This repository contains the official PyTorch code for the paper *Federated Learning via Meta-Variational Dropout*, published at NeurIPS 2023.
- Python >= 3.7.4
- CUDA >= 10.0 supported GPU
- Anaconda
Setup Environment
```
conda env create -f environment.yml
conda activate metavd
```

Run experiments with:

```
python main.py --model <model-name> --dataset <dataset-name> <other-options>
```

Ex) Run a CIFAR-10 experiment with MetaVD:

```
python main.py --model nvdpgaus --dataset cifar10
```

Ex) Run a CIFAR-100 experiment with MetaVD and a heterogeneity level of 5.0:

```
python main.py --model nvdpgaus --dataset cifar100 --alpha 5.0
```

We currently support the following model and dataset options.
| Model Name | Flag | Description |
|---|---|---|
| FedAvg | fedavg | Federated Averaging |
| FedAvg + Finetuning | fedavgper | Personalized Federated Learning |
| FedAvg + MetaVD | fedavgnvdpgausq | Federated Averaging with MetaVD (proposed in this work) |
| FedAvg + SNIP | fedavgsnip | Federated Averaging with SNIP |
| FedProx | fedprox | Federated Proximal Optimization |
| FedBE | fedbe | Federated Learning with Bayesian Ensemble |
| Reptile | reptile | Federated Learning with Reptile |
| Reptile + VD | vdgausq | Reptile with VD |
| Reptile + EnsembleVD | vdgausemq | Reptile with EnsembleVD |
| Reptile + MetaVD | nvdpgausq | Reptile with MetaVD (proposed in this work) |
| Reptile + SNIP | reptilesnip | Reptile with SNIP |
| MAML | maml | Federated Learning with Model-Agnostic Meta-Learning |
| MAML + MetaVD | mamlgausq | MAML with MetaVD (proposed in this work) |
| MAML + SNIP | mamlsnip | MAML with SNIP |
| PerFedAvg | perfedavg | HF-MAML |
| PerFedAvg + MetaVD | perfedavgnvdpgausq | HF-MAML with MetaVD (proposed in this work) |
| PerFedAvg + SNIP | perfedavgsnip | HF-MAML with SNIP |
| Dataset Name | Flag | Description |
|---|---|---|
| FEMNIST | femnist | Federated EMNIST dataset |
| CelebA | celeba | CelebA dataset |
| MNIST | mnist | MNIST dataset |
| CIFAR10 | cifar10 | CIFAR-10 dataset |
| CIFAR100 | cifar100 | CIFAR-100 dataset |
| EMNIST | emnist | Extended MNIST dataset |
| FMNIST | fmnist | Fashion-MNIST dataset |
Please see the argument parser in the main.py file to enable other options.
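Several of the models above build on Gaussian variational dropout (VD), in which each weight receives a learned multiplicative Gaussian noise rate. As a rough illustration of the underlying idea (a minimal NumPy sketch, not the repository's implementation; `vd_forward` and its arguments are hypothetical names):

```python
import numpy as np

def vd_forward(x, theta, log_alpha, rng=None):
    """One linear layer with Gaussian variational dropout noise.

    Weights are sampled as w = theta * (1 + sqrt(alpha) * eps) with
    eps ~ N(0, 1), i.e. E[w] = theta and Var[w] = alpha * theta**2,
    where alpha = exp(log_alpha) is the learned dropout rate.
    At evaluation time (rng=None) the mean weights theta are used.
    """
    if rng is None:
        return x @ theta
    alpha = np.exp(log_alpha)
    eps = rng.standard_normal(theta.shape)
    w = theta * (1.0 + np.sqrt(alpha) * eps)
    return x @ w
```

Large learned `alpha` values effectively prune the corresponding weights, which is what makes VD-style methods useful for sparsifying and personalizing client models.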
For all datasets, we set the number of rounds (num_rounds) to 1000 to ensure sufficient convergence, following common practice. The batch size (local_bs) was set to 64, and the number of local steps (local_epochs) to 5. Personalization was executed with a batch size (adaptation_bs) of 64 and a 1-step update.
For all methods, we searched the server learning rate and the local SGD learning rate over identical ranges. The server learning rate η (server_lr) was explored over [0.6, 0.7, 0.8, 0.9, 1.0], and the local SGD learning rate (inner_lr) over [0.005, 0.01, 0.015, 0.02, 0.025, 0.03]. MetaVD additionally requires a KL divergence weight β (beta), whose optimal value we set to 10.
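The grid above can be enumerated programmatically. The sketch below (a hypothetical helper, assuming the flag names quoted in parentheses above) builds one `main.py` command line per hyperparameter combination:

```python
import itertools

# Search ranges quoted above; flag names (--server_lr, --inner_lr, --beta)
# follow the option names given in parentheses in the text.
SERVER_LRS = [0.6, 0.7, 0.8, 0.9, 1.0]
INNER_LRS = [0.005, 0.01, 0.015, 0.02, 0.025, 0.03]

def sweep_commands(model="nvdpgaus", dataset="cifar10", beta=10.0):
    """Return one `python main.py ...` command per (server_lr, inner_lr) pair."""
    return [
        f"python main.py --model {model} --dataset {dataset}"
        f" --server_lr {s} --inner_lr {i} --beta {beta}"
        for s, i in itertools.product(SERVER_LRS, INNER_LRS)
    ]

commands = sweep_commands()
print(len(commands))  # 5 x 6 = 30 combinations
```

Each returned string can then be launched with your job scheduler of choice.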
Tensorboard Setup
```
cd runs
tensorboard --logdir=./ --port=7770 --samples_per_plugin image=100 --reload_multifile=True --reload_interval 30 --host=0.0.0.0
```

Access visualizations at localhost:7770.
If you find this work useful, please cite our paper:
```
@article{jeon2024federated,
  title={Federated Learning via Meta-Variational Dropout},
  author={Jeon, Insu and Hong, Minui and Yun, Junhyeog and Kim, Gunhee},
  journal={Advances in Neural Information Processing Systems},
  volume={36},
  year={2024}
}
```

Thank you, my colleagues, for your valuable contributions.