Skip to main content

Reparameterization Trick

在 VAE 中,我们要对 ELBO 求关于 ϕ\boldsymbol{\phi} 的偏导数 Eqϕ(zx)[f(z)]\mathbb{E}_{q_{\boldsymbol{\phi}}(\bold{z}|\bold{x})} [\boldsymbol{f}(\bold{z})]z\bold{z} 是连续随机变量, z\bold{z} given x\bold{x} 的条件密度函数为 gϕ(zx)g_{\boldsymbol{\phi}}(\bold{z|x})。在 VAE 中 z\bold{z} 作为 encoder 的输出,其具有灵活性。

Change of variables

我们可以认为 z\bold{z} 是在 x\bold{x}ϕ\boldsymbol{\phi} 给定情况下,由另外一个确定的随机变量 ϵp(ϵ)\boldsymbol{\epsilon} \sim p(\boldsymbol{\epsilon}) 经过可微和可逆的变换 g\bold{g} 得到的,

z=gϕ(ϵ,x)\bold{z} = \bold{g}_{\boldsymbol{\phi}}(\boldsymbol{\epsilon}, \bold{x})

ϵ\boldsymbol{\epsilon}x\bold{x}ϕ\boldsymbol{\phi} 是无关的。

忽略 x\bold{x}ϕ\boldsymbol{\phi} ,因为它们都是给定的。则

z=g(ϵ)=(g1(ϵ1,...,ϵn),...,gn(ϵ1,...,ϵn))\bold{z} = \bold{g}(\boldsymbol{\epsilon}) = (g_1(\epsilon_1, ..., \epsilon_n), ..., g_n(\epsilon_1, ..., \epsilon_n))

由于 g\bold{g} 可逆,因此存在逆映射

ϵ=(ϵ1,...,ϵn)=(h1(z1,...,zn),...,hn(z1,...,zn))=h(z)\boldsymbol{\epsilon}=(\epsilon_1, ..., \epsilon_n) = (h_1(z_1, ..., z_n), ..., h_n(z_1, ..., z_n))=\bold{h}(\bold{z})

因此我们有 z=g(h(z))\bold{z} = \bold{g} (\bold{h}(\bold{z})),根据链式法则:

zz=ghhz=gϵhz\frac{\partial \bold{z}}{\partial \bold{z}} = \frac{\partial \bold {g}}{\partial \bold{h}} \frac{\partial \bold {h}}{\partial \bold{z}} = \frac{\partial \bold {g}}{\partial \boldsymbol{\epsilon}} \frac{\partial \bold {h}}{\partial \bold{z}}

等式左边是单位矩阵 I\boldsymbol{I},我们对两边取行列式,则有

1=det(I)=det(gϵhz)=det(gϵ)det(hz)1 = \det (\boldsymbol{I}) = \det (\frac{\partial \bold {g}}{\partial \boldsymbol{\epsilon}} \frac{\partial \bold {h}}{\partial \bold{z}}) = \det \Big(\frac{\partial \bold {g}}{\partial \boldsymbol{\epsilon}}\Big) · \det \Big(\frac{\partial \bold {h}}{\partial \bold{z}}\Big)

qϕ(zx)q_{\boldsymbol{\phi}}(\bold{z|x}) 是随机变量函数的概率密度函数,则有

qϕ(zx)=p(ϵ)det(hz)q_{\boldsymbol{\phi}}(\bold{z|x}) = p(\boldsymbol{\epsilon}) \Big| \det \Big( \frac{\partial \bold {h}}{\partial \bold{z}} \Big)\Big|

dzd\bold{z} 做变量替换 z=gϕ(ϵ,x)\bold{z} = \bold{g}_{\boldsymbol{\phi}}(\boldsymbol{\epsilon}, \bold{x}),则有

dz=det(gϵ)dϵd\bold{z} = \Big| \det \Big( \frac{\partial \bold {g}}{\partial \boldsymbol{\epsilon}}\Big) \Big| d \boldsymbol{\epsilon}

