読者です 読者をやめる 読者になる 読者になる

遥かへのスピードランナー

シリコンバレーでAndroidアプリの開発してます。コンピュータービジョン・3D・アルゴリズム界隈にもたまに出現します。

Rでガウス過程による分類を実装

PRML

PRMLの6.4.5〜6.4.6の範囲にあるガウス過程による分類をRで実装してみました。

ソースコード全文はgithubにアップしています。
http://github.com/thorikawa/prml/blob/master/gaussian_process_classify.R

ここでは例として、(1,0),(2,0),(3,0)で1、(0,1),(0,2),(0,3)で0の値を取る訓練集合を用いています。

# Training data
x=list(c(1,0),c(2,0), c(3,0), c(0,1), c(0,2), c(0,3))
t=c(1,1,1,0,0,0)
training_data_num <- length(x)

この訓練集合とカーネル関数をもとに予測分布を導出しています。
ガウス過程においては、訓練集合から予測分布を決める(ほぼ)唯一の要素はカーネル関数になるということなのは分かるのですが、カーネル関数によってどのように予測分布が変わってくるのかという点は良く理解できていません。
この辺をまずは感覚でつかむために、いくつか代表的なカーネルで予測分布を求めてみました。

ガウスカーネル

exp(-||{\bf x}-{\bf x'}||^2/(2\sigma^2))で与えられるカーネル。
\sigmaの値を0.2,0,4と変えてみると、id:n_shuyoさんの考察通り、\sigmaは訓練データを中心とするガウス分布の尖り具合を表現していることが分かります。

# Gaussian kernel
sigma <- 0.2
kernel <- function (x1, x2) {
  exp(-sum((x1-x2)*(x1-x2))/(2*sigma^2));
}
gp(kernel, "gaussian kernel", "sigma=0.2");


# Gaussian kernel
sigma <- 0.4
kernel <- function (x1, x2) {
  exp(-sum((x1-x2)*(x1-x2))/(2*sigma^2));
}
gp(kernel, "gaussian kernel", "sigma=0.4");


線形カーネル

{\bf x}^T{\bf x'}で与えられるカーネル。単純な線形分離になります。

# Linear kernel
kernel <- function (x1, x2) {
  sum(x1*x2);
}
gp(kernel, "linear kernel");


指数カーネル

\theta||{\bf x}-{\bf x'}||で与えられるカーネル。
\sigmaの値を1.0,0.3と変えてみると、ガウス分布と同様、\sigmaは指数分布の尖り具合を表していると言えそうです。

# Exponential kernel
theta <- 1.0
kernel <- function (x1, x2) {
  exp(-theta*sqrt(sum((x1-x2)*(x1-x2))));
}
gp(kernel, "exponential kernel", "theta=1.0");


# Exponential kernel
theta <- 0.3
kernel <- function (x1, x2) {
  exp(-theta*sqrt(sum((x1-x2)*(x1-x2))));
}
gp(kernel, "exponential kernel", "theta=0.3");


多項式カーネル

({\bf x}^T{\bf x'}+1)^Mで与えられるカーネル。
M=2,M=3と次数を変えてみると、、、これは、、、どういう特徴を持つのかいまいち不明です。

# Polynomial kernel
kernel <- function (x1, x2) {
  (sum(x1*x2)+1)^2;
}
gp(kernel, "polynomial kernel", "degree=2");


# Polynomial kernel
kernel <- function (x1, x2) {
  (sum(x1*x2)+1)^3;
}
gp(kernel, "polynomial kernel", "degree=3");

まとめ

というわけで理解は浅いのですが、様々なカーネル関数でガウス過程による分類を実装してみました。
カーネルの役割は、訓練集合における任意の2つのデータに対する、出力の類似度をどう評価するかを決定することだと理解しています。
ガウスカーネルや指数カーネルは、RBFであり、中心からの距離に対して反比例する形で値が減少していくため、入力値が近ければ近いほどその出力の値も近い値を示し、予測分布は訓練集合を中心とした分かりやすい分布になります。
線形や多項式カーネルはRBFではないので直感的な理解が現時点では出来ていません。カーネル多変量解析を熟読すればこのあたりも理解できるようになるかなあ。