SVM

3178 days ago by takepwave

Hiroshi TAKEMOTO (take@pwv.co.jp)

はじめに

SVMは、オーバーフィッティングを避けて効率よく識別関数を求めることができる 手法です。

「集合知」の9章で紹介されているSVMをSageを使ってまとめてみます。

簡単な例題

いきなりSVMに進む前にクラス分けを簡単な例を使って解いてみます。

下図は、赤(-1)のグループと青(1)のグループの分布です。

# 線形クラス分類の例 # データの用意 c1 = [[1,2],[1,4],[2,4]] c2 = [[2,1],[5,1],[4,2]] # プロットして分布を確認 pl1 = list_plot(c1, rgbcolor='red') pl2 = list_plot(c2, rgbcolor ='blue') (pl1+pl2).show(xmin=0, xmax=5, ymin=0, figsize=5) 
       

線形分類

青点の値に1、赤点の値に-1をセットし、$w_1 x_1 + w_2 x_2 + b$の線形モデルを find_fit関数を使って解いて、$w_1, w2, b$の値を求めます。

find_fitのデータは、$x_1, x_2, 値$の順にセットします。

# 各クラスに判別値をセット v1 = [-1, -1, -1] v2 = [1, 1, 1] # (x, y, 判別値)のリストを作成 data = [flatten((pt, v)) for (pt, v) in zip(c1 + c2, v1 + v2)] data 
       
[[1, 2, -1], [1, 4, -1], [2, 4, -1], [2, 1, 1], [5, 1, 1], [4, 2, 1]]
[[1, 2, -1], [1, 4, -1], [2, 4, -1], [2, 1, 1], [5, 1, 1], [4, 2, 1]]
# 最もフィットするw1, w2, bを求める var('x1 x2 w1 w2 b') model(x1, x2) = w1*x1 + w2*x2 + b fit = find_fit(data, model, solution_dict=True); print fit # 求まった解(判別式)を返す関数を定義します f(x1, x2) = model.subs(fit); 
       
{b: 0.1962962962945436, w2: -0.4333333333364595, w1:
0.32592592592445524}
{b: 0.1962962962945436, w2: -0.4333333333364595, w1: 0.32592592592445524}

結果の表示

implicit_plot関数を使ってデータ(赤、青)と判別式が0となる線を表示します。 上手く赤と青の点を分離しているのが分かります。

# 判別式が0の線を表示 pl6 = implicit_plot(f(x1, x2) == 0, (x1, 0, 5), (x2, 0, 5)) (pl1 + pl2 +pl6).show(xmin=0, xmax=5, ymin=0, figsize=5) 
       

求まった判別式がどのような形なのかplot3dを使って表示してみます。 データ(赤、青)の判別式の値も合わせて表示してみます。

一つの平面上として表されています。

pl3 = plot3d(f(x1, x2), (x1, 0, 5), (x2, 0, 5)) pl4 = list_plot([(x, y, f(x, y)) for (x, y) in c1], rgbcolor='red') pl5 = list_plot([(x, y, f(x, y)) for (x, y) in c2], rgbcolor='blue') (pl3+pl4+pl5).show(xmin=0, xmax=5) 
       

SVMを使った分類

SVM(サポートベクターマシン)は、クラスの境界線と分離平面(超平面) の距離(マージン)を最大になるようにクラスを分類します。

先ほどの例では、下図のように境界線(半線)の中間に$y = x$の分離平面が、 あります。(線形分類の境界線とずれていることに注意して下さい)

境界線を求めるとき使った学習データのことを「サポートベクター」と呼びます。

# SVMでは、クラスの境界線と分離平面(超平面)の距離を最大になるように求めます sv = plot(lambda x : x, (x, 0, 5)) cl1 = plot(lambda x : x + 1, (x, 0, 5), linestyle='dashed') cl2 = plot(lambda x : x - 1, (x, 0, 5), linestyle='dashed') (sv+cl1+cl2+pl1+pl2).show(xmin=0, xmax=5, figsize=5) # 青の実線が分離平面で、半線がクラスの境界線です 
       

LIBSVMのインストール

