ここでは、 人工知能に関する断創録 のTheanoに関連する記事をSageのノートブックで実装し、Thenoの修得を試みます。
今回は、TheanoのTutorialからMNISTの手書き数字認識を以下のページを参考にSageのノートブックで試してみます。
SageでTheanoのtutorialのMNISTを実行すると、以下のy_predがベクトルではなくスカラーになり、うまく処理できません。
self.y_pred = T.argmax(self.p_y_given_x, axis=1)
そこで、ノートブックの処理系をSageからPythonに切り替えます。上部の左から4つめのプルダウンメニューから 「python」を選択してください。
今回は、ちょっと多くのライブラリを使用します。
|
TheanoのTutorialで使用しているMNISTは手書き数字の画像データセットは、 cPickleモジュールでロードできる形式に圧縮したmnist.pkl.gzが以下のサイトから ダウンロードできます。
MNISTのデータには、70000の手書き数字データが収録されおり、それを以下の様に分割して使用します。
TensorFlowのMNIST for ML BeginnersのMNISTデータの説明が分かりやすいので、引用します。
|
Sageサーバのメモリは1Gと少ないため、MNISTのデータをそのまま使用するとメモリ不足になります。 そこで、1/10のサイズのサブセットデータmini_mnist.pkl.gzを作成し、これを使用して数字を認識することにします。
|
2 (5000, 784) (5000,) <type 'numpy.int64'> 2 (5000, 784) (5000,) <type 'numpy.int64'> |
|
どのような画像データが入っているのか、訓練用データ(train_set)の最初の100個を表示してみましょう。
|
Theanoの共有変数を使用すると、GPUのメモリ領域に保存され、学習時(update)に高速に読み書きできます。 共有変数の値を取り出すときに、get_value()を使用するのは、GPUにセットされたデータをCPUのメモリにコピー するためと考えられます。
また、GPUで使用するデータは必ずfloat型で格納しなければなりません。 そのため、ラベルshared_yは、T.cast(shared_y, 'int32')でキャストして返しています。
これでdatasets変数に訓練用、検証用、テスト用のデータがセットされました。
|
ニューラルネットの入力$x_i$と出力$u_i$の関係は、重み$W_i$とバイアス$b_i$を使って以下の関係になります。 $$ u_i = \sum_j W_{i,j} x_j + b_i $$
これを分かりやすく説明した図を再度、TensorFlowのMNIST for ML Beginnersから引用します。
この重みWとバイアスbをTheanoの共有変数でセットしているのが、以下の箇所です。
# 重み行列を初期化 self.W = theano.shared(value=np.zeros((n_in, n_out), dtype=theano.config.floatX), name='W', borrow=True) # バイアスベクトルを初期化 self.b = theano.shared(value=np.zeros((n_out,), dtype=theano.config.floatX), name='b', borrow=True)
多クラス分類の活性化関数として使用されるのが、ソフトマック関数(Softmax)で、出力層のクラスiに分類される確率は、以下の様に表されます。 $$ y_i = \frac{exp(u_i)}{\sum_j exp(u_j)} $$
この部分をTheanoのシンボルで表現している箇所が以下の部分です。
# 各サンプルが各クラスに分類される確率を計算するシンボル # 全データを行列化してまとめて計算している # 出力は(n_samples, n_out)の行列 self.p_y_given_x = T.nnet.softmax(T.dot(input, self.W) + self.b)
人工知能に関する断創録の貴重なコメントがあります。
確率が求まったら最終的に一番高い確率が得られるクラスをy_predにセットします。 $$ y_{pred} = argmax(y) $$ TheanoではT.argmax()にaxis=1を指定することでp_y_given_xの各行(サンプルに相当)において一番確率が高いインデックス(クラスに相当)がまとめて取得できるそうです。
# 確率が最大のクラスのインデックスを計算 # 出力は(n_samples,)のベクトル self.y_pred = T.argmax(self.p_y_given_x, axis=1)
分類されたクラスkだけ1で、他は0のベクトルを$d_n$とすると、事後分布は、以下の様になります。 $$ p(d | x) = \prod_{k=1}^K p(C_k | x)^{d_k} $$ 訓練データ${ (x_n, d_n)}(n=1,...,N)$に対するwの尤度は、以下の様になります。 $$ L(W) = \prod_{n=1}^N p(d_n | x_n; W) = \prod_{n=1}^N \prod_{k=1}^K p(c_k|x_n)^{d_{nk}} = \prod_{n=1}^N \prod_{k=1}^K (y_k(x; W))^{d_{nk}} $$
負の対数尤度を誤差関数とすると $$ E(W) = - \sum_{n=1}^N \sum_{k=1}^K d_{nk} log y_k(x_n; W) $$
人工知能に関する断創録では、Sumの代わりに平均meanを使っていることに注意!
# 式通りに計算するとsumだがmeanの方がよい return -T.mean(T.log(self.p_y_given_x)[T.arange(y.shape[0]), y])
|
モデルの訓練には、ミニバッチ確率的勾配降下法(MSGD)を使用しています。 確率的勾配降下法(SGD)はただ1つのサンプルで1回だけパラメータを更新するのに対し、 少数のサンプルをひとまとめにしてその単位で重みを更新します。 このひとまとめにしたサンプル集合をミニバッチ(minbatch)と呼びます。
何番目のミニバッチを使用するかを示すのがシンボルindexです。 コードのfunction()のgivensで定義されている部分がミニバッチの設定箇所です。
givens={ x: train_set_x[index * batch_size: (index + 1) * batch_size], y: train_set_y[index * batch_size: (index + 1) * batch_size] }
確率的勾配降下法で使用するコスト関数とその微分は、負の対数尤度を計算する negative_log_likelihood関数とT.gradの箇所で計算しています。
# 誤差(コスト)を計算 => 最小化したい cost = classifier.negative_log_likelihood(y) # コスト関数のtheta = (W,b)の微分を計算 g_W = T.grad(cost=cost, wrt=classifier.W) g_b = T.grad(cost=cost, wrt=classifier.b)
パラメータの更新式は、Wとbの2個をタプルとして指定します。
# パラメータ更新式 updates = [(classifier.W, classifier.W - learning_rate * g_W), (classifier.b, classifier.b - learning_rate * g_b)]
関数の定義は以下の通りです。
モデルの当てはまりの良さをエラー率で評価しています。LogisticRegressionのerrors関数で計算します。
def errors(self, y): """分類の誤差率を計算するシンボルを返す yにはinputに対応する正解クラスを渡す""" return T.mean(T.neq(self.y_pred, y))T.neq(self.y_pred, y)で予測クラスと正解クラスが異なる要素の数を求め、 その平均を取ることでエラー率を計算しています。
検証用データvalid_setのエラー率と訓練用データtrain_setのエラー率から適合(overfitting)を防ぐために、 検証用データのエラー率が増加した時点で学習を打ち切る早期終了(Early-Stopping)というテクニックを使用します。
人工知能に関する断創録に解説されているEarly-Stoppingの説明を引用します。
|
... building the model ... training the model epoch 1, minibatch 83/83, validation error 16.458333 % epoch 1, minibatch 83/83, test error of best model 15.833333 % epoch 2, minibatch 83/83, validation error 14.375000 % epoch 2, minibatch 83/83, test error of best model 12.916667 % epoch 3, minibatch 83/83, validation error 13.333333 % epoch 3, minibatch 83/83, test error of best model 11.666667 % epoch 4, minibatch 83/83, validation error 13.541667 % epoch 5, minibatch 83/83, validation error 13.750000 % epoch 6, minibatch 83/83, validation error 13.125000 % epoch 6, minibatch 83/83, test error of best model 10.833333 % epoch 7, minibatch 83/83, validation error 12.916667 % epoch 7, minibatch 83/83, test error of best model 10.208333 % epoch 8, minibatch 83/83, validation error 13.125000 % epoch 9, minibatch 83/83, validation error 13.125000 % epoch 10, minibatch 83/83, validation error 12.916667 % epoch 11, minibatch 83/83, validation error 12.916667 % epoch 12, minibatch 83/83, validation error 12.291667 % epoch 12, minibatch 83/83, test error of best model 8.750000 % epoch 13, minibatch 83/83, validation error 12.291667 % epoch 14, minibatch 83/83, validation error 12.500000 % epoch 15, minibatch 83/83, validation error 12.500000 % epoch 16, minibatch 83/83, validation error 12.500000 % epoch 17, minibatch 83/83, validation error 12.500000 % epoch 18, minibatch 83/83, validation error 12.500000 % epoch 19, minibatch 83/83, validation error 12.708333 % epoch 20, minibatch 83/83, validation error 12.708333 % epoch 21, minibatch 83/83, validation error 12.708333 % epoch 22, minibatch 83/83, validation error 12.708333 % epoch 23, minibatch 83/83, validation error 13.125000 % Optimization complete with best validation score of 12.291667 %,with test performance 8.750000 % The code run for 24 epochs, with 7.298651 epochs/sec Ran for 3.3s ... building the model ... training the model epoch 1, minibatch 83/83, validation error 16.458333 % epoch 1, minibatch 83/83, test error of best model 15.833333 % epoch 2, minibatch 83/83, validation error 14.375000 % epoch 2, minibatch 83/83, test error of best model 12.916667 % epoch 3, minibatch 83/83, validation error 13.333333 % epoch 3, minibatch 83/83, test error of best model 11.666667 % epoch 4, minibatch 83/83, validation error 13.541667 % epoch 5, minibatch 83/83, validation error 13.750000 % epoch 6, minibatch 83/83, validation error 13.125000 % epoch 6, minibatch 83/83, test error of best model 10.833333 % epoch 7, minibatch 83/83, validation error 12.916667 % epoch 7, minibatch 83/83, test error of best model 10.208333 % epoch 8, minibatch 83/83, validation error 13.125000 % epoch 9, minibatch 83/83, validation error 13.125000 % epoch 10, minibatch 83/83, validation error 12.916667 % epoch 11, minibatch 83/83, validation error 12.916667 % epoch 12, minibatch 83/83, validation error 12.291667 % epoch 12, minibatch 83/83, test error of best model 8.750000 % epoch 13, minibatch 83/83, validation error 12.291667 % epoch 14, minibatch 83/83, validation error 12.500000 % epoch 15, minibatch 83/83, validation error 12.500000 % epoch 16, minibatch 83/83, validation error 12.500000 % epoch 17, minibatch 83/83, validation error 12.500000 % epoch 18, minibatch 83/83, validation error 12.500000 % epoch 19, minibatch 83/83, validation error 12.708333 % epoch 20, minibatch 83/83, validation error 12.708333 % epoch 21, minibatch 83/83, validation error 12.708333 % epoch 22, minibatch 83/83, validation error 12.708333 % epoch 23, minibatch 83/83, validation error 13.125000 % Optimization complete with best validation score of 12.291667 %,with test performance 8.750000 % The code run for 24 epochs, with 7.298651 epochs/sec Ran for 3.3s |
テスト用データtest_setの最初の25個に対して認識を行い、予測結果と画像を合わせて表示してみます。
1/10の学習データでもそこそこ良い結果が得られています。
Predicted values for the first 10 examples in test set: ['(7, 0.988)', '(2, 0.733)', '(1, 0.951)', '(0, 0.995)', '(4, 0.936)', '(1, 0.979)', '(4, 0.957)', '(9, 0.912)', '(6, 0.912)', '(9, 0.844)', '(0, 0.904)', '(6, 0.668)', '(9, 0.917)', '(0, 0.983)', '(1, 0.991)', '(5, 0.832)', '(9, 0.790)', '(7, 0.986)', '(3, 0.659)', '(4, 0.979)', '(9, 0.775)', '(6, 0.952)', '(6, 0.801)', '(5, 0.955)', '(4, 0.862)'] Predicted values for the first 10 examples in test set: ['(7, 0.988)', '(2, 0.733)', '(1, 0.951)', '(0, 0.995)', '(4, 0.936)', '(1, 0.979)', '(4, 0.957)', '(9, 0.912)', '(6, 0.912)', '(9, 0.844)', '(0, 0.904)', '(6, 0.668)', '(9, 0.917)', '(0, 0.983)', '(1, 0.991)', '(5, 0.832)', '(9, 0.790)', '(7, 0.986)', '(3, 0.659)', '(4, 0.979)', '(9, 0.775)', '(6, 0.952)', '(6, 0.801)', '(5, 0.955)', '(4, 0.862)'] |
|
最後に各クラスを識別した重みWがどのように形に求まったのかみてみましょう。
赤い部分で正で、葵部分が負の値です。各数字の特徴的な部分に赤の部分が見られることが分かります。
|
|