2020年1月8日水曜日

KaggleのTitanicの問題をやってみた。

おじさんこんなサイトを見つけました。
https://nehori.com/nikki/2019/12/09/kaggle/
どっかの保田さんって人が書いてんのね。
あらー世の中にはこんなかわいそうな人がいるようです。


このうさんくさそうなサイトの話では、どうもkaggleというとこの問題を解くと、なんと3億円がもらえるらしい。

これは凄い。
ということで今日は上記のサイトにもある、kaggleの入門編のタイタニックの生存者を当てる機会学習をやってみたいとおもいます。

と言っても、これは単なる〇×予想。
なんか問題を見る限りデータをそのまま、svmに入れれば良いんじゃね?
入門編だし、これならおじさんもサクサクできる!

でも、おじさんのこのブログ、c/c++縛りなのです。
なので、今日はこれをc++でといてみたいと思います。
以前c++のsvmの記事をこのブログに書いたし。

戦略として、とりあえずC言語でsvmlight用にデータを整形してsvmlightに食わしてみます。
実装してわかったんですが、データの欠損があるし、正規化とか特徴量抽出とかが意外とめんどくさい。

最初、一発目の手抜き実装で実験をしたら、実行時間が3分くらいかかり、さらに7割くらい正解しているお手本の結果より、さらに正解率が66%違うらしい。
これってひょっとして全然違うのね。
なんという屈辱。あちゃー。

この問題本当によくできているね。
Gooogleとかのプログラミングコンテストと同じで、バカ正直に問題を解こうとすると非常に時間がかかるのね。
おじさん、ひっかかってしまった。

まぁ機械学習って、まず入力の引数が多いと指数関数的に計算時間がかかるので、デバッグを行いやすくするためにもとにかく次元を減らします。
11次元ある入力を重要そうな6次元に絞って計算してみました。
ここでC/C++の威力がでてくるのです。
学習の計算時間が0.2秒くらいになります。

これで計算が早くなったので、次に計算結果が正確になるようにいろいろ試していきます。
誰でも気付くのは性別がけっこう重要。でもこれはもうすでに入力している・・・。
ということでネットでいろいろ検索。

まず、年齢を正規化しないといけないらしい。
ということで年齢を正規化。
だいぶお手本と近くなってきた。
これ、運賃も正規化しないといけないんじゃない?

さらにネットを見る限り、名前に「Mr.」がついてない紳士でないおっさんに「悪い奴だ」フラグを立てると認識率が上がるらしい。
おじさんの周りには苗字に「き」「む」「ら」や「や」「じ」「ま」がついている人は大体極悪人です。
組織がでかくなるとどこにでもいるんだよねぇ、どさくさにまみれてこっそりと悪いことをして、自分だけ生き残ろうとする悪人のおっさん。

ビッグデータってこういう悪いことをするおっさんがいることまでわかるんですね。
おじさん、こういう勝手なゴシップ的な根拠のない人のうわさ的な法則を見つけるのが大好きです。
ということで、悪いおっさんの次元を追加して7次元で実装してみました。

-------------------------
........
OK  1.000000  0.831034
OK  1.000000  1.000158
OK  1.000000  1.094007
OK  1.000000  1.008878
OK  -1.000000  -1.119815
OK  1.000000  1.363857
OK  -1.000000  -1.164235
OK  -1.000000  -1.119815
OK  -1.000000  -0.351570
ct=418  ok=403  96.411483%
-------------------------

おー、やっとお手本と96%一致。
まだ4次元しか正しく入力を合わせてないのに、なんか結果がよすぎる。
なんか正解率の計算間違ってんのかなぁ?

調べて分かったことは、お手本だと思っていたデータはとりあえず男が全員死亡、女が成員生存のダミーデータらしい。
でもとりあえずこれでやっと計算にはバグがなさそうなことがわかりました。