sageを使ってSVMを計算したいところですが、sageにはSVMを計算する関数がありません。 そこで、LIBSVMをインストールします。

LIBSVMの ホームページ のDownload LIBSVMから最新のtar.gzファイルをダウンロードします。

ダウンロードしたファイル(libsvm-2.9.tar.gz)を適当な場所(~/local)で解凍します。

$ tar xzvf libsvm-2.9.tar.gz
$ cd libsvm-2.9/python	
		

MacOSXの場合、Makefileの一部を変更します。

#LDFLAGS = -shared
# Mac OS
LDFLAGS = -framework Python -bundle			
		

また、そのままではsageで動かなかったので、svm.pyの以下の2点にfloatのキャストを 追加するように修正してください。

126c126
< 		svmc.svm_node_array_set(data,j,k,x[k])
---
> 		svmc.svm_node_array_set(data,j,k,float(x[k]))
138c138
< 			svmc.double_setitem(y_array,i,y[i])			
---
> 			svmc.double_setitem(y_array,i,float(y[i]))
		

sage内部のpythonにLIBSVMをインストールするには、sageのpythonで seup.pyを実行します。私は、sageを~/localにインストールしているので

$ ~/local/sage/local/bin/python setup.py install			
		
と実行しました。

正しく動くか、集合知の9.9.2の例題で確認します。 無事、同じ結果がでたので、次に進みます。

# svm.pyを2カ所修正した # 動作を確認するために、集合知の9.9.2の例題を実行する from svm import * prob = svm_problem([1,-1],[[1,0,1],[-1,0,-1]]) param = svm_parameter(kernel_type = LINEAR, C=float(10)) m = svm_model(prob, param) 
       
*
optimization finished, #iter = 1
nu = 0.025000
obj = -0.250000, rho = 0.000000
nSV = 2, nBSV = 0
Total nSV = 2
*
optimization finished, #iter = 1
nu = 0.025000
obj = -0.250000, rho = 0.000000
nSV = 2, nBSV = 0
Total nSV = 2
m.predict([1,1,1]) 
       
1.0
1.0

LIBSVMで簡単な例題を解く

LIBSVMへの入力は、各データどのクラスに属するかを示すclsと データのベクトル値datを引数に取ります。

# LIBSVM用にデータを加工する cls = v1 + v2; print cls dat = c1 + c2; print dat 
       
[-1, -1, -1, 1, 1, 1]
[[1, 2], [1, 4], [2, 4], [2, 1], [5, 1], [4, 2]]
[-1, -1, -1, 1, 1, 1]
[[1, 2], [1, 4], [2, 4], [2, 1], [5, 1], [4, 2]]
# 例題にSVMを適応 prob = svm_problem(cls, dat) m = svm_model(prob, param) 
       
*
optimization finished, #iter = 4
nu = 0.033333
obj = -1.000000, rho = 0.000000
nSV = 2, nBSV = 0
Total nSV = 2
*
optimization finished, #iter = 4
nu = 0.033333
obj = -1.000000, rho = 0.000000
nSV = 2, nBSV = 0
Total nSV = 2

データが正しく分類されているか赤(1,4)の座標を入力すると、-1.0が返され、 正しい結果が返ってきます。

m.predict([1, 4]) 
       
-1.0
-1.0

結果の表示

contour_plotを使って、SVMの予想をコンタマップで表示すると、$y = x$ の線でうまく、分類されているのが、分かります。

contour_plot(lambda x, y : m.predict([x, y]), (x, 0, 5), (y, 0, 5), figsize=5) 
       

さらに、各点(赤、青)の予想値と境界線の形状をlist_plot3で表示します。

# plot3dがうまく表示できないので、list_plot3dで代替 pl7 = list_plot3d([(x, y, m.predict([x, y])) for x in srange(0, 5, 0.1) for y in srange(0, 5, 0.1)]) pl8 = list_plot([(x, y, m.predict([x, y])) for (x, y) in c1], rgbcolor='red') pl9 = list_plot([(x, y, m.predict([x, y])) for (x, y) in c2], rgbcolor='blue') (pl8 + pl9 +pl7).show(xmin=0, xmax=5, ymin=0) 
       

