欢迎您访问 最编程 本站为您分享编程语言代码,编程技术文章!
您现在的位置是: 首页

第 29 章 使用 JAX 进行数值计算

最编程 2024-07-02 15:28:37
...

前面几章对JAX的基本运算和操作规则作了介绍,除了这些初步内容之外,JAX运行时还有很多特征。接下来通过几个章节介绍JAX使用过程中需要注意的细节。

JAX在设计之初的目的就是为了取代NumPy进行数值计算,并在数值计算的基础上期望能够充分利用硬件资源从而提高数值计算速度。本章深入介绍JAX数值计算的一些细节,从底层角度来讲解器使用规则和方法。

jax.grad函数使用细节

前面几章曾用大量篇幅介绍如是jax.grad来进行自动求导。但jaxx.grad的自动求导方法与Python库本身(如NumPy)的求导方法不同。这些库使用数值本身来计算梯度,而jax.grad则直接使用函数,更接近于底层的数学计算。一旦习惯了这种方式,会觉得很自然。代码中损失函数实际上就是参数和数据的函数,求函数的梯度就像在数学中一样。

jax.grad函数必须使用浮点型

jax.grad函数待处理的数据必须是浮点型数据。比如如下代码,


import jax

def function(x):
    # f(x) = x³ + x²
    return x ** 3 + x ** 2

def test():
    # f'(x) = 3x² + 2x
    function_grad = jax.grad(function)

    result = function_grad(2)
    print("result = ", result)

    result = function_grad(3)
    print("result = ", result)

if __name__ == "__main__":

    test()

运行结果打印输出如下,


result = function_grad(2)
TypeError: grad requires real- or complex-valued inputs (input dtype that is a sub-dtype of np.inexact), but got int32. If you want to use Boolean- or integer-valued inputs, use vjp or set allow_int to True.

可以看到,运行出错。根据错误提示,可以使用参数allow_int让函数接受整型输入,更该代码如下,


import jax

def function(x):
    # f(x) = x³ + x²
    return x ** 3 + x ** 2

def test():

    # f'(x) = 3x² + 2x
    function_grad = jax.grad(function, allow_int = True)

    result = function_grad(2)
    print("result = ", result)

    result = function_grad(3)
    print("result = ", result)

if __name__ == "__main__":

    test()

运行结果打印输出如下,


result = function_grad(2)
TypeError: grad requires real-valued outputs (output dtype that is a sub-dtype of np.floating), but got int32. For differentiation of functions with integer outputs, use jax.vjp directly.

仍然出错。这可能是jax本身的bug?至于jax.vjp将在后续章节进行讨论。现在,更改源代码,把整型变成浮点型,


result = function_grad(2.)
print("result = ", result)

result = function_grad(3.)
print("result = ", result)

运行结果打印输出如下,


result =  16.0
result =  33.0

符合手动计算导数的预期。也就是说,输入的数据必须是浮点型,不可以为整数。浮点型数据兼容整型,反过来则不行。

同时获取函数值和求导值

上面的例子直接输出了求导结果,如果既想输出求导值又想输出函数值,则可以使用jax.value_and_grad函数。代码如下,


import jax

def function(x):
    
    # f(x) = x       + x
    return x ** 3 + x ** 2

def test():
    
    # f'(x) = 3x       + 2x
    function_value_and_grad = jax.value_and_grad(function)
    (value, gradient) = function_value_and_grad(2.)
    
    print("Function value = ", value, "gradient = ", gradient)
    
if __name__ == "__main__":
    
    test()

运行结果打印输出如下,


Function value =  12.0 gradient =  16.0

多元函数求导

前面例子主要集中在一元函数的求导,即只有一个未知数的函数的求导,而数学中还存在多元函数,即有多元函数求导的需求。

f\left( x, y \right) = 2x^2 + 3y^3\
d\left( x \right) = 4x + 0 = 4x
d\left( y \right) = 0 + 9y^2 = 9y^2
上面是一个二元三次函数及其求导公式,使用代码分别对其求导。


import jax

def function(x, y):
    
    # f(x, y) = 2x² + 3y³
    return 2 * x ** 2 + 3 * y ** 3

def test():
    
    x = 2.
    y = 3.
    
    # f(x, y) = 2x² + 3y³
    # Partial Derivatives
    # d(x) = 4x + 0 = 4x
    # d(y) = 0 + 9y² = 9y²
    function_grad = jax.grad(function)
    value = function_grad(x, y)
    
    print("Derivative value = ", value)
    
if __name__ == "__main__":
    
    test()

运行结果打印输出如下,


Derivative value =  8.0

更改一下代码如下,


function_grad = jax.grad(function, argnums = 0)

运行结果打印输出如下,


Derivative value =  8.0

再更改一下代码,


function_grad = jax.grad(function, argnums = 1)

运行结果打印输出如下,


Derivative value =  81.0

可见默认情况下,argnums = 0,即对函数的第一个参数x求导,实际上是对x求偏导;当argnums = 1,则是对函数的第二个参数y求导,实际上是对y求偏导。

\frac{\partial f\left( x, y \right)}{\partial x}
\frac{\partial f\left( x, y \right)}{\partial y}

如果相对2个参数同时求导,如何处理呢?代码如下所示,


import jax

def function(x, y):

    # f(x, y) = 2x² + 3y³
    return 2 * x ** 2 + 3 * y ** 3

def test():

    x = 2.
    y = 3.

    # (d(x), d(y)) = (4x, 9y²)
    function_grad = jax.grad(function, argnums = (0, 1))
    dx, dy = function_grad(x, y)

    print("dx = ", dx, ", dy = ", dy)

