[解读] 浅析变分自编码器 (VAE)

    技术2022-07-10  72

    变分自编码器 (VAE)

    VAE 开篇论文:

    (1312) Auto-Encoding Variational Bayes

    (1401)Stochastic Backpropagation and Approximate Inference in Deep Generative Models

    VAE 教程 (1606) Tutorial on Variational Autoencoders

    VAE 综述 (1906) An Introduction to Variational Autoencoders

    变分推理 Variational Inference

    博文解读: https://blog.csdn.net/jackytintin/article/details/53641885

    https://blog.csdn.net/weixin_40255337/article/details/83088786

    研究背景

    在机器学习领域, 我们对学习数据的概率模型非常感兴趣, 概率模型对于未知预测和各种形式的辅助和自动决策有很大帮助.

    假设观察变量 x \mathbf{x} x 服从于一个未知的真实分布 p ⋆ ( x ) p^{\star}(\mathbf{x}) p(x), 我们希望学习分布 p θ ( x ) p_{\theta}(\mathbf{x}) pθ(x) 的参数 θ \theta θ 来逼近真实分布 p ⋆ ( x ) p^{\star}(\mathbf{x}) p(x).

    这个逼近方法有很多, 本文考虑深度隐变量模型 (DLVM).

    首先使用 z \mathbf{z} z 来代表隐变量, 有向图模型表示为一种建立于 x \mathbf{x} x z \mathbf{z} z 上的联合分布 p θ ( x , z ) p_{\theta}(\mathbf{x}, \mathbf{z}) pθ(x,z), 计算它的边际分布即可得到 p θ ( x ) = ∫ p θ ( x , z ) d z p_{\theta}(\mathbf{x})=\int p_{\theta}(\mathbf{x},\mathbf{z})d \mathbf{z} pθ(x)=pθ(x,z)dz 如果 z \mathbf{z} z 是连续型变量, 则 p θ ( x ) p_{\theta}(\mathbf{x}) pθ(x) 可以看成一个无限的混合模型, 由此可见它可以表示任意的真实分布. 如果使用神经网络来参数化 p θ ( x , z ) p_{\theta}(\mathbf{x}, \mathbf{z}) pθ(x,z), 则这个网络称为 DLVM.

    DLVM 的一个重要优点是即使有向图中的每个因子(先验或条件分布) 是相对简单的分布, p θ ( x ) p_{\theta}(\mathbf{x}) pθ(x) 可以是非常复杂的. 因此我们使用 DLVM 来逼近真实分布 p ⋆ ( x ) p^{\star}(\mathbf{x}) p(x).

    DLVM 作为一种最简单最常见的方法, 它的分解具有以下结构: p θ ( x , z ) = p θ ( z ) p θ ( x ∣ z ) p_{\theta}(\mathbf{x}, \mathbf{z})=p_{\theta}(\mathbf{z}) p_{\theta}(\mathbf{x}|\mathbf{z}) pθ(x,z)=pθ(z)pθ(xz) 其中 p θ ( z ) p_{\theta}(\mathbf{z}) pθ(z) 和 $ p_{\theta}(\mathbf{x}|\mathbf{z})$ 至少一个是被指定的, p θ ( z ) p_{\theta}(\mathbf{z}) pθ(z) z \mathbf{z} z 的先验分布.

    DLVM 可以通过最大似然来学习, 但由于计算 p θ ( x , z ) p_{\theta}(\mathbf{x},\mathbf{z}) pθ(x,z) 的积分没有分析解或者有效的估计, 所以无法通过求微分的方式来优化它.

    计算 p θ ( x ) p_{\theta}(x) pθ(x) 的难易程度与 p θ ( z ∣ x ) p_{\theta}(\mathbf{z} |\mathbf{x}) pθ(zx) 难易程度是相关联的, 换句话说, 如果 p θ ( x ) p_{\theta}(x) pθ(x) 好算,则后者也好算, 注意到联合分布 p θ ( x , z ) p_{\theta}(\mathbf{x}, \mathbf{z}) pθ(x,z) 是容易计算的, 这些密度分布有以下关系: p θ ( z ∣ x ) = p θ ( x , z ) p θ ( x ) p_{\boldsymbol{\theta}}(\mathbf{z} | \mathbf{x})=\frac{p_{\boldsymbol{\theta}}(\mathbf{x}, \mathbf{z})}{p_{\boldsymbol{\theta}}(\mathbf{x})} pθ(zx)=pθ(x)pθ(x,z) 近似推理方法能够估计后验分布 p θ ( z ∣ x ) p_{\boldsymbol{\theta}}(\mathbf{z} | \mathbf{x}) pθ(zx) 和边际分布 p θ ( x ) {p_{\boldsymbol{\theta}}(\mathbf{x})} pθ(x). 然而传统的计算方法是比较复杂的, 例如这需要逐个样本的循环, 或者产生不好的后验近似, 我们的目的是避免这种高代价的处理过程.

    变分自编码器

    前面说到 DLVM 模型的训练问题, 变分自编码器则是一种非常高效的计算框架. 首先引入一个参数化的推理模型 q ϕ ( z ∣ x ) q_{\phi}(\mathbf{z}|\mathbf{x}) qϕ(zx), 也被称为编码器或识别模型. 称 ϕ \phi ϕ 为变分参数. 我们优化这个参数来使得 q ϕ ( z ∣ x ) ≈ p θ ( z ∣ x ) q_{\phi}(\mathbf{z} | \mathbf{x}) \approx p_{\boldsymbol{\theta}}(\mathbf{z} | \mathbf{x}) qϕ(zx)pθ(zx) 对任意一个推理模型 q q q 和参数 ϕ \phi ϕKaTeX parse error: \cr valid only within a tabular/array environment 其中第二项是 p p p q q q 之间的 KL 散度, 这是一个非负值, 当分布相同时值为0. 第一项是一个变分下界, 也被称为 evidence lower bound (ELBO), 记为 L θ , ϕ ( x ) = E q ϕ ( z ∣ x ) [ log ⁡ p θ ( x , z ) − log ⁡ q ϕ ( z ∣ x ) ] \mathcal{L}_{\boldsymbol{\theta}, \phi}(\mathbf{x})=\mathbb{E}_{q_{\phi}(\mathbf{z} | \mathbf{x})}\left[\log p_{\boldsymbol{\theta}}(\mathbf{x}, \mathbf{z})-\log q_{\phi}(\mathbf{z} | \mathbf{x})\right] Lθ,ϕ(x)=Eqϕ(zx)[logpθ(x,z)logqϕ(zx)]

    已知 KL 散度值是非负的, 由 L θ , ϕ ( x ) = log ⁡ p θ ( x ) − D K L ( q ϕ ( z ∣ x ) ∥ p θ ( z ∣ x ) ) ≤ log ⁡ p θ ( x ) \begin{aligned} \mathcal{L}_{\boldsymbol{\theta}, \phi}(\mathbf{x}) &=\log p_{\boldsymbol{\theta}}(\mathbf{x})-D_{K L}\left(q_{\phi}(\mathbf{z} | \mathbf{x}) \| p_{\boldsymbol{\theta}}(\mathbf{z} | \mathbf{x})\right) \\ & \leq \log p_{\boldsymbol{\theta}}(\mathbf{x}) \end{aligned} Lθ,ϕ(x)=logpθ(x)DKL(qϕ(zx)pθ(zx))logpθ(x) 可知 ELBO 是 log ⁡ p θ ( x ) \log p_{\boldsymbol{\theta}}(\mathbf{x}) logpθ(x) 的下界. 从上式可以看出, 最大化 ELBO (优化参数为 θ , ϕ \theta, \phi θ,ϕ) 将会带来两个结果: 首先这能够同时最大化似然 log ⁡ p θ ( x ) \log p_{\boldsymbol{\theta}}(\mathbf{x}) logpθ(x), 使得模型变得更好, 其次能够最小化 KL 散度.

    现在我们的目标是最大化 ELBO, 它有一个重要的特性是可以使用随机梯度下降方法来联合优化所有参数.

    ELBO 关于 θ \theta θ 的无偏梯度为 ∇ θ L θ , ϕ ( x ) = ∇ θ E q ϕ ( z ∣ x ) [ log ⁡ p θ ( x , z ) − log ⁡ q ϕ ( z ∣ x ) ] = E q ϕ ( z ∣ x ) [ ∇ θ ( log ⁡ p θ ( x , z ) − log ⁡ q ϕ ( z ∣ x ) ) ] ≃ ∇ θ ( log ⁡ p θ ( x , z ) − log ⁡ q ϕ ( z ∣ x ) ) = ∇ θ ( log ⁡ p θ ( x , z ) ) \begin{aligned} \nabla_{\boldsymbol{\theta}} \mathcal{L}_{\boldsymbol{\theta}, \boldsymbol{\phi}}(\mathbf{x}) &=\nabla_{\boldsymbol{\theta}} \mathbb{E}_{q_{\phi}(\mathbf{z} | \mathbf{x})}\left[\log p_{\boldsymbol{\theta}}(\mathbf{x}, \mathbf{z})-\log q_{\phi}(\mathbf{z} | \mathbf{x})\right] \\ &=\mathbb{E}_{q_{\phi}(\mathbf{z} | \mathbf{x})}\left[\nabla_{\boldsymbol{\theta}}\left(\log p_{\boldsymbol{\theta}}(\mathbf{x}, \mathbf{z})-\log q_{\phi}(\mathbf{z} | \mathbf{x})\right)\right] \\ & \simeq \nabla_{\boldsymbol{\theta}}\left(\log p_{\boldsymbol{\theta}}(\mathbf{x}, \mathbf{z})-\log q_{\phi}(\mathbf{z} | \mathbf{x})\right) \\ &=\nabla_{\boldsymbol{\theta}}\left(\log p_{\boldsymbol{\theta}}(\mathbf{x}, \mathbf{z})\right) \end{aligned} θLθ,ϕ(x)=θEqϕ(zx)[logpθ(x,z)logqϕ(zx)]=Eqϕ(zx)[θ(logpθ(x,z)logqϕ(zx))]θ(logpθ(x,z)logqϕ(zx))=θ(logpθ(x,z)) 其中第四行是第二行的简单的蒙特卡罗估计, 最后两行的 z \mathbf{z} z 随机采样于 q ϕ ( z ∣ x ) q_{\phi}(\mathbf{z} | \mathbf{x}) qϕ(zx).

    然而关于 ϕ \phi ϕ 的无偏梯度比较难以获得, 这是因为与 q ϕ ( z ∣ x ) q_{\phi}(\mathbf{z} | \mathbf{x}) qϕ(zx) 有关, 微分算子无法穿过期望算子: ∇ ϕ L θ , ϕ ( x ) = ∇ ϕ E q ϕ ( z ∣ x ) [ log ⁡ p θ ( x , z ) − log ⁡ q ϕ ( z ∣ x ) ] ≠ E q ϕ ( z ∣ x ) [ ∇ ϕ ( log ⁡ p θ ( x , z ) − log ⁡ q ϕ ( z ∣ x ) ) ] \begin{aligned} \nabla_{\phi} \mathcal{L}_{\boldsymbol{\theta}, \boldsymbol{\phi}}(\mathbf{x}) &=\nabla_{\boldsymbol{\phi}} \mathbb{E}_{q_{\phi}(\mathbf{z} | \mathbf{x})}\left[\log p_{\boldsymbol{\theta}}(\mathbf{x}, \mathbf{z})-\log q_{\boldsymbol{\phi}}(\mathbf{z} | \mathbf{x})\right] \\ & \neq \mathbb{E}_{q_{\phi}(\mathbf{z} | \mathbf{x})}\left[\nabla_{\boldsymbol{\phi}}\left(\log p_{\boldsymbol{\theta}}(\mathbf{x}, \mathbf{z})-\log q_{\boldsymbol{\phi}}(\mathbf{z} | \mathbf{x})\right)\right] \end{aligned} ϕLθ,ϕ(x)=ϕEqϕ(zx)[logpθ(x,z)logqϕ(zx)]=Eqϕ(zx)[ϕ(logpθ(x,z)logqϕ(zx))]

    重参数技巧

    对于连续型的隐变量 z \mathbf{z} z, 我们可以用重参数的技巧来计算无偏梯度 ∇ ϕ L θ , ϕ ( x ) \nabla_{\phi} \mathcal{L}_{\boldsymbol{\theta}, \boldsymbol{\phi}}(\mathbf{x}) ϕLθ,ϕ(x). 把随机变量变量 z \mathbf{z} z 表示为另一个随机变量 ϵ ∼ p ( ϵ ) \boldsymbol{\epsilon} \sim p(\boldsymbol{\epsilon}) ϵp(ϵ) 的可微并且可逆的变换, 即 z = g ( ϵ , ϕ , x ) \mathbf{z}=\mathbf{g}(\boldsymbol{\epsilon}, \boldsymbol{\phi}, \mathbf{x}) z=g(ϵ,ϕ,x) 其中三个随机变量都是互相独立的. 给定这个变换后, ELBO 可以重写为 L θ , ϕ ( x ) = E q ϕ ( z ∣ x ) [ log ⁡ p θ ( x , z ) − log ⁡ q ϕ ( z ∣ x ) ] = E p ( ϵ ) [ log ⁡ p θ ( x , z ) − log ⁡ q ϕ ( z ∣ x ) ] \begin{aligned} \mathcal{L}_{\boldsymbol{\theta}, \boldsymbol{\phi}}(\mathbf{x}) &=\mathbb{E}_{q_{\phi}(\mathbf{z} | \mathbf{x})}\left[\log p_{\boldsymbol{\theta}}(\mathbf{x}, \mathbf{z})-\log q_{\boldsymbol{\phi}}(\mathbf{z} | \mathbf{x})\right] \\ &=\mathbb{E}_{p(\boldsymbol{\epsilon})}\left[\log p_{\boldsymbol{\theta}}(\mathbf{x}, \mathbf{z})-\log q_{\boldsymbol{\phi}}(\mathbf{z} | \mathbf{x})\right] \end{aligned} Lθ,ϕ(x)=Eqϕ(zx)[logpθ(x,z)logqϕ(zx)]=Ep(ϵ)[logpθ(x,z)logqϕ(zx)] 然后便可以采用简单的蒙特卡洛模拟来估计出关于单点的 ELBO, 整理如下 ϵ ∼ p ( ϵ ) z = g ( ϕ , x , ϵ ) L ~ θ , ϕ ( x ) = log ⁡ p θ ( x , z ) − log ⁡ q ϕ ( z ∣ x ) \begin{aligned} \epsilon & \sim p(\epsilon) \\ \mathbf{z} &=\mathbf{g}(\boldsymbol{\phi}, \mathbf{x}, \boldsymbol{\epsilon}) \\ \tilde{\mathcal{L}}_{\boldsymbol{\theta}, \phi}(\mathbf{x}) &=\log p_{\boldsymbol{\theta}}(\mathbf{x}, \mathbf{z})-\log q_{\phi}(\mathbf{z} | \mathbf{x}) \end{aligned} ϵzL~θ,ϕ(x)p(ϵ)=g(ϕ,x,ϵ)=logpθ(x,z)logqϕ(zx)

    算法伪代码如下:

    只要选择一个好的 g ( ) g() g() 函数, 那么就是关于 log ⁡ q ϕ ( z ∣ x ) \log q_{\boldsymbol{\phi}}(\mathbf{z} | \mathbf{x}) logqϕ(zx) 的计算就简单了. 当 g ( ) g() g() 是一个可逆函数时, ϵ \boldsymbol{\epsilon} ϵ z \mathbf{z} z 的密度的关系为 log ⁡ q ϕ ( z ∣ x ) = log ⁡ p ( ϵ ) − log ⁡ d ϕ ( x , ϵ ) \log q_{\phi}(\mathbf{z} | \mathbf{x})=\log p(\epsilon)-\log d_{\phi}(\mathbf{x}, \epsilon) logqϕ(zx)=logp(ϵ)logdϕ(x,ϵ) 其中第二项为 log ⁡ d ϕ ( x , ϵ ) = log ⁡ ∣ det ⁡ ( ∂ z ∂ ϵ ) ∣ \log d_{\phi}(\mathrm{x}, \epsilon)=\log \left|\operatorname{det}\left(\frac{\partial \mathrm{z}}{\partial \epsilon}\right)\right| logdϕ(x,ϵ)=logdet(ϵz).

    然后构造一个灵活的 g ( ) g() g() 来使得 log ⁡ d ϕ ( x , ϵ ) \log d_{\phi}(\mathrm{x}, \epsilon) logdϕ(x,ϵ) 计算更简单, 并且得到高度灵活的推理模型 q ϕ ( z ∣ x ) q_{\boldsymbol{\phi}}(\mathbf{z} | \mathbf{x}) qϕ(zx). 常用的选择是一个简单的 factorized Gaussian encoder q ϕ ( z ∣ x ) = N ( z ; μ , diag ⁡ ( σ 2 ) ) q_{\phi}(\mathbf{z} | \mathbf{x})=\mathcal{N}\left(\mathbf{z} ; \boldsymbol{\mu}, \operatorname{diag}\left(\boldsymbol{\sigma}^{2}\right)\right) qϕ(zx)=N(z;μ,diag(σ2)): ( μ , log ⁡ σ ) =  EncoderNeuralNet  ϕ ( x ) q ϕ ( z ∣ x ) = ∏ i q ϕ ( z i ∣ x ) = ∏ i N ( z i ; μ i , σ i 2 ) \begin{aligned} (\boldsymbol{\mu}, \log \boldsymbol{\sigma}) &=\text { EncoderNeuralNet }_{\boldsymbol{\phi}}(\mathbf{x}) \\ q_{\phi}(\mathbf{z} | \mathbf{x}) &=\prod_{i} q_{\phi}\left(z_{i} | \mathbf{x}\right)=\prod_{i} \mathcal{N}\left(z_{i} ; \mu_{i}, \sigma_{i}^{2}\right) \end{aligned} (μ,logσ)qϕ(zx)= EncoderNeuralNet ϕ(x)=iqϕ(zix)=iN(zi;μi,σi2) 重参数后可以写为 ϵ ∼ N ( 0 , I ) ( μ , log ⁡ σ ) =  EncoderNeuralNet  ϕ ( x ) z = μ + σ ⊙ ϵ \begin{aligned} \epsilon & \sim \mathcal{N}(0, \mathbf{I}) \\ (\boldsymbol{\mu}, \log \boldsymbol{\sigma}) &=\text { EncoderNeuralNet }_{\phi}(\mathbf{x}) \\ \mathbf{z} &=\boldsymbol{\mu}+\boldsymbol{\sigma} \odot \boldsymbol{\epsilon} \end{aligned} ϵ(μ,logσ)zN(0,I)= EncoderNeuralNet ϕ(x)=μ+σϵ 并且有 log ⁡ d ϕ ( x , ϵ ) = log ⁡ ∣ det ⁡ ( ∂ z ∂ ϵ ) ∣ = ∑ i log ⁡ σ i \log d_{\phi}(\mathbf{x}, \epsilon)=\log \left|\operatorname{det}\left(\frac{\partial \mathbf{z}}{\partial \epsilon}\right)\right|=\sum_{i} \log \sigma_{i} logdϕ(x,ϵ)=logdet(ϵz)=ilogσi. 从而得到 log ⁡ q ϕ ( z ∣ x ) = log ⁡ p ( ϵ ) − log ⁡ d ϕ ( x , ϵ ) = ∑ i log ⁡ N ( ϵ i ; 0 , 1 ) − log ⁡ σ i \begin{aligned} \log q_{\phi}(\mathbf{z} | \mathbf{x}) &=\log p(\boldsymbol{\epsilon})-\log d_{\phi}(\mathbf{x}, \boldsymbol{\epsilon}) \\ &=\sum_{i} \log \mathcal{N}\left(\epsilon_{i} ; 0,1\right)-\log \sigma_{i} \end{aligned} logqϕ(zx)=logp(ϵ)logdϕ(x,ϵ)=ilogN(ϵi;0,1)logσi

    Processed: 0.031, SQL: 9