TorchはLuaというスクリプト言語で書かれています。
スクリプト言語ですが、中間言語にコンパイルして高速に動かすこともできます。
Luaのソースコードを見て分かったのですが、とてもコンパクト。
さらに、Lua言語は簡単にC言語の関数やDLLを呼び出せるので、簡単にいろんなライブラリと連携させることができます。
動的にDLLを呼び出せて、オブジェクトの演算子の再定義ができて、簡単にC言語連携ができるので、たくさんのソースをビルドしないといけない機械学習に適しているんだと思います。
TorchのプラグインのDLLってどうやって作るのかを調べたら、Lua用のプラグインDLLをつくればそのままTorchから使えることがわかったので、本日はLua用のプラグインを作ってみます。
luadllsample.h
--------------------------------------------------
#ifndef DF_LUADLL_SAMPLE_H_
#define DF_LUADLL_SAMPLE_H_
#include "lua.h"
#include "lualib.h"
#include "lauxlib.h"
#include "mylua_base.h"
#ifdef _WIN32
#ifdef LUADLLSAMPLE_EXPORTS
#define LUADLLSAMPLE_API __declspec(dllexport)
#else
#define LUADLLSAMPLE_API __declspec(dllimport)
#endif
#else
#define LUADLLSAMPLE_API
#endif
#ifdef __cplusplus
extern "C"
{
#endif
LUADLLSAMPLE_API int func_sample(lua_State* l);
#ifdef __cplusplus
}
#endif
#endif
--------------------------------------------------
luadllsample.cpp
--------------------------------------------------
#include <stdio.h>
#include <stdlib.h>
#include "lua.h"
#include "lualib.h"
#include "lauxlib.h"
#include "mylua_base.h"
#include "luadllsample.h"
LUADLLSAMPLE_API int func_sample(lua_State* l)
{
printf("kita!!!!!\n");
return 123;
}
--------------------------------------------------
test.lua
--------------------------------------------------
f=package.loadlib("luadll_sample.dll","func_sample")
f()
--------------------------------------------------
こんだけでプラグインができるなんて。
とても簡単にプラグインを作れます。
でもこれだとビルド時にliblua.dllがいるので、ビルド時にliblua.dllがいらない版も作ってみました。
これで、ばんばんtorchのプラグインが作れる。
mylua_base.h
--------------------------------------------------
#ifndef MY_LUA_BASE_H_
#define MY_LUA_BASE_H_
#include "lua.h"
#include "lualib.h"
#include "lauxlib.h"
#ifdef __cplusplus
extern "C" {
#endif /* __cplusplus */
extern lua_State *(*p_luaL_newstate) (void);
extern void(*p_luaL_openlibs)(lua_State *L);
extern int(*p_luaL_loadfilex) (lua_State *L, const char *filename,
const char *mode);
extern void(*p_lua_close)(lua_State *L);
extern void(*p_lua_pushcclosure)(lua_State *L, lua_CFunction fn, int n);
extern void(*p_lua_setglobal)(lua_State *L, const char *name);
extern int(*p_lua_pcallk)(lua_State *L, int nargs, int nresults, int errfunc,
lua_KContext ctx, lua_KFunction k);
#ifdef __cplusplus
}
#endif /* __cplusplus */
#define luaL_newstate p_luaL_newstate
#define luaL_openlibs p_luaL_openlibs
#define luaL_loadfilex p_luaL_loadfilex
#define lua_close p_lua_close
#define lua_pushcclosure p_lua_pushcclosure
#define lua_setglobal p_lua_setglobal
#define lua_pcallk p_lua_pcallk
#endif
--------------------------------------------------
mylua_base.cpp
--------------------------------------------------
#include <stdio.h>
#include <stdlib.h>
#include "mylua_base_min.h"
#include "dll_client.h"
#define LUA_SO_NAME "liblua"
// api function
lua_State *(*p_luaL_newstate) (void) = NULL;
void(*p_luaL_openlibs)(lua_State *L) = NULL;
void(*p_lua_close)(lua_State *L) = NULL;
int(*p_luaL_loadfilex) (lua_State *L, const char *filename,
const char *mode) = NULL;
void(*p_lua_pushcclosure)(lua_State *L, lua_CFunction fn, int n) = NULL;
void(*p_lua_setglobal)(lua_State *L, const char *name) = NULL;
int(*p_lua_pcallk)(lua_State *L, int nargs, int nresults, int errfunc,
lua_KContext ctx, lua_KFunction k) = NULL;
// init
class myluabase_init {
public:
myluabase_init();
virtual ~myluabase_init();
};
static myluabase_init mi;
static void* osh = NULL;
myluabase_init::myluabase_init()
{
osh = dll_load(DLL_NAME(LUA_SO_NAME));
if (osh == NULL)printf("liblua dll_load error\n");
p_luaL_newstate = (lua_State *(*) (void))dll_access(osh, "luaL_newstate");
if (p_luaL_newstate == NULL)printf("luaL_newstate dll_access error\n");
p_luaL_openlibs = (void(*)(lua_State *))dll_access(osh, "luaL_openlibs");
if (p_luaL_openlibs == NULL)printf("luaL_openlibs dll_access error\n");
p_lua_close = (void(*)(lua_State *))dll_access(osh, "lua_close");
if (p_lua_close == NULL)printf("lua_close dll_access error\n");
p_luaL_loadfilex = (int(*) (lua_State *, const char *,const char *))dll_access(osh, "luaL_loadfilex");
if (p_luaL_loadfilex == NULL)printf("luaL_loadfilex dll_access error\n");
p_lua_pushcclosure = (void(*)(lua_State *, lua_CFunction, int))dll_access(osh, "lua_pushcclosure");
if (p_lua_pushcclosure == NULL)printf("lua_pushcclosure dll_access error\n");
p_lua_setglobal = (void(*)(lua_State *, const char *))dll_access(osh, "lua_setglobal");
if (p_lua_setglobal == NULL)printf("lua_setglobal dll_access error\n");
p_lua_pcallk = (int(*)(lua_State *, int, int, int,lua_KContext,lua_KFunction))dll_access(osh, "lua_pcallk");
if (p_lua_pcallk == NULL)printf("lua_pcallk dll_access error\n");
}
myluabase_init::~myluabase_init()
{
if (osh) {
dll_close(osh);
}
osh = NULL;
}
--------------------------------------------------
dll_client.h
--------------------------------------------------
#ifndef _DLL_CLIENT_H_
#define _DLL_CLIENT_H_
#define _DLL_CLIENT_H_VER "20130701"
#define LIB_OK 0
#define LIB_ERR_INVALID_REQUEST (-1)
#define LIB_ERR_INVALID_PARAMETER (-2)
#define LIB_ERR_NOSERVICE (-3)
#define LIB_ERR_NOREQUEST (-4)
#if defined(_WIN32) && !defined(__GNUC__)
#define LIBEXT ".dll"
#else
#define LIBEXT ".so"
#endif
#define DLL_NAME(a) a LIBEXT
#ifdef __cplusplus
extern "C"{
#endif
void *dll_load(char *dllname);
void *dll_access(void *pDllmod, char *funcname);
void dll_close(void *pDllmod);
#ifdef __cplusplus
}
#endif
#endif
--------------------------------------------------
dll_client.c
--------------------------------------------------
#include <stdio.h>
#ifdef WIN32 /* For Windows */
#include <windows.h>
#endif
#include "dll_client.h"
#ifdef ERR_PRINT
#define EPRINT(a) fprintf(stderr,a)
#define EPRINT2(a,b) fprintf(stderr,a,b)
#else
#define EPRINT(a) do; while(0)
#define EPRINT2(a,b) do; while(0)
#endif
void *dll_load(char *dllname)
{
void *pFunclib=NULL;
if(dllname==NULL) {
EPRINT("### dllname is NULL\n");
return NULL;
}
#ifdef _WIN32
pFunclib = LoadLibraryA(dllname);
#endif
#if defined(unix) || defined(__APPLE__)
pFunclib = dlopen(dllname, RTLD_LAZY);
#endif
if(!pFunclib) {
EPRINT2("### cannot dll_load >%s<\n",dllname);
}
return pFunclib;
}
void *dll_access(void *pDllmod, char *funcname)
{
void *pDll_func=NULL;
if(pDllmod==NULL) {
EPRINT("### dll_access handle is NULL\n");
return NULL;
}
if(funcname==NULL) {
EPRINT("### dll_access funcname is NULL\n");
return NULL;
}
#ifdef _WIN32
pDll_func = GetProcAddress((HMODULE)pDllmod, funcname);
#endif
#if defined(unix) || defined(__APPLE__)
pDll_func = dlsym(pDllmod, funcname);
#endif
if(!pDll_func) {
EPRINT2("### cannot dll_access >%s<\n",funcname);
}
return pDll_func;
}
void dll_close(void *pDllmod)
{
if(pDllmod==NULL)return;
#ifdef _WIN32
FreeLibrary(pDllmod);
#endif
#if defined(unix) || defined(__APPLE__)
dlclose(pDllmod);
#endif
return;
}
--------------------------------------------------
0 件のコメント:
コメントを投稿