DeepMind又放福利:开源了一个内部的分布式机器学习库TF-Replicator

网友投稿 252 2022-10-27

DeepMind又放福利:开源了一个内部的分布式机器学习库TF-Replicator

TF-Replicator允许研究人员针对机器学习定位不同的硬件加速器进行,将工作负载扩展到许多设备,并在不同类型的加速器之间无缝切换。

虽然它最初是作为TensorFlow上面的一个库开发的,但目前TF-Replicator的API已经集成到TensorFlow 2.0新的tf.distribute.Strategy中,作为 tf.distribute.Strategy的一部分开源:

接下来,我们将介绍TF-Replicator背后的想法和技术挑战。

构建一个分布式机器学习库

我们开发TF-Replicator的初衷是为DeepMind的研究人员提供一个使用TPU的简单API。TPU为机器学习工作负载提供了可扩展性,实现了许多研究突破,例如使用我们的BigGAN模型实现了最先进的图像合成。

TensorFlow针对TPU的原生API与针对GPU的方式不同,这造成了使用TPU的障碍。TF-Replicator提供了一个更简单、更用户友好的API,隐藏了TensorFlow的TPU API的复杂性。此外,研究平台团队与不同机器学习领域的研究人员密切合作,开发了TF-Replicator API,以确保必要的灵活性和易用性。

TF-Replicator API

使用TF-Replicator编写的代码与使用TensorFlow中为单个设备编写的代码类似,允许用户自由定义自己的模型运行循环。用户只需要定义(1)一个公开数据集的输入函数,以及(2)一个定义其模型逻辑的step函数(例如,梯度下降的单个step):

将计算扩展到多个设备需要设备之间进行通信。在训练机器学习模型的背景下,最常见的通信形式是累积梯度(accumulate gradients)以用于优化算法,如随机梯度下降。

输入数据从主机发送到各个GPU, GPU立即开始处理。当需要在GPU之间交换信息时,它们会在发送数据之前进行同步。

实现

对于多GPU计算,TF-Replicator依赖于“图内复制”(“in-graph replication)模式,其中每个设备的计算在同一个TensorFlow graph中复制。设备之间的通信是通过连接设备对应子图中的节点来实现的。在TF-Replicator中实现这一点很具挑战性,因为在TensorFlow graph中的任何位置都可能发生通信。因此,构造计算的顺序至关重要。

然而,在我们考虑这种方法时,TensorFlow的图形构建API不是线程安全的,这使得在不同线程中同时构建子图非常困难。相反,我们使用图形重写(graph rewriting)在所有设备的子图构建完成后插入通信。在构造子图时,占位符被插入到需要通信的位置。然后,我们跨设备收集所有匹配占位符,并用适当的跨设备计算替换它们。

当TF-Replicator构建一个in-graph replicated计算时,它首先独立地为每个设备构建计算,并将占位符留给用户指定的跨设备计算。构建好所有设备的子图之后,TF-Replicator通过用实际的跨设备计算替换占位符来连接它们。

为AI研究构建一个平台

通过在TF-Replicator的设计和实现过程中与研究人员密切合作,我们最终构建一个库,让用户能够轻松地跨多个硬件加速器进行大规模计算,同时让他们拥有进行前沿AI研究所需的控制和灵活性。

例如,在与研究人员讨论之后,我们添加了MPI风格的通信原语,如all-reduce。TF-Replicator和其他共享基础架构使我们能够在稳健的基础上构建越来越复杂的实验,并在整个DeepMind快速传播最佳实践。

在撰写本文时,TF-Replicator已经成为DeepMind应用最广泛的TPU编程接口。虽然这个库本身并不局限于训练神经网络,但它最常用来训练大量数据。例如,BigGAN模型是在一个512核的TPUv3 pod训练的,batch size为2048。

在采用分布式actor-learner设置的增强学习智能体中,例如我们的重要性加权actor-learner架构,可扩展性是通过让许多actor通过与环境的交互生成新的体验来实现的。然后,learner对这些数据进行处理,以改进agent的策略,表示为一个神经网络。为了应对越来越多的actor,TF-Replicator可以很轻松地将learner分布在多个硬件加速器上。

这些以及更多例子在我们的arXiv论文中有更详细的描述。

Blog:

Paper:

版权声明:本文内容由网络用户投稿,版权归原作者所有,本站不拥有其著作权,亦不承担相应法律责任。如果您发现本站中有涉嫌抄袭或描述失实的内容,请联系我们jiasou666@gmail.com 处理,核实后本网站将在24小时内删除侵权内容。

上一篇:Kubernetes集群监控Prometheus + Grafana监控方案部署及配置
下一篇:k8s新版本集群搭建从0到1
相关文章

 发表评论

暂时没有评论,来抢沙发吧~