ニューラルネットワークで画像認識
ニューラルネットワークの簡単な関数近似プログラムを先日書いたので、今は画像認識プログラムを書いてますが、ものすごく簡単なバージョンが出来上がったので晒しておきます。
C++で画像解析部分を作って、Rで訓練データの学習、テストデータの判定をしています。
MNISTの手書き文字から学習データ準備
まずは、インプットとなる画像のデータですが、定番のMNISTの手書き数値データを使います。
元々のIDXフォーマットというフォーマットは再利用性が無さそうなので、http://www.cs.toronto.edu/~roweis/data.htmlから既にJPEG化されたものを引っ張ってきてそれを解析します。
OpenCVを使って2値化+CSV出力する簡単なプログラム。
#include "cv.h" #include "highgui.h" #include <stdio.h> #include <stdlib.h> const int kDIGIT_W = 28; const int kDIGIT_H = 28; int main (int argc, char* argv[]) { IplImage *src_img, *dest_img; char* imgfile = argv[1]; int answer = atoi(argv[2]); src_img = cvLoadImage(imgfile, CV_LOAD_IMAGE_GRAYSCALE); dest_img = cvCreateImage(cvGetSize(src_img), IPL_DEPTH_8U, 1); // make src image gray scale cvThreshold(src_img, dest_img, 90, 255, CV_THRESH_BINARY); int width = dest_img->width; int height = dest_img->height; int xCount = width / kDIGIT_W; int yCount = height / kDIGIT_H; for (int x=0; x<xCount; x++) { for (int y=0; y<yCount; y++) { int offsetX = x*kDIGIT_W; int offsetY = y*kDIGIT_H; printf("%d", answer); for (int i=0; i<kDIGIT_W; i++) { for (int j=0; j<kDIGIT_H; j++) { int val = dest_img->imageData[(j+offsetY)*dest_img->widthStep+(i+offsetX)]; printf(",%d", val*-1); } } printf("\n"); } } cvReleaseImage(&src_img); cvReleaseImage(&dest_img); return 0; }
これで、以下のように実行すると画像データが格納されたCSVが出来上がります。
1行が1文字を表わしており、1列目が正解の値(数値)、2列目以降が二値化された画像データです。
$ g++ -g -I/usr/include/opencv make_test_data.cpp -o make_test_data -lcxcore -lcv -lhighgui -lcvaux -lml #訓練データの準備 $ ./make_test_data mnist_train0.jpg 0 > mnist_train_all.txt $ ./make_test_data mnist_train1.jpg 1 >> mnist_train_all.txt $ ./make_test_data mnist_train2.jpg 2 >> mnist_train_all.txt .... $ ./make_test_data mnist_train9.jpg 9 >> mnist_train_all.txt #テストデータの準備 $ ./make_test_data mnist_test0.jpg 0 > mnist_test_all.txt $ ./make_test_data mnist_test0.jpg 1 >> mnist_test_all.txt .... $ ./make_test_data mnist_test0.jpg 9 >> mnist_test_all.txt
学習プログラム
次にこのCSVデータをインプットするニューラルネットワークの学習プログラムを組みます。
ニューラルネットの設定は以下です。
- 基本的な構造:入力が784次元、出力が10次元、隠れユニット層を1つ含む2層ネットワーク(PRMLで言うところの)
- 隠れユニット数:30
- 学習パラメータ:0.05
- 学習方法:オンライン学習。ただし、1つのエポックでは60338件の訓練データ集合の中から1000件のデータをランダム抽出し、それを1000エポック繰り返す。
- 出力ユニットの活性化関数:ソフトマックス関数
- 誤差関数:交差エントロピー誤差関数(ただし、プログラム中は活性に対する微分であるyk-tkのみ使われているため、交差エントロピー関数そのものは出てこない。)
#number of hidden unit s <- 30 #learning rate parameter l <- 0.05 #count of loop MAX_EPOCH <- 1000 #read traning data digit_data <- read.csv("train/mnist_train_all.txt",header=F) ##initialize weight parameter #weight parameter between input and hidden units w1 <- matrix(runif(s*length(digit_data), -1, 1), s, length(digit_data)) #weight parameter between hidden units and output w2 <- matrix(runif(10*(s+1), -1, 1), 10, s+1) #definition of function logsumexp <- function (x) { m <- max(x) m + log(sum(exp(x-m))) } softmax <- function (a) { sapply(a, function (x) { exp(x-logsumexp(a)) }) } neuro_func <- function (input) { a1 <- w1 %*% c(1, input) z1 <- tanh(a1) a2 <- w2 %*% c(1, z1) z2 <- softmax(a2) return(z2) } ##train the neural network for (k in 1:MAX_EPOCH) { #sample 1000 training data sample_index = sample(1:nrow(digit_data), 1000) for (i in sample_index) { tmp <- digit_data[i,] #target data t <- rep(0, 10) t[as.integer(tmp[1]+1)] <- 1 input <- t(tmp[2:length(tmp)]) #forward propagation a1 <- w1 %*% c(1, input) z1 <- tanh(a1) a2 <- w2 %*% c(1, z1) z2 <- softmax(a2) #back propagation d2 <- z2 - t w2 <<- w2 - l * d2 %*% t(c(1, z1)) d1 <- d2 %*% w2[,2:(s+1)] * (1-tanh(t(a1))^2) w1 <<- w1 - l * t(d1) %*% t(c(1, input)) } } save(w1, w2, logsumexp, softmax, neuro_func, file="nnet_image.rdata")
テストプログラム
準備したテストデータを読み込んで、ネットワークの出力のうち、最大の出力を持つユニットをネットワークが判断した数値とみなしています。
#read trained parameter load("nnet_image.rdata") #read traning data digit_data <- read.csv("test/mnist_test_all.txt",header=F) ##main function to train the neural network correct <- 0 wrong <- 0 for (i in 1:nrow(digit_data)) { tmp <- digit_data[i,] # target number of class answer <- tmp[1] t <- rep(0, 10) t[as.integer(answer+1)] <- 1 input <- t(tmp[2:length(tmp)]) # forward propagation z <- neuro_func(input) zt <- which.max(z) - 1 if (zt == answer) { correct <- correct + 1 } else { wrong <- wrong + 1 } } cat(paste("correct case=",correct,", wrong case=", wrong, "\n"))
結果とまとめ
テストデータ集合10153件中9360件正解、ということで正解率は約92.5%でした。
(correct case= 9360 , wrong case= 793 )
自分の手書き文字でも試してみましたが、まあ意地悪いことをしなければ大体正解しているみたいです。
ただ、他の方の結果と比較すると、id:ultraistさんの結果やhttp://yann.lecun.com/exdb/mnist/の例では、2層ネットワークで97%近い正解率まで達しているので、まだまだ。
id:ultraistさんは隠れユニット数が512で、LeCun氏は300とか1000とかなので、僕の隠れユニット数30はちょっと少なすぎたか、という印象。ただ、入力ベクトルが784次元なので、それに対し隠れユニットが1000はいくらなんでも多い、というか一般化できていないんじゃないかそれ?という感じなのだが、どうなんだろう。
とりあえず始めたばかりなので、TODOを列挙。
- 隠れユニット数をいじってみて比較
- 学習パラメータをいじってみて比較
- deskew処理
- オブジェクト抽出処理
- 画素以外の特徴ベクトルを入力にして学習
- 2層ネットワーク以外のやり方と比較
と見えているだけでもまだまだやることてんこ盛りです・・・。地道にやっていきます。
2009/9/13追記
学習時間について
この例ではのべ100万件のデータを学習していますが、全部で丸2〜3日ほどかかりました。