読者です 読者をやめる 読者になる 読者になる

遥かへのスピードランナー

シリコンバレーでAndroidアプリの開発してます。コンピュータービジョン・3D・アルゴリズム界隈にもたまに出現します。

Rでニューラルネットワーク(1変数の関数近似)

PRML

機械学習・パターン認識方面の勉強初めてから4ヶ月ほど立ちました。最近はnaoya_tさん主催のPRML読書会に参加させて頂いています。

来週末8/29の第6回読書会ではニューラルネットワークの章の発表を担当することになったので、Rを使ってサンプルプログラムを組んでみました。参考にしたのはPRML5.1〜5.3の範囲で、sin関数からサンプリングした点データをニューラルネットワークを使って誤差逆伝播法で学習し、元のsin関数を近似します。

学習前の初期状態が以下の図。赤字が元の関数(線)およびサンプルデータ(点)で青字がニューラルネットワークの出力です。

で、学習後の状態が以下です。

いい感じに再現できています。
以下ソースコード

library(animation)

#number of training data
N <- 50

#number of hidden unit
s <- 5

#learning rate parameter
l <- 0.05

#standard deviation of training data
sd <- 0.05

#count of loop
LOOP <- 7000

#frame interval of animation movie
INTERVAL <- 0.1

#total time of animation movie
TOTAL <- 25

#initialize weight parameter
w1 <- matrix(seq(0.1,length=s*2, by=0.1),s,2)
w2 <- matrix(seq(0.1,length=s+1, by=0.1),1,s+1)

#generate traning data
xt <- seq(0,1,length=N)
yt <- sin(2*pi*xt)+rnorm(N,sd=sd)

neuro_func_ <- function (x) {
  a <- w1 %*% c(1, x)
  z <- tanh(a)
  y <- w2 %*% c(1, z)
  return(y)
}
#vectorize neuro_func_
neuro_func <- function (x) {
  sapply(x, neuro_func_)
}

png(filename = "graphic1.png", width = 480, height = 480, pointsize = 18, bg = "white")
plot(function(x) {return(sin(2*pi*x))}, xlab="input", ylab="output", xlim=c(0, 1), ylim=c(-2.5,2.5), sub=paste("S=", s, ",N=", N, ",l=", l), col=2)
plot(neuro_func, add=T, col=4, xlim=c(0,1))
points(xt, yt, col=6)

#main function to train the neural network
neuro_training <- function () {
  trigger_count <- as.integer(LOOP / (TOTAL / INTERVAL))
  for (k in 1:LOOP) {
    if (k %% trigger_count == 1) {
      plot(function(x) {return(sin(2*pi*x))}, xlab="input", ylab="output", xlim=c(0, 1), ylim=c(-2.5,2.5), main=paste("LOOP=", k), sub=paste("S=", s, ",N=", N, ",l=", l), col=2)
      plot(neuro_func, add=T, col=4, xlim=c(0,1))
      points(xt, yt, col=6)
    }
    
    for (i in 1:N) {
      # forward propagation
      a <- w1 %*% c(1, xt[i])
      z <- tanh (a)
      y <- w2 %*% c(1, z)

      # back propagation
      d2 <- y - yt[i]
      w2 <<- w2 - l * d2 * c(1, z)
      d1 <- d2 %*% (1-tanh(t(a))^2) * t(w2[2:(s+1)])
      w1 <<- w1 - l * t(d1) %*% c(1, xt[i])
    }
  }
}

saveMovie(neuro_training(), interval=INTERVAL, moviename="neural_network_training_movie", movietype="mpg", outdir=getwd(), width=640, height=480)

cat("learning end.print w1 and w2\n")
print(w1)
print(w2)

png(filename = "graphic2.png", width = 480, height = 480, pointsize = 18, bg = "white")
plot(function(x) {return(sin(2*pi*x))}, xlab="input", ylab="output", xlim=c(0, 1), ylim=c(-2.5,2.5), sub=paste("S=", s, ",N=", N, ",l=", l), col=2)
plot(neuro_func, add=T, col=4, xlim=c(0,1))
points(xt, yt, col=6)

収束する様子をムービーにしてみました。
中央上部に表示されている数値はサンプル点学習のループ回数です。

考察・雑感

学習パラメータは0.05固定としていますが、いろいろといじってみて、この値を大きくしすぎると収束せずに発散したり、振動したりするので、これくらいの値が妥当と判断しました。学習するデータ集合の数を多くした場合は、一つ一つの学習におけるの更新量は小さくするべきなので、これらの学習パラメータも小さくすることが望ましいはずです。
また、いずれにしても精度のよい近似が得られるまではデータ集合を繰り返し学習する必要があります。上記の動画を見ても、視覚的に精度のよい近似に近づくまでは1000回程度の学習ループを必要としています。学習パラメータを、固定ではなく、学習データに基づくパラメータに依存させたりすることでこの近似のスピードをもう少し早めることが出来るのかも知れません。PRMLの後の方で出てくるのかな?

参考文献

R全般に関して、id:syou6162氏のはてダを多いに参考にさせて頂きました。出し惜しみなく情報を晒して頂いてありがとうございます。
あと、Rでアニメーションを作るのは以下のエントリを参考にしました。

やろうと思ったことがほぼ何でもできちゃうR、ステキです!!。