SVMのカーネルメソッドのすごさを確かめるために点の値がチェスボードのように 格子状に部分布する点を分析することにします。

まずは、チェスボードのマスの値を返すchessBox関数を作成します。 ただしく、値がセットされるか、conour_plotでみてみましょう。

# チェスボックスの例 def chessBox(x, y): if ((int(x)+int(y))%2) == 0: return 1 else: return -1 #[chessBox(x, y) for x in range(0,5) for y in range(0,5)] contour_plot(chessBox, (x, 0, 5), (y, 0, 5), figsize=5) 
       

テストデータの作成

0から5の範囲にランダムに点を生成し、chessBox関数で赤(red)と青(blue) に振り分けます。

# ランダムな点を生成 rndPts = [[5*random(), 5*random()] for i in range(0,1000)]; 
       

Kernel関数LINEARの場合

kernel_type=LINEAR(線形を意味する)を使って分類すると、 うまく格子状のデータを分類することができません。

red = [pt for pt in rndPts if chessBox(pt[0], pt[1]) == -1]; blue = [pt for pt in rndPts if chessBox(pt[0], pt[1]) == 1]; redCls = [-1 for pt in red] blueCls = [1 for pt in blue] chessRedPlt = list_plot([(x, y) for (x, y) in red], rgbcolor='red') chessBluePlt = list_plot([(x, y) for (x, y) in blue], rgbcolor='blue') (chessRedPlt+chessBluePlt).show(xmin=0, xmax=5, figsize=5) 
       
# カーネルメソッドをLINEARだとうまく分けられない prob = svm_problem(redCls + blueCls, red + blue) param = svm_parameter(kernel_type = LINEAR, C=float(1000)) m = svm_model(prob, param) contour_plot(lambda x, y : m.predict([x, y]), (x, 0, 5), (y, 0, 5), figsize=5) 
       
........................................................................\
........................................................................\
........................................................................\
........................................................................\
........................................................................\
.................................*......................................\
........................................................................\
........................................................................\
........................................................................\
........................................................................\
........................................................................\
........................................................................\
........................................................................\
........................................................................\
........................................................................\
........................................................................\
........................................................................\
........................................................................\
........................................................................\
........................................................................\
........................................................................\
........................................................................\
........................................................................\
...........................................*............................\
........................................................................\
........................................................................\
........................................................................\
............................................................*...........\
........................................................................\
....................*...................................................\
.................*......................................................\
........................................................................\
.......................*................................................\
...................................................................*....\
......................................................*.................\
........................................................................\
........................................................................\
........................................................................\
...............................*........................................\
........................................................................\
........................................................................\
............................................*...........................\
........................................................................\
........................................................................\
........................................................................\
........................................................................\
..................................................*
optimization finished, #iter = 3348299
nu = 0.974402
obj = -974232.337534, rho = -0.857296
nSV = 976, nBSV = 973
Total nSV = 976
.........................................................................................................................................................................................................................................................................................................................................................................................................*.........................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................*................................................................................................................................................................................................................................................................................................................*.......................................................................................................*....................................................................*.....................................................................................................................................................*...................................................................................................................*..........................................................*........................................................................................................................................................................................................................................................................*....................................................................................................................................................................................................................................*.............................................................................................................................................................................................................................................................................................................................................................................*
optimization finished, #iter = 3348299
nu = 0.974402
obj = -974232.337534, rho = -0.857296
nSV = 976, nBSV = 973
Total nSV = 976

Kernel関数RBFの場合

次に、kernel_type=RBF(ガウシアンカーネル関数)を使って分類すると、 なんとなく格子状に分類しているように見えます。

ガウシアンカーネル関数は、 $$ K(x, x') = e ^{\left( - \frac{| x - x' |^2}{\sigma^2} \right)} $$

# カーネルメソッドをRBFにするとある程度格子の形が見て取れる prob = svm_problem(redCls + blueCls, red + blue) param = svm_parameter(kernel_type = RBF, C=float(1000)) m = svm_model(prob, param) contour_plot(lambda x, y : m.predict([x, y]), (x, 0, 5), (y, 0, 5), figsize=5) 
       
