はじめに
前回の投稿では、論文Adversarial Variational Bayesの前半部分をまとめた。今回は後半部分をまとめる。
前半部分の結論
従来のVAEの目的関数は \begin{equation} (\theta^{*},\phi^{*})=\arg\max_{\theta,\phi} {\rm E}_{p_{d}(x)}\left[ -{\rm KL} \left( q_{\phi}(z|x), p(z) \right) + {\rm E}_{q_{\phi}(z|x)} \left[\ln{p_{\theta}(x|z)}\right] \right] \label{eq1} \end{equation}
である。本論文の前半部分では式(\ref{eq1})から次の2つの目的関数を導出した。
\begin{eqnarray}
T^{*}(x,z)&=&\arg{\max_{T}}\;
{\rm E}_{p_{d}(x)}
\left[
{\rm E}_{q_{\phi}(z|x)}
\left[
\ln{\sigma\left(T(x, z)\right)}
\right]
+
{\rm E}_{p(z)}
\left[
\ln{\left(1-\sigma\left(T(x, z)\right)\right)}
\right]
\right]
\label{eq2}\\
(\theta^{*},\phi^{*})&=&\arg\max_{\theta,\phi}
{\rm E}_{p_{d}(x)}\;{\rm E}_{q_{\phi}(z|x)}
\left[
\ln{p_{\theta}(x|z)}-T^{*}(x,z)
\right]
\label{eq3}
\end{eqnarray}
ここで、\sigma(\cdot)はシグモイド関数である。実際に計算を行うと、式(\ref{eq2})から最適な値T^*(x,z)を求めることは難しいことが分かる。その原因は、式(\ref{eq1})の右辺にあるKullback Leibler divergence
にある。この量は、事後分布q_{\phi}(z|x)の形状を事前分布p(z)の形状に近づけようとする。一般にこれら2つの分布は大きく異なる形状を持つため、q_{\phi}(z|x)\approx p(z)の実現は難しい。また、事後分布を事前分布に近づける操作は、事後分布から観測値xの依存性を失くすことに相当するため、Bayes推論の立場から考えても妥当ではない。本論文の後半部分ではこの操作を改善する。
後半部分の内容
事後分布q_{\phi}(z|x)に良く似た形状を持ち、かつ、解析計算が可能な補助分布(auxiliary distribution)r_{\alpha}(z|x)を導入する。例えば、r_{\alpha}(z|x)として正規分布を考えることができる。この補助分布を用いて、式(\ref{eq1})の右辺を次のように書き換える。 \begin{equation} {\rm E}_{p_{d}(x)}\left[ -{\rm KL} \left( q_{\phi}(z|x), r_{\alpha}(z|x) \right) + {\rm E}_{q_{\phi}(z|x)} \left[ -\ln{r_{\alpha}(z|x)} +\ln{p_{\theta}(x,z)}\right] \right] \label{eq4} \end{equation}
r_{\alpha}(z|x)をq_{\phi}(z|x)を近似する関数と置いたので、式(\ref{eq4})に現れるKullback Leibler divergenceの値は、式(\ref{eq1})のそれよりもずっと容易に小さな値とすることができる。
いま、補助分布として次の正規分布を仮定する。 \begin{equation} r_{\alpha}(z|x)=\mathcal{N}(z|\mu(x),\sigma_s^2(x)) \end{equation}
これは、変数変換
\begin{equation}
\tilde{z}=\frac{z-\mu(x)}{\sigma_s(x)}
\end{equation}
の下で
\begin{equation}
r_{\alpha}(z|x)=\frac{1}{\sigma_s}\mathcal{N}(\tilde{z}|0,1)\equiv \frac{1}{\sigma_s}r_0(\tilde{z})
\end{equation}
となるから
\begin{eqnarray}
{\rm KL}
\left(
q_{\phi}(z|x),
r_{\alpha}(z|x)
\right)
&=&
\int dz\;q_{\phi}(z|x)\ln{\frac{q_{\phi}(z|x)}{r_{\alpha}(z|x)}}
\label{eq5} \\
&=&
\int d\tilde{z}\;\sigma_s q_{\phi}(\sigma_s \tilde{z}+\mu|x)
\ln{
\frac
{\sigma_s q_{\phi}(\sigma_s\tilde{z}+\mu|x)}
{r_0(\tilde{z})}
}\\
&=&
{\rm KL}
\left(
\tilde{q}_{\phi}(\tilde{z}|x),
r_0(\tilde{z})
\right)
\label{eq6}
\end{eqnarray}
となる。ただし、\tilde{q}_{\phi}(\tilde{z}|x)\equiv \sigma_s q_{\phi}(\sigma_s\tilde{z}+\mu|x)とした。\tilde{q}_{\phi}(\tilde{z}|x)は
\begin{equation}
\int d\tilde{z}\tilde{q}_{\phi}(\tilde{z}|x)=1
\end{equation}
を満たす新たな分布である。さらに、次式が成り立つ。
\begin{eqnarray}
{\rm E}_{q_{\phi}(z|x)}
\left[
\ln{r_{\alpha}(z|x)}
\right]
&=&
\int dz\;q_{\phi}(z|x)\ln{r_{\alpha}(z|x)} \\
&=&
\int d\tilde{z}\;\tilde{q}_{\phi}(\tilde{z}|x)\ln{r_{0}(\tilde{z})}-\ln{\sigma_s} \\
&=&
{\rm E}_{\tilde{q}_{\phi}(\tilde{z}|x)}
\left[
\ln{r_{0}(\tilde{z})}
\right]+{\rm const.}
\end{eqnarray}
以上から、式(\ref{eq4})は、定数項を無視して、次式に変換される。
\begin{equation}
{\rm E}_{p_{d}(x)}\left[
-{\rm KL}
\left(
\tilde{q}_{\phi}(\tilde{z}|x),
r_{0}(\tilde{z})
\right)
+
{\rm E}_{\tilde{q}_{\phi}(\tilde{z}|x)}
\left[
-\ln{r_{0}(\tilde{z})}
\right]+
{\rm E}_{q_{\phi}(z|x)}
\left[
\ln{p_{\theta}(x,z)}\right]
\right]
\label{eq7}
\end{equation}
ここで、-{\rm KL}
\left(
\tilde{q}_{\phi}(\tilde{z}|x),
r_{0}(\tilde{z})
\right)に注目して、論文の前半部分で用いた議論を繰り返し、T(x,\tilde{z})を導入すると、式(\ref{eq7})は次の2式と等価となる。
\begin{eqnarray}
T^{*}(x,\tilde{z})&=&\arg{\max_{T}}
\;{\rm E}_{p_{d}(x)}
\left[
{\rm E}_{\tilde{q}_{\phi}(\tilde{z}|x)}
\left[
\ln{\sigma\left(T(x, \tilde{z})\right)}
\right]
+
{\rm E}_{r_0(\tilde{z})}
\left[
\ln{\left(1-\sigma\left(T(x, \tilde{z})\right)\right)}
\right]
\right]
\label{eq8}\\
(\theta^{*},\phi^{*})
&=&
\arg\max_{\theta,\phi}
\;{\rm E}_{p_{d}(x)}
\left[
{\rm E}_{q_{\phi}(z|x)}
\left[
\ln{p_{\theta}(x,z)}
\right]
+
{\rm E}_{\tilde{q}_{\phi}(\tilde{z}|x)}
\left[
-T^*(x,\tilde{z})-\ln{r_0(\tilde{z})}
\right]
\right]
\label{eq9}
\end{eqnarray}
これらに、再パラメータ化トリックを適用する。すなわち、
\begin{equation}
z\sim q_{\phi}(z|x)
\end{equation}
を
\begin{eqnarray}
\epsilon&\sim&p(\epsilon)\\
z&=&z_{\phi}(x,\epsilon)
\end{eqnarray}
に置き換える。このとき
\begin{equation}
\tilde{z}\sim \tilde{q}_{\phi}(\tilde{z}|x)
\end{equation}
は
\begin{eqnarray}
\tilde{z}&=&\frac{z_{\phi}(x,\epsilon)-\mu}{\sigma_s}\equiv\tilde{z}_{\phi}(x,\epsilon)
\end{eqnarray}
に置き換わる。また、分布r_0(\tilde{z})は標準正規分布であるから
\begin{equation}
\tilde{z}
\sim r_0(\tilde{z})
\end{equation}
は
\begin{eqnarray}
\eta&\sim&\mathcal{N}(\eta|0,1)\equiv p(\eta) \\
\tilde{z}&=&\eta
\end{eqnarray}
と書くことができる。以上から式(\ref{eq8}),(\ref{eq9})は次式に変形される。
\begin{eqnarray}
T^{*}(x,\tilde{z})
&=&
\arg{\max_{T}}
\;{
\rm E}_{p_{d}(x)}
\Bigl[
{\rm E}_{p(\epsilon)}
\left[
\ln{\sigma\left(T(x, \tilde{z}_{\phi}(x,\epsilon))\right)}
\right]
+
{\rm E}_{p(\eta)}
\left[
\ln{\left(1-\sigma\left(T(x, \eta)\right)\right)}
\right]
\Bigr]
\label{eq10}\\
(\theta^{*},\phi^{*})
&=&
\arg\max_{\theta,\phi}
\;{\rm E}_{p_{d}(x)}
\Bigl[
{\rm E}_{p(\epsilon)}
\left[
\ln{p_{\theta}(x,z_{\phi}(x,\epsilon))}
\right]
+
{\rm E}_{p(\epsilon)}
\left[
-T^*(x,\tilde{z}_{\phi}(x,\epsilon))-\ln{r_0(\tilde{z}_{\phi}(x,\epsilon))}
\right]
\Bigr]
\label{eq11}
\end{eqnarray}
r_0は標準正規分布であるから、式(\ref{eq11})の最後の項は
\begin{equation}
-\ln{r_0(\tilde{z}_{\phi}(x,\epsilon))}=\frac{1}{2}\|\tilde{z}_{\phi}(x,\epsilon)\|^2+{\rm const.}
\end{equation}
となる。最適化を行う際は定数項は無視すれば良い。また、p_{\theta}(x,z)の計算は
\begin{equation}
p_{\theta}(x,z)=p_{\theta}(x|z)p(z)
\end{equation}
を利用すれば良い。ここまでの議論を反映したアルゴリズムが論文に掲載されている。