井出 剛著の「入門機械学習による異常検出」(以降、井出本と記す)の例題をSageを使ってお復習いします。
井出本の基本は、データに対するモデルを使って負の対数尤度を求め、それを異常検出関数として使うことです。
Davisの体重(weight)のヒストグラムを見ると、分布が左右対称ではないことに気づきます。 井出本では、ガンマ分布を使ってこの非対称な分布の異常度を検出しています。
いつものように必要なパッケージを読込ます。
|
[1] "MASS" "car" "jsonlite" "ggplot2" "stats" "graphics" "grDevices" "utils" [9] "datasets" "methods" "base" [1] "MASS" "car" "jsonlite" "ggplot2" "stats" "graphics" "grDevices" "utils" [9] "datasets" "methods" "base" |
Davisの体重のヒストグラムを再度プロットしてみます。 1個だけ160のところにありますが、2章ではこれが体重と身長を入れ間違えたものと結論づけています。
<ggplot: (8738330208357)> <ggplot: (8738330208357)> |
Saving 11.0 x 8.0 in image. Saving 11.0 x 8.0 in image. |
2章と同様にガンマ分布の対数尤度を求めてみます。
井出本では、ガンマ分布を以下の様に定義しています。 $$ \mathcal{G} (x | k, s) = \frac{1}{s \Gamma(k)}(\frac{x}{s})^{k-1} exp(- \frac{x}{s}) $$
これから対数尤度Lは、以下の様になります。 $$ L(k, s | \mathcal{D}) = \sum_{n=1}^N \left [ -ln\{s \Gamma(k) \} + (k-1) ln\frac{x^{(n)}}{s} - \frac{x^{(n)}}{s} \right ] $$
この式をパラメータkとsで微分し、0と等しいとします。
$-ln\{s \Gamma(k) \}$が$-ln(s) - ln(\Gamma(k))$となりますから、sの偏微分は以下の様になります。 $$ \begin{eqnarray} \frac{\partial L}{\partial s} & = & \sum_{n=1}^N \left [ -\frac{1}{s} + (k-1)\frac{\partial}{\partial s} ln \frac{x^{(n)}}{s} + \frac{x^{(n)}}{s^2} \right ] \\ & = & \sum_{n=1}^N \left [ -\frac{1}{s} - \frac{(k-1)}{s} + \frac{x^{(n)}}{s^2} \right ] \\ & = & \sum_{n=1}^N \left [ -\frac{k}{s} + \frac{x^{(n)}}{s^2} \right ] \end{eqnarray} $$ この式が0となることから、 $$ \hat{s} = \frac{1}{\hat{k} N} \sum_{n=1}^N x^{(n)} = \frac{\hat{\mu}}{\hat{k}} $$
同様にkの偏微分は以下の様になります。 $$ \frac{\partial L}{\partial k} = \sum_{n=1}^N \left [ -\frac{1}{\Gamma(k)}\frac{\partial \Gamma(k)}{\partial k} + ln \frac{x^{(n)}}{s}\right ] $$ しかし、kがガンマ関数が入っているので、簡単には$\hat{k}$は求まりません。
井出本では、最尤推定でk, sを求める代わりにモーメント法を使ってkとsを推定しています。 1次のモーメントは以下の様に計算されます。$\Gamma(k+1) = k \Gamma(k)$を関係を使っています。(以下の式は未フォロー) $$ \lt x \gt = \int_0^{\infty} dx \, x \, \mathcal{G}(x | k,s) = s\frac{\Gamma(k+1)}{\Gamma(k)} \int_0^{\infty} dx \, x \, \mathcal{G}(x | k+1,s) = ks $$ 同様に2次のモーメントは、以下の様になります。 $$ \lt x^2 \gt = k(k+1)s^2 $$
これとデータから求められる1次の2次モーメントの関係 $$ \lt x \gt = \frac{1}{N} \sum_{n=1}^N x^{(n)} $$ $$ \lt x^2 \gt = \frac{1}{N} \sum_{n=1}^N {x^{(n)}}^2 $$
この関係から $$ \hat{k}_{mo} = \frac{\hat{\mu}^2}{\hat{\sigma}^2}, \hat{s}_{mo} = \frac{\hat{\sigma}^2}{\hat{\mu}} $$
davisのweight(体重)でモーメント法を使ってk, sを求めてみます。
19.1928236809 3.42836474164 19.1928236809 3.42836474164 |
Rのfitdistrを使ってdavisのweight(体重)のガンマ分布を推定してみます。 shapeがkの値、rateが1/sの値として返されます。
モーメント法とfitdistrの結果は、そこそこ近い結果となっています。
shape rate 22.4854793 0.3417247 ( 2.2317648) ( 0.0342978) shape rate 22.4854793 0.3417247 ( 2.2317648) ( 0.0342978) |
shape 22.48548 rate 2.926333 shape 22.48548 rate 2.926333 |
このノートの特徴は、Sageとpythonを使って井出本と同じ結果を求めることにあります。
scipyのstatsのgamma.fitを使って、kとsを求めてみます。 gamma.fitでは、floc=0を指定しないとfitdistrと同じ結果になりません。
(22.485303592007426, 0, 2.9263558630975797) (22.485303592007426, 0, 2.9263558630975797) |
今回バージョンアップしたSage6.7からヒストグラムがサポートされました。
通常のヒストグラムと正規化(normed=True)したヒストグラムが表示でき、 他のプロットデータと重ね合わせることができます。
|
|
|
|
|
以下の様に正規分布関数_gaussを定義して、ヒストグラム、正規分布、ガンマ分布の推定結果を 重ね合わせてみます。
|
|
最初、flocを指定しないでフィッティングをしたため、fitdistrと結果が異なり 悩んでいました。しかしfloc=0とすることで、fitdistrと同じ結果が得られることが 分かり、floc制約をしない場合、どのような推定結果がでるのか調べて見ました。
今回、可視化はmatplotを使いました。floc=0の設定をしない方がより良くフィット していることが分かります。
4.37074130051 36.4314864684 6.71934493135 Saving 8.0 x 6.0 in image. 4.37074130051 36.4314864684 6.71934493135 Saving 8.0 x 6.0 in image. |
Saving 8.0 x 6.0 in image. Saving 8.0 x 6.0 in image. |
データに異常データが混ざっている場合の例として、以下の様に2種類の正規分布が混ざっている場合に ついて考えてみます。
|
分布は以下の様な混合正規分布として表されます。 $$ p(x) = \pi_0 \mathcal{N}(x | \mu_0, \sigma_0^2) + \pi_1 \mathcal{N}(x | \mu_1, \sigma_1^2) $$
パラメータを$ \theta = (\pi_0, \mu_0, \sigma_0^2, \pi_1, \mu_1, \sigma_1^2)$にまとめると、 対数尤度は、以下の様になります。 $$ L(\theta, \mathcal{D}) = \sum_{n=1}^N ln \left \{ \pi_0 \mathcal{N}(x | \mu_0, \sigma_0^2) + \pi_1 \mathcal{N}(x | \mu_1, \sigma_1^2) \right \} $$
これを$\mu_0$で偏微分すると、以下の様になります。 $$ 0 = \frac{\partial L}{\partial \mu_0} = \sum_{n=1}^N \frac{\pi_0 \mathcal{N}(x | \mu_0, \sigma_0^2) }{\pi_0 \mathcal{N}(x | \mu_0, \sigma_0^2) + \pi_1 \mathcal{N}(x | \mu_1, \sigma_1^2)} \frac{(x^{(n)} - \mu_0)}{\sigma_0^2} $$ 同様に$\sigma_0^2$で偏微分すると、以下の様になります。 $$ 0 = \frac{\partial L}{\partial \sigma_0^2} = \sum_{n=1}^N \frac{\pi_0 \mathcal{N}(x | \mu_0, \sigma_0^2) }{\pi_0 \mathcal{N}(x | \mu_0, \sigma_0^2) + \pi_1 \mathcal{N}(x | \mu_1, \sigma_1^2)} \frac{ \{-(x^{(n)} - \mu_0)^2 + \sigma_0^2 \}}{2} $$ ここで、データが$\pi_i$のどちらの集合に属するかその期待値を帰属度として以下の様に定義します。 $$ q_i^{(n)} = \frac{\pi_i \mathcal{N}(x | \mu_i, \sigma_i^2) }{\pi_0 \mathcal{N}(x | \mu_0, \sigma_0^2) + \pi_1 \mathcal{N}(x | \mu_1, \sigma_1^2)} $$
$\mu_i, \sigma_i^2$の推定値は、以下の様に求まります。 $$ \hat{\mu}_i = \frac{\sum_{n=1}^N q_i^{(n)} x^{(n)}}{\sum_{n'=1}^N q_i^{(n')}} $$ $$ \hat{\sigma}_i^2 = \frac{\sum_{n=1}^N q_i^{(n)} (x^{(n)} - \mu_i)^2}{\sum_{n'=1}^N q_i^{(n')}} $$
また、$\pi_i$は、以下の様になります。 $$ \hat{\pi}_i = \frac{1}{N} \sum_{n=1}^N q_i^{(n)} $$
混合正規分布をEMアルゴリズムを使って求めます。
EMアルゴリズムでは、以下の手順でパラメータの値を求めます。
以下の2つの正規分布をもつテストデータ1000個を生成します。 $$ (\pi_0, \pi_1) = (0.6, 0.4), (\mu_0, \sigma_0) = (3, 0.5), (\mu_1, \sigma_1) = (0, 3) $$
|
各パラメータの初期値を以下の様に設定します。 $$ (\pi_0, \pi_1) = (0.5, 0.5), (\mu_0, \sigma_0) = (5, 1.0), (\mu_1, \sigma_1) = (-5, 5.0) $$
|
EMアルゴリズムを10回繰り返します。
高々10回の繰り返しで、とても精度良くパラメータが推定されています。
|
0.602760455466 0.397239544534 2.98730105891 -0.292164057442 0.512590277559 2.85810740491 0.602760455466 0.397239544534 2.98730105891 -0.292164057442 0.512590277559 2.85810740491 |
|
|
最後にScikit-learnの混合正規分布パッケージsklearn.mixtureを使って、 davisの体重・身長のデータに対して、混合正規分布を求めてみます。
GMMは、異常なデータに影響を受けるため、12番目のデータを除いた集合で、 混合正規分布を求めます。
GMMの引数のn_componentsが求める分布の数です。RのMclustを使うとベストな分布数も計算できますが、 ここでは2として計算します。
|
GMM(covariance_type='full', init_params='wmc', min_covar=0.001, n_components=2, n_init=1, n_iter=100, params='wmc', random_state=None, thresh=None, tol=0.001) GMM(covariance_type='full', init_params='wmc', min_covar=0.001, n_components=2, n_init=1, n_iter=100, params='wmc', random_state=None, thresh=None, tol=0.001) |
大切なのは、GMMによる分類ではなく、その結果を可視化して特徴を把握することです。
GMMによってデータがどのグループに属するか分類した結果は、predict関数で取得し、 davis_e11のpredにセットします。
次にggplotでpredで色を変えて分布をプロットします。 たったこれだけで、GMMの結果を確認することができます。
<ggplot: (8738328197473)> <ggplot: (8738328197473)> |
Saving 11.0 x 8.0 in image. Saving 11.0 x 8.0 in image. |
GMMの結果は、$\mu$がmeans_、$\sigma$がcovars_、$\pi$がweights_変数にセットされています。
Sageのcontour_plot関数を使って正規分布をプロットしてみます。
|
|
|
|
混合正規分布の異常度は、以下の様になります。 $$ a(x') = -ln \left \{ \sum_{k=1}^K \hat{\pi}_k \mathcal{N}(x' | \hat{\mu}_k, \hat{\Sigma}_k) \right \} $$
|
|