この後は、自分の回答とお手本のどちらが正しいかは回答を実際にコミットしてみないとわからいです。
ここから、何度もWebサイトに結果のコミットを繰り返さないと、正しい答えがわからいのですが、時間がないので今回はここまで。

まだ使っていない残りの次元が7次元あるので、これらに対して入力を正しく行うと正解率がもっと上がると思うんですが、100%正解のデータがダウンロードできないので実験をやるのがめんどくさい。

なんかこの課題は認識率80%を超えると挑戦者の上位3%に入ってすごいらしいです。
僕の回答は良い方向にくるっているのでしょうか?それとも悪い方向にくるってるのでしょうか?

今日はKaggleはPythonやKerasをインストールしなくてもC/C++言語だけでもできるよという話でした。

------------------------
#include <stdio.h>
#include <stdlib.h>
#include <string.h>

static char *strtokrev(char *s1, int s2) {
static char *str = 0;
int f = 0;

if (s1)str = s1;
else s1 = str;

if (!s1)return NULL;

while (1) {
if (!*str) {
str = 0;
return(s1);
}
if (*str == '\r' || *str == '\n') {
*str = 0;
str = 0;
return(s1);
}
if (*str == '\"' && f == 0) {
str++;
s1 = str;
f = 1;
continue;
}
if (*str == '\"' && f == 1 && *(str + 1) == '\"') {
str += 2;
continue;
}
if (*str == '\"' && f == 1) {
*str++ = 0;
f = 0;
continue;
}
if (*str == s2 && f == 0) {
*str++ = 0;
return(s1);
}
str++;
}
}


struct person{
int id;
int survive;
int cls;
char name[64];
int sex;
int age;
double s_age;
int sib;
int per;
int f_mr;
double fare;
};

struct person{
int id;
int survive;
int cls;
char name[64];
int sex;
int age;
double s_age;
int sib;
int per;
int f_mr;
double fare;
};

struct person pli[1024];
struct person pli2[1024];

void store(struct person* plist,int pos, int n, const char* str)
{
char buf[256];
switch (n)
{
case 0:
sscanf(str, "%d", &(plist[pos].id));
break;
case 1:
sscanf(str, "%d", &(plist[pos].survive));
if (plist[pos].survive == 0)plist[pos].survive = -1;
break;
case 2:
sscanf(str, "%d", &(plist[pos].cls));
break;
case 3:
if (strlen(str) < 64 - 1)strcpy(plist[pos].name, str);
if (strstr(plist[pos].name, "Mr."))plist[pos].f_mr = 1;
else plist[pos].f_mr = -1;
break;
case 4:
buf[0] = 0;
plist[pos].sex = -1;
if (strlen(str) < 64 - 1)strcpy(buf, str);
if(strstr(buf,"female"))plist[pos].sex = 1;
else if (strstr(buf, "male"))plist[pos].sex = -1;
break;
case 5:
plist[pos].age = -1;
sscanf(str, "%d", &(plist[pos].age));
break;
case 6:
sscanf(str, "%d", &(plist[pos].sib));
break;
case 7:
sscanf(str, "%d", &(plist[pos].per));
break;
case 9:
sscanf(str, "%lf", &(plist[pos].fare));
break;
default:
break;
}
}
int read_plist(struct person* plist,const char* fn,int ff)
{
FILE *fp=NULL;
char buf[256];
int i, f, id, re, ct = 0;
char *p,*q;

int ct2=0;
static double ave,dmax = 0, dmin = 1000,sub=0;

int fct2=0;
static double fave, fdmax = 0, fdmin = 1000, fsub = 0;

fp = fopen(fn,"rb");
if (fp == NULL)goto end;
fgets(buf, sizeof(buf), fp);
while (1) {
buf[0] = 0;
fgets(buf, sizeof(buf), fp);
if (buf[0] == 0)break;
for (i = 0; i < 14; i++) {
p = buf;
if (i != 0)p = NULL;
if (ff && i == 1)continue;
q = strtokrev(p, ',');
if (q == NULL)q = "";
store(plist,ct,i, q);
}
ct++;
}
if (ff == 0) {
for (i = 0; i < ct; i++) {
if (plist[i].age == -1)continue;
ct2++;
sub += plist[i].age;
if (plist[i].age > dmax)dmax = plist[i].age;
if (plist[i].age < dmin)dmin = plist[i].age;
}
ave = sub; ave /= ct2;
for (i = 0; i < ct; i++) {
if (plist[i].fare == 0)continue;
fct2++;
fsub += plist[i].fare;
if (plist[i].fare > fdmax)fdmax = plist[i].fare;
if (plist[i].fare < fdmin)fdmin = plist[i].fare;
}
fave = fsub; fave /= fct2;
}
//printf("ave=%f\n",ave);
//printf("fave=%f\n",fave);
for (i = 0; i < ct; i++) {
if (plist[i].age == -1)plist[i].s_age = ave;
else plist[i].s_age = plist[i].age;

plist[i].s_age = (plist[i].s_age - dmin)*2/dmax-1;
plist[i].fare = (plist[i].fare - fdmin) * 2 / fdmax - 1;
}

end:
if(fp)fclose(fp);
return 0;
}