因此

qϕ(zx)dz=p(ϵ)det(gϵ)det(hz)=p(ϵ)dϵq_{\boldsymbol{\phi}}(\bold{z|x}) d\bold{z} = p(\boldsymbol{\epsilon})\Big| \det \Big( \frac{\partial \bold {g}}{\partial \boldsymbol{\epsilon}}\Big) \Big|·\Big| \det \Big( \frac{\partial \bold {h}}{\partial \bold{z}} \Big)\Big| = p(\boldsymbol{\epsilon}) d \boldsymbol{\epsilon}

Monte Carlo estimte under change of variables

在使用了上面的变量替换之后,对于任何一个随机变量函数 f(z)\boldsymbol{f}(\bold{z}),我们有

Eqϕ(zx)[f(z)]=f(z)qϕ(zx)dz=f(gϕ(ϵ,x))p(ϵ)dϵ=Ep(ϵ)[f(z)]\begin{aligned} \mathbb{E}_{q_{\boldsymbol{\phi}}(\bold{z}|\bold{x})} [\boldsymbol{f}(\bold{z})] &= \int \boldsymbol{f}(\bold{z}) q_{\boldsymbol{\phi}}(\bold{z}|\bold{x}) d \bold{z}\\ &= \int \boldsymbol{f}(\bold{g}_{\boldsymbol{\phi}}(\boldsymbol{\epsilon}, \bold{x})) p(\boldsymbol{\epsilon}) d \boldsymbol{\epsilon}\\ &= \mathbb{E}_{p(\boldsymbol{\epsilon})} [\boldsymbol{f}(\bold{z})] \end{aligned}

期望的 Monte Carlo estimte:

Eqϕ(zx)[f(z)]=Ep(ϵ)[f(z)]=Ep(ϵ)[f(z)]1Ll=1Lf(gϕ(ϵ(l),x))\begin{aligned} \mathbb{E}_{q_{\boldsymbol{\phi}}(\bold{z}|\bold{x})} [\boldsymbol{f}(\bold{z})] &= \mathbb{E}_{p(\boldsymbol{\epsilon})} [\boldsymbol{f}(\bold{z})]\\ &= \mathbb{E}_{p(\boldsymbol{\epsilon})}[\boldsymbol{f}(\bold{z})] \\ &\simeq \frac{1}{L} \sum_{l=1}^L \boldsymbol{f}(\bold{g}_{\boldsymbol{\phi}}(\boldsymbol{\epsilon}^{(l)}, \bold{x})) \end{aligned}

其中 ϵ(l)p(ϵ)\boldsymbol{\epsilon}^{(l)} \sim p(\boldsymbol{\epsilon}).

如上所示,利用 reparameterization 可以用来重写关于 qϕ(zx)q_{\boldsymbol{\phi}}(\bold{z|x}) 的期望,使得对期望的 Monte Carlo estimate 关于 ϕ\boldsymbol{\phi} 可微。

Example

单变量高斯分布 zp(zx)=N(μ,σ2)z\sim p(z|x) = \mathcal{N}(\mu, \sigma^2),其中 μ\muσ\sigmaxxϕ\phi 决定,根据上述对变换 g\bold{g} 的要求,令 z=μ+σϵz = \mu + \sigma \epsilon,其中 ϵN(0,1)\epsilon \sim \mathcal{N}(0, 1),则

EN(μ,σ2)[f(z)]=EN(μ,σ2)[f(μ+σϵ)]l=1Lf(μ+σϵ(l))\begin{aligned} \mathbb{E}_{\mathcal{N}(\mu, \sigma^2)}[f(z)] &= \mathbb{E}_{\mathcal{N}(\mu, \sigma^2)}[f(\mu + \sigma \epsilon)] \\ &\simeq \sum_{l=1}^L f(\mu + \sigma \epsilon^{(l)}) \\ \end{aligned}

其中 ϵ(l)N(0,1)\epsilon^{(l)} \sim \mathcal{N}(0, 1).