Header: nn

神经网络框架调研:PyTorch、TensorFlow、JAX

2024-12-22

从安装到数组运算、自动微分、网络搭建等多个方面,分析 PyTorch、TensorFlow、JAX 这三个神经网络框架的特色和优劣。

 

概览 #

在目前的神经网络领域,主流的框架基本就是 Facebook 的 PyTorch,Google 的 TensorFlow、JAX 这三个。其他的框架已经停止更新的 Theano,或者一些小众的框架比如 Amazon 的 MXNet,百度的 PaddlePaddle,华为的 MindSpore 等等,基本就是在他们公司内部及相关联的生态中使用,故在此不做过多介绍。

PyTorch 由 Facebook AI Research Lab (FAIR) 于 2016 年发布,强调易用性和灵活性,广泛应用于学术实验和模型原型设计,已经逐渐成为科研界和业界流行的深度学习框架。

TensorFlow 由 Google Brain 团队开发,于 2015 年开源。不过众所周知,TensorFlow 1.0 的静态图 API 相当反人类,所以 2019 年 TensorFlow 2.0 版本的发布算是它的新生。

JAX 由 Google 于 2018 年推出,具有与 NumPy 高度相似的 API,以及极高的灵活性和强大的自动微分功能。它的底层和 TensorFlow 一样使用了 XLA (Accelerated Linear Algebra)。近年来,JAX 在深度学习领域,尤其是在涉及到大量数学计算的问题上,受到越来越多关注。

框架的安装 #

在配置好 CUDA 环境或者仅考虑 CPU 版本的情况下,这三个框架的安装都已经做到了足够的简单。这里简单做一个概述,但建议在实际安装的时候参考官网给出的说明。

PyTorch #

PyTorch 的安装参考 https://pytorch.org/get-started/locally/, 需要注意的一点是 CPU 版本需要指定 PyTorch 自己的网址。

pip3 install torch --index-url https://download.pytorch.org/whl/cpu  # CPU
pip3 install torch  # GPU
sh

其 CPU 版本的安装包大小约 170 MB。

TensorFlow #

TensorFlow 的安装参考官网,需要使用合适的方式访问 https://www.tensorflow.org/install/pip

pip install tensorflow_cpu  # CPU
pip install 'tensorflow[and-cuda]'  # GPU
sh

其 CPU 版本的安装包大小约 220 MB。

JAX #

pip install jax  # CPU
pip install "jax[cuda12]"  # GPU
sh

Import #

import torch
python
import tensorflow as tf
python
from jax import numpy as jnp
import jax
python

功能和 API 对比 #

一个神经网络框架最基础的功能莫过于在 CPU 或 GPU 上进行多维矩阵运算和自动微分了。利用这两个功能可以手搓一个神经网络,但为了易用性,神经网络框架往往还内置了一些常用的网络结构,包括全连接层、卷积层等等。除此之外,监督式学习往往还涉及到了数据的预处理,最好有一些函数能够方便地读取数据、给数据分组。接下来,我们就从这些方面来看看这三个框架。

数组运算 #

PyTorch #

a = torch.tensor([[1, 2], [3, 4]])
b = torch.tensor([[5, 6], [7, 8]])
c = torch.matmul(a, b)
c
python
tensor([[19, 22],
        [43, 50]])
text
torch.log(c)
python
tensor([[2.9444, 3.0910],
        [3.7612, 3.9120]])
text
torch.svd(c.to(torch.float32))
python
torch.return_types.svd(
U=tensor([[-0.4033, -0.9150],
        [-0.9150,  0.4033]]),
S=tensor([7.2069e+01, 5.5507e-02]),
V=tensor([[-0.6523, -0.7580],
        [-0.7580,  0.6523]]))
text

TensorFlow #

a = tf.constant([[1, 2], [3, 4]])
b = tf.constant([[5, 6], [7, 8]])
c = tf.matmul(a, b)
c
python
<tf.Tensor: shape=(2, 2), dtype=int32, numpy=
array([[19, 22],
       [43, 50]], dtype=int32)>
text
tf.math.log(tf.cast(c, tf.float32))
python
<tf.Tensor: shape=(2, 2), dtype=float32, numpy=
array([[2.944439 , 3.0910425],
       [3.7612002, 3.912023 ]], dtype=float32)>
