Blender+SIO2で3D物理エンジン
すっごい間が空いてしまい、いまさらなんですが、年末年始のプチ成果まとめ第2段です。前回のBox2Dに引き続き、今度は3Dの物理エンジンをいじってみます。
Box2Dの場合、貧弱ではありますが線画機能(DebugDraw)が内蔵されていますが、3Dの場合はレンダリング(描画)は部分はそれ専用のエンジンを用いることが多いです。逆に物理エンジンの部分はそれだけに特化していて、描画機能はもたないことが多いです。
というわけでフリーで使える物理エンジンとレンダリングエンジンを分けて紹介すると、
物理エンジンでは
あたりが有名、レンダリングエンジンでは、
が有名どころだと思います。
今回はその中でも「blender」を使ってみます。blenderは物理エンジンとして、デフォルトでbullletを内蔵しており、レンダリングから物理シミュレーションまで、一つのGUI上で簡単に編集・表現することができます。
またSIO2というiPhone用の3D物理エンジンに簡単にエクスポートすることができる(SIO2にpythonで書かれたExporterスクリプトが付属してます)ので、将来iPhoneでゲームを作ってみたいという人にはうってつけ、というわけです。
blenderの入門ページとしては、藤堂+さんのblenderチュートリアルが素晴らしくまとまっています。ただSIO2のことについて触れていないので、SIO2にガンガンExportして楽しみたい人は、SIO2付属のチュートリアル(Export前のblendファイルも付いてくる)をいじってみるか、Haraさんのブログエントリー
あたりを読むと分かりやすいです。やっていることはSIO2のチュートリアルをちょこっと改造して、自作の3Dモデルに置き換えてみよう、というものです。
で、blener+SIO2にはここには書いていない落とし穴も結構あったります。
以下に書くことは、全てsio2interactive wiki-Blender and SIO2にも書いてあることで、最初からここを読んでいれば特にはまることもないのですが、日本語の文献はあまり無いようだったので、書いておきます。
- blender上で同じ計上のオブジェクトしていくと、Cube.001,Cube.002,Cube.003のように末尾に連番が付いたIDのオブジェクトが自動的に作成されていきますが、これらはblender上だと重複オブジェクトとみなされ、SIO2上で表示されないことがあります。他のIDに手動で変更が必要です。
- テクスチャはUV Mappingしなければ表示されません。UV Mappingという用語は3D素人にはなじみがないのですが、以下のエントリーが詳しいです。
- Blender.jp - フォーラム 日本語のチュートリアルがZIPで落とせる
- Blender 3D: Noob to Pro/UV Map Basics - Wikibooks, collection of open-content textbooks 英語、地球を貼付ける
- BioRUST.com :: Tutorials >> UV Mapping 英語、さいころ的なもの
- スケール値が1.0以外の値が設定されていると衝突判定されない。これはHaraさんのブログにも書いてありますね。物体の大きさを変更する場合は、Object Modeで変更するのではなく、Edit Modeで変更するか、Object Modeで変更した後に、Ctl-A -> scale and rotation to objectでscaleが1.0になるように再調整します。
とまあこんなところです。仮にiPhoneアプリ上で複雑な3Dモデルを作る場合、SIO2でモデルを組み立ててもよいのですが、blender上でpythonのマクロが使えるのでそこでモデリングしてしまうのが結構簡単です。次回はpythonでモデルを組み立てるところをやってみます。
Box2DFlashAS3でブロック崩しゲームを作る
世の中は年が開けて抱負やら振り返りやらでとてもフレッシュな香りがただよっていますが、僕は空気も読まずにとりあえず年末年始の休暇にやってたことをまとめようと思います。
まずはBox2DFlashAS3という物理エンジンのライブラリを使ってブロック崩しゲームを作りました。
やっつけで作った糞ゲーもいいところですが、せっかくなので貼付けておきます。左右でバーを動かすことができます。床に玉をついてもゲームオーバーはありません。
Box2DFlashAS3(以下Box2D)とは
ActionScriptから利用できる物理エンジン、いわゆるニュートン力学が支配する世界を手軽に表現することが出来るライブラリの一つです。これを使うことで重力による落下とか衝突とかを簡単に実装することができます。
物理エンジンには様々なライブラリがあり、たとえばNVIDIAのPHYSXなんかはよく知られています。2Dも3Dもありますが、Box2Dは2Dの物理エンジンの代表的なものです。僕は物理エンジンというもの自体になじみがなかったのですが、Web+DB PRESS vol.54の特集を読んで初めて知りました。
情報源
- Box2D(Box2DFlashAS3)によるActionScript物理シミュレーション | Flash | HapHands
- 特集:Box2DでActionScript物理プログラミング|gihyo.jp … 技術評論社
あたりが詳しいです。僕のブロック崩しは、下の技評のサンプルプログラムを改造しています。
工夫した点など
あまり情報がないのですが、Box2Dの物体(b2Bodyオブジェクト)には作成した物体のIDや名前などのプロパティがないため、後から物体を一意に識別したい場合は、UserDataというカスタムフィールドを参照します。
以下はブロック崩しゲームにおける、ブロックの初期作成処理です。UserDataに、ブロックの連番であるindexプロパティと、ブロックであることを示すtypeプロパティをセットしています。
for (var i:int = 0; i<boxNum; i++) { //中略 var block:b2Body = world.CreateBody(blockBodyDef); var blockObj:Object = new Object(); blockObj.index = i; blockObj.type = "block"; //中略 block.SetUserData(blockObj); //中略 }
以下がUserDataを参照している部分。ContactListenerという衝突検知用のクラスの中で、UserDataのtypeプロパティを参照して、衝突した物体がブロックだったかどうかを判定しています。
public override function Add(point:b2ContactPoint):void { var body1:b2Body = point.shape1.GetBody(); var body2:b2Body = point.shape2.GetBody(); var obj1:Object = body1.GetUserData(); var obj2:Object = body2.GetUserData(); if (obj1 != null && obj1.type == "block") { //中略 } else if (obj2 != null && obj2.type == "block") { //中略 } }
全コード
BlockGame.as
package { import Box2D.Collision.b2AABB; import Box2D.Collision.b2ContactPoint; import Box2D.Collision.Shapes.b2PolygonDef; import Box2D.Collision.Shapes.b2CircleDef; import Box2D.Collision.Shapes.b2MassData; import Box2D.Common.Math.b2Vec2; import Box2D.Dynamics.b2Body; import Box2D.Dynamics.b2BodyDef; import Box2D.Dynamics.b2ContactListener; import Box2D.Dynamics.b2DebugDraw; import Box2D.Dynamics.b2World; import flash.display.Sprite; import flash.events.Event; import flash.events.MouseEvent; import flash.events.KeyboardEvent; import flash.ui.Keyboard; import flash.text.TextField; import flash.text.TextFieldAutoSize; import flash.text.TextFormat; public class BlockGame extends Sprite { private var world:b2World; private var barstate:int; private var bar:b2Body; public var blockArray:Array; public var destroyFlagArray:Array; private var destroyNum:int; private var textField:TextField; private const boxNum:int = 35; public function BlockGame():void { // イベントハンドラを登録する barstate = 0; textField = new TextField(); // テキストフィールドの準備 textField.autoSize = TextFieldAutoSize.LEFT; textField.textColor = 0xFFFFFF; textField.x = 100; textField.y = 220; textField.text = "Click to start."; var f:TextFormat = new TextFormat(); f.font="Arial"; f.size=24; textField.setTextFormat(f); addChild(textField); stage.addEventListener(MouseEvent.CLICK, clickHandler); stage.addEventListener(Event.ENTER_FRAME, enterFrameHandler); stage.addEventListener(KeyboardEvent.KEY_DOWN, keyDownHandler); stage.addEventListener(KeyboardEvent.KEY_UP, keyUpHandler); } private function clickHandler(event:MouseEvent):void { trace("start clickHandler"); destroyNum = 0; textField.text = ""; //////////////////////////////////////// // 物理エンジンのセットアップ var worldAABB:b2AABB = new b2AABB(); worldAABB.lowerBound.Set(-100, -100); worldAABB.upperBound.Set(100, 100); var gravity:b2Vec2 = new b2Vec2(0, 10); world = new b2World(worldAABB, gravity, true); //////////////////////////////////////// // 固定オブジェクトの配置 var floor:b2Body = createStaticBoxShape(2.5, 3.0, 2.3, 0.1); var top:b2Body = createStaticBoxShape(2.5, -0.1, 2.3, 0.1); var l:b2Body = createStaticBoxShape(0.1, 3.0, 0.1, 50.0); var r:b2Body = createStaticBoxShape(4.9, 3.0, 0.1, 50.0); /////////////////////////////////////// // ユーザーが動かすバーを作成する var barBodyDef:b2BodyDef = new b2BodyDef(); barBodyDef.position.Set(2.5, 2.8); var barShapeDef:b2PolygonDef = new b2PolygonDef(); barShapeDef.SetAsBox(0.5, 0.1); barShapeDef.density = 1; barShapeDef.restitution = 0; bar = world.CreateBody(barBodyDef); bar.CreateShape(barShapeDef); bar.SetMassFromShapes(); blockArray = new Array(); destroyFlagArray = new Array(); var bw:int = 7; for (var i:int = 0; i<boxNum; i++) { var x:int = i%bw; var y:int = i/bw; var xd:Number = 0.7 + x * 0.6; var yd:Number = 0.2 + y * 0.2; var blockBodyDef:b2BodyDef = new b2BodyDef(); blockBodyDef.position.Set(xd, yd); var blockShapeDef:b2PolygonDef = new b2PolygonDef(); blockShapeDef.SetAsBox(0.3, 0.1); var block:b2Body = world.CreateBody(blockBodyDef); var blockObj:Object = new Object(); blockObj.index = i; blockObj.type = "block"; blockObj.destroy = false; block.SetUserData(blockObj); block.CreateShape(blockShapeDef); blockArray.push(block); destroyFlagArray.push(false); } //////////////////////////////////////// // ボールの設置 var ballBodyDef:b2BodyDef = new b2BodyDef(); ballBodyDef.position.Set(2.5, 1.0); var ballShapeDef:b2CircleDef = new b2CircleDef(); ballShapeDef.radius = 0.05; ballShapeDef.density = 0.2; // 密度 [kg/m^2] ballShapeDef.restitution = 1; // 反発係数、通常は0〜1 var ballBody:b2Body = world.CreateBody(ballBodyDef); ballBody.CreateShape(ballShapeDef); ballBody.SetMassFromShapes(); ballBody.ApplyImpulse(new b2Vec2(0.01, 0.01), ballBody.GetWorldCenter().Copy()); var contactListener:ContactListener = new ContactListener(); world.SetContactListener(contactListener); //////////////////////////////////////// // 描画設定 var debugDraw:b2DebugDraw = new b2DebugDraw(); debugDraw.m_sprite = this; debugDraw.m_drawScale = 100; // 1mを100ピクセルにする debugDraw.m_fillAlpha = 0.3; // 不透明度 debugDraw.m_lineThickness = 1; // 線の太さ debugDraw.m_drawFlags = b2DebugDraw.e_shapeBit; world.SetDebugDraw(debugDraw); } private function keyDownHandler(event:KeyboardEvent):void { if (event.keyCode == Keyboard.LEFT) { barstate = 1; stage.addEventListener(Event.ENTER_FRAME, enterFrameHandler2); } else if (event.keyCode == Keyboard.RIGHT) { barstate = 2; stage.addEventListener(Event.ENTER_FRAME, enterFrameHandler2); } } private function keyUpHandler(event:KeyboardEvent):void { if (barstate == 1 && event.keyCode == Keyboard.LEFT) { barstate = 0; stage.removeEventListener(Event.ENTER_FRAME, enterFrameHandler2); } else if (barstate == 2 && event.keyCode == Keyboard.RIGHT) { barstate = 0; stage.removeEventListener(Event.ENTER_FRAME, enterFrameHandler2); } } private function enterFrameHandler2(event:Event):void { var c:b2Vec2 = bar.GetWorldCenter().Copy(); var f:b2Vec2; if (barstate == 1) { f = new b2Vec2(-1,0); } else if (barstate == 2) { f = new b2Vec2(1,0); } bar.ApplyForce(f, c); } private function enterFrameHandler(event:Event):void { if (world == null) { return; } if (destroyNum == boxNum) { gemeEnd(); return; } for (var i:int=0; i<boxNum; i++) { if (!destroyFlagArray[i]) { var obj:Object = blockArray[i].GetUserData(); if (obj && obj.destroy) { trace("destory:"+i); world.DestroyBody(blockArray[i]); destroyFlagArray[i] = true; ++destroyNum; } } } world.Step(1/24, 10); } private function gemeEnd():void { textField.text = "Clear! Click to restart."; world = null; } private function createStaticBoxShape(locX:Number, locY:Number, sizeX:Number, sizeY:Number):b2Body { var bbd:b2BodyDef = new b2BodyDef(); bbd.position.Set(locX, locY); var bpd:b2PolygonDef = new b2PolygonDef(); bpd.SetAsBox(sizeX, sizeY); var b:b2Body = world.CreateBody(bbd); b.CreateShape(bpd); return b; } } }
ContactListener.as
package { import Box2D.Collision.b2ContactPoint; import Box2D.Collision.Shapes.b2MassData; import Box2D.Dynamics.b2ContactListener; import Box2D.Dynamics.b2Body; import Box2D.Collision.Shapes.b2PolygonDef; public class ContactListener extends b2ContactListener { public function ContactListener() { } public override function Add(point:b2ContactPoint):void { var body1:b2Body = point.shape1.GetBody(); var body2:b2Body = point.shape2.GetBody(); var obj1:Object = body1.GetUserData(); var obj2:Object = body2.GetUserData(); if (obj1 != null && obj1.type == "block") { trace("this is obj1"); obj1.destroy = true; } else if (obj2 != null && obj2.type == "block") { trace("this is obj2"); obj2.destroy = true; } } } }
その他補足情報
コンパイルするときは、BlockGame.asとContactListener.asを同じディレクトリにおいて、
mxmlc -source-path=path_to_box2d_library BlockGame.as
traceで吐き出したログを参照する方法は
windowsやmacで、flashのtraceログが吐かれる場所 - カサヒラボ
あたりを参照。
次回は3Dの物理エンジンをいじった成果を書く予定です。
heap sortを使った上位ランキング取得プログラム
Managing Gigabytes 4.6章で解説されているソートのプログラムを実装してみた。
検索エンジンなどでN個のデータの中から上位r個を取得したい場合、まずN個のデータからなるmax-heapを構成して、ルート(最大値)から順にr個をヒープから取り除くというアプローチが考えられる。しかしN>>rの場合、r個のデータからなるmin-heapを構成して、残りのN-r個のデータをheapのルート(最小値)と順次比較して、ルートよりも大きい場合はルートと入れ替えて、heapを再構成する、という手順を取った方がより計算量が少なくて済む、という話。
10万件のランダムな数値をスペース区切りで出力するプログラム
#include <iostream> #include <stdlib.h> using namespace std; int main (int argc, char *argv[]) { srand((unsigned int)time(0)); for (int i=0; i<100000; i++) { int tmp = rand()/1000; cout << tmp << " "; } cout << endl; return 1; }
10件のmin-heapを構成して、10万件のデータから上位10件を取得・表示するプログラム
#include <iostream> #include <fstream> #define SWAP(a,b) (a ^= b,b = a ^ b,a ^= b) using namespace std; #define HEAPSIZE 10 #define DATAN 100000 int heapsize = HEAPSIZE; int heap[HEAPSIZE]; int datan = DATAN; void make_heap (int root); int main (int argc, char *argv[]) { srand((unsigned int)time(0)); ifstream input_file("testdata"); int testdata[datan]; for (int i=0; i<datan; i++) { input_file >> testdata[i]; } //copy the first r accumlators into the heap for (int i=0; i<heapsize; i++) { heap[i] = testdata[i]; } //convert array into min-heap for (int i=((heapsize+1)/2); i>0; i--) { make_heap(i-1); } for (int i=heapsize; i<datan; i++) { //testdata[i]; if (heap[0] < testdata[i]) { //discard rheap[0] and insert testdata[i] into min-heap heap[0] = testdata[i]; make_heap(0); } } //display top-r number in the ascending order while (heapsize > 0) { cout << heap[0] << endl; heap[0] = heap[heapsize-1]; make_heap(0); heapsize--; } return 1; } void make_heap (int root) { int left = (root+1)*2-1; int right = left+1; if (left >= heapsize) { return; } if (heap[root] <= heap[left]) { if (right >= heapsize || heap[root] <= heap[right]) { //root is min //do nothing } else { //right is min //printf("swap %d and %d\n", root, right); SWAP(heap[root], heap[right]); make_heap(right); } } else { if (right >= heapsize || heap[left] <= heap[right]) { //left is min //printf("swap %d and %d\n", root, left); SWAP(heap[root], heap[left]); make_heap(left); } else { //right is min //printf("swap %d and %d\n", root, right); SWAP(heap[root], heap[right]); make_heap(right); } } }
考察とかは特になしです。他のアルゴリズムとのベンチ比較とかも時間があればやってみたい。
第9回PRML読書会
土曜日はサイボウズ・ラボで行われた第9回PRML読書会に参加しました。
自分は発表者トップバッターでSVMの基本的なところを説明しました。
参加者の方からもいろいろ指摘をいただきました。
- なぜマージンを最大化するとよいのか?の説明で『まず2値に分類された学習データをガウスカーネルでのParzen推定を適用して入力の分布を推定する。誤分類が最小になる分類平面は、ガウスカーネルの分散を→0の極限において、マージンを最大化する分類平面に一致する』とあるが、なぜ分散を0に近づけるのかがわからない。
- そういうものとして理解するしかない?理論的な説明はまだ分からずです。。
- Randomized Algorythmを適用してSVMの計算を高速化する手法がある。
- ちょっとググってみたところこの辺ですかね。いろいろと制約はるみたいですがO(log n)で二次計画問題の近似解が求まる!
- biasをゼロと仮定して、二次計画問題を高速で解く手法が存在する。
- liblinearなどのライブラリでこの手法が利用されている。
多クラスSVMの発表のときにも参加者の間で議論がありましたが、SVMはやはり2クラス分類器で、多クラスに応用する事例というのは実際にも少ないそうです。(PRML7.2で紹介されているRVMはSVMとは名前は似ているが完全な別モノ!)
個人的には2クラス分類、と言われてもあまり実際的な使い道が思いつかないのですが、世間的にSVMがもてはやされてるのはそれだけ2クラス分類をしたい問題が多いということなのか。今後SVMの事例を見る際は、そういう観点からも考えていきたいと思いました。
あと、8月末の第6回読書会でもニューラルネットワークの説明をしたのでいまさらですが資料をあげておきます。
このときにも冒頭にSVMとの比較がありましたが、今回改めて両方を比較してみると、
- SVMはSparseな解になる
- 訓練データの特性にもよるがマージン境界上の訓練データ集合が少ない場合は予測計算量がすくない。
- SVMは局所解問題がない
- NNでは局所解問題があるので、初期値選択を慎重にする必要がある。
- SVMにはマージン最大化による汎化能力がある。
- NNは過学習の問題がある。
- SVMは基本的に二クラス分類器
- NNはどちらかといえば回帰に向いている。
という感じでしょうかな。多クラス分類とか回帰とかだったらSVM以外の選択肢を使った方がよさそうかな、と改めて思いました。
Rでガウス過程による分類を実装
PRMLの6.4.5〜6.4.6の範囲にあるガウス過程による分類をRで実装してみました。
ソースコード全文はgithubにアップしています。
http://github.com/thorikawa/prml/blob/master/gaussian_process_classify.R
ここでは例として、(1,0),(2,0),(3,0)で1、(0,1),(0,2),(0,3)で0の値を取る訓練集合を用いています。
# Training data x=list(c(1,0),c(2,0), c(3,0), c(0,1), c(0,2), c(0,3)) t=c(1,1,1,0,0,0) training_data_num <- length(x)
この訓練集合とカーネル関数をもとに予測分布を導出しています。
ガウス過程においては、訓練集合から予測分布を決める(ほぼ)唯一の要素はカーネル関数になるということなのは分かるのですが、カーネル関数によってどのように予測分布が変わってくるのかという点は良く理解できていません。
この辺をまずは感覚でつかむために、いくつか代表的なカーネルで予測分布を求めてみました。
ガウスカーネル
で与えられるカーネル。
の値を0.2,0,4と変えてみると、id:n_shuyoさんの考察通り、は訓練データを中心とするガウス分布の尖り具合を表現していることが分かります。
# Gaussian kernel sigma <- 0.2 kernel <- function (x1, x2) { exp(-sum((x1-x2)*(x1-x2))/(2*sigma^2)); } gp(kernel, "gaussian kernel", "sigma=0.2");
# Gaussian kernel sigma <- 0.4 kernel <- function (x1, x2) { exp(-sum((x1-x2)*(x1-x2))/(2*sigma^2)); } gp(kernel, "gaussian kernel", "sigma=0.4");
線形カーネル
で与えられるカーネル。単純な線形分離になります。
# Linear kernel kernel <- function (x1, x2) { sum(x1*x2); } gp(kernel, "linear kernel");
指数カーネル
で与えられるカーネル。
の値を1.0,0.3と変えてみると、ガウス分布と同様、は指数分布の尖り具合を表していると言えそうです。
# Exponential kernel theta <- 1.0 kernel <- function (x1, x2) { exp(-theta*sqrt(sum((x1-x2)*(x1-x2)))); } gp(kernel, "exponential kernel", "theta=1.0");
# Exponential kernel theta <- 0.3 kernel <- function (x1, x2) { exp(-theta*sqrt(sum((x1-x2)*(x1-x2)))); } gp(kernel, "exponential kernel", "theta=0.3");
多項式カーネル
で与えられるカーネル。
M=2,M=3と次数を変えてみると、、、これは、、、どういう特徴を持つのかいまいち不明です。
# Polynomial kernel kernel <- function (x1, x2) { (sum(x1*x2)+1)^2; } gp(kernel, "polynomial kernel", "degree=2");
# Polynomial kernel kernel <- function (x1, x2) { (sum(x1*x2)+1)^3; } gp(kernel, "polynomial kernel", "degree=3");
まとめ
というわけで理解は浅いのですが、様々なカーネル関数でガウス過程による分類を実装してみました。
カーネルの役割は、訓練集合における任意の2つのデータに対する、出力の類似度をどう評価するかを決定することだと理解しています。
ガウスカーネルや指数カーネルは、RBFであり、中心からの距離に対して反比例する形で値が減少していくため、入力値が近ければ近いほどその出力の値も近い値を示し、予測分布は訓練集合を中心とした分かりやすい分布になります。
線形や多項式カーネルはRBFではないので直感的な理解が現時点では出来ていません。カーネル多変量解析を熟読すればこのあたりも理解できるようになるかなあ。
ニューラルネットワークで画像認識
ニューラルネットワークの簡単な関数近似プログラムを先日書いたので、今は画像認識プログラムを書いてますが、ものすごく簡単なバージョンが出来上がったので晒しておきます。
C++で画像解析部分を作って、Rで訓練データの学習、テストデータの判定をしています。
MNISTの手書き文字から学習データ準備
まずは、インプットとなる画像のデータですが、定番のMNISTの手書き数値データを使います。
元々のIDXフォーマットというフォーマットは再利用性が無さそうなので、http://www.cs.toronto.edu/~roweis/data.htmlから既にJPEG化されたものを引っ張ってきてそれを解析します。
OpenCVを使って2値化+CSV出力する簡単なプログラム。
#include "cv.h" #include "highgui.h" #include <stdio.h> #include <stdlib.h> const int kDIGIT_W = 28; const int kDIGIT_H = 28; int main (int argc, char* argv[]) { IplImage *src_img, *dest_img; char* imgfile = argv[1]; int answer = atoi(argv[2]); src_img = cvLoadImage(imgfile, CV_LOAD_IMAGE_GRAYSCALE); dest_img = cvCreateImage(cvGetSize(src_img), IPL_DEPTH_8U, 1); // make src image gray scale cvThreshold(src_img, dest_img, 90, 255, CV_THRESH_BINARY); int width = dest_img->width; int height = dest_img->height; int xCount = width / kDIGIT_W; int yCount = height / kDIGIT_H; for (int x=0; x<xCount; x++) { for (int y=0; y<yCount; y++) { int offsetX = x*kDIGIT_W; int offsetY = y*kDIGIT_H; printf("%d", answer); for (int i=0; i<kDIGIT_W; i++) { for (int j=0; j<kDIGIT_H; j++) { int val = dest_img->imageData[(j+offsetY)*dest_img->widthStep+(i+offsetX)]; printf(",%d", val*-1); } } printf("\n"); } } cvReleaseImage(&src_img); cvReleaseImage(&dest_img); return 0; }
これで、以下のように実行すると画像データが格納されたCSVが出来上がります。
1行が1文字を表わしており、1列目が正解の値(数値)、2列目以降が二値化された画像データです。
$ g++ -g -I/usr/include/opencv make_test_data.cpp -o make_test_data -lcxcore -lcv -lhighgui -lcvaux -lml #訓練データの準備 $ ./make_test_data mnist_train0.jpg 0 > mnist_train_all.txt $ ./make_test_data mnist_train1.jpg 1 >> mnist_train_all.txt $ ./make_test_data mnist_train2.jpg 2 >> mnist_train_all.txt .... $ ./make_test_data mnist_train9.jpg 9 >> mnist_train_all.txt #テストデータの準備 $ ./make_test_data mnist_test0.jpg 0 > mnist_test_all.txt $ ./make_test_data mnist_test0.jpg 1 >> mnist_test_all.txt .... $ ./make_test_data mnist_test0.jpg 9 >> mnist_test_all.txt
学習プログラム
次にこのCSVデータをインプットするニューラルネットワークの学習プログラムを組みます。
ニューラルネットの設定は以下です。
- 基本的な構造:入力が784次元、出力が10次元、隠れユニット層を1つ含む2層ネットワーク(PRMLで言うところの)
- 隠れユニット数:30
- 学習パラメータ:0.05
- 学習方法:オンライン学習。ただし、1つのエポックでは60338件の訓練データ集合の中から1000件のデータをランダム抽出し、それを1000エポック繰り返す。
- 出力ユニットの活性化関数:ソフトマックス関数
- 誤差関数:交差エントロピー誤差関数(ただし、プログラム中は活性に対する微分であるyk-tkのみ使われているため、交差エントロピー関数そのものは出てこない。)
#number of hidden unit s <- 30 #learning rate parameter l <- 0.05 #count of loop MAX_EPOCH <- 1000 #read traning data digit_data <- read.csv("train/mnist_train_all.txt",header=F) ##initialize weight parameter #weight parameter between input and hidden units w1 <- matrix(runif(s*length(digit_data), -1, 1), s, length(digit_data)) #weight parameter between hidden units and output w2 <- matrix(runif(10*(s+1), -1, 1), 10, s+1) #definition of function logsumexp <- function (x) { m <- max(x) m + log(sum(exp(x-m))) } softmax <- function (a) { sapply(a, function (x) { exp(x-logsumexp(a)) }) } neuro_func <- function (input) { a1 <- w1 %*% c(1, input) z1 <- tanh(a1) a2 <- w2 %*% c(1, z1) z2 <- softmax(a2) return(z2) } ##train the neural network for (k in 1:MAX_EPOCH) { #sample 1000 training data sample_index = sample(1:nrow(digit_data), 1000) for (i in sample_index) { tmp <- digit_data[i,] #target data t <- rep(0, 10) t[as.integer(tmp[1]+1)] <- 1 input <- t(tmp[2:length(tmp)]) #forward propagation a1 <- w1 %*% c(1, input) z1 <- tanh(a1) a2 <- w2 %*% c(1, z1) z2 <- softmax(a2) #back propagation d2 <- z2 - t w2 <<- w2 - l * d2 %*% t(c(1, z1)) d1 <- d2 %*% w2[,2:(s+1)] * (1-tanh(t(a1))^2) w1 <<- w1 - l * t(d1) %*% t(c(1, input)) } } save(w1, w2, logsumexp, softmax, neuro_func, file="nnet_image.rdata")
テストプログラム
準備したテストデータを読み込んで、ネットワークの出力のうち、最大の出力を持つユニットをネットワークが判断した数値とみなしています。
#read trained parameter load("nnet_image.rdata") #read traning data digit_data <- read.csv("test/mnist_test_all.txt",header=F) ##main function to train the neural network correct <- 0 wrong <- 0 for (i in 1:nrow(digit_data)) { tmp <- digit_data[i,] # target number of class answer <- tmp[1] t <- rep(0, 10) t[as.integer(answer+1)] <- 1 input <- t(tmp[2:length(tmp)]) # forward propagation z <- neuro_func(input) zt <- which.max(z) - 1 if (zt == answer) { correct <- correct + 1 } else { wrong <- wrong + 1 } } cat(paste("correct case=",correct,", wrong case=", wrong, "\n"))
結果とまとめ
テストデータ集合10153件中9360件正解、ということで正解率は約92.5%でした。
(correct case= 9360 , wrong case= 793 )
自分の手書き文字でも試してみましたが、まあ意地悪いことをしなければ大体正解しているみたいです。
ただ、他の方の結果と比較すると、id:ultraistさんの結果やhttp://yann.lecun.com/exdb/mnist/の例では、2層ネットワークで97%近い正解率まで達しているので、まだまだ。
id:ultraistさんは隠れユニット数が512で、LeCun氏は300とか1000とかなので、僕の隠れユニット数30はちょっと少なすぎたか、という印象。ただ、入力ベクトルが784次元なので、それに対し隠れユニットが1000はいくらなんでも多い、というか一般化できていないんじゃないかそれ?という感じなのだが、どうなんだろう。
とりあえず始めたばかりなので、TODOを列挙。
- 隠れユニット数をいじってみて比較
- 学習パラメータをいじってみて比較
- deskew処理
- オブジェクト抽出処理
- 画素以外の特徴ベクトルを入力にして学習
- 2層ネットワーク以外のやり方と比較
と見えているだけでもまだまだやることてんこ盛りです・・・。地道にやっていきます。
2009/9/13追記
学習時間について
この例ではのべ100万件のデータを学習していますが、全部で丸2〜3日ほどかかりました。
Rでニューラルネットワーク(2変数の関数近似)
で、1変数の関数近似がうまくいったので、調子にのって2変数の関数近似にもチャレンジしてみました。
2変数のsinc関数を、ニューラルネットワークの誤差伝播法を使って近似する(しようとする)ものです。
library(animation) #number of training data N <- 30 #number of hidden unit s <- 5 #learning rate parameter l <- 0.01 #standard deviation of training data sd <- 0.05 #count of loop LOOP <- 10000 #frame interval of animation movie INTERVAL <- 0.1 #total time of animation movie TOTAL <- 20 #initialize weight parameter w1 <- matrix(seq(0.1,length=s*3, by=0.1),s,3) w2 <- matrix(seq(0.1,length=s+1, by=0.1),1,s+1) #generate traning data xt <- seq(-10,10,length=N) yt <- xt f <- function(x, y) { r <- sqrt(x^2+y^2) temp = function (r) { if ( r == 0 ) { return(10) } else { return(10*sin(r)/r) } } sapply(r, temp) } zt <- outer(xt, yt, f) persp(xt,yt,zt,theta = 30, phi = 30, expand = 0.5, col = rainbow(50), border=NA) neuro_func_ <- function (x, y) { a <- w1 %*% c(1, x, y) z <- tanh(a) y <- w2 %*% c(1, z) return(y) } #vectorize neuro_func_ neuro_func <- function (x, y) { mapply(neuro_func_, x, y) } persp(xt, yt, outer(xt, yt, neuro_func), theta = 30, phi = 30, expand = 0.5, col = rainbow(50), border=NA, axes=TRUE, ticktype="detailed") #main function to train the neural network neuro_training <- function () { trigger_count <- as.integer(LOOP / (TOTAL / INTERVAL)) for (k in 1:LOOP) { if (k %% trigger_count == 1) { persp(xt, yt, outer(xt, yt, neuro_func), theta = 30, phi = 30, expand = 0.5, col = rainbow(50), border=NA, axes=TRUE, ticktype="detailed") } for (i in 1:N) { for (j in 1:N) { # forward propagation a <- w1 %*% c(1, xt[i], yt[j]) b <- tanh (a) z <- w2 %*% c(1, b) # back propagation d2 <- z - zt[i,j] w2 <<- w2 - l * d2 * c(1, b) d1 <- d2 %*% (1-tanh(t(a))^2) * t(w2[2:(s+1)]) w1 <<- w1 - l * t(d1) %*% c(1, xt[i], yt[j]) } } } } saveMovie(neuro_training(), interval=INTERVAL, moviename="neural_network_training_movie", movietype="mpg", outdir=getwd(), width=640, height=480) cat("learning end print w1 and w2\n") print(w1) print(w2) persp(xt, yt, outer(xt, yt, neuro_func), theta = 30, phi = 30, expand = 0.5, col = rainbow(50), border=NA, axes=TRUE, ticktype="detailed")
で、結果が以下のムービー。
サンプルデータを10000回ループで読み込ませているんですが、途中から微動だにしなくなってしまいました・・・。
何かパラメータの設定が悪いのか、もっとループを多くしなければいけないのか、まだ原因は分かっていません。。。どなたか知識のある方がこれを読んで頂けたら指摘してもらえると嬉しいです。