はじめに
このページの続きです。
やりたいこと
単語列 \begin{equation} W = \{w_1,\cdots, w_T\} \end{equation}
が与えられているとき、この単語列に含まれる任意の単語w_tを考える。w_tを含む文からその周辺の単語の集合\Xi_{w_t}を作ることができる。
\begin{equation}
\Xi_{w_t} = \{\xi_1^t, \cdots, \xi_C^t\}
\end{equation}
このとき、以下の対数尤度を考える。
\begin{eqnarray}
L&=&\sum_{t=1}^T \sum_{c=1}^{C}
\biggl\{
\log{\sigma(s(\xi_c^t,w_t))}+\sum_{w \in {\rm Ng}}\log{\sigma(-s(w,w_t))}
\biggr\} \label{obj}\\
\sigma(x) &=& \frac{1}{1+\exp{(-x)}}\\
s(a,b) &=& \vec{u}^T_a \cdot \vec{v}_b
\end{eqnarray}
ここで、{\rm Ng}は分布
\begin{equation}
p(w) = \frac{U(w)^{0.75}}{\sum_{t=1}^{T}U(w_t)^{0.75}}
\end{equation}
に従ってサンプリングしたk個の単語の集合である。U(w)は単語の出現頻度である。本ページでは、上記の対数尤度の最大化問題を、Neural Networkを用いて解くことができることを示す。
Neural Networkによる表現
単語列W内のt番目の単語w_tを以下のone-hot vector \vec{x}_tで表現する。 \begin{equation} \vec{x}_t = \left( \begin{array}{c} 0 \\ \vdots \\ 1 \\ 0 \\ \vdots \\ 0 \end{array} \right) \end{equation}
t番目の要素だけが1のT次元ベクトルである。次に、行列W_IとW_Oを次式で定義する。
\begin{eqnarray}
W_I&=&\left[\vec{v}_1, \cdots, \vec{v}_T\right] \\
W_O&=&\left(
\begin{array}{c}
\vec{u}_1^T \\
\vdots \\
\vec{u}_T^T
\end{array}
\right)
\end{eqnarray}
ここで、\vec{v}_tと\vec{u}_tはM次元ベクトルとする。従って、W_IはM\times T行列、W_OはT\times M行列である。これらを用いると
\begin{eqnarray}
(W_O\; W_I\; \vec{x}_t)_i
&=& (W_O)_{ik}\;(W_I\;\vec{x}_t)_k \\
&=&(W_O)_{ik}\;(W_I)_{km} (\vec{x}_t)_{m}\\
&=& u_{ik}\;v_{mk}\;x_{tm} \\
&=& u_{ik}\;v_{mk}\;\delta_{tm} \\
&=& u_{ik}\;v_{tk} \\
&=& \vec{u}_{i}^T \cdot \vec{v}_{t} \\
&=& s(i, t)
\end{eqnarray}
となる。これを図にすると以下のようなる。
以上から以下のことが分る。
- 式(\ref{obj})の第1項に含まれるs(\xi_c^t, w_t)を計算するには、w_tに相当するone-hot vector \vec{x}_tをネットワークに入力し、その出力ベクトル(T次元ペクトル)の成分のうち、\xi_c^tに相当するものを取り出せば良い。
- 式(\ref{obj})の第2項に含まれるs(w, w_t)の計算も同様である。
T \gg Mと取るので低次元で効率良く単語を表現できるベクトルを得ることができる。
Chainerによる実装
Chainerのソースには、word2vecのサンプルプログラムが含まれている。このサンプルプログラムには、Skip Gram以外にContinuous BoWも実装されており、さらに、Negative Sampling以外の損失関数も選択できるようになっている。以下に示すのは、サンプルプログラムを参考にして、Skip GramかつNegative Samplingの場合のクラスを実装したものである。
- 5行目で定義されるL.EmbedIDは、W_Iに相当する。
- 6行目で定義されるL.NegativeSamplingは、式(\ref{obj})を計算する。この関数の中にW_Oに相当するものがある。
- n_vocabはT、n_unitsはMに相当する。
- countsはU(w)に相当する。
- sample_sizeはkに相当する。
- 13行目の引数xはいま対象にする単語w_t、contextは\Xi_{w_t}に相当する。ただし、batch処理が考慮されている。