text
tf.linalg.svd(tf.cast(c, tf.float32))
python
(<tf.Tensor: shape=(2,), dtype=float32, numpy=array([7.206939e+01, 5.550127e-02], dtype=float32)>,
 <tf.Tensor: shape=(2, 2), dtype=float32, numpy=
 array([[ 0.40334514,  0.9150479 ],
        [ 0.9150479 , -0.40334514]], dtype=float32)>,
 <tf.Tensor: shape=(2, 2), dtype=float32, numpy=
 array([[ 0.6522966 ,  0.75796384],
        [ 0.75796384, -0.6522966 ]], dtype=float32)>)
text

JAX #

a = jnp.array([[1, 2], [3, 4]])
b = jnp.array([[5, 6], [7, 8]])
c = jnp.matmul(a, b)
c
python
Array([[19, 22],
       [43, 50]], dtype=int32)
text
jnp.log(c)
python
Array([[2.944439 , 3.0910425],
       [3.7612002, 3.912023 ]], dtype=float32)
text
jnp.linalg.svd(c)
python
SVDResult(U=Array([[-0.40334493, -0.915048  ],
       [-0.91504794,  0.40334517]], dtype=float32), S=Array([7.206938e+01, 5.550685e-02], dtype=float32), Vh=Array([[-0.6522967, -0.7579638],
       [-0.7579638,  0.6522967]], dtype=float32))
text

自动微分 #

PyTorch #

只需要对输出做 backward, 就能得到它对输入的导数

x = torch.tensor(3.0, requires_grad=True)
y = x**2
y.backward()
print(x.grad)
python
tensor(6.)
text

TensorFlow #

可以对 GradientTape 这一上下文中进行的操作做微分

x = tf.Variable(3.0)
with tf.GradientTape() as tape:
    y = x**2
tape.gradient(y, x)
python
<tf.Tensor: shape=(), dtype=float32, numpy=6.0>
text

JAX #

JAX 更加函数化。你需要通过 jax.grad 得到函数 ff 的导函数 ff',再计算导函数 f(x)f'(x) 即可。

def loss_fn(x):
    return x**2
grad_fn = jax.grad(loss_fn)
grad_fn(3.0)
python
Array(6., dtype=float32, weak_type=True)
text

JAX 还支持更加灵活的导数,包括高阶导数、Hessian 矩阵,Jacobian-vector 乘积(jax.jvp,前向传播)和 vector-Jacobian 乘积(jax.vjp,反向传播)等等,在此不一一列举,感兴趣可查阅 JAX Autodiff Cookbook

常用网络结构 #

PyTorch #

torch.nn 模块内置各类层和神经网络组件。指定每一层的输入输出维度和激活函数即可。

import torch.nn as nn

model = nn.Sequential(
    nn.Linear(64, 128),
    nn.ReLU(),
    nn.Linear(128, 10),
    nn.Softmax(dim=-1)
)
model
python
Sequential(
  (0): Linear(in_features=64, out_features=128, bias=True)
  (1): ReLU()
  (2): Linear(in_features=128, out_features=10, bias=True)
  (3): Softmax(dim=-1)
)
text
model(torch.randn(1, 64))
python
tensor([[0.1130, 0.1304, 0.0751, 0.0859, 0.0861, 0.0781, 0.0827, 0.1015, 0.1505,
         0.0967]], grad_fn=<SoftmaxBackward0>)
text

TensorFlow #

TensorFlow 2.0 引入了 Keras 库作为其高级 API。无需手动指定每一层的输入维度,但也因此多了一个 build 的步骤。

from tensorflow.keras import Sequential
from tensorflow.keras.layers import Dense

model = Sequential([
    Dense(128, activation='relu'),
    Dense(10, activation='softmax')
])
model.layers
python
[<Dense name=dense, built=False>, <Dense name=dense_1, built=False>]
text
model.build(input_shape=(None, 64))
model(tf.random.normal(shape=(1, 64)))
python
<tf.Tensor: shape=(1, 10), dtype=float32, numpy=
array([[0.05575525, 0.0179697 , 0.07684414, 0.06185716, 0.19958737,
        0.14854056, 0.06055942, 0.06959403, 0.15985522, 0.14943717]],
      dtype=float32)>
text

JAX #

本体不直接提供高层神经网络组件,需要使用第三方库例如 Flax

from flax import linen as nn