if __name__ == "__main__":

    test()

运行结果打印输出如下,


dx =  8.0 , dy =  81.0

结果与手动代入的数学计算一致。

通过上面代码的运行结果可以看到,当加入argnums参数设置时。jax.grad会根据指定的不同位置进行求导。也就是说,解决多元函数求导问题的方法是通过设置参数argnums的形状来指定对不同的参数求导。

下面再通过一个对三元函数求导来加深理解。


import jax

def function(x, y, z):

    # f(x, y, z) = 2x² + 3y³ + 4z?~A?
    return 2 * x ** 2 + 3 * y ** 3 + 4 * z ** 4

def test():

    x = 2.
    y = 3.
    z = 4.

    # Partial Derivates = (dx, dy, dz) = (4x, 9y², 16z³)
    function_grad = jax.grad(function, argnums = (0, 1, 2))
    dx, dy, dz = function_grad(x, y, z)

    print("dx = ", dx, " dy = ", dy, " dz = ", dz)

if __name__ == "__main__":

    test()

运行结果打印输出如下,


dx =  8.0  dy =  81.0  dz =  1024.0

可以看到,最终生成了三个结果,依次是对dx、dy、dz求导后的计算结果。注意argnums = (0, 1, 2)里的0,1,2分别代表三个参数的索引位置,必须严格对应。如果改成argnums = (0, 1)值回返回两个结果,其他情况感兴趣可以自行测试。

多个返回值的函数求导

一般函数仅包含一个返回值,但有些情况下会包含两个或者两个以上的返回值,jax.grad函数同样也提供了处理的方法。代码如下,


import jax

def function(x, y):

    # f(x, y) = 2x² * 3y³
    # g(x, y) = 4x³ + 5y?~A?
    return 2 * x ** 2 * 3 * y ** 3, 4 * x ** 3 + 5 * y ** 4

def test():

    x = 2.
    y = 3.

    # Derivatives = [(dx, dy), (dx, dy)] = [(4x * 3y³, 2x² * 9y²), (12x², 20y³)] = [(648,,
 648), (48, 540)]
    function_grad = jax.grad(function, argnums = (0, 1), has_aux = True)
    derivatives = function_grad(x, y)

    print("derivatives = ", derivatives)

if __name__ == "__main__":

    test()

运行结果打印输出如下,


derivatives =  ((Array(648., dtype=float32, weak_type=True), Array(648., dtype=float32, weak_type=True)), Array(437., dtype=float32, weak_type=True))

代码中,通过设置has_aux = True告诉jax.grad函数有“辅助的”,即更多的返回值。

注意,本以为会返回两个函数的导数,但实际上Array(437., dtype=float32, weak_type=True)是第二个函数的原函数的结果。

经过在GitHub JAX项目社区问答,代码更新如下,


import jax

def function(x, y):

    # f(x, y) = 2x² * 3y³
    # g(x, y) = 4x³ + 5y?~A?
    return 2 * x ** 2 * 3 * y ** 3, 4 * x ** 3 + 5 * y ** 4

def test():

    x = 2.
    y = 3.

    # Derivatives = [(dx, dy), (dx, dy)] = [(4x * 3y³, 2x² * 9y²), (12x², 200
y³)] = [(648, 648), (48, 540)]
    #function_grad = jax.grad(function, argnums = (0, 1), has_aux = True)
    # derivatives = function_grad(x, y)

    df1_dx, df2_dx = jax.jacobian(function, argnums = 0)(x, y)
    df1_dy, df2_dy = jax.jacobian(function, argnums = 1)(x, y)

    print("derivatives = ", ((df1_dx, df1_dy), (df2_dx, df2_dy)))

if __name__ == "__main__":

    test()

再谈副作用代码

经过这么多章节的学习,绝大多数情况下,NumPy的API可以无缝对接jax.numpy,比如numpy.arange与jax.numpy.arange、numpy.ones与jax.numpy.ones等。

最重要的区别是JAX被设计为函数式,就像函数式编程语言一样。这背后的原因是JAX支持的程序转换类型在函数式程序中更可行。使用函数式编程的一个好处是不需要编写带有副作用的代码。副作用是指没有出现在输出中的函数所带来的其他影响。一个明显的例子如下所示,


import numpy

x = numpy.array([1, 2, 3])

def update_in_place(x):

    x[0] = 10

    return None

def test():

    update_in_place(x)

    print(“x = “, x)

if __name__ == "__main__":

    test()

运行结果打印输出如下,


x =  [10  2  3]

可以明显看到程序运行后,外部数据x数组的元素被修改,这就对外部产生了影响,造成了副作用。

而JAX在设计之初就确定了尤其包装的数据无法被修改(immutable),因此在一定程度上杜绝了副作用的产生。

前面在讲解纯函数(Pure Function)时提到,无副作用的代码有时被称为纯函数。一个无副作用的纯函数就像船过水无痕。纯函数由于需要额外在存储中间生成一个数据,会不会降低JAX的效率呢?严格来说会。

然而JAX计算通常是在JAX使用JIT编译缓存后进行,对于编译器来说,新生成的是一个必须生成的“数据模板”,而在运行时只需要将数据注入已经生成的“模板”即可。

注意,如果有必要,可以将有副作用的Python代码和纯函数代码混合使用。

结论

本章讨论了jax.grad求导细节,包含必须使用浮点型输入、同时获得函数值和求导值、多元函数求导已经多个函数返回值求导。另外也再次讨论了副作用代码和纯函数,介绍了为什么JAX要使用纯函数。