SVMとは?Scikit-learnを使ってSVMでクラス分類する方法を解説

こんにちは三谷です。
今回は、AI(人工知能)の一つ、機械学習においてとても有名なアルゴリズムを紹介したいと思います。
機械学習を行う上で、SVM(サポートベクターマシン)はデータの分類をするためにも重要な手法です。
SVM(サポートベクターマシン)の使い方について徹底解説します!
Scikit-learnを初めて使う方でもわかりやすく解説しますので、是非チャレンジしてみてください。
Scikit-learnについて復習したいという方は、まずはこちらの記事を読んでからこの記事に戻ってきましょう。

SVM(サポートベクターマシン)とは?

SVM(サポートベクターマシン)は、教師あり学習のクラス分類と、回帰のできる機械学習アルゴリズムです。
それぞれSVC(Support Vector Classification)、SVR(Support Vector Regression)と書かれることもあります。
SVM(サポートベクターマシン)は、少ない教師データで高い汎化性能を持てることが特徴で、計算も早く過学習も起こしづらいです。
使い勝手が良いため、今でも様々な分野で活用されているアルゴリズムです。
ただし、データがばらついたり偏ると、計算量が膨大になったり(次元の呪い)、学習が非常に非効率なため、データのサンプル数が多い場合(100,000サンプル以上)はメモリ使用量や実行において難しくなるデメリットがあります。

SVM(サポートベクターマシン)の仕組みを解説

SVM(サポートベクターマシン)は、パーセプトロンに「カーネル関数」と「マージン最大化」を加えて次元を増やすことで、非線形の分割を線形に分割できるようにしているアルゴリズムです。
SVM(サポートベクターマシン)を解説していく上で、「カーネル関数」と「マージン最大化」は重要な概念です。
境界に最も近いサンプルとの距離(マージン)が最大となるような超平面で分離する計算をしています。
言葉だけで説明していると難しいと思いますのでひとつずつ紐解いていきましょう。

線形のモデルの解決方法

このようなデータを用意してみました。
まずは、下に用意した図を観てください。
図にプロットしてあるオレンジの点と青い点を分類する問題があると考えてください。
これらの点を学習することで、いい感じにオレンジと青を分類できるようになったモデルは、新しい点が打たれた際にその点が青なのかオレンジなのかを当てることができるようになるはずです。
この考え方が基本的なクラス分類問題の考え方です。

SVM用のサンプルデータ

一番シンプルな分類問題の解決方法は、線形モデルと言われるものです。
線形モデルはこの2つの点を分類するのに、直線を引くことで分けようとすると思ってください。
データ分類のために線を引けば、分類線以上のデータと分類線以下のデータに分類することができます。
単純に分類すべきデータが2つに分かれていれば線形分離は容易にできます。
しかし、今回のデータでは直線では分離することはできません。
なぜなら、無理に線を引こうとすると図の通り、青い丸とオレンジの三角のいデータに被ってしまうからです。
無理やり線を引いても、オレンジと青は分離できません。
このことを、線形分離できない、といいます。

SVM用のサンプルデータのエラー分類線

カーネル関数の考え方とは?

そこで、このような分離をうまくできるように、まず「カーネル関数」
という考え方が使われます。
カーネル関数とは、非線形の特徴量をデータ表現に加えることで次元を増やし、分離をする際に使用されるアルゴリズムです。
入力の特徴量を拡張する意味で、2番めの特徴量の2乗を新しい特徴量として加えてみます。
先ほどまで平面でしか表現できなかったデータが、立体的に表現できるようになりました。

SVM用のデータを立体的に変更

以下の図を確認してください。
こうすると、先程まで直線によって分離できなかった問題が、「平面」によって分離できるようになります。

SVM用のデータを立体的にして平面で分類線を引く

このように、入力特徴量の次元を増やして線形分離できるようにしたのが、SVM(サポートベクターマシン)の基本的な考え方です。
このカーネル関数を適用すると、平面では線形分離できなかったデータが線形分離できるようになります。
しかし、カーネル関数を導入すると、実際にデータポイントの拡張を計算し始めますので、高次元で計算量が多くなりすぎる問題が発生します。
そのため、数学的に解決するカーネルトリックという手法が用いられる事が多いです。

SVM(サポートベクターマシン)の特徴マージン最大化について解説

