实现横向联邦学习
.jpg)
实现横向联邦学习
Alive~o.0实现横向联邦学习
横向联邦学习的服务端的主要功能是将被选择的客户端上传的本地模型进行模型聚合。
首先定义一个服务端类Server,类中的主要函数包括以下几个。
-
定义构造函数。
在构造函数中,服务端的工作包括:
第一,将配置信息拷贝到服务端中;
第二,按照配置中的模型信息获取模型,使用torchvision 的models模块内置的ResNet-18模型。
-
定义模型聚合函数。
服务端的主要功能是进行模型的聚合,通过接收客户端上传的模型,使用聚合函数更新全局模型。采用经典的FedAvg 算法。
1
2
3
4
5
6
7def model_aggregate(self, weight_accumulator):
for name, data in self.global_model.state_dict().items():
update_per_layer = weight_accumulator[name] * self.conf["lambda"]
if data.type() != update_per_layer.type():
data.add_(update_per_layer.to(torch.int64))
else:
data.add_(update_per_layer) -
定义模型评估函数。
评估当前的全局模型性能。
客户端
横向联邦学习的客户端主要功能是接收服务端的下发指令和全局模型,利用本地数据进行局部模型训练。定义客户端类Client,类中的主要函数包括以下两种。
-
定义构造函数。
在客户端构造函数中,客户端的主要工作包括
1.将配置信息拷贝到客户端中;
2.按照配置中的模型信息获取模型,通常由服务端将模型参数传递给客户端,客户端将该全局模型覆盖掉本地模型;
3.配置本地训练数据,通过torchvision 的datasets 模块获取cifar10 数据集后按客户端ID切分,不同的客户端拥有不同的子数据集,相互之间没有交集。
-
定义模型本地训练函数。
这是一个图像分类的任务,使用交叉熵作为本地模型的损失函数,利用梯度下降来求解并更新参数值,实现细节如下面代码块所示。
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23def local_train(self, model):
for name, param in model.state_dict().items():
self.local_model.state_dict()[name].copy_(param.clone())
optimizer = torch.optim.SGD(self.local_model.parameters(), lr=self.conf['lr'],
momentum=self.conf['momentum'])
self.local_model.train()
for e in range(self.conf["local_epochs"]):
for batch_id, batch in enumerate(self.train_loader):
data, target = batch
if torch.cuda.is_available():
data = data.cuda()
target = target.cuda()
optimizer.zero_grad()
output = self.local_model(data)
loss = torch.nn.functional.cross_entropy(output, target)
loss.backward()
optimizer.step()
print("Epoch %d done." % e)
diff = dict()
for name, data in self.local_model.state_dict().items():
diff[name] = (data - model.state_dict()[name])
return diff
整合
当配置文件、服务端类和客户端类都定义完毕,我们将这些信息组合起来。
首先,读取配置文件信息。
1 | with open(args.conf, 'r') as f: |
分别定义一个服务端对象和多个客户端对象,用来模拟横向联邦训练场景。
1 | train_datasets, eval_datasets = datasets.get_dataset("./data/", conf["type"]) |
每一轮的迭代,服务端会从当前的客户端集合中随机挑选一部分参与本轮迭代训练,被选中的客户端调用本地训练接口local_train进行本地训练,最后服务端调用模型聚合函数model_aggregate来更新全局模型。
配置信息
- model_name:模型名称
- no_models:客户端数量
- type:数据集信息
- global_epochs:全局迭代次数,即服务端与客户端的通信迭代次数
- local_epochs:本地模型训练迭代次数
- k:每一轮迭代时,服务端会从所有客户端中挑选k个客户端参与训练。
- batch_size:本地训练每一轮的样本数
- lr,momentum,lambda:本地训练的超参数设置
联邦学习与中心化训练的效果对比
- 联邦训练配置:一共10台客户端设备(no_models=10),每一轮任意挑选其中的5台参与训练(k=5), 每一次本地训练迭代次数为3次(local_epochs=3),全局迭代次数为20次(global_epochs=20)。
- 集中式训练配置:我们不需要单独编写集中式训练代码,只需要修改联邦学习配置既可使其等价于集中式训练。具体来说,我们将客户端设备no_models和每一轮挑选的参与训练设备数k都设为1即可。这样只有1台设备参与的联邦训练等价于集中式训练。其余参数配置信息与联邦学习训练一致。图中我们将局部迭代次数分别设置了1,2,3来进行比较。
3.7 联邦学习在模型推断上的效果对比
单点训练只的是在某一个客户端下,利用本地的数据进行模型训练的结果。
我们看到单点训练的模型效果(蓝色条)明显要低于联邦训练 的效果(绿色条和红色条),这也说明了仅仅通过单个客户端的数据,不能够 很好的学习到数据的全局分布特性,模型的泛化能力较差。
此外,对于每一轮 参与联邦训练的客户端数目(k 值)不同,其性能也会有一定的差别,k 值越大,每一轮参与训练的客户端数目越多,其性能也会越好,但每一轮的完成时间也会相对较长。
运行原作者代码结果如下:
尝试增大local_epochs,lr,batch_size,单独增加batch_size训练速度会有所提升,发现总体速率和原始保持一致,效果有明显提升:
参考文献及代码
2.Communication-Efficient Learning of Deep Networks from Decentralized Data
3.Communication-Efficient Learning of Deep Networks from Decentralized Data(代码)