欢迎您访问 最编程 本站为您分享编程语言代码,编程技术文章!
您现在的位置是: 首页

理解Stein变分梯度下降的论文笔记

最编程 2024-08-08 08:32:53
...

论文地址:点这里。作者还提供了Stein变分梯度下降法的源码
Note:
源码不涉及深度学习,所以PyTorch用户或者TF用户都可以使用。

Stein变分梯度下降(SVGD)可以理解是一种和随机梯度下降(SGD)一样的优化算法。在强化学习算法中,Soft-Q-Learning使用了SVGD去优化,而Soft-AC选择了SGD去做优化。

Stein Variational Gradient Descent:A General Purpose Bayesian Inference Algorithm

  • Abstract
  • 1 Introduction
  • 2 Background
  • 3 Variational Inference Using Smooth Transforms
    • 3.1 Stein Operator as the Derivative of KL Divergence
    • 3.2 Stein Variational Gradient Descent
  • 4 Related Works
  • 5 Experiments
  • 6 Conclusion
  • 源码解析

Abstract

本文旨在提供一种新的变分推断(VI)算法用来做梯度下降(优化)。通过优化两个分布的KL散度,粒子不断迭代,使之从初始的分布一步步达到目标分布。算法涉及到Stein特征平滑转换以及核差异(KSD)。
Note:

  1. 关于KSD的介绍,可以参考我的另一篇论文笔记

1 Introduction

贝叶斯推断是一种模型估计的方法,核心就是那个贝叶斯公式。该公式中有一个归一化因子 p ( x ) = ∫ p ( x ) d x p(x)=\int p(x)\mathrm{d}x p(x)=p(x)dx,当高维问题时候,计算将变得困难。因此MCMC和VI这两种方法就应运而生,用于解决这个计算复杂度问题。
马尔科夫链蒙特卡洛(MCMC)的缺陷在于收敛慢且难收敛,优点是偏差较小,毕竟是个需要采样的方法(强化学习最初的MC方法也有这2个特点)。
变分推断(VI)是一种优化算法,用随机梯度下降最小化目标分布与估计分布之间的KL散度,其优点是优化速度快适合于大数据,缺陷是存在较大的偏差,优化结果的好坏与否取决于最先定义的数簇(如经典的平均场变分数簇)偏差大小。
最大后验概率(MAP)是一种依靠先验知识的计算后验概率的方法,这是一种MAP贝叶斯,并不是全贝叶斯。
作者提出一种新型通用的变分推断算法——Stein变分梯度下降。可以看成是一种用于贝叶斯推断的梯度下降方法。它有以下特点:

  1. 使用粒子去做分布的估计。
  2. 是一种梯度下降算法,可以用在任何使用SGD做优化的地方。
  3. 是一种最小化目标分布和估计分布间KL散度的算法,不断地驱动粒子去拟合(fit)目标分布。
  4. 算法引入了平滑转换、KSD(Stein方法+RKHS)
  5. 是一种全贝叶斯算法。
  6. 当使用单个粒子的时候,该算法就相当于对MAP做梯度下降;当使用所有的粒子的时候,就是全贝叶斯算法。
  7. 是一种超越经典VI的新VI算法。
  8. Stein梯度下降最大的贡献在于:它在再生核希尔伯特空间下给出了使得KL散度下降最快的确定性方向

2 Background

这里是对贝叶斯中的后验分布 p ( x ) = p ˉ ( x ) / Z p(x)=\bar{p}(x)/Z p(x)=pˉ(x)/Z以及KSD的介绍。KSD有个优点在于通过使用Stein得分函数 ∇ x log ⁡ p ( x ) \nabla_x\log p(x) xlogp(x)来避免计算归一化因子 Z Z Z

3 Variational Inference Using Smooth Transforms

VI做模型估计
目标分布 p ( x ) p(x) p(x),估计分布 q ( x ) q(x) q(x),数簇 Q = { q ( x ) } \mathcal{Q}=\{q(x)\} Q={q(x)},通过优化KL散度来得到估计模型 q ∗ q^* q
q ∗ = arg min ⁡ q ∈ Q { K L ( q ∣ ∣ p ) = E q [ log ⁡ q ( x ) ] − E q [ log ⁡ p ˉ ( x ) ] + log ⁡ Z } (4) q^*=\argmin_{q\in\mathcal{Q}}\{KL(q||p)=\mathbb{E}_q[\log q(x)]-\mathbb{E}_q[\log\bar{p}(x)]+\log Z\}\tag{4} q=qQargmin{KL(qp)=Eq[logq(x)]Eq[logpˉ(x)]+logZ}(4)
VI就是通过这种方式来免除归一化因子的计算。问题就是这个数簇 Q \mathcal{Q} Q很难去选择一个合适通用的。
作者指出,合适的 Q \mathcal{Q} Q应该具备以下原则:

  1. 准确性:广度足够以至于可以近似一系列目标分布。
  2. 计算可行性:容易去计算,不要整个难以实现的。
  3. 可解决性:可以和KL散度优化结合起来。