SVM(サポートベクターマシン)のもう一つの特徴、「マージン最大化」についてもご紹介しておきましょう。
例えばこのようなデータを赤と青の点に分けることを考えます。

SVM用の平面データ

直線の引き方は幾通りも考えられるのがおわかりいただけますでしょうか。
どちらの線でも、2つの点を分類できています。
つまり、データを分類できる線形は幾通りも引けてしまうのです。
場合によっては、三角のデータに近い線が引けてしまいますし、逆に丸のデータに近い線も引けてしまいます。
そうなると、データに偏りができてしまいます。

SVM用のデータには無数に線が引ける

マージンとは線形分離する線からデータまでの距離

最終的にほしい分離の線は、「新しい青または赤の点が入力されたときに、正しく(汎用的に)赤か青かを分類できる線」ですので、どちらのほうがいい線なのかを考える必要があります。
このときに使われるのが、「マージン」という考え方です。マージンとは、線形分離する線からそれぞれのデータまでの距離です。
このマージンが最大化するように分離の線を決めることをマージン最大化と呼びます。

SVM用のデータに最適な分類線を入れる

ちなみに、一番近いデータ点を「サポートベクトル」と言います。
このような考え方で分離が行われるため、ディープラーニングで使われるニューラルネットワークに比べると、少ないデータでも汎用性が高い、つまり汎化性能が高いモデルができるのがSVM(サポートベクターマシン)の特徴となっているのです。
SVM(サポートベクターマシン)の特徴や考え方を簡単に解説している動画もありますのでぜひ参考にしてください。

Scikit-learnでSVM(サポートベクターマシン)を実装する方法

ここからは、Scikit-learnで、実際にSVM(サポートベクターマシン)を試してみましょう。
インストールがまだの方は、インストールしてみてください。
Pythonの機械学習ライブラリなので誰でも無料で利用可能です。
Pythonの機能を把握したい場合はこちらの記事を参考にしてください。

Scikit-learnの実装方法について学ぼう

[Create New Project]で新しいプロジェクトを作成します。

SVMを利用するためのScikit-learnの実装方法

Locationに「opencv」と入力し[Create]をクリックします。Pycharmでは、プロジェクトという単位でプログラムを管理することができます。プロジェクトは、指定したLocationのディレクトリに作られたフォルダです。

SVMを利用するためのScikit-learn実装方法画面1

[File]-[New]で新しくファイルを作成します。

SVMを利用するためのScikit-learn実装方法画面2

ダイアログボックスから「Python File」を選びます。

SVMを利用するためのScikit-learn実装方法画面3

「Name」に「svm.py」と入力し[OK]をクリックします。

SVMを利用するためのScikit-learn実装方法画面4

ライブラリをインポートします。

from sklearn import datasets
from sklearn import svm
import matplotlib.pyplot as plt
from sklearn import metrics

こちらは、デフォルト(default)画面です。
上から、データセットのインポート、SVM(サポートベクターマシン)のインポート、グラフ描画用ライブラリのmatplotlibのインポート、結果を混同行列として確認するためのmetricsのインポートです。

データの準備と可視化

今回は、手書の数字が含まれたMNISTデータセットを利用します。

#データの準備
digits = datasets.load_digits()

n_samples = len(digits.data)
print("データ数:{}".format(n_samples))

#データの可視化
images_and_labels
= list(zip(digits.images, digits.target))

for index, (image, label) in enumerate(images_and_labels[:10]):
 plt.subplot(2, 5, index + 1)
 plt.imshow(image, cmap=plt.cm.gray_r, interpolation='nearest')
 plt.axis('off')
 plt.title('Training: %i'% label)
plt.show()

MNISTとは、Mixed National Institute of Standards and Technology databaseの略で、0~9までの手書きの数字画像が含まれたデータセットです。
学習用のデータが60,000個、テスト用のデータが10,000個の、合計70,000個の手書き文字データが含まれています。

SVMを利用するためのデータセット

datasets.load_digits()で、簡単にMNISTデータセットを呼び出すことができます。
最後のfor文は、1797データが含まれているデータセットのうち、始めの10個だけをビジュアライズするコードになっています。
ここまで実行してみると、以下のような出力になるはずです。

SVMを利用するためのトレーニング用データセット

SVM(サポートベクターマシン)の設定

いよいよここからSVM(サポートベクターマシン)の設定です。

# SVM の読み込み
clf = svm.SVC(gamma=0.001, C=100.)

