ここでは、 人工知能に関する断創録 のTheanoに関連する記事をSageのノートブックで実装し、Thenoの修得を試みます。
今回は、TheanoのTutorialから畳み込みニューラルネット(CNN)を使った手書き数字認識を以下のページを参考にSageのノートブックで試してみます。 前半は、人工知能に関する断創録から参照されている「StatsFragments」さんのページを参考に畳み込みとMaxPoolingをSageで動かしてみました(ほぼ引用ですみません)。
SageでTheanoのtutorialのCNNを実行すると、DimShuffleでエラーになるため、今回もPythonを使用します。
そこで、ノートブックの処理系をSageからPythonに切り替えます。上部の左から4つめのプルダウンメニューから 「python」を選択してください。
最初に、theanoを使うのに必要なライブラリをインポートします。
|
「Theanoによる畳み込みニューラルネットワークの実装 (1)」から ニューラルネットワーク(CNN)の構成図を引用します。
最初に畳み込み(convolutionの矢印部)とプーリング(maxpoolingの矢印部)を持つLeNetConvPoolLayerが2層あり、 その後に多層パーセプトロン(HiddenLayer)と最後のロジスティック(LogisticRegression)から構成されています。
前半の畳み込みでは画像の特徴を際立たせるためのフィルタリングを行い、Max Poolingでは画像のずれを吸収し、 疎な結合を構成しています。
Theano で Deep Learning <3> : 畳み込みニューラルネットワーク の例題をSageで動かしながら、畳み込み演算の効果をみてみましょう。
Theanoでは、畳み込みの処理がパッケージtheano.tensor.nnet.convのconv2dで提供されています。
conv2dへの入力テンソルは、以下の様な次元を持ちます。
入力テンソルinputを4次元のテンソルとして定義します。
|
重みテンソルWの次元は、(2, 3, 9, 9)で、フィルタは9x9で、2x3(出力の特徴マップ数x入力の特徴マップ数)=6種類のフィルタが使われます。
以下の例では、重みの値は一様乱数を使って生成した意味のないものです。実際にはこのWが学習によって画像の特徴をより抽出できる形になります。
(-0.064150029909958411, 0.064150029909958411) (-0.064150029909958411, 0.064150029909958411) |
バイアスベクトルは、出力の特徴マップ数の長さを持ち、フィルタリングの後に各画像に加えられます。
そのため、4次元のテンソルと足し合わせられるようにdimshuffleで次元を調整します。 ここでは、conv_outの2つ目の次元と合うようにします。
array([-0.3943425 , 0.16818965]) array([-0.3943425 , 0.16818965]) |
array([[[[-0.3943425 ]], [[ 0.16818965]]]]) array([[[[-0.3943425 ]], [[ 0.16818965]]]]) |
|
通常の画像をどのようにしてTheanoで解析可能な形式にするのか、 Theano で Deep Learning <3> : 畳み込みニューラルネットワーク の例がとても分かりやすく参考になりました。
入力画像は、縦130px * 横120pxのRGBのJPEGファイルです。
これをswapaxesを使ってTheanoの入力にあった形式に変換していきます。
(130, 120, 3) (130, 120, 3) |
(3, 120, 130) (3, 120, 130) |
(3, 130, 120) (3, 130, 120) |
(1, 3, 130, 120) (1, 3, 130, 120) |
|
サンプル画像のRGBのそれぞれの画像にフィルタを使って畳み込みを施すと以下の様になります。
フィルタによる畳み込みで2つの特徴マップ(画像)は、異なる表情を示します。 これがフィルタによる特徴抽出の効果です。
(1, 2, 122, 112) (1, 2, 122, 112) |
|
|
|
畳み込みに使用されたフィルタの形を表示してみると訳の分からないランダムな模様に見えます。 畳み込みニューラルネットの学習によってこのランダムなフィルタに少しずつ特徴的な形がでてきます。
|
Max Pooling 法とは、あるウィンドウサイズの中で 最大の値を代表値としてサンプリングする方法です。
Theanoのtheano.tensor.signal.downsampleパッケージのmax_pool_2dでMax Poolingを提供しています。
|
invals[0, 0, :, :] = [[ 4.17022005e-01 7.20324493e-01 1.14374817e-04 3.02332573e-01 1.46755891e-01] [ 9.23385948e-02 1.86260211e-01 3.45560727e-01 3.96767474e-01 5.38816734e-01] [ 4.19194514e-01 6.85219500e-01 2.04452250e-01 8.78117436e-01 2.73875932e-02] [ 6.70467510e-01 4.17304802e-01 5.58689828e-01 1.40386939e-01 1.98101489e-01] [ 8.00744569e-01 9.68261576e-01 3.13424178e-01 6.92322616e-01 8.76389152e-01]] output[0, 0, :, :] = [[ 0.72032449 0.39676747] [ 0.6852195 0.87811744]] invals[0, 0, :, :] = [[ 4.17022005e-01 7.20324493e-01 1.14374817e-04 3.02332573e-01 1.46755891e-01] [ 9.23385948e-02 1.86260211e-01 3.45560727e-01 3.96767474e-01 5.38816734e-01] [ 4.19194514e-01 6.85219500e-01 2.04452250e-01 8.78117436e-01 2.73875932e-02] [ 6.70467510e-01 4.17304802e-01 5.58689828e-01 1.40386939e-01 1.98101489e-01] [ 8.00744569e-01 9.68261576e-01 3.13424178e-01 6.92322616e-01 8.76389152e-01]] output[0, 0, :, :] = [[ 0.72032449 0.39676747] [ 0.6852195 0.87811744]] |
以下の例では、2x2のマトリックスからその最大の要素(0.7203)が抽出されていることが確認できます。
[[ 0.417022 0.72032449] [ 0.09233859 0.18626021]] 0.720324493442 [[ 0.417022 0.72032449] [ 0.09233859 0.18626021]] 0.720324493442 |
サンプル画像のMax Pooling後の画像を表示してみると、左の画像は瞳の白い部分が強調され、右の部分はまぶたの部分が強調されているように見えます。
(1, 2, 61, 56) (1, 2, 61, 56) |
|
LeNetに関する Theano で Deep Learning <3> : 畳み込みニューラルネットワーク の説明(図を引用)も分かりやすいです。
特徴マップに畳み込みを実行し、Max Poolingで疎な結合とし、これを複数層連結して最後に、多層パーセプトロンと目的の活性化関数を施す、 一連の流れが、とても分かりやすいです。
LeNetというのは、畳み込みニューラルネットを発明したLeCunの最初のニューラルネットの名前に由来するのだそうです。
Deep Learning TutorialのLeNetConvPoolLayerクラスは、畳み込み層とプーリング層のペアを実装しています。
重みWの初期値に、$[ - \sqrt{\frac{6}{in + out}},\sqrt{\frac{6}{in + out}}]$を与えるのは、 Theanoによる多層パーセプトロンの実装 で紹介されている、活性化関数にtanhを使う時の収束の良いWの初期値です。
|
2つのLeNetConvPoolLayerと全結合したHiddenLayerとLogisticRegressionを組み合わせて、 MNISTの数字文字認識のMini版を試してみます。
Theanoによる畳み込みニューラルネットワークの実装 (1) との違いは、最後に収束したモデルを以下の様に保存しているところです。
# dump layers with gzip.open(DATA+'model.pkl.gz', 'wb') as f: pickle.dump([layer0_input, layer0, layer1, layer2_input, layer2, layer3], f)
|
さくらのVPSでこの計算をすると約2時間CPUを占有してしまいますので、以下の処理は実行しないでください。
または、このノートブックをダウンロードして、ローカルのSageで試してみてください。
WARNING: Output truncated! full_output.txt ... loading data building the model ... train model ... epoch 1, minibatch 10/10, validation error 38.000000 % *** iter 9 / patience 10000 epoch 1, minibatch 10/10, test error of best model 38.400000 % epoch 2, minibatch 10/10, validation error 34.200000 % *** iter 19 / patience 10000 epoch 2, minibatch 10/10, test error of best model 34.600000 % epoch 3, minibatch 10/10, validation error 22.700000 % *** iter 29 / patience 10000 epoch 3, minibatch 10/10, test error of best model 23.300000 % epoch 4, minibatch 10/10, validation error 17.700000 % *** iter 39 / patience 10000 epoch 4, minibatch 10/10, test error of best model 18.400000 % epoch 5, minibatch 10/10, validation error 16.100000 % *** iter 49 / patience 10000 epoch 5, minibatch 10/10, test error of best model 16.600000 % epoch 6, minibatch 10/10, validation error 14.900000 % *** iter 59 / patience 10000 epoch 6, minibatch 10/10, test error of best model 15.100000 % epoch 7, minibatch 10/10, validation error 12.900000 % *** iter 69 / patience 10000 epoch 7, minibatch 10/10, test error of best model 13.800000 % epoch 8, minibatch 10/10, validation error 12.500000 % *** iter 79 / patience 10000 epoch 8, minibatch 10/10, test error of best model 13.000000 % epoch 9, minibatch 10/10, validation error 12.100000 % *** iter 89 / patience 10000 epoch 9, minibatch 10/10, test error of best model 12.000000 % epoch 10, minibatch 10/10, validation error 11.500000 % *** iter 99 / patience 10000 epoch 10, minibatch 10/10, test error of best model 11.800000 % epoch 11, minibatch 10/10, validation error 11.000000 % *** iter 109 / patience 10000 epoch 11, minibatch 10/10, test error of best model 11.700000 % epoch 12, minibatch 10/10, validation error 10.500000 % *** iter 119 / patience 10000 epoch 12, minibatch 10/10, test error of best model 11.200000 % epoch 13, minibatch 10/10, validation error 10.400000 % *** iter 129 / patience 10000 epoch 13, minibatch 10/10, test error of best model 10.900000 % epoch 14, minibatch 10/10, validation error 10.000000 % *** iter 139 / patience 10000 epoch 14, minibatch 10/10, test error of best model 10.600000 % epoch 15, minibatch 10/10, validation error 9.700000 % *** iter 149 / patience 10000 epoch 15, minibatch 10/10, test error of best model 10.400000 % epoch 16, minibatch 10/10, validation error 9.000000 % *** iter 159 / patience 10000 epoch 16, minibatch 10/10, test error of best model 10.000000 % epoch 17, minibatch 10/10, validation error 9.100000 % epoch 18, minibatch 10/10, validation error 8.800000 % *** iter 179 / patience 10000 epoch 18, minibatch 10/10, test error of best model 9.600000 % epoch 19, minibatch 10/10, validation error 8.400000 % *** iter 189 / patience 10000 epoch 19, minibatch 10/10, test error of best model 9.400000 % epoch 20, minibatch 10/10, validation error 7.900000 % ... epoch 147, minibatch 10/10, validation error 3.200000 % epoch 148, minibatch 10/10, validation error 3.200000 % epoch 149, minibatch 10/10, validation error 3.200000 % epoch 150, minibatch 10/10, validation error 3.200000 % epoch 151, minibatch 10/10, validation error 3.200000 % epoch 152, minibatch 10/10, validation error 3.200000 % epoch 153, minibatch 10/10, validation error 3.100000 % epoch 154, minibatch 10/10, validation error 3.100000 % epoch 155, minibatch 10/10, validation error 3.100000 % epoch 156, minibatch 10/10, validation error 3.200000 % epoch 157, minibatch 10/10, validation error 3.200000 % epoch 158, minibatch 10/10, validation error 3.200000 % epoch 159, minibatch 10/10, validation error 3.200000 % epoch 160, minibatch 10/10, validation error 3.200000 % epoch 161, minibatch 10/10, validation error 3.200000 % epoch 162, minibatch 10/10, validation error 3.200000 % epoch 163, minibatch 10/10, validation error 3.200000 % epoch 164, minibatch 10/10, validation error 3.200000 % epoch 165, minibatch 10/10, validation error 3.200000 % epoch 166, minibatch 10/10, validation error 3.200000 % epoch 167, minibatch 10/10, validation error 3.200000 % epoch 168, minibatch 10/10, validation error 3.200000 % epoch 169, minibatch 10/10, validation error 3.100000 % epoch 170, minibatch 10/10, validation error 3.100000 % epoch 171, minibatch 10/10, validation error 3.100000 % epoch 172, minibatch 10/10, validation error 3.100000 % epoch 173, minibatch 10/10, validation error 3.100000 % epoch 174, minibatch 10/10, validation error 3.100000 % epoch 175, minibatch 10/10, validation error 3.100000 % epoch 176, minibatch 10/10, validation error 3.100000 % epoch 177, minibatch 10/10, validation error 3.100000 % epoch 178, minibatch 10/10, validation error 3.100000 % epoch 179, minibatch 10/10, validation error 3.100000 % epoch 180, minibatch 10/10, validation error 3.100000 % epoch 181, minibatch 10/10, validation error 3.100000 % epoch 182, minibatch 10/10, validation error 3.100000 % epoch 183, minibatch 10/10, validation error 3.100000 % epoch 184, minibatch 10/10, validation error 3.100000 % epoch 185, minibatch 10/10, validation error 3.100000 % epoch 186, minibatch 10/10, validation error 3.100000 % epoch 187, minibatch 10/10, validation error 3.100000 % epoch 188, minibatch 10/10, validation error 3.100000 % epoch 189, minibatch 10/10, validation error 3.100000 % epoch 190, minibatch 10/10, validation error 3.000000 % *** iter 1899 / patience 10000 epoch 190, minibatch 10/10, test error of best model 2.900000 % epoch 191, minibatch 10/10, validation error 3.000000 % epoch 192, minibatch 10/10, validation error 3.000000 % epoch 193, minibatch 10/10, validation error 3.000000 % epoch 194, minibatch 10/10, validation error 3.000000 % epoch 195, minibatch 10/10, validation error 3.000000 % epoch 196, minibatch 10/10, validation error 3.000000 % epoch 197, minibatch 10/10, validation error 3.000000 % epoch 198, minibatch 10/10, validation error 3.000000 % epoch 199, minibatch 10/10, validation error 3.000000 % epoch 200, minibatch 10/10, validation error 3.000000 % Optimization complete. Best validation score of 3.000000 % obtained at iteration 1900, with test performance 2.900000 % Ran for 119.10m WARNING: Output truncated! full_output.txt ... loading data building the model ... train model ... epoch 1, minibatch 10/10, validation error 38.000000 % *** iter 9 / patience 10000 epoch 1, minibatch 10/10, test error of best model 38.400000 % epoch 2, minibatch 10/10, validation error 34.200000 % *** iter 19 / patience 10000 epoch 2, minibatch 10/10, test error of best model 34.600000 % epoch 3, minibatch 10/10, validation error 22.700000 % *** iter 29 / patience 10000 epoch 3, minibatch 10/10, test error of best model 23.300000 % epoch 4, minibatch 10/10, validation error 17.700000 % *** iter 39 / patience 10000 epoch 4, minibatch 10/10, test error of best model 18.400000 % epoch 5, minibatch 10/10, validation error 16.100000 % *** iter 49 / patience 10000 epoch 5, minibatch 10/10, test error of best model 16.600000 % epoch 6, minibatch 10/10, validation error 14.900000 % *** iter 59 / patience 10000 epoch 6, minibatch 10/10, test error of best model 15.100000 % epoch 7, minibatch 10/10, validation error 12.900000 % *** iter 69 / patience 10000 epoch 7, minibatch 10/10, test error of best model 13.800000 % epoch 8, minibatch 10/10, validation error 12.500000 % *** iter 79 / patience 10000 epoch 8, minibatch 10/10, test error of best model 13.000000 % epoch 9, minibatch 10/10, validation error 12.100000 % *** iter 89 / patience 10000 epoch 9, minibatch 10/10, test error of best model 12.000000 % epoch 10, minibatch 10/10, validation error 11.500000 % *** iter 99 / patience 10000 epoch 10, minibatch 10/10, test error of best model 11.800000 % epoch 11, minibatch 10/10, validation error 11.000000 % *** iter 109 / patience 10000 epoch 11, minibatch 10/10, test error of best model 11.700000 % epoch 12, minibatch 10/10, validation error 10.500000 % *** iter 119 / patience 10000 epoch 12, minibatch 10/10, test error of best model 11.200000 % epoch 13, minibatch 10/10, validation error 10.400000 % *** iter 129 / patience 10000 epoch 13, minibatch 10/10, test error of best model 10.900000 % epoch 14, minibatch 10/10, validation error 10.000000 % *** iter 139 / patience 10000 epoch 14, minibatch 10/10, test error of best model 10.600000 % epoch 15, minibatch 10/10, validation error 9.700000 % *** iter 149 / patience 10000 epoch 15, minibatch 10/10, test error of best model 10.400000 % epoch 16, minibatch 10/10, validation error 9.000000 % *** iter 159 / patience 10000 epoch 16, minibatch 10/10, test error of best model 10.000000 % epoch 17, minibatch 10/10, validation error 9.100000 % epoch 18, minibatch 10/10, validation error 8.800000 % *** iter 179 / patience 10000 epoch 18, minibatch 10/10, test error of best model 9.600000 % epoch 19, minibatch 10/10, validation error 8.400000 % *** iter 189 / patience 10000 epoch 19, minibatch 10/10, test error of best model 9.400000 % epoch 20, minibatch 10/10, validation error 7.900000 % ... epoch 147, minibatch 10/10, validation error 3.200000 % epoch 148, minibatch 10/10, validation error 3.200000 % epoch 149, minibatch 10/10, validation error 3.200000 % epoch 150, minibatch 10/10, validation error 3.200000 % epoch 151, minibatch 10/10, validation error 3.200000 % epoch 152, minibatch 10/10, validation error 3.200000 % epoch 153, minibatch 10/10, validation error 3.100000 % epoch 154, minibatch 10/10, validation error 3.100000 % epoch 155, minibatch 10/10, validation error 3.100000 % epoch 156, minibatch 10/10, validation error 3.200000 % epoch 157, minibatch 10/10, validation error 3.200000 % epoch 158, minibatch 10/10, validation error 3.200000 % epoch 159, minibatch 10/10, validation error 3.200000 % epoch 160, minibatch 10/10, validation error 3.200000 % epoch 161, minibatch 10/10, validation error 3.200000 % epoch 162, minibatch 10/10, validation error 3.200000 % epoch 163, minibatch 10/10, validation error 3.200000 % epoch 164, minibatch 10/10, validation error 3.200000 % epoch 165, minibatch 10/10, validation error 3.200000 % epoch 166, minibatch 10/10, validation error 3.200000 % epoch 167, minibatch 10/10, validation error 3.200000 % epoch 168, minibatch 10/10, validation error 3.200000 % epoch 169, minibatch 10/10, validation error 3.100000 % epoch 170, minibatch 10/10, validation error 3.100000 % epoch 171, minibatch 10/10, validation error 3.100000 % epoch 172, minibatch 10/10, validation error 3.100000 % epoch 173, minibatch 10/10, validation error 3.100000 % epoch 174, minibatch 10/10, validation error 3.100000 % epoch 175, minibatch 10/10, validation error 3.100000 % epoch 176, minibatch 10/10, validation error 3.100000 % epoch 177, minibatch 10/10, validation error 3.100000 % epoch 178, minibatch 10/10, validation error 3.100000 % epoch 179, minibatch 10/10, validation error 3.100000 % epoch 180, minibatch 10/10, validation error 3.100000 % epoch 181, minibatch 10/10, validation error 3.100000 % epoch 182, minibatch 10/10, validation error 3.100000 % epoch 183, minibatch 10/10, validation error 3.100000 % epoch 184, minibatch 10/10, validation error 3.100000 % epoch 185, minibatch 10/10, validation error 3.100000 % epoch 186, minibatch 10/10, validation error 3.100000 % epoch 187, minibatch 10/10, validation error 3.100000 % epoch 188, minibatch 10/10, validation error 3.100000 % epoch 189, minibatch 10/10, validation error 3.100000 % epoch 190, minibatch 10/10, validation error 3.000000 % *** iter 1899 / patience 10000 epoch 190, minibatch 10/10, test error of best model 2.900000 % epoch 191, minibatch 10/10, validation error 3.000000 % epoch 192, minibatch 10/10, validation error 3.000000 % epoch 193, minibatch 10/10, validation error 3.000000 % epoch 194, minibatch 10/10, validation error 3.000000 % epoch 195, minibatch 10/10, validation error 3.000000 % epoch 196, minibatch 10/10, validation error 3.000000 % epoch 197, minibatch 10/10, validation error 3.000000 % epoch 198, minibatch 10/10, validation error 3.000000 % epoch 199, minibatch 10/10, validation error 3.000000 % epoch 200, minibatch 10/10, validation error 3.000000 % Optimization complete. Best validation score of 3.000000 % obtained at iteration 1900, with test performance 2.900000 % Ran for 119.10m |
Theanoによる畳み込みニューラルネットワークの実装 (1) のフィルタの可視化を使って最初のフィルタを可視化してみました。
当たり前かもしれませんが、 Theanoによる畳み込みニューラルネットワークの実装 (1) と同じ結果が得られました。
各フィルタの形を見ると数字を特徴付ける形の断片のようにも思われますが、ややシャープさに欠けるように思います。
|
|
|
畳み込みニューラルネットの練習用データから求まったモデルを使って、 テスト用データを予測してみましょう。
mini_mnist.pkl.gzからテスト用の画像データを読み込み、batch_sizeはCNNの計算と同じ500として、 モデルの入力layer0_inputとLogisticRegressionの0〜9の数字に対する確率p_y_given_xを返す 関数predict_modelを定義し、テストデータの最初の500サンプルに対して結果を求めます。
predict_model = theano.function( [layer0_input], layer3.p_y_given_x, )
|
|
ロジスティック回帰の結果は、
['(7, 0.988)', '(2, 0.733)', '(1, 0.951)', '(0, 0.995)', '(4, 0.936)', '(1, 0.979)', '(4, 0.957)', '(9, 0.912)', '(6, 0.912)', '(9, 0.844)', '(0, 0.904)', '(6, 0.668)', '(9, 0.917)', '(0, 0.983)', '(1, 0.991)', '(5, 0.832)', '(9, 0.790)', '(7, 0.986)', '(3, 0.659)', '(4, 0.979)', '(9, 0.775)', '(6, 0.952)', '(6, 0.801)', '(5, 0.955)', '(4, 0.862)']多層回帰の結果は、
['(7, 0.999)', '(2, 0.780)', '(1, 0.991)', '(0, 0.999)', '(4, 0.993)', '(1, 0.998)', '(4, 0.995)', '(9, 0.999)', '(6, 0.995)', '(9, 0.982)', '(0, 0.996)', '(6, 0.994)', '(9, 0.996)', '(0, 0.993)', '(1, 1.000)', '(5, 0.974)', '(9, 0.995)', '(7, 0.998)', '(3, 0.980)', '(4, 1.000)', '(9, 0.911)', '(6, 0.905)', '(6, 0.933)', '(5, 1.000)', '(4, 0.994)']であることから、CNNではかなり認識率が良くなっているように思えます。
予測計算終了 ['(7, 1.000)', '(2, 0.999)', '(1, 1.000)', '(0, 0.998)', '(4, 1.000)', '(1, 1.000)', '(4, 1.000)', '(9, 0.999)', '(5, 0.985)', '(9, 0.995)', '(0, 1.000)', '(6, 1.000)', '(9, 1.000)', '(0, 1.000)', '(1, 1.000)', '(5, 0.997)', '(9, 0.999)', '(7, 1.000)', '(3, 0.976)', '(4, 1.000)', '(9, 0.995)', '(6, 0.992)', '(6, 0.999)', '(5, 1.000)', '(4, 1.000)'] 予測計算終了 ['(7, 1.000)', '(2, 0.999)', '(1, 1.000)', '(0, 0.998)', '(4, 1.000)', '(1, 1.000)', '(4, 1.000)', '(9, 0.999)', '(5, 0.985)', '(9, 0.995)', '(0, 1.000)', '(6, 1.000)', '(9, 1.000)', '(0, 1.000)', '(1, 1.000)', '(5, 0.997)', '(9, 0.999)', '(7, 1.000)', '(3, 0.976)', '(4, 1.000)', '(9, 0.995)', '(6, 0.992)', '(6, 0.999)', '(5, 1.000)', '(4, 1.000)'] |
同様にlayer0の出力画像も求めることができます。
|
テストデータの最初の7をフィルタリングした結果は、以下の様になりました。
(500, 20, 12, 12) (500, 20, 12, 12) |
|