在此基础上,作者引入平滑变换 T : X → X , X ⊂ R d T:\mathcal{X}\to\mathcal{X},\mathcal{X}\subset\mathbb{R}^d T:XX,XRd,也就是说将 Q \mathcal{Q} Q看成一个一对一平滑变换 z = T ( x ) , x ∈ X z=T(x),x\in\mathcal{X} z=T(x),xX得到的数簇,且 x ∼ q 0 ( x ) x\sim q_0(x) xq0(x)
Note:

  1. 一对一的原因在于:①必须保证是函数。②如果是多对一的话,就不能保证是增量转换(单调)了,比如如果T是个 y = x 2 y=x^2 y=x2,那么当你粒子 x x x经过好几轮平滑转化之后, q q q已经发生改变了,这时候你从 q q q采样的 x x x经过平滑转换可能就会回到之前发生过的某个 q q q分布了,这是不允许的,因为这样就有点回环曲折的意思,不能保证最快的优化了。

那么经过 T T T之后的数簇是怎么样的呢?利用坐标变换公式可得:
q [ T ] ( z ) = q ( x ) ⋅ ∣ det ⁡ ( ∂ x ∂ z ) ∣ = q ( T − 1 ( z ) ) ⋅ ∣ det ⁡ ( ∇ z T − 1 ( z ) ) ∣ q_{[T]}(z)=q(x)\cdot|\det(\frac{\partial x}{\partial z})|=q(T^{-1}(z))\cdot|\det(\nabla_zT^{-1}(z))| q[T](z)=q(x)det(zx)=q(T1(z))det(zT1(z))其中 det ⁡ ( ∇ z T − 1 ( z ) ) \det(\nabla_zT^{-1}(z)) det(zT1(z))是矩阵 T − 1 ( z ) T^{-1}(z) T1(z)雅克比行列式
或者反过来:
q [ T − 1 ] ( x ) = q ( T ( x ) ) ⋅ ∣ det ⁡ ( ∇ x T ( x ) ) ∣ q_{[T^{-1}]}(x) = q(T(x))\cdot|\det(\nabla_xT(x))| q[T1](x)=q(T(x))det(xT(x))作者指出平滑变化的引入使得准确性和计算可行性满足了,可是可解决性还不能满足,因此一种方法是将 T T T参数化,但参数化将面临3个问题:

  1. 需要选择合适的参数表示来平衡三个原则。
  2. 参数化表示不一定能满足T的一对一性。
  3. 参数化不一定使得Jacobian可以得到有效计算。

参数化既然这么麻烦,作者索性放弃选用了一种不需要参数化表示 T T T的方法——一种迭代式构建增量变化的方法,在RKHS下,可表现出最快的梯度下降方向:

  1. 该方法不需要计算雅克比矩阵。
  2. 形式上简单,类似于经典的梯度下降算法,可计算性强。

3.1 Stein Operator as the Derivative of KL Divergence

作者提出了转换的表达式: T ( x ) = x + ϵ ⋅ ϕ ( x ) T(x)=x+\epsilon\cdot\phi(x) T(x)=x+ϵϕ(x)
Note:

  1. 这个表达式长得很像梯度下降吧,意味着 ϕ ( x ) \phi(x) ϕ(x)代表着方向。
  2. ϵ \epsilon ϵ是一个很小很小的数。
  3. ϕ ( x ) \phi(x) ϕ(x)可微,是一个向量。
  4. ∣ ϵ ∣ |\epsilon| ϵ很小的时候。 T ( x ) ≈ x T(x)\approx x T(x)x,所以其雅可比矩阵 ∇ x T − 1 \nabla_xT^{-1} xT1就是个单位阵,显然雅克比行列式不等于0,根据反函数理论,保证了 T T T的单调性,也就保证了一对一。

接下来就是本文最重要的地方之一了。

Lemma1
如果 X ∈ X , X ⊂ R X\in\mathcal{X},\mathcal{X}\subset\mathbb{R} XX,XR F F F是平滑变换,则:
∂ det ⁡ ( F ( X ) ) ∂ F ( X ) T = det ⁡ ( F ( X ) ) ⋅ F − 1 ( X ) \frac{\partial{\det(F(X))}}{\partial{F(X)}^T}=\det(F(X))\cdot F^{-1}(X) F(X)Tdet(F(X))=det(F(X))F1(X)证明如下:
在这里插入图片描述
Note:

  1. 求导对象 x x x必须是方阵

Lemma2
如果