From 64ff027e645208c1a64a6dbeb2d1540da9e718f8 Mon Sep 17 00:00:00 2001 From: Niels Breuker Date: Sun, 2 Oct 2022 12:54:47 +0200 Subject: First PoC for multithreading in plugins. Crashes on reload while another thread is active. No nice way to return to the default lua_State --- src/Bindings/CMakeLists.txt | 1 + src/Bindings/LuaState.cpp | 21 ++ src/Bindings/LuaState.h | 2 + src/Bindings/ManualBindings.cpp | 5 +- src/Bindings/ManualBindings.h | 3 + src/Bindings/ManualBindings_Threading.cpp | 407 ++++++++++++++++++++++++++++++ 6 files changed, 437 insertions(+), 2 deletions(-) create mode 100644 src/Bindings/ManualBindings_Threading.cpp (limited to 'src') diff --git a/src/Bindings/CMakeLists.txt b/src/Bindings/CMakeLists.txt index 3e7101cbf..d5b6e4f24 100644 --- a/src/Bindings/CMakeLists.txt +++ b/src/Bindings/CMakeLists.txt @@ -17,6 +17,7 @@ target_sources( ManualBindings_Network.cpp ManualBindings_RankManager.cpp ManualBindings_World.cpp + ManualBindings_Threading.cpp Plugin.cpp PluginLua.cpp PluginManager.cpp diff --git a/src/Bindings/LuaState.cpp b/src/Bindings/LuaState.cpp index e4c537967..a6ba3397c 100644 --- a/src/Bindings/LuaState.cpp +++ b/src/Bindings/LuaState.cpp @@ -1033,6 +1033,16 @@ void cLuaState::Push(cLuaTCPLink * a_TCPLink) +void cLuaState::Push(std::thread* a_Thread) { + ASSERT(IsValid()); + + tolua_pushusertype(m_LuaState, a_Thread, "std::thread *"); +} + + + + + void cLuaState::Push(cLuaUDPEndpoint * a_UDPEndpoint) { ASSERT(IsValid()); @@ -1099,6 +1109,17 @@ void cLuaState::Push(std::chrono::milliseconds a_Value) + +void cLuaState::Push(std::mutex * a_Mutex) { + ASSERT(IsValid()); + + tolua_pushusertype(m_LuaState, a_Mutex, "std::mutex *"); +} + + + + + void cLuaState::Pop(int a_NumValuesToPop) { ASSERT(IsValid()); diff --git a/src/Bindings/LuaState.h b/src/Bindings/LuaState.h index d579369f0..6434bb625 100644 --- a/src/Bindings/LuaState.h +++ b/src/Bindings/LuaState.h @@ -624,12 +624,14 @@ public: void Push(const cEntity * a_Entity); void Push(cLuaServerHandle * a_ServerHandle); void Push(cLuaTCPLink * a_TCPLink); + void Push(std::thread * a_Thread); void Push(cLuaUDPEndpoint * a_UDPEndpoint); void Push(double a_Value); void Push(int a_Value); void Push(long a_Value); void Push(const UInt32 a_Value); void Push(std::chrono::milliseconds a_time); + void Push(std::mutex * a_Mutex); /** Pops the specified number of values off the top of the Lua stack. */ void Pop(int a_NumValuesToPop = 1); diff --git a/src/Bindings/ManualBindings.cpp b/src/Bindings/ManualBindings.cpp index f5517dc84..ba257e315 100644 --- a/src/Bindings/ManualBindings.cpp +++ b/src/Bindings/ManualBindings.cpp @@ -1447,12 +1447,12 @@ static int tolua_cPluginManager_CallPlugin(lua_State * tolua_S) { return 0; } - if (ThisPlugin->GetName() == PluginName) + /*if (ThisPlugin->GetName() == PluginName) { LOGWARNING("cPluginManager::CallPlugin(): Calling self is not implemented (why would it?)"); L.LogStackTrace(); return 0; - } + }*/ // Call the destination plugin using a plugin callback: int NumReturns = 0; @@ -4768,6 +4768,7 @@ void cManualBindings::Bind(lua_State * tolua_S) BindRankManager(tolua_S); BindWorld(tolua_S); BindBlockArea(tolua_S); + BindThreading(tolua_S); tolua_endmodule(tolua_S); } diff --git a/src/Bindings/ManualBindings.h b/src/Bindings/ManualBindings.h index f0b7cb607..1dc1f873b 100644 --- a/src/Bindings/ManualBindings.h +++ b/src/Bindings/ManualBindings.h @@ -46,6 +46,9 @@ protected: Implemented in ManualBindings_BlockArea.cpp. */ static void BindBlockArea(lua_State * tolua_S); + /** Binds the manually implemented threading API */ + static void BindThreading(lua_State * tolua_S); + public: // Helper functions: diff --git a/src/Bindings/ManualBindings_Threading.cpp b/src/Bindings/ManualBindings_Threading.cpp new file mode 100644 index 000000000..3dde1f988 --- /dev/null +++ b/src/Bindings/ManualBindings_Threading.cpp @@ -0,0 +1,407 @@ + +// ManualBindings_Network.cpp + +// Implements the cNetwork-related API bindings for Lua +// Also implements the cUrlClient bindings + +#include "Globals.h" +#include "ManualBindings.h" +#include "tolua++/include/tolua++.h" +#include "LuaState.h" +#include +extern "C" +{ + #include + #include + #include +} + + + + +using LuaState = lua_State; + + + + +std::string lua_getstring(lua_State* L, int aIndex) +{ + size_t len; + auto s = lua_tolstring(L, aIndex, &len); + return std::string(s, len); +} + + + + + +std::string lua_getstringfield(lua_State* L, int aTableIndex, const char* aFieldName) +{ + lua_getfield(L, aTableIndex, aFieldName); + std::string res = lua_getstring(L, -1); + lua_pop(L, 1); + return res; +} + + + + + +int pushThreadIdOnLuaStack(lua_State* aState, const std::thread::id& aThreadId) +{ + std::stringstream ss; + ss << aThreadId; + auto str = ss.str(); + lua_pushlstring(aState, str.data(), str.size()); + return 1; +} + + + + +/** The name of the thread's Lua object's metatable within the Lua registry. +Every thread object that is pushed to the Lua side has the metatable of this name set to it. */ +static const char* THREAD_METATABLE_NAME = "std::thread *"; +static const char* MUTEX_METATABLE_NAME = "std::mutex *"; + + + + + +/** Dumps the contents of the Lua stack to the specified ostream. */ +static void dumpLuaStack(LuaState* L, std::ostream& aDest) +{ + aDest << "Lua stack contents:" << std::endl; + for (int i = lua_gettop(L); i >= 0; --i) + { + aDest << " " << i << "\t"; + aDest << lua_typename(L, lua_type(L, i)) << "\t"; + aDest << lua_getstring(L, i).c_str() << std::endl; + } + aDest << "(stack dump completed)" << std::endl; +} + + + + + +/** Dumps the call stack to the specified ostream. */ +static void dumpLuaTraceback(LuaState* L, std::ostream& aDest) +{ + + + //luaL_traceback(L, L, "Stack trace: ", 0); + aDest << lua_getstring(L, -1).c_str() << std::endl; + lua_pop(L, 1); + return; +} + + + + + +/** Called by Lua when it encounters an unhandler error in the script file. */ +extern "C" int errorHandler(LuaState * L) +{ + auto err = lua_getstring(L, -1); + LOGWARNING(err); + //std::cerr << "Caught an error: " << err << std::endl; + //dumpLuaStack(L, std::cerr); + //exit(1); + return 0; +} + + + + + + +extern "C" static int mutexNew(LuaState * aState) +{ + cLuaState L(aState); + lua_pushcfunction(aState, errorHandler); + + // Push the (currently empty) mutex object to the Lua side + auto mutexObj = reinterpret_cast(lua_newuserdata(aState, sizeof(std::mutex**))); + L.Push(mutexObj); + //luaL_setmetatable(aState, MUTEX_METATABLE_NAME); + + // Create the new mutex: + *mutexObj = new std::mutex(); + return 1; +} + + + + + + + +/** */ +extern "C" static int mutexLock(LuaState * aState) +{ + auto mutexObj = reinterpret_cast(luaL_checkudata(aState, 1, MUTEX_METATABLE_NAME)); + if (mutexObj == nullptr) + { + luaL_argerror(aState, 0, "'mutex' expected"); + return 0; + } + (*mutexObj)->lock(); + auto numParams = lua_gettop(aState); + luaL_checktype(aState, 2, LUA_TFUNCTION); + lua_pcall(aState, 0, 0, 0); + (*mutexObj)->unlock(); + + return 0; +} + + + + + +static int threadNew(LuaState * aState) +{ + static std::recursive_mutex mtx; + std::scoped_lock lock(mtx); + luaL_checktype(aState, 1, LUA_TFUNCTION); + lua_pushvalue(aState, 1); // Push a copy of the fn to the top of the stack... + auto luaFnRef = luaL_ref(aState, LUA_REGISTRYINDEX); // ... move it to the registry... + auto luaThread = lua_newthread(aState); + lua_pushcfunction(aState, errorHandler); + lua_rawgeti(luaThread, LUA_REGISTRYINDEX, luaFnRef); // ... push it onto the new thread's stack... + luaL_unref(aState, LUA_REGISTRYINDEX, luaFnRef); // ... and remove it from the registry + + // Push the (currently empty) thread object to the Lua side + auto threadObj = reinterpret_cast(lua_newuserdata(aState, sizeof(std::thread**))); + //luaL_setmetatable(aState, THREAD_METATABLE_NAME); + + // Start the new thread: + *threadObj = new std::thread( + [luaThread, luaFnRef]() + { + auto numParams = lua_gettop(luaThread) - 1; + lua_call(luaThread, numParams, LUA_MULTRET); + //if (status == LUA_OK) + } + ); + cLuaState L(aState); + L.Push(threadObj); + return 1; +} + + + + + +/** Provides the thread.sleep() function. +Parameter: the number of seconds to sleep for (floating-point number). */ +extern "C" static int threadSleep(LuaState * aState) +{ + auto seconds = luaL_checknumber(aState, 1); + std::this_thread::sleep_for(std::chrono::milliseconds(static_cast(seconds * 1000))); + return 0; +} + + + + + +/** Implements the thread:join() function. +Joins the specified thread. +Errors if asked to join the current thread. */ +extern "C" static int threadObjJoin(LuaState * aState) +{ + auto threadObj = reinterpret_cast(luaL_checkudata(aState, 1, THREAD_METATABLE_NAME)); + if (threadObj == nullptr) + { + luaL_argerror(aState, 0, "`thread' expected"); + return 0; + } + if (*threadObj == nullptr) + { + luaL_argerror(aState, 0, "thread already joined"); + return 0; + } + if ((*threadObj)->get_id() == std::this_thread::get_id()) + { + luaL_argerror(aState, 0, "`thread' must not be the current thread"); + return 0; + } + (*threadObj)->join(); + *threadObj = nullptr; + return 0; +} + + + + + +/** Implements the thread:id() function. +Returns the thread's ID, as an implementation-dependent detail. +The ID is guaranteed to be unique within a single process at any single time moment (but not within multiple time moments). */ +extern "C" static int threadObjID(LuaState * aState) +{ + auto threadObj = reinterpret_cast(luaL_checkudata(aState, 1, THREAD_METATABLE_NAME)); + if (threadObj == nullptr) + { + luaL_argerror(aState, 0, "`thread' expected"); + return 0; + } + if (*threadObj == nullptr) + { + luaL_argerror(aState, 0, "thread already joined"); + return 0; + } + if ((*threadObj)->get_id() == std::this_thread::get_id()) + { + luaL_argerror(aState, 0, "`thread' must not be the current thread"); + return 0; + } + return pushThreadIdOnLuaStack(aState, (*threadObj)->get_id()); +} + + + + + +/** Implements the thread.currentid() function. +Returns the current thread's ID. This also works on the main thread. */ +extern "C" static int threadCurrentID(LuaState * aState) +{ + return pushThreadIdOnLuaStack(aState, std::this_thread::get_id()); +} + + + + + +/** Called when the Lua side GC's the thread object. +Joins the thread, if not already joined. */ +extern "C" static int threadObjGc(LuaState * aState) +{ + auto threadObj = reinterpret_cast(luaL_checkudata(aState, 1, THREAD_METATABLE_NAME)); + // We shouldn't get an invalid thread object, but let's check nevertheless: + if (threadObj == nullptr) + { + luaL_argerror(aState, 0, "`thread' expected"); + return 0; + } + if (*threadObj == nullptr) + { + return 0; + } + if ((*threadObj)->get_id() == std::this_thread::get_id()) + { + // Current thread is GC-ing self? No idea if that is allowed to happen, but we don't care; just don't join + return 0; + } + (*threadObj)->join(); + *threadObj = nullptr; + return 0; +} + + + + + +/** The functions in the thread library. */ +static const luaL_Reg threadFuncs[] = +{ + {"new", &threadNew}, + {"sleep", &threadSleep}, + {"currentid", &threadCurrentID}, + {NULL, NULL} +}; + + + + + +/** The functions of the thread object. */ +static const luaL_Reg threadObjFuncs[] = +{ + {"join", &threadObjJoin}, + {"id", &threadObjID}, + {"__gc", &threadObjGc}, + {NULL, NULL} +}; + + + + + +/***/ +static const luaL_Reg mutexFuncs[] = +{ + {"new", &mutexNew}, + {NULL,NULL} +}; + + + + + +/** */ +static const luaL_Reg mutexObjFuncs[] = +{ + {"lock", &mutexLock}, + {NULL, NULL} +}; + + + + + +/** Registers the thread library into the Lua VM. */ +extern "C" static int luaopen_thread(LuaState * aState) +{ + //luaL_newlib(aState, threadFuncs); + + //// Register the metatable for std::thread objects: + //luaL_newmetatable(aState, THREAD_METATABLE_NAME); + //lua_pushvalue(aState, -1); + //lua_setfield(aState, -2, "__index"); // metatable.__index = metatable + //luaL_setfuncs(aState, threadObjFuncs, 0); // Add the object functions to the table + //lua_pop(aState, 1); // pop the new metatable + + return 1; +} + + + + + +/**Registers the mutex library into the Lua VM. */ +extern "C" static int luaopen_mutex(LuaState * aState) +{ + //luaL_newlib(aState, mutexFuncs); + + //// Register the metatable for std::thread objects: + //luaL_newmetatable(aState, MUTEX_METATABLE_NAME); + //lua_pushvalue(aState, -1); + //lua_setfield(aState, -2, "__index"); // metatable.__index = metatable + //luaL_setfuncs(aState, mutexObjFuncs, 0); // Add the object functions to the table + //lua_pop(aState, 1); // pop the new metatable + + return 1; +} + + + + + +void cManualBindings::BindThreading(lua_State * tolua_S) +{ + tolua_usertype(tolua_S, "cThread"); + tolua_cclass(tolua_S, "cThread", "cThread", "", nullptr); + tolua_beginmodule(tolua_S, nullptr); + tolua_beginmodule(tolua_S, "cThread"); + tolua_function(tolua_S, "new", threadNew); + tolua_function(tolua_S, "sleep", threadSleep); + tolua_function(tolua_S, "currentid", threadCurrentID); + tolua_function(tolua_S, "join", threadObjJoin); + tolua_function(tolua_S, "id", threadObjID); + /*tolua_function(tolua_S, threadObjFuncs)*/ + tolua_endmodule(tolua_S); + tolua_endmodule(tolua_S); +} -- cgit v1.2.3