...............................................*.......................*\
....................*
optimization finished, #iter = 90313
nu = 0.213203
obj = -161853.152563, rho = 9.646479
nSV = 238, nBSV = 187
Total nSV = 238
...............................................*.......................*....................*
optimization finished, #iter = 90313
nu = 0.213203
obj = -161853.152563, rho = 9.646479
nSV = 238, nBSV = 187
Total nSV = 238

簡単な画像認識

最後に簡単な画像認識をLIBSVMを使って確かめてみます。

5x5のマス目に数字の0から4までを書いた画像5個(本当に少ないです!) をテストデータとして使用します。

# 文字認識 m0 = [[0,1,1,1,0], [1,0,0,0,1], [1,0,0,0,1], [1,0,0,0,1], [0,1,1,1,0]] m1 = [[0,0,1,0,0], [0,0,1,0,0], [0,0,1,0,0], [0,0,1,0,0], [0,0,1,0,0]] m2 = [[0,1,1,1,1], [1,0,0,1,0], [0,0,1,0,0], [0,1,0,0,0], [1,1,1,1,1]] m3 = [[0,1,1,1,0], [1,0,0,0,1], [0,0,1,1,0], [1,0,0,0,1], [0,1,1,1,0]] m4 = [[0,0,1,0,0], [0,1,0,0,0], [1,0,0,1,0], [1,1,1,1,1], [0,0,0,1,0]] p0 = flatten(m0); p1 = flatten(m1); p2 = flatten(m2); p3 = flatten(m3) p4 = flatten(m4) v = [0,1,2,3,4] 
       

作成したデータをcontour_plotを使って表示して、確認します。 なんとなくそれらしく見えるでしょう。(笑)

# 画像作成用の関数 def mesh(x, y, tbl): idX = int(x) idY = int(4.9-y) return tbl[idY][idX] # 学習用画像を表示 trn0 = contour_plot(lambda x, y : mesh(x, y, m0), (x, 0, 4.9), (y, 0, 4.9), aspect_ratio=1, figsize=(2, 1.23)) trn1 = contour_plot(lambda x, y : mesh(x, y, m1), (x, 0, 4.9), (y, 0, 4.9), aspect_ratio=1, figsize=(2, 1.23)) trn2 = contour_plot(lambda x, y : mesh(x, y, m2), (x, 0, 4.9), (y, 0, 4.9), aspect_ratio=1, figsize=(2, 1.23)) trn3 = contour_plot(lambda x, y : mesh(x, y, m3), (x, 0, 4.9), (y, 0, 4.9), aspect_ratio=1, figsize=(2, 1.23)) trn4 = contour_plot(lambda x, y : mesh(x, y, m4), (x, 0, 4.9), (y, 0, 4.9), aspect_ratio=1, figsize=(2, 1.23)) html.table([[trn0, trn1, trn2], [trn3, trn4]]) 
       

画像認識

いよいよ画像認識の準備が整いました。 えいや〜で、モデルを作成します。

# 画像認識 param = svm_parameter(kernel_type = RBF, C=float(10)) prob = svm_problem(v,[p0,p1,p2,p3,p4]) m = svm_model(prob, param) 
       
