2017年6月28日水曜日

強化学習してみた。

昨日書いた、便利なtiny-dnnとstbライブラリ。さらにstbの動画版の簡単に動画ファイルを出力できるjo_mpeg。
これらのライブラリと強化学習と組み合わせれば、CaffeとかKerasとかを使わないでDQN(深層Qネットワーク)とかできんでない?

と思ったのですが、おじさん仕事が忙しくて、実装や学習させる根性がないので、今日はまずC++で普通の強化学習をやるところまで。

強化学習ってなんなのかというと、ここに理論とソースコードが書いてあります。
http://kivantium.hateblo.jp/entry/2015/09/29/181954

すごいねー。おじさんが大学生だった時代はこんな便利なWebページなかったのね。
深くない普通の強化学習はとってもわかりやすいですね。
おじさんのような浅い知識でも大丈夫。

いつものようにここのソースをそのままとってきて、ちょっと改良。
stbで途中経過と結果の画像を作ります。

おじさん意地でもC++でやるんです。
なんか同じ結果になったぞ。




これに深層学習とかDNNとかを組み合わせるとDQNになるらしい。
ほんとかよ。
ということで、次はC++でtiny-DQNとかやりたいんだけどなー。
そんなに暇ないしなー。誰かライブラリ作ってくれないかなぁ。



-----------------------------
#include <iostream>
#include <cmath>
#include <algorithm>
#define STB_IMAGE_WRITE_IMPLEMENTATION
#include "stb_image_write.h"

using namespace std;

double V[101]; // 状態価値関数
double pi[101];   // 方策
int ct = 1;

void write_graph(const char* name,int n,double* d)
{
int x, y,my;
unsigned char img[100 * 100*4];
char fn[256];
sprintf(fn, "%s%d.png",name, n);
memset(img, 255, sizeof(img));

for (x = 0; x < 100; x++){
my = d[x]*100;
if (my > 100)my = 100;
for (y = 0; y < my; y++){
img[(100 - 1 - y) * 100 * 4 + x * 4] = 0;
img[(100 - 1 - y) * 100 * 4 + x * 4+1] = 0;
img[(100 - 1 - y) * 100 * 4 + x * 4+2] = 255;
img[(100 - 1 - y) * 100 * 4 + x * 4+3] = 255;
}
}
stbi_write_png(fn, 100, 100, 4, img, 4 * 100);
}


int main(void){
const double p = 0.4; //表が出る確率

// 状態価値関数の初期化
for (int s = 0; s<100; ++s) V[s] = 0;
V[100] = 1.0;

const double theta = 1e-5; // ループ終了のしきい値
double delta = 1.0; // 最大変更量

// 状態価値関数の更新
while (delta >= theta){
delta = 0.0;
for (int s = 1; s<100; ++s){
double V_old = V[s];
double cand = 0.0;
// 可能な掛け金ごとに勝率を調べる
for (int bet = 1; bet <= min(s, 100 - s); ++bet){
double tmp = p*V[s + bet] + (1.0 - p)*V[s - bet];
cand = max(tmp, cand);
}
V[s] = cand;
delta = max(delta, abs(V_old - V[s])); // 変更量のチェック
}
//状態価値関数の表示
//for (int i = 0; i<101; i++){
// cout << V[i] << ", ";
//}
//cout << endl << endl;
write_graph("v",ct,V);
ct++;
}

// 最適方策の更新
double threshold = 1e-5;
for (int s = 1; s<100; ++s){
double cand = 0;
double tmp;
for (int bet = 1; bet <= min(s, 100 - s); ++bet){
tmp = p*V[s + bet] + (1 - p)*V[s - bet];
if (tmp > cand + threshold){
cand = tmp;
//pi[s] = bet
pi[s] = (double)bet / 100;
}
}
}
// 最適方策の表示
//for (int i = 1; i<100; i++){
// cout << pi[i] << ", ";
//}
//cout << endl << endl;
write_graph("s",0,pi);
}


0 件のコメント:

コメントを投稿