谷歌JAX深度学习从零开始学
上QQ阅读APP看书,第一时间看更新

1.2.2 JAX的安装和验证

新安装的WSL需要更新一次,打开WSL终端界面,依次输入如下操作语句:

sudo apt update
sudo apt install gcc make g++
sudo apt install build-essential
sudo apt install python3-pip
pip install --upgrade pip
pip install -i https://pypi.tuna.tsinghua.edu.cn/simple jax== 0.2.19
pip install -i https://pypi.tuna.tsinghua.edu.cn/simple jaxlib== 0.1.70

注意

因为JAX现在仍旧处于调整阶段,可能后面函数会有调整。本书使用的是0.2.19版本的jax和0.1.70版本的jaxlib,读者一定要注意版本的选择。

在需要输入密码的地方直接输入,并且在需要确认的地方输入字符“y”进行确认。

等全部命令运行完毕后,用户可以打开WSL终端运行如下命令:

python3

这是启动WSL自带的Python命令,之后键入如下命令:

import jax.numpy as np
np.add(1.0,1.7)

最终结果如图1.10所示。还可以看到Ubuntu系统上默认安装了Python 3.8.10。

图1.10 运行结果

可以看到最终结果是2.7,并且也提示了本机在运行中只使用CPU而非GPU。对于想使用GPU版本的JAX读者来说,最好的方案是使用纯Ubuntu系统作为开发平台,或者可以升级到Windows 11并安装特定的CUDA驱动程序,这里不再过多阐述,有兴趣的读者可以参考本书附录。