void print_plist(struct person* plist,FILE* fd)
{
int i;
if (fd == NULL)return;
for (i = 0; i < 1024; i++) {
if (plist[i].id == 0)break;
fprintf(fd,"%d 1:%d 2:%d 3:%d 4:%f 5:%d 6:%d 7:%f\n",
plist[i].survive,plist[i].cls,plist[i].f_mr,plist[i].sex,plist[i].s_age,
plist[i].sib,plist[i].per,plist[i].fare);
}
}

void comp(const char* fn1,const char* fn2)
{
FILE *f1 = NULL, *f2 = NULL;
char buf1[256];
char buf2[256];
int ct = 0;
int ok = 0;
double d1, d2;
int id, f;
double per;
f1 = fopen(fn1, "rt");
f2 = fopen(fn2, "rt");
if (f1 == NULL || f2 == NULL)goto err;

fgets(buf1, sizeof(buf1), f1);
while (1) {
buf1[0] = 0;
buf2[0] = 0;
d1 = 0;
d2 = 0;
f = 0;
fgets(buf1, sizeof(buf1), f1);
fgets(buf2, sizeof(buf2), f2);
if (buf1[0] == 0 || buf2[0] == 0)break;
ct++;
sscanf(buf1, "%d,%lf", &id, &d1);
sscanf(buf2, "%lf", &d2);
if (d1 == 0)d1 = -1;

if (d1 > 0 && d2 > 0)f = 1;
if (d1 < 0 && d2 < 0)f = 1;
if (f)ok++;
printf("%s  %lf  %lf \n", f ? "OK" : "NG", d1, d2);
}
per = ok * 100;
per /= ct;
printf("ct=%d  ok=%d  %f%%\n", ct, ok, per);
err:
if (f1)fclose(f1);
if (f2)fclose(f2);
}

int main()
{
FILE* fp;
read_plist(pli,"train.csv",0);
read_plist(pli2,"test.csv",1);
fp = fopen("yomei_train.csv","wb");
print_plist(pli,fp);
if (fp)fclose(fp);
fp = fopen("yomei_test.csv", "wb");
print_plist(pli2, fp);
if (fp)fclose(fp);
//printf("----------pli\n");
//print_plist(pli, stdout);
//printf("----------pli2\n");
//print_plist(pli2, stdout);
system("svmlight_learn yomei_train.csv yomei_model.txt");
system("svmlight_classfy yomei_test.csv yomei_model.txt yomei_result.csv");
comp("gender_submission.csv", "yomei_result.csv");
return 0;
}
-------------------------


0 件のコメント:

コメントを投稿