2022年9月5日月曜日

PytorchのC++のライブラリlibtorchを使う2022

最近AI囲碁や将棋、たんぱく質の構造推定などあらゆるAIでResNetが使われています。

ResNetって画像分類以外のこともできるようです。

ResNetくらいなら簡単なものならおじさんの個人所有のGPUでもなんとか学習ができそう。

C++でやってみたい・・・・



c++でResNetができるフレームワークを探すと、Pytorchがでてきます。

Pytorchはもともとコア部分がlibtorchにまとまっていてc++から使えるようになっているし、Android版はNewralNetworkAPIなどに対応しており、GPUを使って高速に演算ができるようです。

 最近、GoogleのTensorFlow、マイクロソフトのonnxruntime、FacebookのPytorchなど、いろいろな機械学習フレームワークが続々とONNX対応、c++対応やスマホ対応になってきています。

Pytorchのようなスマホ対応のフレームワークを使えば、フルセットのPytorchがスマホでも動き、GPU付きのPCで学習させてスマホでその結果を利用することもできるようです。

いろいろなOSで動いて、学習データの再利用ができてほんとうに便利ねー。


そこで、本日は最新のPytorch 1.12.1のc++ライブラリlibtorchをビルドして動かしてみたいと思います。


ビルド方法はドキュメントに書かれている通りでとっても簡単

以前の自分のブログに書いてある通りでビルドできます。

http://yomeiotani.blogspot.com/2019/04/windowslibtorch.html


--------------------

cd <pytorch_root>\tools

mkdir build

cd build

python ..\build_libtorch.py

--------------------

上記のように打つだけ。
Windowsでのビルドは、公式にはVisualStudio 2017にしか対応していないようですが、コンパイラのバージョンをチェックしている部分を直せばVisualStudio 2022でもビルドできるようです。
外部依存ライブラリが前回より変わっています。
演算をより高速にするためにいろいろな機能を追加しているみたいですね。

PyTorchのexample/cppフォルダにはc++のサンプルがあります。
このなかに定番の文字認識「MNIST」のlibtorch版、mnist.cppがあります。


とりあえずビルドして動かしてみました。
----------------------------
Training on CPU.
Train Epoch: 1 [59584/60000] Loss: 0.2362
Test set: Average loss: 0.2119 | Accuracy: 0.935
Train Epoch: 2 [59584/60000] Loss: 0.1365
Test set: Average loss: 0.1301 | Accuracy: 0.959
Train Epoch: 3 [59584/60000] Loss: 0.1805
Test set: Average loss: 0.1069 | Accuracy: 0.965
Train Epoch: 4 [59584/60000] Loss: 0.1147
Test set: Average loss: 0.0899 | Accuracy: 0.970
Train Epoch: 5 [59584/60000] Loss: 0.1017
Test set: Average loss: 0.0834 | Accuracy: 0.974
Train Epoch: 6 [59584/60000] Loss: 0.0549
Test set: Average loss: 0.0725 | Accuracy: 0.976
Train Epoch: 7 [59584/60000] Loss: 0.0657
Test set: Average loss: 0.0697 | Accuracy: 0.979
Train Epoch: 8 [59584/60000] Loss: 0.0832
Test set: Average loss: 0.0639 | Accuracy: 0.980
Train Epoch: 9 [59584/60000] Loss: 0.0523
Test set: Average loss: 0.0609 | Accuracy: 0.982
Train Epoch: 10 [59584/60000] Loss: 0.0504
Test set: Average loss: 0.0559 | Accuracy: 0.983
----------------------------

なんかあってそう。
とりあえずlibtorchのカスタムビルドができたから次はResNetでいろいろ遊ぼうっと。
android版のビルドもしてみたいなぁ。