class MLP(nn.Module):
    @nn.compact
    def __call__(self, x):
        x = nn.Dense(128)(x)
        x = nn.relu(x)
        x = nn.Dense(10)(x)
        return nn.softmax(x)

model = MLP()
model
python
MLP()
text

与 PyTorch 和 TensorFlow 不同,JAX 的随机数 key 必须显式地作为参数传入,不存在一个全局的随机数状态。另外,Flax 中模型的参数也是作为参数传入的,模型本身不存储参数信息。这和 JAX 函数式编程的思想是高度一致的。

key = jax.random.PRNGKey(0)
x = jax.random.normal(key, (1, 64))
params = model.init(key, x)
model.apply(params, x)
python
Array([[0.057403  , 0.07946321, 0.06227293, 0.06888593, 0.04725461,
        0.12753344, 0.15180823, 0.09623848, 0.1865097 , 0.12263051]],      dtype=float32)
text

数据处理工具 #

PyTorch #

PyTorch 的数据加载和处理模块设计较为简单和直接,提供 torch.utils.data,内置 DatasetDataLoader 可以实现高效的数据预处理。

from torch.utils.data import DataLoader, TensorDataset


x = torch.tensor([[1], [2], [3]])
y = torch.tensor([[4], [5], [6]])
dataset = TensorDataset(x, y)
dataloader = DataLoader(dataset, batch_size=2)
for batch in dataloader:
    print(batch)
python
[tensor([[1],
        [2]]), tensor([[4],
        [5]])]
[tensor([[3]]), tensor([[6]])]
text

TensorFlow #

提供 tf.data API,设计较为复杂,但功能强大。提供了丰富的 API 和高级功能,如数据流水线(pipeline)、并行处理、缓存等。

data = tf.data.Dataset.from_tensor_slices([i for i in range(10)])
data = data.shuffle(buffer_size=10).batch(2).prefetch(buffer_size=tf.data.experimental.AUTOTUNE)

for batch in data:
    print(batch)
python
tf.Tensor([4 6], shape=(2,), dtype=int32)
tf.Tensor([8 1], shape=(2,), dtype=int32)
tf.Tensor([0 5], shape=(2,), dtype=int32)
tf.Tensor([7 3], shape=(2,), dtype=int32)
tf.Tensor([9 2], shape=(2,), dtype=int32)
text

JAX #

JAX 没有自带的数据处理工具,但可以安装 tensorflow_datasets 库,从而获得和 TensorFlow 类似的数据处理体验。

其他生态 #

TensorFlow #

TensorFlow 在工业化部署上有较大优势,提供了包括 TensorFlow Lite(移动端), TensorFlow.js(浏览器端)等的多种部署方案。

此外 TensorFlow Extended (TFX) 提供了完整的端到端生产工作流,TensorBoard 提供了丰富的训练可视化工具。

PyTorch #

PyTorch 对图像(TorchVision)、音频(TorchAudio)、自然语言(TorchText)都有着强大的支持。

此外 TorchServe 也提供了大量模型的服务端部署功能。

JAX #

JAX 依然是一个比较新的框架,在生态方面处于一定的劣势。但是近年来,越来越多的第三方库除了对 TensorFlow 和 PyTorch 的支持以外,也加入了 JAX 的支持,包括 Keras, Hugging Face 的 Transformers 和 Diffusers 等等。

我该怎么选? #

另外,Keras? #

Keras 严格意义上不是一个独立的神经网络框架,它是依赖于 JAX、PyTorch 或 TensorFlow 作为后端,而它自己扮演的是高级封装的角色。Keras 库最早由谷歌的程序员 Francois Chollet 在 2015 年开发,TensorFlow 问世之后从上古框架 Theano 换到了 TensorFlow。TensorFlow 2.0 出来之后被正式“收编”为 TensorFlow 的官方高级 API。不过后来 Keras 后来又从 TensorFlow 的代码库分离了出来,重新成为了一个独立项目。自从 2023 年 Keras 3.0 发布以来,Keras 正式支持了 JAX、PyTorch 和 TensorFlow 作为后端。

所以说 Keras 从定位上讲更像 Flax,我们这里限于篇幅就不展开介绍 Keras 了。

(本文相关代码可以在 AllanChain/nn-frameworks 找到)

👍
1
Leave your comments and reactions on GitHub