Rでニューラルネットワーク(2変数の関数近似)
で、1変数の関数近似がうまくいったので、調子にのって2変数の関数近似にもチャレンジしてみました。
2変数のsinc関数を、ニューラルネットワークの誤差伝播法を使って近似する(しようとする)ものです。
library(animation) #number of training data N <- 30 #number of hidden unit s <- 5 #learning rate parameter l <- 0.01 #standard deviation of training data sd <- 0.05 #count of loop LOOP <- 10000 #frame interval of animation movie INTERVAL <- 0.1 #total time of animation movie TOTAL <- 20 #initialize weight parameter w1 <- matrix(seq(0.1,length=s*3, by=0.1),s,3) w2 <- matrix(seq(0.1,length=s+1, by=0.1),1,s+1) #generate traning data xt <- seq(-10,10,length=N) yt <- xt f <- function(x, y) { r <- sqrt(x^2+y^2) temp = function (r) { if ( r == 0 ) { return(10) } else { return(10*sin(r)/r) } } sapply(r, temp) } zt <- outer(xt, yt, f) persp(xt,yt,zt,theta = 30, phi = 30, expand = 0.5, col = rainbow(50), border=NA) neuro_func_ <- function (x, y) { a <- w1 %*% c(1, x, y) z <- tanh(a) y <- w2 %*% c(1, z) return(y) } #vectorize neuro_func_ neuro_func <- function (x, y) { mapply(neuro_func_, x, y) } persp(xt, yt, outer(xt, yt, neuro_func), theta = 30, phi = 30, expand = 0.5, col = rainbow(50), border=NA, axes=TRUE, ticktype="detailed") #main function to train the neural network neuro_training <- function () { trigger_count <- as.integer(LOOP / (TOTAL / INTERVAL)) for (k in 1:LOOP) { if (k %% trigger_count == 1) { persp(xt, yt, outer(xt, yt, neuro_func), theta = 30, phi = 30, expand = 0.5, col = rainbow(50), border=NA, axes=TRUE, ticktype="detailed") } for (i in 1:N) { for (j in 1:N) { # forward propagation a <- w1 %*% c(1, xt[i], yt[j]) b <- tanh (a) z <- w2 %*% c(1, b) # back propagation d2 <- z - zt[i,j] w2 <<- w2 - l * d2 * c(1, b) d1 <- d2 %*% (1-tanh(t(a))^2) * t(w2[2:(s+1)]) w1 <<- w1 - l * t(d1) %*% c(1, xt[i], yt[j]) } } } } saveMovie(neuro_training(), interval=INTERVAL, moviename="neural_network_training_movie", movietype="mpg", outdir=getwd(), width=640, height=480) cat("learning end print w1 and w2\n") print(w1) print(w2) persp(xt, yt, outer(xt, yt, neuro_func), theta = 30, phi = 30, expand = 0.5, col = rainbow(50), border=NA, axes=TRUE, ticktype="detailed")
で、結果が以下のムービー。
サンプルデータを10000回ループで読み込ませているんですが、途中から微動だにしなくなってしまいました・・・。
何かパラメータの設定が悪いのか、もっとループを多くしなければいけないのか、まだ原因は分かっていません。。。どなたか知識のある方がこれを読んで頂けたら指摘してもらえると嬉しいです。