概览 #
在目前的神经网络领域,主流的框架基本就是 Facebook 的 PyTorch,Google 的 TensorFlow、JAX 这三个。其他的框架已经停止更新的 Theano,或者一些小众的框架比如 Amazon 的 MXNet,百度的 PaddlePaddle,华为的 MindSpore 等等,基本就是在他们公司内部及相关联的生态中使用,故在此不做过多介绍。
PyTorch 由 Facebook AI Research Lab (FAIR) 于 2016 年发布,强调易用性和灵活性,广泛应用于学术实验和模型原型设计,已经逐渐成为科研界和业界流行的深度学习框架。
TensorFlow 由 Google Brain 团队开发,于 2015 年开源。不过众所周知,TensorFlow 1.0 的静态图 API 相当反人类,所以 2019 年 TensorFlow 2.0 版本的发布算是它的新生。
JAX 由 Google 于 2018 年推出,具有与 NumPy 高度相似的 API,以及极高的灵活性和强大的自动微分功能。它的底层和 TensorFlow 一样使用了 XLA (Accelerated Linear Algebra)。近年来,JAX 在深度学习领域,尤其是在涉及到大量数学计算的问题上,受到越来越多关注。
框架的安装 #
在配置好 CUDA 环境或者仅考虑 CPU 版本的情况下,这三个框架的安装都已经做到了足够的简单。这里简单做一个概述,但建议在实际安装的时候参考官网给出的说明。
PyTorch #
PyTorch 的安装参考 https://pytorch.org/get-started/locally/, 需要注意的一点是 CPU 版本需要指定 PyTorch 自己的网址。
pip3 install torch --index-url https://download.pytorch.org/whl/cpu # CPU
pip3 install torch # GPU
其 CPU 版本的安装包大小约 170 MB。
TensorFlow #
TensorFlow 的安装参考官网,需要使用合适的方式访问 https://www.tensorflow.org/install/pip
pip install tensorflow_cpu # CPU
pip install 'tensorflow[and-cuda]' # GPU
其 CPU 版本的安装包大小约 220 MB。
JAX #
pip install jax # CPU
pip install "jax[cuda12]" # GPU
Import #
import torch
import tensorflow as tf
from jax import numpy as jnp
import jax
功能和 API 对比 #
一个神经网络框架最基础的功能莫过于在 CPU 或 GPU 上进行多维矩阵运算和自动微分了。利用这两个功能可以手搓一个神经网络,但为了易用性,神经网络框架往往还内置了一些常用的网络结构,包括全连接层、卷积层等等。除此之外,监督式学习往往还涉及到了数据的预处理,最好有一些函数能够方便地读取数据、给数据分组。接下来,我们就从这些方面来看看这三个框架。
数组运算 #
PyTorch #
a = torch.tensor([[1, 2], [3, 4]])
b = torch.tensor([[5, 6], [7, 8]])
c = torch.matmul(a, b)
c
tensor([[19, 22],
[43, 50]])
- 数组操作和 NumPy 有一些不同。比如,你应该使用
torch.svd
而非torch.linalg.svd
- 数组类型不全是自动转换的,你依然需要确保数组被初始化为合适的类型(如
float32
而非int
)
torch.log(c)
tensor([[2.9444, 3.0910],
[3.7612, 3.9120]])
torch.svd(c.to(torch.float32))
torch.return_types.svd(
U=tensor([[-0.4033, -0.9150],
[-0.9150, 0.4033]]),
S=tensor([7.2069e+01, 5.5507e-02]),
V=tensor([[-0.6523, -0.7580],
[-0.7580, 0.6523]]))
TensorFlow #
a = tf.constant([[1, 2], [3, 4]])
b = tf.constant([[5, 6], [7, 8]])
c = tf.matmul(a, b)
c
<tf.Tensor: shape=(2, 2), dtype=int32, numpy=
array([[19, 22],
[43, 50]], dtype=int32)>
- 数组操作和 NumPy 有一些不同。比如,你应该使用
tf.math.log
而非tf.log
- 数组类型需要手动转换,你需要确保数组被初始化为合适的类型(如
float32
而非int
)
tf.math.log(tf.cast(c, tf.float32))
<tf.Tensor: shape=(2, 2), dtype=float32, numpy=
array([[2.944439 , 3.0910425],
[3.7612002, 3.912023 ]], dtype=float32)>
tf.linalg.svd(tf.cast(c, tf.float32))
(<tf.Tensor: shape=(2,), dtype=float32, numpy=array([7.206939e+01, 5.550127e-02], dtype=float32)>,
<tf.Tensor: shape=(2, 2), dtype=float32, numpy=
array([[ 0.40334514, 0.9150479 ],
[ 0.9150479 , -0.40334514]], dtype=float32)>,
<tf.Tensor: shape=(2, 2), dtype=float32, numpy=
array([[ 0.6522966 , 0.75796384],
[ 0.75796384, -0.6522966 ]], dtype=float32)>)
JAX #
a = jnp.array([[1, 2], [3, 4]])
b = jnp.array([[5, 6], [7, 8]])
c = jnp.matmul(a, b)
c
Array([[19, 22],
[43, 50]], dtype=int32)
- 数组操作和 NumPy 几乎一致,除了一点细微的差别
jnp.log(c)
Array([[2.944439 , 3.0910425],
[3.7612002, 3.912023 ]], dtype=float32)
jnp.linalg.svd(c)
SVDResult(U=Array([[-0.40334493, -0.915048 ],
[-0.91504794, 0.40334517]], dtype=float32), S=Array([7.206938e+01, 5.550685e-02], dtype=float32), Vh=Array([[-0.6522967, -0.7579638],
[-0.7579638, 0.6522967]], dtype=float32))
自动微分 #
PyTorch #
只需要对输出做 backward
, 就能得到它对输入的导数
x = torch.tensor(3.0, requires_grad=True)
y = x**2
y.backward()
print(x.grad)
tensor(6.)
TensorFlow #
可以对 GradientTape
这一上下文中进行的操作做微分
x = tf.Variable(3.0)
with tf.GradientTape() as tape:
y = x**2
tape.gradient(y, x)
<tf.Tensor: shape=(), dtype=float32, numpy=6.0>
JAX #
JAX 更加函数化。你需要通过 jax.grad
得到函数 的导函数 ,再计算导函数 即可。
def loss_fn(x):
return x**2
grad_fn = jax.grad(loss_fn)
grad_fn(3.0)
Array(6., dtype=float32, weak_type=True)
JAX 还支持更加灵活的导数,包括高阶导数、Hessian 矩阵,Jacobian-vector 乘积(jax.jvp
,前向传播)和 vector-Jacobian 乘积(jax.vjp
,反向传播)等等,在此不一一列举,感兴趣可查阅 JAX Autodiff Cookbook。
常用网络结构 #
PyTorch #
torch.nn
模块内置各类层和神经网络组件。指定每一层的输入输出维度和激活函数即可。
import torch.nn as nn
model = nn.Sequential(
nn.Linear(64, 128),
nn.ReLU(),
nn.Linear(128, 10),
nn.Softmax(dim=-1)
)
model
Sequential(
(0): Linear(in_features=64, out_features=128, bias=True)
(1): ReLU()
(2): Linear(in_features=128, out_features=10, bias=True)
(3): Softmax(dim=-1)
)
model(torch.randn(1, 64))
tensor([[0.1130, 0.1304, 0.0751, 0.0859, 0.0861, 0.0781, 0.0827, 0.1015, 0.1505,
0.0967]], grad_fn=<SoftmaxBackward0>)
TensorFlow #
TensorFlow 2.0 引入了 Keras 库作为其高级 API。无需手动指定每一层的输入维度,但也因此多了一个 build
的步骤。
from tensorflow.keras import Sequential
from tensorflow.keras.layers import Dense
model = Sequential([
Dense(128, activation='relu'),
Dense(10, activation='softmax')
])
model.layers
[<Dense name=dense, built=False>, <Dense name=dense_1, built=False>]
model.build(input_shape=(None, 64))
model(tf.random.normal(shape=(1, 64)))
<tf.Tensor: shape=(1, 10), dtype=float32, numpy=
array([[0.05575525, 0.0179697 , 0.07684414, 0.06185716, 0.19958737,
0.14854056, 0.06055942, 0.06959403, 0.15985522, 0.14943717]],
dtype=float32)>
JAX #
本体不直接提供高层神经网络组件,需要使用第三方库例如 Flax
from flax import linen as nn
class MLP(nn.Module):
@nn.compact
def __call__(self, x):
x = nn.Dense(128)(x)
x = nn.relu(x)
x = nn.Dense(10)(x)
return nn.softmax(x)
model = MLP()
model
MLP()
与 PyTorch 和 TensorFlow 不同,JAX 的随机数 key
必须显式地作为参数传入,不存在一个全局的随机数状态。另外,Flax 中模型的参数也是作为参数传入的,模型本身不存储参数信息。这和 JAX 函数式编程的思想是高度一致的。
key = jax.random.PRNGKey(0)
x = jax.random.normal(key, (1, 64))
params = model.init(key, x)
model.apply(params, x)
Array([[0.057403 , 0.07946321, 0.06227293, 0.06888593, 0.04725461,
0.12753344, 0.15180823, 0.09623848, 0.1865097 , 0.12263051]], dtype=float32)
数据处理工具 #
PyTorch #
PyTorch 的数据加载和处理模块设计较为简单和直接,提供 torch.utils.data
,内置 Dataset
和 DataLoader
可以实现高效的数据预处理。
from torch.utils.data import DataLoader, TensorDataset
x = torch.tensor([[1], [2], [3]])
y = torch.tensor([[4], [5], [6]])
dataset = TensorDataset(x, y)
dataloader = DataLoader(dataset, batch_size=2)
for batch in dataloader:
print(batch)
[tensor([[1],
[2]]), tensor([[4],
[5]])]
[tensor([[3]]), tensor([[6]])]
TensorFlow #
提供 tf.data
API,设计较为复杂,但功能强大。提供了丰富的 API 和高级功能,如数据流水线(pipeline)、并行处理、缓存等。
data = tf.data.Dataset.from_tensor_slices([i for i in range(10)])
data = data.shuffle(buffer_size=10).batch(2).prefetch(buffer_size=tf.data.experimental.AUTOTUNE)
for batch in data:
print(batch)
tf.Tensor([4 6], shape=(2,), dtype=int32)
tf.Tensor([8 1], shape=(2,), dtype=int32)
tf.Tensor([0 5], shape=(2,), dtype=int32)
tf.Tensor([7 3], shape=(2,), dtype=int32)
tf.Tensor([9 2], shape=(2,), dtype=int32)
JAX #
JAX 没有自带的数据处理工具,但可以安装 tensorflow_datasets
库,从而获得和 TensorFlow 类似的数据处理体验。
其他生态 #
TensorFlow #
TensorFlow 在工业化部署上有较大优势,提供了包括 TensorFlow Lite(移动端), TensorFlow.js(浏览器端)等的多种部署方案。
此外 TensorFlow Extended (TFX) 提供了完整的端到端生产工作流,TensorBoard 提供了丰富的训练可视化工具。
PyTorch #
PyTorch 对图像(TorchVision)、音频(TorchAudio)、自然语言(TorchText)都有着强大的支持。
此外 TorchServe 也提供了大量模型的服务端部署功能。
JAX #
JAX 依然是一个比较新的框架,在生态方面处于一定的劣势。但是近年来,越来越多的第三方库除了对 TensorFlow 和 PyTorch 的支持以外,也加入了 JAX 的支持,包括 Keras, Hugging Face 的 Transformers 和 Diffusers 等等。
我该怎么选? #
- TensorFlow:部署和产业化友好,适合大规模分布式训练和生产环境。
- PyTorch:科研和开发模型的首选,强调代码简洁性和动态性。
- JAX:与 NumPy 高度一致,高效灵活的数学计算和自动微分。
另外,Keras? #
Keras 严格意义上不是一个独立的神经网络框架,它是依赖于 JAX、PyTorch 或 TensorFlow 作为后端,而它自己扮演的是高级封装的角色。Keras 库最早由谷歌的程序员 Francois Chollet 在 2015 年开发,TensorFlow 问世之后从上古框架 Theano 换到了 TensorFlow。TensorFlow 2.0 出来之后被正式“收编”为 TensorFlow 的官方高级 API。不过后来 Keras 后来又从 TensorFlow 的代码库分离了出来,重新成为了一个独立项目。自从 2023 年 Keras 3.0 发布以来,Keras 正式支持了 JAX、PyTorch 和 TensorFlow 作为后端。
所以说 Keras 从定位上讲更像 Flax,我们这里限于篇幅就不展开介绍 Keras 了。
(本文相关代码可以在 AllanChain/nn-frameworks 找到)