谷歌JAX深度学习从零开始学
上QQ阅读APP看书,第一时间看更新

1.1.2 为什么是JAX

JAX是机器学习框架领域的新生力量。JAX从诞生就具有相对于其他深度学习框架更高的高度,并迈出了重要的一步,不是因为它比现有的机器学习框架具有更简洁的API,或者因为它比TensorFlow和PyTorch在被设计的事情上做得更好,而是因为它允许我们更容易地尝试更广阔的思想空间。

JAX把看不到的细节藏在底层内部结构中,而无须关心其使用过程和细节,很明显,JAX关心的是如何让开发者做出创造性的工作。JAX对如何使用做了很少的假设,具有很好的灵活性。

JAX目前已经达到深度学习框架的最高水平。在当前开源的框架中,没有哪一个框架能够在简洁、易用、速度这3个方面有两个能同时超过JAX。

● 简洁:JAX的设计追求最少的封装,尽量避免重复造轮子。不像TensorFlow中充斥着graph、operation、name_scope、variable、tensor、layer等全新的概念,JAX的设计遵循tensor→variable(autograd)→Module 3个由低到高的抽象层次,分别代表高维数组(张量)、自动求导(变量)和神经网络(层/模块),而且这3个抽象之间联系紧密,可以同时进行修改和操作。简洁的设计带来的另外一个好处就是代码易于理解。JAX的源码只有TensorFlow的十分之一左右,更少的抽象、更直观的设计使得JAX的源码十分易于阅读。

● 速度:JAX的灵活性不以速度为代价,在许多评测中,JAX的速度表现胜过TensorFlow和PyTorch等框架。框架的运行速度和程序员的编码水平存在着一定关系,但同样的算法,使用JAX实现可能快过用其他框架实现。

● 易用:JAX是所有的框架中面向对象设计得最优雅的一个。JAX的设计最符合人们的思维,它让用户尽可能地专注于实现自己的想法,即所思即所得,不需要考虑太多关于框架本身的束缚。

JAX的设计体现了Linux设计哲学—do one thing and do it well。JAX很轻量级,专注于高效的数值计算,由于提供了调用其他框架的功能,这样JAX程序的编写以及数据的加载可以使用其他框架的现成工具。并且Google也在基于JAX构建生态:包括神经网络库Haiku、梯度处理和优化的库Optax、强化学习库Rlax,以及用来帮助编写可靠代码的chex工具库。很多Google的研究组利用JAX来开发训练神经网络的工具库,比如Flax、Trax。