图书目录

目    录

第Ⅰ部分  入门基础

第1章  何时以及为什么使用JAX  3

1.1  使用JAX的理由  6

1.1.1  计算性能  7

1.1.2  函数式方法  9

1.1.3  JAX生态系统  10

1.2  JAX与NumPy的区别  11

1.2.1  JAX作为NumPy  12

1.2.2  可组合的变换  12

1.3  JAX与TensorFlow和PyTorch的区别  14

1.4  本章小结  16

第2章  你的第一个JAX程序  17

2.1  一个简单的机器学习问题:手写数字分类  18

2.2  JAX深度学习项目概览  19

2.3  加载和准备数据集  20

2.4  在JAX中构建一个简单的神经网络  22

2.4.1  神经网络初始化  24

2.4.2  神经网络前向传播  25

2.5  vmap:自动向量化计算以支持批处理  28

2.6  自动微分:如何在不手动计算导数的情况下计算梯度  30

2.6.1  损失函数  32

2.6.2  获取梯度  33

2.6.3  梯度更新步骤  33

2.6.4  训练循环  34

2.7  JIT:将代码编译为更快的版本  36

2.8  保存和部署模型  37

2.9  纯函数和可组合的转换:它们为什么重要  39

2.10  练习  40

2.11  本章小结  40

第Ⅱ部分  JAX核心机制

第3章  数组操作  43

3.1  使用NumPy数组进行图像处理  44

3.1.1  将图像加载到NumPy数组中  45

3.1.2  对图像执行基本预处理操作  48

3.1.3  向图像添加噪声  50

3.1.4  实现图像滤波  51

3.1.5  将张量保存为图像文件  55

3.2  JAX中的数组  57

3.2.1  切换到JAX类似NumPy的API  57

3.2.2  什么是Array  58

3.2.3  与设备相关的操作  61

3.2.4  异步调度  65

3.2.5  在TPU上运行计算  66

3.3  与NumPy的区别  69

3.3.1  不可变性  70

3.3.2  类型  73

3.4  高级接口与低级接口:jax.numpy和jax.lax  76

3.4.1  控制流原语  77

3.4.2  类型提升  78

3.5  练习  79

3.6  本章小结  79

第4章  计算梯度  81

4.1  获取导数的不同方法  82

4.1.1  手动求导  83

4.1.2  符号微分  84

4.1.3  数值微分  85

4.1.4  自动微分  87

4.2  使用自动微分计算梯度  89

4.2.1  在TensorFlow中使用梯度  91

4.2.2  在PyTorch中使用梯度  92

4.2.3  在JAX中使用梯度  92

4.2.4  高阶导数  100

4.2.5  多变量情况  102

4.3  前向模式和反向模式自动微分  104

4.3.1  计算轨迹  105

4.3.2  前向模式和jvp()  106

4.3.3  反向模式和vjp()  110

4.3.4  深入探索  113

4.4  本章小结  114

第5章  编译代码  115

5.1  使用编译  116

5.1.1  使用JIT编译  117

5.1.2  纯函数与编译过程  123

5.2  JIT内部机制  125

5.2.1  Jaxpr:JAX程序的中间表示形式  125

5.2.2  XLA  134

5.2.3  使用AOT编译  138

5.3  JIT的局限性  142

5.3.1  纯函数与非纯函数  142

5.3.2  精确数值  142

5.3.3  输入参数值依赖的条件控制  142

5.3.4  编译速度慢  142

5.3.5  类方法  144

5.3.6  简单函数  145

5.4  练习  146

5.5  本章小结  146

第6章  向量化代码  147

6.1  向量化函数的不同方法  148

6.1.1  朴素方法  149

6.1.2  手动向量化  151

6.1.3  自动向量化  151

6.1.4  性能比较  152

6.2  控制vmap()行为  154

6.2.1  控制映射的数组轴  154

6.2.2  控制输出数组的轴  157

6.2.3  使用命名参数  158

6.2.4  使用装饰器风格  160

6.2.5  使用集合操作  161

6.3  vmap()的实际应用案例  162

6.3.1  批数据处理  162

6.3.2  批量化神经网络模型  164

6.3.3  每个样本的梯度  165

6.3.4  向量化循环  166

6.4  本章小结  169

第7章  并行化计算  171

7.1  使用pmap()并行化计算  172

7.1.1  问题设置  172

7.1.2  像使用vmap一样使用pmap  175

7.2  控制pmap()的行为  180

7.2.1  控制输入和输出的映射轴  181

7.2.2  使用命名轴和集合操作  186

7.3  数据并行的神经网络训练示例  193

7.3.1  准备数据和神经网络结构  194

7.3.2  实现数据并行训练  196

7.4  使用多主机配置  201

7.5  本章小结  206

第8章  使用张量切分  209

8.1  张量分片基础  210

8.1.1  设备网格  212

8.1.2  位置分片  213

8.1.3  二维网格示例  213

8.1.4  使用复制  217

8.1.5  分片约束  219

8.1.6  命名切分  221

8.1.7  设备放置策略与错误  222

8.2  使用张量分片的多层感知机(MLP)  224

8.2.1  八路数据并行  224

8.2.2  四路数据并行,双路张量并行  226

8.3  本章小结  229

第9章  JAX中的随机数  231

9.1  生成随机数据  232

9.1.1  载入数据集  233

9.1.2  生成随机噪声  235

9.1.3  执行随机增强  239

9.2  与NumPy的区别  241

9.2.1  NumPy的工作原理  241

9.2.2  NumPy中的种子和状态  243

9.2.3  JAX PRNG  246

9.2.4  JAX PRNG高级配置  252

9.3  在实际应用中生成随机数  253

9.3.1  构建一个完整的数据增强管道  253

9.3.2  为神经网络生成随机初始化  255

9.4  本章小结  256

第10章  处理pytree  257

10.1  将复杂数据结构表示为pytree  258

10.2  处理pytree的函数  262

10.2.1  使用tree_map()  263

10.2.2  扁平化/还原pytree  265

10.2.3  使用tree_reduce()  267

10.2.4  转置pytree  268

10.3  创建自定义pytree节点  271

10.4  本章小结  274

第Ⅲ部分   生态系统

第11章  高级神经网络库  277

11.1  使用MLP进行MNIST图像分类  278

11.1.1  Flax中的MLP  278

11.1.2  Optax梯度变换库  284

11.1.3  使用Flax训练神经网络  286

11.2  使用ResNet进行图像分类  290

11.2.1  在Flax中管理状态  290

11.2.2  使用Orbax保存和加载模型  296

11.3  使用Hugging Face生态系统  298

11.3.1  使用Hugging Face Model Hub中的预训练模型  299

11.3.2  进一步探索:微调与再训练  304

11.3.3  使用diffusers库  306

11.4  本章小结  310

第12章  JAX生态系统的其他成员  313

12.1  深度学习生态系统  314

12.1.1  高层神经网络库  314

12.1.2  JAX中的大型语言模型(LLM)  315

12.1.3  工具库  317

12.2  机器学习模块  319

12.2.1  强化学习  319

12.2.2  其他机器学习库  320

12.3  其他领域的JAX模块  321

12.4  本章小结  322

附录A  安装JAX  325

附录B  使用Google Colab  329

附录C  使用Google Cloud TPU  331

附录D  实验性并行化  335