gammaとCが、調整が必要なハイパーパラメーターです。
gammaは、高次元空間へのマップ方法である放射基底関数(RBF:radial basis function)カーネルとも呼ばれるガウシアンカーネルの計算時に使用される、幅を制御する調整用のパラメーターです。gammaが小さいとガウシアンカーネルの直径が大きくなり、多くの点を近いと判断するようになり、gammaが大きいとデータポイントを重視するようになり、モデルがどんどん複雑になります。

SVMの非線形の数式表現方法1

Cは、正則化パラメーターです。これは、誤分類をどれだけ許容するかを決めるパラメーターで、大きく設定するほど誤分類をしないように分離が行われます。

SVMの非線形の数式表現方法2

学習を実行するコードの記述

学習を実行するコードを記述します。
今回は、全データの60%を使用して学習を行いました。

# 60%のデータで学習実行
clf.fit(digits.data[:int(n_samples * 6 / 10)], digits.target[:int(n_samples * 6 / 10)])

なんと、1行で済んでしまいます!
Scikit-learnにはSVM(サポートベクターマシン)のアルゴリズムが既に実装されているため、複雑な計算式は書かなくても、ハイパーパラメーターと学習実行さえ行えば機械学習ができるようになっているのです。
clf.fit(入力データ, ラベルデータ)と指定することで学習が実行されます。
残っている40%のデータを使って、テストを実行するコードを記述します。

# 40%のデータでテスト
expected = digits.target[int(n_samples *-4 / 10):]
predicted = clf.predict(digits.data[int(n_samples *-4 / 10):])

print("Classification report for classifier 
 %s:¥n%s¥n"% (clf,metrics.classification_report(expected, predicted)))
print("Confusion matrix:¥n%s" % metrics.confusion_matrix(expected, predicted))

実行すると、以下のような表が表示されます。
この表は、Scikit-learnで用意されているmetrics機能を使用したもので、学習済みモデルの評価ができます。
学習済みモデルが予測をした結果がどの程度の評価で認識できているかを示してくれます。

SVMの非線形の数式と評価基準

Scikit-learnによってSVMによる結果データ

また、下に表示されるのは混同行列(Confusion Matrix)です。行(正解ラベル)に0~9の手書き数字、列(予測ラベル)にも0~9の手書き数字があるとして、正解数が表示されます。
例えば、2行目を確認すると、1という手書き数字に対して1と予測したものが70個あり、2と間違えてしまったものが1個、8と間違えてしまったのが2個ある、という意味になります。

Scikit-learnによるSVMのエラー結果の割合

予測結果の可視化

こちらは必須ではないですが、予測結果を可視化してみます。

#予測結果を可視化
images_and_predictions = list(zip(digits.images[int(n_samples *-4 / 10):], predicted))
for index,(image, prediction) in enumerate(images_and_predictions[:12]):
 plt.subplot(3, 4, index + 1)
 plt.axis('off')
 plt.imshow(image, cmap=plt.cm.gray_r, interpolation='nearest')
 plt.title('Prediction: %i' % prediction)
plt.show()

出力結果は以下のようになります。

Scikit-learnによるSVMのデータ抽出
Scikit-learnでは、SVM(サポートベクターマシン)によるデータ分類のため、数多くのデータ呼び出し、さらにデータの整理もしなければいけません。
混乱してしまうこともあるかもしれませんのでタブ(Tab)を上手く利用して、自分自身が解りやすいように整理するようにしましょう。

SVMとは まとめ

いかがだったでしょうか?
単純なモデルでしたが、Scikit-learnのSVM(サポートベクターマシン)を使ったクラス分類の方法をご理解いただけたのではないでしょうか。
SVM(サポートベクターマシン)は、「カーネル関数」と「マージン最大化」の概念を理解すれば、機械的にデータを入力していくだけです。
まずは、「カーネル関数」と「マージン最大化」の考え方を理解しましょう。
ハイパーパラメーターの調整など、是非いろいろ試してみてください。

SVMとは?徹底解説
最新情報をチェックしよう!
>企業向けAI人材育成サービス

企業向けAI人材育成サービス

AI事業発足やAI導入に必要な人材育成のステップとAI研究所が提供するサービス。AI研究所の人材育成サービスでは、3つのステップを軸に御社の業務内でAIを活用できる人材育成やAIプロジェクトの支援を行います。

CTR IMG