理解Stein变分梯度下降的论文笔记
论文地址:点这里。作者还提供了Stein变分梯度下降法的源码。
Note:
源码不涉及深度学习,所以PyTorch用户或者TF用户都可以使用。
Stein变分梯度下降(SVGD)可以理解是一种和随机梯度下降(SGD)一样的优化算法。在强化学习算法中,Soft-Q-Learning使用了SVGD去优化,而Soft-AC选择了SGD去做优化。
Stein Variational Gradient Descent:A General Purpose Bayesian Inference Algorithm
- Abstract
- 1 Introduction
- 2 Background
- 3 Variational Inference Using Smooth Transforms
- 3.1 Stein Operator as the Derivative of KL Divergence
- 3.2 Stein Variational Gradient Descent
- 4 Related Works
- 5 Experiments
- 6 Conclusion
- 源码解析
Abstract
本文旨在提供一种新的变分推断(VI)算法
用来做梯度下降(优化)。通过优化两个分布的KL散度,粒子
不断迭代,使之从初始的分布一步步达到目标分布。算法涉及到Stein特征
的平滑
转换以及核差异(KSD
)。
Note:
- 关于KSD的介绍,可以参考我的另一篇论文笔记。
1 Introduction
贝叶斯推断是一种模型估计的方法,核心就是那个贝叶斯公式。该公式中有一个归一化因子
p
(
x
)
=
∫
p
(
x
)
d
x
p(x)=\int p(x)\mathrm{d}x
p(x)=∫p(x)dx,当高维问题时候,计算将变得困难。因此MCMC和VI这两种方法就应运而生,用于解决这个计算复杂度问题。
马尔科夫链蒙特卡洛(MCMC)的缺陷在于收敛慢且难收敛,优点是偏差较小,毕竟是个需要采样的方法(强化学习最初的MC方法也有这2个特点)。
变分推断(VI)是一种优化算法,用随机梯度下降最小化目标分布与估计分布之间的KL散度,其优点是优化速度快适合于大数据,缺陷是存在较大的偏差,优化结果的好坏与否取决于最先定义的数簇(如经典的平均场变分数簇)偏差大小。
最大后验概率(MAP)是一种依靠先验知识的计算后验概率的方法,这是一种MAP贝叶斯,并不是全贝叶斯。
作者提出一种新型通用的变分推断算法——Stein变分梯度下降。可以看成是一种用于贝叶斯推断的梯度下降方法。它有以下特点:
- 使用粒子去做分布的估计。
- 是一种梯度下降算法,可以用在任何使用SGD做优化的地方。
- 是一种最小化目标分布和估计分布间KL散度的算法,不断地驱动粒子去拟合(fit)目标分布。
- 算法引入了平滑转换、KSD(Stein方法+RKHS)。
- 是一种全贝叶斯算法。
- 当使用单个粒子的时候,该算法就相当于对MAP做梯度下降;当使用所有的粒子的时候,就是全贝叶斯算法。
- 是一种超越经典VI的新VI算法。
- Stein梯度下降最大的贡献在于:它在再生核希尔伯特空间下给出了使得KL散度下降最快的确定性方向。
2 Background
这里是对贝叶斯中的后验分布 p ( x ) = p ˉ ( x ) / Z p(x)=\bar{p}(x)/Z p(x)=pˉ(x)/Z以及KSD的介绍。KSD有个优点在于通过使用Stein得分函数 ∇ x log p ( x ) \nabla_x\log p(x) ∇xlogp(x)来避免计算归一化因子 Z Z Z。
3 Variational Inference Using Smooth Transforms
VI做模型估计:
目标分布
p
(
x
)
p(x)
p(x),估计分布
q
(
x
)
q(x)
q(x),数簇
Q
=
{
q
(
x
)
}
\mathcal{Q}=\{q(x)\}
Q={q(x)},通过优化KL散度来得到估计模型
q
∗
q^*
q∗:
q
∗
=
arg min
q
∈
Q
{
K
L
(
q
∣
∣
p
)
=
E
q
[
log
q
(
x
)
]
−
E
q
[
log
p
ˉ
(
x
)
]
+
log
Z
}
(4)
q^*=\argmin_{q\in\mathcal{Q}}\{KL(q||p)=\mathbb{E}_q[\log q(x)]-\mathbb{E}_q[\log\bar{p}(x)]+\log Z\}\tag{4}
q∗=q∈Qargmin{KL(q∣∣p)=Eq[logq(x)]−Eq[logpˉ(x)]+logZ}(4)
VI就是通过这种方式来免除归一化因子的计算。问题就是这个数簇
Q
\mathcal{Q}
Q很难去选择一个合适通用的。
作者指出,合适的
Q
\mathcal{Q}
Q应该具备以下原则:
- 准确性:广度足够以至于可以近似一系列目标分布。
- 计算可行性:容易去计算,不要整个难以实现的。
- 可解决性:可以和KL散度优化结合起来。
在此基础上,作者引入平滑变换
T
:
X
→
X
,
X
⊂
R
d
T:\mathcal{X}\to\mathcal{X},\mathcal{X}\subset\mathbb{R}^d
T:X→X,X⊂Rd,也就是说将
Q
\mathcal{Q}
Q看成一个一对一
平滑变换
z
=
T
(
x
)
,
x
∈
X
z=T(x),x\in\mathcal{X}
z=T(x),x∈X得到的数簇,且
x
∼
q
0
(
x
)
x\sim q_0(x)
x∼q0(x)。
Note:
- 一对一的原因在于:①必须保证是函数。②如果是多对一的话,就不能保证是增量转换(单调)了,比如如果T是个 y = x 2 y=x^2 y=x2,那么当你粒子 x x x经过好几轮平滑转化之后, q q q已经发生改变了,这时候你从 q q q采样的 x x x经过平滑转换可能就会回到之前发生过的某个 q q q分布了,这是不允许的,因为这样就有点回环曲折的意思,不能保证最快的优化了。
那么经过
T
T
T之后的数簇是怎么样的呢?利用坐标变换公式可得:
q
[
T
]
(
z
)
=
q
(
x
)
⋅
∣
det
(
∂
x
∂
z
)
∣
=
q
(
T
−
1
(
z
)
)
⋅
∣
det
(
∇
z
T
−
1
(
z
)
)
∣
q_{[T]}(z)=q(x)\cdot|\det(\frac{\partial x}{\partial z})|=q(T^{-1}(z))\cdot|\det(\nabla_zT^{-1}(z))|
q[T](z)=q(x)⋅∣det(∂z∂x)∣=q(T−1(z))⋅∣det(∇zT−1(z))∣其中
det
(
∇
z
T
−
1
(
z
)
)
\det(\nabla_zT^{-1}(z))
det(∇zT−1(z))是矩阵
T
−
1
(
z
)
T^{-1}(z)
T−1(z)的雅克比行列式。
或者反过来:
q
[
T
−
1
]
(
x
)
=
q
(
T
(
x
)
)
⋅
∣
det
(
∇
x
T
(
x
)
)
∣
q_{[T^{-1}]}(x) = q(T(x))\cdot|\det(\nabla_xT(x))|
q[T−1](x)=q(T(x))⋅∣det(∇xT(x))∣作者指出平滑变化的引入使得准确性和计算可行性满足了,可是可解决性还不能满足,因此一种方法是将
T
T
T参数化,但参数化将面临3个问题:
- 需要选择合适的参数表示来平衡三个原则。
- 参数化表示不一定能满足T的一对一性。
- 参数化不一定使得Jacobian可以得到有效计算。
参数化既然这么麻烦,作者索性放弃选用了一种不需要参数化表示 T T T的方法——一种迭代式构建增量变化的方法,在RKHS下,可表现出最快的梯度下降方向:
- 该方法不需要计算雅克比矩阵。
- 形式上简单,类似于经典的梯度下降算法,可计算性强。
3.1 Stein Operator as the Derivative of KL Divergence
作者提出了转换的表达式:
T
(
x
)
=
x
+
ϵ
⋅
ϕ
(
x
)
T(x)=x+\epsilon\cdot\phi(x)
T(x)=x+ϵ⋅ϕ(x)。
Note:
- 这个表达式长得很像梯度下降吧,意味着 ϕ ( x ) \phi(x) ϕ(x)代表着方向。
- ϵ \epsilon ϵ是一个很小很小的数。
- ϕ ( x ) \phi(x) ϕ(x)可微,是一个向量。
- 当 ∣ ϵ ∣ |\epsilon| ∣ϵ∣很小的时候。 T ( x ) ≈ x T(x)\approx x T(x)≈x,所以其雅可比矩阵 ∇ x T − 1 \nabla_xT^{-1} ∇xT−1就是个单位阵,显然雅克比行列式不等于0,根据反函数理论,保证了 T T T的单调性,也就保证了一对一。
接下来就是本文最重要的地方之一了。
Lemma1
如果
X
∈
X
,
X
⊂
R
X\in\mathcal{X},\mathcal{X}\subset\mathbb{R}
X∈X,X⊂R,
F
F
F是平滑变换,则:
∂
det
(
F
(
X
)
)
∂
F
(
X
)
T
=
det
(
F
(
X
)
)
⋅
F
−
1
(
X
)
\frac{\partial{\det(F(X))}}{\partial{F(X)}^T}=\det(F(X))\cdot F^{-1}(X)
∂F(X)T∂det(F(X))=det(F(X))⋅F−1(X)证明如下:
Note:
- 求导对象
x
x
x必须是
方阵
。
Lemma2
如果
推荐阅读