*
optimization finished, #iter = 1
nu = 0.246622
obj = -2.466216, rho = 0.000000
nSV = 2, nBSV = 0
*
optimization finished, #iter = 1
nu = 0.280928
obj = -2.809276, rho = 0.000000
nSV = 2, nBSV = 0
*
optimization finished, #iter = 1
nu = 0.676333
obj = -6.763327, rho = 0.000000
nSV = 2, nBSV = 0
*
optimization finished, #iter = 1
nu = 0.262318
obj = -2.623181, rho = 0.000000
nSV = 2, nBSV = 0
*
optimization finished, #iter = 1
nu = 0.262318
obj = -2.623181, rho = 0.000000
nSV = 2, nBSV = 0
*
optimization finished, #iter = 1
nu = 0.280928
obj = -2.809276, rho = 0.000000
nSV = 2, nBSV = 0
*
optimization finished, #iter = 1
nu = 0.280928
obj = -2.809276, rho = 0.000000
nSV = 2, nBSV = 0
*
optimization finished, #iter = 1
nu = 0.330771
obj = -3.307713, rho = 0.000000
nSV = 2, nBSV = 0
*
optimization finished, #iter = 1
nu = 0.202682
obj = -2.026823, rho = 0.000000
nSV = 2, nBSV = 0
*
optimization finished, #iter = 1
nu = 0.262318
obj = -2.623181, rho = 0.000000
nSV = 2, nBSV = 0
Total nSV = 5
*
optimization finished, #iter = 1
nu = 0.246622
obj = -2.466216, rho = 0.000000
nSV = 2, nBSV = 0
*
optimization finished, #iter = 1
nu = 0.280928
obj = -2.809276, rho = 0.000000
nSV = 2, nBSV = 0
*
optimization finished, #iter = 1
nu = 0.676333
obj = -6.763327, rho = 0.000000
nSV = 2, nBSV = 0
*
optimization finished, #iter = 1
nu = 0.262318
obj = -2.623181, rho = 0.000000
nSV = 2, nBSV = 0
*
optimization finished, #iter = 1
nu = 0.262318
obj = -2.623181, rho = 0.000000
nSV = 2, nBSV = 0
*
optimization finished, #iter = 1
nu = 0.280928
obj = -2.809276, rho = 0.000000
nSV = 2, nBSV = 0
*
optimization finished, #iter = 1
nu = 0.280928
obj = -2.809276, rho = 0.000000
nSV = 2, nBSV = 0
*
optimization finished, #iter = 1
nu = 0.330771
obj = -3.307713, rho = 0.000000
nSV = 2, nBSV = 0
*
optimization finished, #iter = 1
nu = 0.202682
obj = -2.026823, rho = 0.000000
nSV = 2, nBSV = 0
*
optimization finished, #iter = 1
nu = 0.262318
obj = -2.623181, rho = 0.000000
nSV = 2, nBSV = 0
Total nSV = 5

テスト用の画像

テストに使う画像は、以下の3個です。 判別結果は、ただしく0, 1, 1となっています。(めでたし、めでたし)

# テスト用の画像を作成する m0_1 = [[0,1,1,1,0], [0,1,0,0,1], [0,1,0,0,1], [0,1,0,0,1], [0,1,1,1,0]] m1_1 = [[0,0,1,0,0], [0,1,1,0,0], [0,0,1,0,0], [0,0,1,0,0], [0,1,1,1,0]] m1_2 = [[0,0,0,1,0], [0,0,0,1,0], [0,0,1,0,0], [0,0,1,0,0], [0,0,1,0,0]] t0_1 = flatten(m0_1) t1_1 = flatten(m1_1) t1_2 = flatten(m1_2) 
       
# テスト用画像の表示 cnt1 = contour_plot(lambda x, y : mesh(x, y, m0_1), (x, 0, 4.9), (y, 0, 4.9), aspect_ratio=1, figsize=(2, 1.23)) cnt2 = contour_plot(lambda x, y : mesh(x, y, m1_1), (x, 0, 4.9), (y, 0, 4.9), aspect_ratio=1, figsize=(2, 1.23)) cnt3 = contour_plot(lambda x, y : mesh(x, y, m1_2), (x, 0, 4.9), (y, 0, 4.9), aspect_ratio=1, figsize=(2, 1.23)) html.table([[cnt1, cnt2, cnt3]]) 
       
# テスト用画像の認識 print m.predict(t0_1) print m.predict(t1_1) print m.predict(t1_1) 
       
0.0
1.0
1.0
0.0
1.0
1.0

念のため、学習に使ったデータが正しく識別されるかもみてみました。

こんなに少ないデータでよく認識できるものだと感心しました。これがSVMのすごさなのでしょうか?

# 学習用画像の認識 print m.predict(p0) print m.predict(p1) print m.predict(p2) print m.predict(p3) print m.predict(p4) 
       
0.0
1.0
2.0
3.0
4.0
0.0
1.0
2.0
3.0
4.0