目 录
第Ⅰ部分 入门基础
第1章 何时以及为什么使用JAX 3
1.1 使用JAX的理由 6
1.1.1 计算性能 7
1.1.2 函数式方法 9
1.1.3 JAX生态系统 10
1.2 JAX与NumPy的区别 11
1.2.1 JAX作为NumPy 12
1.2.2 可组合的变换 12
1.3 JAX与TensorFlow和PyTorch的区别 14
1.4 本章小结 16
第2章 你的第一个JAX程序 17
2.1 一个简单的机器学习问题:手写数字分类 18
2.2 JAX深度学习项目概览 19
2.3 加载和准备数据集 20
2.4 在JAX中构建一个简单的神经网络 22
2.4.1 神经网络初始化 24
2.4.2 神经网络前向传播 25
2.5 vmap:自动向量化计算以支持批处理 28
2.6 自动微分:如何在不手动计算导数的情况下计算梯度 30
2.6.1 损失函数 32
2.6.2 获取梯度 33
2.6.3 梯度更新步骤 33
2.6.4 训练循环 34
2.7 JIT:将代码编译为更快的版本 36
2.8 保存和部署模型 37
2.9 纯函数和可组合的转换:它们为什么重要 39
2.10 练习 40
2.11 本章小结 40
第Ⅱ部分 JAX核心机制
第3章 数组操作 43
3.1 使用NumPy数组进行图像处理 44
3.1.1 将图像加载到NumPy数组中 45
3.1.2 对图像执行基本预处理操作 48
3.1.3 向图像添加噪声 50
3.1.4 实现图像滤波 51
3.1.5 将张量保存为图像文件 55
3.2 JAX中的数组 57
3.2.1 切换到JAX类似NumPy的API 57
3.2.2 什么是Array 58
3.2.3 与设备相关的操作 61
3.2.4 异步调度 65
3.2.5 在TPU上运行计算 66
3.3 与NumPy的区别 69
3.3.1 不可变性 70
3.3.2 类型 73
3.4 高级接口与低级接口:jax.numpy和jax.lax 76
3.4.1 控制流原语 77
3.4.2 类型提升 78
3.5 练习 79
3.6 本章小结 79
第4章 计算梯度 81
4.1 获取导数的不同方法 82
4.1.1 手动求导 83
4.1.2 符号微分 84
4.1.3 数值微分 85
4.1.4 自动微分 87
4.2 使用自动微分计算梯度 89
4.2.1 在TensorFlow中使用梯度 91
4.2.2 在PyTorch中使用梯度 92
4.2.3 在JAX中使用梯度 92
4.2.4 高阶导数 100
4.2.5 多变量情况 102
4.3 前向模式和反向模式自动微分 104
4.3.1 计算轨迹 105
4.3.2 前向模式和jvp() 106
4.3.3 反向模式和vjp() 110
4.3.4 深入探索 113
4.4 本章小结 114
第5章 编译代码 115
5.1 使用编译 116
5.1.1 使用JIT编译 117
5.1.2 纯函数与编译过程 123
5.2 JIT内部机制 125
5.2.1 Jaxpr:JAX程序的中间表示形式 125
5.2.2 XLA 134
5.2.3 使用AOT编译 138
5.3 JIT的局限性 142
5.3.1 纯函数与非纯函数 142
5.3.2 精确数值 142
5.3.3 输入参数值依赖的条件控制 142
5.3.4 编译速度慢 142
5.3.5 类方法 144
5.3.6 简单函数 145
5.4 练习 146
5.5 本章小结 146
第6章 向量化代码 147
6.1 向量化函数的不同方法 148
6.1.1 朴素方法 149
6.1.2 手动向量化 151
6.1.3 自动向量化 151
6.1.4 性能比较 152
6.2 控制vmap()行为 154
6.2.1 控制映射的数组轴 154
6.2.2 控制输出数组的轴 157
6.2.3 使用命名参数 158
6.2.4 使用装饰器风格 160
6.2.5 使用集合操作 161
6.3 vmap()的实际应用案例 162
6.3.1 批数据处理 162
6.3.2 批量化神经网络模型 164
6.3.3 每个样本的梯度 165
6.3.4 向量化循环 166
6.4 本章小结 169
第7章 并行化计算 171
7.1 使用pmap()并行化计算 172
7.1.1 问题设置 172
7.1.2 像使用vmap一样使用pmap 175
7.2 控制pmap()的行为 180
7.2.1 控制输入和输出的映射轴 181
7.2.2 使用命名轴和集合操作 186
7.3 数据并行的神经网络训练示例 193
7.3.1 准备数据和神经网络结构 194
7.3.2 实现数据并行训练 196
7.4 使用多主机配置 201
7.5 本章小结 206
第8章 使用张量切分 209
8.1 张量分片基础 210
8.1.1 设备网格 212
8.1.2 位置分片 213
8.1.3 二维网格示例 213
8.1.4 使用复制 217
8.1.5 分片约束 219
8.1.6 命名切分 221
8.1.7 设备放置策略与错误 222
8.2 使用张量分片的多层感知机(MLP) 224
8.2.1 八路数据并行 224
8.2.2 四路数据并行,双路张量并行 226
8.3 本章小结 229
第9章 JAX中的随机数 231
9.1 生成随机数据 232
9.1.1 载入数据集 233
9.1.2 生成随机噪声 235
9.1.3 执行随机增强 239
9.2 与NumPy的区别 241
9.2.1 NumPy的工作原理 241
9.2.2 NumPy中的种子和状态 243
9.2.3 JAX PRNG 246
9.2.4 JAX PRNG高级配置 252
9.3 在实际应用中生成随机数 253
9.3.1 构建一个完整的数据增强管道 253
9.3.2 为神经网络生成随机初始化 255
9.4 本章小结 256
第10章 处理pytree 257
10.1 将复杂数据结构表示为pytree 258
10.2 处理pytree的函数 262
10.2.1 使用tree_map() 263
10.2.2 扁平化/还原pytree 265
10.2.3 使用tree_reduce() 267
10.2.4 转置pytree 268
10.3 创建自定义pytree节点 271
10.4 本章小结 274
第Ⅲ部分 生态系统
第11章 高级神经网络库 277
11.1 使用MLP进行MNIST图像分类 278
11.1.1 Flax中的MLP 278
11.1.2 Optax梯度变换库 284
11.1.3 使用Flax训练神经网络 286
11.2 使用ResNet进行图像分类 290
11.2.1 在Flax中管理状态 290
11.2.2 使用Orbax保存和加载模型 296
11.3 使用Hugging Face生态系统 298
11.3.1 使用Hugging Face Model Hub中的预训练模型 299
11.3.2 进一步探索:微调与再训练 304
11.3.3 使用diffusers库 306
11.4 本章小结 310
第12章 JAX生态系统的其他成员 313
12.1 深度学习生态系统 314
12.1.1 高层神经网络库 314
12.1.2 JAX中的大型语言模型(LLM) 315
12.1.3 工具库 317
12.2 机器学习模块 319
12.2.1 强化学习 319
12.2.2 其他机器学习库 320
12.3 其他领域的JAX模块 321
12.4 本章小结 322
附录A 安装JAX 325
附录B 使用Google Colab 329
附录C 使用Google Cloud TPU 331
附录D 实验性并行化 335
