pytorch实现多个模型的weights平均和修改weights

    技术2022-07-10  184

    文章目录

    1. 操作说明2.代码

    1. 操作说明

    有3个结构相同但是weights不同的model组成一个list,models=[model1,model2,model3],还有一个中心模型fl_model,这四个模型的结构和超参数都相同。

    需要进行这样一种操作:平均models里面三个模型的weights,把平均之后的weights"赋值"给fl_model的weights。

    2.代码

    在tensorflow里可以直接用model.get_weights()和model.set_weights()来做,比较直观和方便。感觉pytorch里面稍微复杂一些。进行上述操作的代码如下:

    worker_state_dict=[x.state_dict() for x in models] weight_keys=list(worker_state_dict[0].keys()) fed_state_dict=collections.OrderedDict() for key in weight_keys: key_sum=0 for i in range(len(models)): key_sum=key_sum+worker_state_dict[i][key] fed_state_dict[key]=key_sum/len(models) #### update fed weights to fl model fl_model.load_state_dict(fed_state_dict)
    Processed: 0.024, SQL: 12