译 者 序
在当今快速发展的深度学习领域,JAX作为一种新兴的工具,凭借其高性能、灵活性和优雅的编程范式,吸引了越来越多的研究者和开发者。它不仅简化了在CPU、GPU和TPU上开发和部署的过程,还通过独特的功能(如自动微分、自动向量化和即时编译)为深度学习、优化和科学计算等多个领域提供了强大的支撑。
《JAX深度学习》正是在这样的背景下应运而生。本书旨在为读者提供一条从入门到精通JAX的完整学习路径。无论你是深度学习从业者、研究人员,还是物理学、化学等领域的跨学科开发者,本书都能够为你提供清晰的指导和丰富的实践示例,帮助你快速掌握JAX的关键概念,并将其应用于实际项目中。
本书内容主要分为三大部分:首先,通过概览介绍了JAX的独特之处,以及快速上手这一工具的方法;其次,深入剖析了JAX的核心功能,包括张量操作、自动微分、编译优化和并行计算等技术细节;最后,介绍了JAX的生态系统,以及围绕这一框架构建的高级工具和库,如Flax和Optax等。本书以深入浅出的方式,结合丰富的代码实例,帮助读者理解并掌握这些复杂的技术内容。
对于希望通过JAX来构建神经网络的开发者,本书提供了清晰的路径:从基础概念到复杂模型的实现,无论是初学者还是资深研究者都能从中受益。与此同时,那些关注高性能计算的专业人士也能从并行化和张量分片等高级章节中获得灵感。值得一提的是,本书还为所有示例代码提供了在线资源和Colab notebook(笔记本),读者可以直接运行这些代码,以便更好地理解相关内容。
学习JAX不仅是掌握一门强大的计算工具,更是迈向高效科学计算和前沿研究领域的一次重要尝试。在未来,JAX及其生态系统将持续发展,不断拓展深度学习和科学计算的边界。本书的每一章节和案例均以此为目标,帮助读者在理论与实践中找到最优平衡。
我衷心希望《JAX深度学习》能够成为你学习和应用JAX的得力助手。无论你是初次接触JAX,还是希望在现有技能上更进一步,本书都将为你提供明确的方向和有力的支持。让我们共同开启这段关于JAX的探索之旅,迈向更加高效且创新的计算新时代!
最后,我想衷心感谢清华大学出版社的编辑老师们,感谢他们帮助我出版了多种关于机器学习、人工智能、云计算以及高性能计算的书籍,感谢他们为我提供一种全新的与大家分享知识的方式。我还要感谢我的好友黎阳诗,感谢他耐心地帮助我校对并精心润色文稿。
殷海英
埃尔塞贡多,加利福尼亚州
谨以此书献给我的父母,感谢他们鼓励我追随自己的兴趣,为我提供了大量的优秀书籍,并在电脑还只是奢侈品(主要用于娱乐)的时代,赠予我人生中第一台电脑。
作 者 简 介
格里戈里•萨普诺夫(Grigory Sapunov)是Intento公司的联合创始人兼首席技术官(CTO)。他是一位拥有超过20年从业经验的软件工程师,获得人工智能博士学位,同时持有机器学习领域谷歌开发者专家(Google Developer Expert,GDE)认证。
致 谢
本书的编写所花费的时间远超我的预期。在这期间,我调整了一些内容,加上JAX的版本也发生了一些变化,导致部分章节不得不重新编写。但现在,一切终于圆满完成了!
首先,我要感谢我的家人,特别是我的妻子Mila和我的两个孩子,Danya和Fedya。你们一直在我缺少陪伴的日子里默默承受,却始终给予我支持和鼓励。
我要感谢亚美尼亚人民,感谢你们的热情好客,我们在亚美尼亚生活过的一段时间里,你们给予了我们极大的温暖与支持。特别感谢耶烈万的科技创业社区,感谢你们的帮助与支持。感谢Hrant Khachatrian、Zaven Navoyan、Arsen Yeghiazaryan、Andranik Khachatryan、Ashot Arzumanyan、Ash Vardanian、Adam Bittlingmayer、Artur Aleksanyan、Erik Arakelyan、Karén Gyulbudaghyan,以及其他许多朋友,衷心感谢你们!
感谢亚美尼亚企业局(Enterprise Armenia)和亚美尼亚国家投资促进局(National Investment Promotion Agency of Armenia)。你们的工作非常出色,所提供的帮助对我而言意义非凡。
我要感谢Manning出版社的编辑们:Patrick Barb、Becky Whitney和Frances Lefkowitz。尽管经历了三位编辑的更替和多次修改的过程,但你们每一位都为本书增色添彩。同时,我也要感谢Mike Stephens和Marjan Bace,感谢你们从我最初的提案阶段就一直相信本书的价值。
我要感谢我的技术编辑Nick McGreivy,他不仅是普林斯顿大学的博士生,专攻等离子体物理学,还在他的研究中使用JAX,致力于优化科学实验并将深度学习集成到数值模拟中。感谢技术审校者Kostas Passadis,以及我的审稿人Arslan Gabdulkhakov、Chansung Park、Fillipe Dornelas、James Black、James Wang、Jun Jiang、Keith Kim、Lucian-Paul Torje、Maxim Volgin、Najeeb Arif、Or Golan、Ritobrata Ghosh、Seunghyun Lee、Simone De Bonis、Stephen Oates、Tony Holdroyd、Vidhya Vinay和Vojta Tuma。你们提供了许多宝贵的意见和建议,帮助我改进了本书。尽管大家的帮助极为重要,但书中仍可能存在一些疏漏,对此,责任完全由我个人承担。
最后,我还要感谢我的GDE(Google开发者专家)朋友们以及Google对这个项目的支持。GDE社区真的很棒!许多GDE成员审阅了本书的早期版本,并提供了极具价值的反馈。特别感谢David Cardozo,感谢他提供的精彩反馈!
关于封面插图
《JAX深度学习》一书封面上的插图标题为“La Bearnaise”(或“The Bearnese”),这是一个来自法国比利牛斯山脉特定地区的民族。该插图选自Jacques Grasset de Saint-Sauveur于1797年出版的一部收藏作品,经过精细的手工绘制和着色。
在那个年代,人们仅凭衣着就可以轻松辨别他们的居住地、职业或社会地位。Manning出版社独具匠心,选择以这些基于几个世纪前丰富多彩的区域文化的书籍封面作为载体,来颂扬计算机行业的创造力和主动性。这些封面的灵感源于诸如本图集这样的珍贵收藏,将昔日的文化重新带入现代生活。
序 言
JAX是由Google开发的一个功能强大的Python库,其广泛应用于深度学习和高性能计算。它在机器学习研究中备受青睐,是仅次于TensorFlow和PyTorch的第三大深度学习框架。值得一提的是,JAX已成为DeepMind等公司的首选框架,Google的研究工作也越来越依赖于JAX。
JAX在深度学习中对函数式编程的重视令人印象深刻。它提供了强大的函数变换功能,包括梯度计算、通过XLA实现的JIT编译、自动向量化和并行化。JAX支持GPU和TPU,性能表现非常出色。
现在正是深入学习JAX的绝佳时刻,因为它的生态系统正在快速扩展。尽管JAX问世已有数年,但对于初学者而言,仍然缺乏全面的学习资源。虽然JAX的官方网站提供了可靠的文档和一个支持性的社区,但将这些内容整合起来,特别是在与其他库进行集成时,可能会让人感到有些困惑。
本书是为渴望掌握JAX的读者精心编写的。本书目标是将关键的信息集中在一起,帮助你理解JAX的核心概念,提升技能,并增强你在项目和研究中应用JAX的能力。
本书假设你已经具备深度学习的基础知识以及熟练的Python编程能力,因此不会涉及深度学习的基础内容,因为市面上已有大量相关资源可供参考。本书的重点是JAX,在必要时会简要介绍一些关键的深度学习概念,以帮助来自非深度学习背景的读者,如物理学等领域的读者。
JAX不只是一个深度学习框架。随着其模块范围不断扩展,JAX在可微分编程、大规模物理模拟等领域的潜力逐渐显现。我希望本书能够为对这类应用感兴趣的读者提供帮助。
JAX的最新发展,在本书的多个章节中都有相应的体现。读者不必担心未来可能出现的变化,因为本书中讲解的核心知识将继续适用于JAX的后续版本。
关 于 本 书
《JAX深度学习》旨在帮助读者理解并开始在项目和研究中使用JAX。本书将关键的知识点集中在一起,通过一系列易于理解的实例,逐步引导读者掌握JAX的概念,帮助读者建立对这一主题的直观理解。
本书读者对象
《JAX深度学习》主要面向熟悉PyTorch和TensorFlow等框架的深度学习从业者和研究人员,尤其是希望开始使用JAX的读者。读者应具备一定的深度学习基础,并且能够熟练使用Python。来自其他领域的研究人员(如物理学或优化领域),以及专注于深度学习、数值优化或分布式计算的研究生,也能从本书中获得有关学习和实践方面的帮助。
本书结构:路线图
本书内容分为三个部分,共12章。
第Ⅰ部分为引言及JAX概览。
● 第1章回答了一个关键问题:“为什么选择JAX?”我们将探讨JAX是什么,与TensorFlow、PyTorch等其他框架相比,它的优缺点,以及在何种情况下JAX可能是你项目中最合适的工具。
● 第2章将引导读者完成与JAX的首次实践。我们将构建一个用于图像分类的简单神经网络,并介绍关键概念,如JAX的自动向量化、梯度计算和即时(Just-In-Time,JIT)编译等转换功能。读者还将学习如何保存和加载模型,并理解JAX中纯函数与非纯函数的区别。
第Ⅱ部分将深入探讨JAX的核心功能。
● 第3章将深入探讨深度学习的核心工具:张量或多维数组。我们将对比NumPy数组与JAX数组,讨论如何在CPU、GPU和TPU等不同硬件上使用它们,并解释在NumPy和JAX之间转换代码时的细微差别。
● 第4章将重点讨论计算梯度的关键任务,这对于训练神经网络而言至关重要。我们将比较各种微分方法,深入探讨JAX的自动微分功能,并探索前向自动微分和反向自动微分模式。
● 第5章将介绍如何使用JIT编译优化代码性能。我们将深入探讨JIT的工作原理、它与XLA编译器的交互方式,以及如何应对可能的限制。
● 第6章将介绍自动向量化,这是一种高效处理批量数据的强大技术。我们将探讨不同的向量化方法,讨论如何控制JAX的vmap()转换,并分析自动向量化在实际场景中的优势。
● 第7章将深入探讨并行化,使读者能够在多个设备上同时运行计算。我们将讨论如何使用pmap()转换实现并行执行,控制其行为,并实现数据并行的神经网络训练。此外,我们还将探索在多主机配置上运行代码以应对大规模任务。
● 第8章将介绍张量分片,这是一种在JAX中实现并行化的现代高效方法。我们将解释如何利用XLA进行自动并行化,如何实现数据并行和张量并行来训练神经网络,并探讨这种技术的优势。
● 第9章将探讨JAX中生成随机数的重要主题。我们将比较JAX与NumPy在这方面的差异,讨论密钥在表示随机数生成器状态中的作用,并解释如何在实际应用中运用这些概念。
● 第10章将介绍pytree,这是JAX中用于表示复杂数据结构的强大工具。我们将讨论如何高效地使用pytree,利用相关函数对其进行操作,甚至根据特定需求创建自定义的pytree节点。
第Ⅲ部分将介绍围绕JAX构建的丰富多样的库和工具生态系统。
● 第11章将介绍Flax和Optax等高级神经网络库,它们为构建和训练复杂模型提供了便捷的抽象机制。我们将使用Flax构建一个简单的多层感知机(Multilayer Perceptron, MLP)和一个更高级的残差网络用于图像分类,并探索如何利用Hugging Face库处理Transformer和扩散模型。
● 第12章将更全面地介绍JAX生态系统,展示其支持多种机器学习任务的库,包括训练大语言模型( Large Language Model, LLM)、强化学习和进化计算等。我们还将探索JAX在物理学、化学等其他科学领域中的应用模块。
如果你是管理者,那么我建议阅读前两章,以了解JAX的优势、它与PyTorch和TensorFlow的区别,以及一个典型的JAX机器学习项目是怎样的。第12章同样是非技术性的,可以让你深入了解JAX的亮点所在。
对于渴望使用JAX构建神经网络的开发者,建议先从第2章开始,学习一个简单的深度学习示例;第3~6章则介绍了JAX的基础概念;第11章则提供了该生态系统中高级库的概述。读者可以根据个人兴趣选择任意顺序阅读其余章节。如果你目前还不关注并行化,那么可以跳过第7章和第8章——以后再回来阅读也不迟。如果你对随机数和pytree感兴趣,那么可以深入学习第9章和第10章,不过前面的章节已经为你提供了足够的基础知识,能让你快速上手。
关于代码
本书包含了许多源代码示例,既有编号的代码清单,也有嵌入普通文本中的代码。在这两种情况下,源代码都使用等宽字体格式,以便与普通文本区分开来。有时,为了突出章节中发生变化的部分,代码会以加粗形式显示,如在添加新功能时,会对原有代码行进行修改。
在许多情况下,我们对原始源代码进行了重新格式化;添加了换行符并调整了缩进,以适应书中的页面空间。在某些情况下,即便如此,代码仍可能过长,代码清单中可能会出现续行标记(➥)。此外,当代码在正文中进行描述时,源代码中的注释通常会被移除。许多代码清单还配有代码注释,以突出显示重要的概念。
读者可以通过本书的在线版本(liveBook)获取可执行的代码片段,网址是https://livebook.manning.com/book/deep-learning-with-jax。书中所有示例的完整代码可以从GitHub(https://github.com/che-shr-cat/JAX-in-Action)下载,也可通过扫描本书封底上的二维码下载。
几乎每一章都有对应的Colab notebook(或多个笔记本)。这些代码已在JAX版本0.4.14上进行了测试。
其他在线资源
最重要的信息来源是JAX文档(https://jax.readthedocs.io/en/latest/)。该文档更新频繁,读者可以在其中找到很多问题的答案。其他重要的信息来源包括GitHub上的讨论区(https://github.com/google/jax/discussions)和问题区(https://github.com/ google/jax/issues)。
