From 4b068b27ca854ce90beefb1b57ef2380edd1b8ff Mon Sep 17 00:00:00 2001 From: Petr Mikheev Date: Tue, 12 Jan 2021 23:17:48 +0100 Subject: [PATCH] Add components/lua/luastate and components/lua/utilpackage --- apps/openmw_test_suite/CMakeLists.txt | 5 +- apps/openmw_test_suite/lua/test_lua.cpp | 167 +++++++++++++++++ .../lua/test_utilpackage.cpp | 80 +++++++++ apps/openmw_test_suite/lua/testing_util.hpp | 59 ++++++ components/CMakeLists.txt | 4 + components/lua/luastate.cpp | 169 ++++++++++++++++++ components/lua/luastate.hpp | 107 +++++++++++ components/lua/utilpackage.cpp | 98 ++++++++++ components/lua/utilpackage.hpp | 13 ++ 9 files changed, 701 insertions(+), 1 deletion(-) create mode 100644 apps/openmw_test_suite/lua/test_lua.cpp create mode 100644 apps/openmw_test_suite/lua/test_utilpackage.cpp create mode 100644 apps/openmw_test_suite/lua/testing_util.hpp create mode 100644 components/lua/luastate.cpp create mode 100644 components/lua/luastate.hpp create mode 100644 components/lua/utilpackage.cpp create mode 100644 components/lua/utilpackage.hpp diff --git a/apps/openmw_test_suite/CMakeLists.txt b/apps/openmw_test_suite/CMakeLists.txt index d78cb69553..3f82df1a52 100644 --- a/apps/openmw_test_suite/CMakeLists.txt +++ b/apps/openmw_test_suite/CMakeLists.txt @@ -15,6 +15,9 @@ if (GTEST_FOUND AND GMOCK_FOUND) esm/test_fixed_string.cpp esm/variant.cpp + lua/test_lua.cpp + lua/test_utilpackage.cpp + misc/test_stringops.cpp misc/test_endianness.cpp @@ -39,7 +42,7 @@ if (GTEST_FOUND AND GMOCK_FOUND) openmw_add_executable(openmw_test_suite openmw_test_suite.cpp ${UNITTEST_SRC_FILES}) - target_link_libraries(openmw_test_suite ${GMOCK_LIBRARIES} components) + target_link_libraries(openmw_test_suite ${GMOCK_LIBRARIES} components ${LUA_LIBRARIES}) # Fix for not visible pthreads functions for linker with glibc 2.15 if (UNIX AND NOT APPLE) target_link_libraries(openmw_test_suite ${CMAKE_THREAD_LIBS_INIT}) diff --git a/apps/openmw_test_suite/lua/test_lua.cpp b/apps/openmw_test_suite/lua/test_lua.cpp new file mode 100644 index 0000000000..69a326060c --- /dev/null +++ b/apps/openmw_test_suite/lua/test_lua.cpp @@ -0,0 +1,167 @@ +#include "gmock/gmock.h" +#include + +#include + +#include "testing_util.hpp" + +namespace +{ + using namespace testing; + + TestFile counterFile(R"X( +x = 42 +return { + get = function() return x end, + inc = function(v) x = x + v end +} +)X"); + + TestFile invalidScriptFile("Invalid script"); + + TestFile testsFile(R"X( +return { + -- should work + sin = function(x) return math.sin(x) end, + requireMathSin = function(x) return require('math').sin(x) end, + useCounter = function() + local counter = require('aaa.counter') + counter.inc(1) + return counter.get() + end, + callRawset = function() + t = {a = 1, b = 2} + rawset(t, 'b', 3) + return t.b + end, + print = print, + + -- should throw an error + incorrectRequire = function() require('counter') end, + modifySystemLib = function() math.sin = 5 end, + rawsetSystemLib = function() rawset(math, 'sin', 5) end, + callLoadstring = function() loadstring('print(1)') end, + setSqr = function() require('sqrlib').sqr = math.sin end, + setOmwName = function() require('openmw').name = 'abc' end, + + -- should work if API is registered + sqr = function(x) return require('sqrlib').sqr(x) end, + apiName = function() return require('test.api').name end +} +)X"); + + struct LuaStateTest : Test + { + std::unique_ptr mVFS = createTestVFS({ + {"aaa/counter.lua", &counterFile}, + {"bbb/tests.lua", &testsFile}, + {"invalid.lua", &invalidScriptFile} + }); + + LuaUtil::LuaState mLua{mVFS.get()}; + }; + + TEST_F(LuaStateTest, Sandbox) + { + sol::table script1 = mLua.runInNewSandbox("aaa/counter.lua"); + + EXPECT_EQ(LuaUtil::call(script1["get"]).get(), 42); + LuaUtil::call(script1["inc"], 3); + EXPECT_EQ(LuaUtil::call(script1["get"]).get(), 45); + + sol::table script2 = mLua.runInNewSandbox("aaa/counter.lua"); + EXPECT_EQ(LuaUtil::call(script2["get"]).get(), 42); + LuaUtil::call(script2["inc"], 1); + EXPECT_EQ(LuaUtil::call(script2["get"]).get(), 43); + + EXPECT_EQ(LuaUtil::call(script1["get"]).get(), 45); + } + + TEST_F(LuaStateTest, ErrorHandling) + { + EXPECT_ERROR(mLua.runInNewSandbox("invalid.lua"), "[string \"invalid.lua\"]:1:"); + } + + TEST_F(LuaStateTest, CustomRequire) + { + sol::table script = mLua.runInNewSandbox("bbb/tests.lua"); + + EXPECT_FLOAT_EQ(LuaUtil::call(script["sin"], 1).get(), + -LuaUtil::call(script["requireMathSin"], -1).get()); + + EXPECT_EQ(LuaUtil::call(script["useCounter"]).get(), 43); + EXPECT_EQ(LuaUtil::call(script["useCounter"]).get(), 44); + { + sol::table script2 = mLua.runInNewSandbox("bbb/tests.lua"); + EXPECT_EQ(LuaUtil::call(script2["useCounter"]).get(), 43); + } + EXPECT_EQ(LuaUtil::call(script["useCounter"]).get(), 45); + + EXPECT_ERROR(LuaUtil::call(script["incorrectRequire"]), "Resource 'counter.lua' not found"); + } + + TEST_F(LuaStateTest, ReadOnly) + { + sol::table script = mLua.runInNewSandbox("bbb/tests.lua"); + + // rawset itself is allowed + EXPECT_EQ(LuaUtil::call(script["callRawset"]).get(), 3); + + // but read-only object can not be modified even with rawset + EXPECT_ERROR(LuaUtil::call(script["rawsetSystemLib"]), "bad argument #1 to 'rawset' (table expected, got userdata)"); + EXPECT_ERROR(LuaUtil::call(script["modifySystemLib"]), "a userdata value"); + + EXPECT_EQ(mLua.getMutableFromReadOnly(mLua.makeReadOnly(script)), script); + } + + TEST_F(LuaStateTest, Print) + { + { + sol::table script = mLua.runInNewSandbox("bbb/tests.lua"); + testing::internal::CaptureStdout(); + LuaUtil::call(script["print"], 1, 2, 3); + std::string output = testing::internal::GetCapturedStdout(); + EXPECT_EQ(output, "[bbb/tests.lua]:\t1\t2\t3\n"); + } + { + sol::table script = mLua.runInNewSandbox("bbb/tests.lua", "prefix"); + testing::internal::CaptureStdout(); + LuaUtil::call(script["print"]); // print with no arguments + std::string output = testing::internal::GetCapturedStdout(); + EXPECT_EQ(output, "prefix[bbb/tests.lua]:\n"); + } + } + + TEST_F(LuaStateTest, UnsafeFunction) + { + sol::table script = mLua.runInNewSandbox("bbb/tests.lua"); + EXPECT_ERROR(LuaUtil::call(script["callLoadstring"]), "a nil value"); + } + + TEST_F(LuaStateTest, ProvideAPI) + { + LuaUtil::LuaState lua(mVFS.get()); + + sol::table api1 = lua.makeReadOnly(lua.sol().create_table_with("name", "api1")); + sol::table api2 = lua.makeReadOnly(lua.sol().create_table_with("name", "api2")); + + sol::table script1 = lua.runInNewSandbox("bbb/tests.lua", "", {{"test.api", api1}}); + + lua.addCommonPackage( + "sqrlib", lua.sol().create_table_with("sqr", [](int x) { return x * x; })); + + sol::table script2 = lua.runInNewSandbox("bbb/tests.lua", "", {{"test.api", api2}}); + + EXPECT_ERROR(LuaUtil::call(script1["sqr"], 3), "Resource 'sqrlib.lua' not found"); + EXPECT_EQ(LuaUtil::call(script2["sqr"], 3).get(), 9); + + EXPECT_EQ(LuaUtil::call(script1["apiName"]).get(), "api1"); + EXPECT_EQ(LuaUtil::call(script2["apiName"]).get(), "api2"); + } + + TEST_F(LuaStateTest, GetLuaVersion) + { + EXPECT_THAT(LuaUtil::getLuaVersion(), HasSubstr("Lua")); + } + +} diff --git a/apps/openmw_test_suite/lua/test_utilpackage.cpp b/apps/openmw_test_suite/lua/test_utilpackage.cpp new file mode 100644 index 0000000000..afd9fa2d3c --- /dev/null +++ b/apps/openmw_test_suite/lua/test_utilpackage.cpp @@ -0,0 +1,80 @@ +#include "gmock/gmock.h" +#include + +#include + +#include "testing_util.hpp" + +namespace +{ + using namespace testing; + + TEST(LuaUtilPackageTest, Vector2) + { + sol::state lua; + lua.open_libraries(sol::lib::base, sol::lib::math, sol::lib::string); + lua["util"] = LuaUtil::initUtilPackage(lua); + lua.safe_script("v = util.vector2(3, 4)"); + EXPECT_FLOAT_EQ(lua.safe_script("return v.x").get(), 3); + EXPECT_FLOAT_EQ(lua.safe_script("return v.y").get(), 4); + EXPECT_EQ(lua.safe_script("return tostring(v)").get(), "(3, 4)"); + EXPECT_FLOAT_EQ(lua.safe_script("return v:length()").get(), 5); + EXPECT_FLOAT_EQ(lua.safe_script("return v:length2()").get(), 25); + EXPECT_FALSE(lua.safe_script("return util.vector2(1, 2) == util.vector2(1, 3)").get()); + EXPECT_TRUE(lua.safe_script("return util.vector2(1, 2) + util.vector2(2, 5) == util.vector2(3, 7)").get()); + EXPECT_TRUE(lua.safe_script("return util.vector2(1, 2) - util.vector2(2, 5) == -util.vector2(1, 3)").get()); + EXPECT_TRUE(lua.safe_script("return util.vector2(1, 2) == util.vector2(2, 4) / 2").get()); + EXPECT_TRUE(lua.safe_script("return util.vector2(1, 2) * 2 == util.vector2(2, 4)").get()); + EXPECT_FLOAT_EQ(lua.safe_script("return util.vector2(3, 2) * v").get(), 17); + EXPECT_FLOAT_EQ(lua.safe_script("return util.vector2(3, 2):dot(v)").get(), 17); + EXPECT_ERROR(lua.safe_script("v2, len = v.normalize()"), "value is not a valid userdata"); // checks that it doesn't segfault + lua.safe_script("v2, len = v:normalize()"); + EXPECT_FLOAT_EQ(lua.safe_script("return len").get(), 5); + EXPECT_TRUE(lua.safe_script("return v2 == util.vector2(3/5, 4/5)").get()); + lua.safe_script("_, len = util.vector2(0, 0):normalize()"); + EXPECT_FLOAT_EQ(lua.safe_script("return len").get(), 0); + } + + TEST(LuaUtilPackageTest, Vector3) + { + sol::state lua; + lua.open_libraries(sol::lib::base, sol::lib::math, sol::lib::string); + lua["util"] = LuaUtil::initUtilPackage(lua); + lua.safe_script("v = util.vector3(5, 12, 13)"); + EXPECT_FLOAT_EQ(lua.safe_script("return v.x").get(), 5); + EXPECT_FLOAT_EQ(lua.safe_script("return v.y").get(), 12); + EXPECT_FLOAT_EQ(lua.safe_script("return v.z").get(), 13); + EXPECT_EQ(lua.safe_script("return tostring(v)").get(), "(5, 12, 13)"); + EXPECT_FLOAT_EQ(lua.safe_script("return util.vector3(4, 0, 3):length()").get(), 5); + EXPECT_FLOAT_EQ(lua.safe_script("return util.vector3(4, 0, 3):length2()").get(), 25); + EXPECT_FALSE(lua.safe_script("return util.vector3(1, 2, 3) == util.vector3(1, 3, 2)").get()); + EXPECT_TRUE(lua.safe_script("return util.vector3(1, 2, 3) + util.vector3(2, 5, 1) == util.vector3(3, 7, 4)").get()); + EXPECT_TRUE(lua.safe_script("return util.vector3(1, 2, 3) - util.vector3(2, 5, 1) == -util.vector3(1, 3, -2)").get()); + EXPECT_TRUE(lua.safe_script("return util.vector3(1, 2, 3) == util.vector3(2, 4, 6) / 2").get()); + EXPECT_TRUE(lua.safe_script("return util.vector3(1, 2, 3) * 2 == util.vector3(2, 4, 6)").get()); + EXPECT_FLOAT_EQ(lua.safe_script("return util.vector3(3, 2, 1) * v").get(), 5*3 + 12*2 + 13*1); + EXPECT_FLOAT_EQ(lua.safe_script("return util.vector3(3, 2, 1):dot(v)").get(), 5*3 + 12*2 + 13*1); + EXPECT_TRUE(lua.safe_script("return util.vector3(1, 0, 0) ^ util.vector3(0, 1, 0) == util.vector3(0, 0, 1)").get()); + EXPECT_ERROR(lua.safe_script("v2, len = util.vector3(3, 4, 0).normalize()"), "value is not a valid userdata"); + lua.safe_script("v2, len = util.vector3(3, 4, 0):normalize()"); + EXPECT_FLOAT_EQ(lua.safe_script("return len").get(), 5); + EXPECT_TRUE(lua.safe_script("return v2 == util.vector3(3/5, 4/5, 0)").get()); + lua.safe_script("_, len = util.vector3(0, 0, 0):normalize()"); + EXPECT_FLOAT_EQ(lua.safe_script("return len").get(), 0); + } + + TEST(LuaUtilPackageTest, UtilityFunctions) + { + sol::state lua; + lua.open_libraries(sol::lib::base, sol::lib::math, sol::lib::string); + lua["util"] = LuaUtil::initUtilPackage(lua); + lua.safe_script("v = util.vector2(1, 0):rotate(math.rad(120))"); + EXPECT_FLOAT_EQ(lua.safe_script("return v.x").get(), -0.5); + EXPECT_FLOAT_EQ(lua.safe_script("return v.y").get(), 0.86602539); + EXPECT_FLOAT_EQ(lua.safe_script("return util.normalizeAngle(math.pi * 10 + 0.1)").get(), 0.1); + EXPECT_FLOAT_EQ(lua.safe_script("return util.clamp(0.1, 0, 1.5)").get(), 0.1); + EXPECT_FLOAT_EQ(lua.safe_script("return util.clamp(-0.1, 0, 1.5)").get(), 0); + EXPECT_FLOAT_EQ(lua.safe_script("return util.clamp(2.1, 0, 1.5)").get(), 1.5); + } + +} diff --git a/apps/openmw_test_suite/lua/testing_util.hpp b/apps/openmw_test_suite/lua/testing_util.hpp new file mode 100644 index 0000000000..28c4d59930 --- /dev/null +++ b/apps/openmw_test_suite/lua/testing_util.hpp @@ -0,0 +1,59 @@ +#ifndef LUA_TESTING_UTIL_H +#define LUA_TESTING_UTIL_H + +#include + +#include +#include + +namespace +{ + + class TestFile : public VFS::File + { + public: + explicit TestFile(std::string content) : mContent(std::move(content)) {} + + Files::IStreamPtr open() override + { + return std::make_shared(mContent, std::ios_base::in); + } + + private: + const std::string mContent; + }; + + struct TestData : public VFS::Archive + { + std::map mFiles; + + TestData(std::map files) : mFiles(std::move(files)) {} + + void listResources(std::map& out, char (*normalize_function) (char)) override + { + out = mFiles; + } + + bool contains(const std::string& file, char (*normalize_function) (char)) const override + { + return mFiles.count(file) != 0; + } + + std::string getDescription() const override { return "TestData"; } + + }; + + inline std::unique_ptr createTestVFS(std::map files) + { + auto vfs = std::make_unique(true); + vfs->addArchive(new TestData(std::move(files))); + vfs->buildIndex(); + return vfs; + } + + #define EXPECT_ERROR(X, ERR_SUBSTR) try { X; FAIL() << "Expected error"; } \ + catch (std::exception& e) { EXPECT_THAT(e.what(), HasSubstr(ERR_SUBSTR)); } + +} + +#endif // LUA_TESTING_UTIL_H diff --git a/components/CMakeLists.txt b/components/CMakeLists.txt index 43987d6c7b..c09e118cbe 100644 --- a/components/CMakeLists.txt +++ b/components/CMakeLists.txt @@ -28,6 +28,10 @@ endif (GIT_CHECKOUT) # source files +add_component_dir (lua + luastate utilpackage + ) + add_component_dir (settings settings parser ) diff --git a/components/lua/luastate.cpp b/components/lua/luastate.cpp new file mode 100644 index 0000000000..4020c815d0 --- /dev/null +++ b/components/lua/luastate.cpp @@ -0,0 +1,169 @@ +#include "luastate.hpp" + +#ifndef NO_LUAJIT +#include +#endif // NO_LUAJIT + +#include + +namespace LuaUtil +{ + + static std::string packageNameToPath(std::string_view packageName) + { + std::string res(packageName); + for (size_t i = 0; i < res.size(); ++i) + if (res[i] == '.') + res[i] = '/'; + res.append(".lua"); + return res; + } + + static const std::string safeFunctions[] = { + "assert", "error", "ipairs", "next", "pairs", "pcall", "select", "tonumber", "tostring", + "type", "unpack", "xpcall", "rawequal", "rawget", "rawset", "getmetatable", "setmetatable"}; + static const std::string safePackages[] = {"coroutine", "math", "string", "table"}; + + LuaState::LuaState(const VFS::Manager* vfs) : mVFS(vfs) + { + mLua.open_libraries(sol::lib::base, sol::lib::coroutine, sol::lib::math, sol::lib::string, sol::lib::table); + + mLua["math"]["randomseed"](static_cast(time(NULL))); + mLua["math"]["randomseed"] = sol::nil; + + mLua["writeToLog"] = [](std::string_view s) { Log(Debug::Level::Info) << s; }; + mLua.script(R"(printToLog = function(name, ...) + local msg = name + for _, v in ipairs({...}) do + msg = msg .. '\t' .. tostring(v) + end + return writeToLog(msg) + end)"); + mLua.script("printGen = function(name) return function(...) return printToLog(name, ...) end end"); + + // Some fixes for compatibility between different Lua versions + if (mLua["unpack"] == sol::nil) + mLua["unpack"] = mLua["table"]["unpack"]; + else if (mLua["table"]["unpack"] == sol::nil) + mLua["table"]["unpack"] = mLua["unpack"]; + + mSandboxEnv = sol::table(mLua, sol::create); + mSandboxEnv["_VERSION"] = mLua["_VERSION"]; + for (const std::string& s : safeFunctions) + { + if (mLua[s] == sol::nil) throw std::logic_error("Lua function not found: " + s); + mSandboxEnv[s] = mLua[s]; + } + for (const std::string& s : safePackages) + { + if (mLua[s] == sol::nil) throw std::logic_error("Lua package not found: " + s); + mCommonPackages[s] = mSandboxEnv[s] = makeReadOnly(mLua[s]); + } + } + + LuaState::~LuaState() + { + // Should be cleaned before destructing mLua. + mCommonPackages.clear(); + mSandboxEnv = sol::nil; + } + + sol::table LuaState::makeReadOnly(sol::table table) + { + if (table.is()) + return table; // it is already userdata, no sense to wrap it again + + table[sol::meta_function::index] = table; + sol::stack::push(mLua, std::move(table)); + lua_newuserdata(mLua, 0); + lua_pushvalue(mLua, -2); + lua_setmetatable(mLua, -2); + return sol::stack::pop(mLua); + } + + sol::table LuaState::getMutableFromReadOnly(const sol::userdata& ro) + { + sol::stack::push(mLua, ro); + lua_getmetatable(mLua, -1); + sol::table res = sol::stack::pop(mLua); + lua_pop(mLua, 1); + return res; + } + + void LuaState::addCommonPackage(const std::string& packageName, const sol::object& package) + { + if (package.is()) + mCommonPackages[packageName] = package; + else + mCommonPackages[packageName] = makeReadOnly(package); + } + + sol::protected_function_result LuaState::runInNewSandbox( + const std::string& path, const std::string& namePrefix, + const std::map& packages, const sol::object& hiddenData) + { + sol::protected_function script = loadScript(path); + + sol::environment env(mLua, sol::create, mSandboxEnv); + std::string envName = namePrefix + "[" + path + "]:"; + env["print"] = mLua["printGen"](envName); + + sol::table loaded(mLua, sol::create); + for (const auto& [key, value] : mCommonPackages) + loaded[key] = value; + for (const auto& [key, value] : packages) + loaded[key] = value; + env["require"] = [this, env, loaded, hiddenData](std::string_view packageName) + { + sol::table packages = loaded; + sol::object package = packages[packageName]; + if (package == sol::nil) + { + sol::protected_function packageLoader = loadScript(packageNameToPath(packageName)); + sol::set_environment(env, packageLoader); + package = throwIfError(packageLoader()); + if (!package.is()) + throw std::runtime_error("Lua package must return a table."); + packages[packageName] = package; + } + else if (package.is()) + package = packages[packageName] = call(package.as(), hiddenData); + return package; + }; + + sol::set_environment(env, script); + return call(script); + } + + sol::protected_function_result LuaState::throwIfError(sol::protected_function_result&& res) + { + if (!res.valid() && static_cast(res.get_type()) == LUA_TSTRING) + throw std::runtime_error("Lua error: " + res.get()); + else + return std::move(res); + } + + sol::protected_function LuaState::loadScript(const std::string& path) + { + auto iter = mCompiledScripts.find(path); + if (iter != mCompiledScripts.end()) + return mLua.load(iter->second.as_string_view(), path, sol::load_mode::binary); + + std::string fileContent(std::istreambuf_iterator(*mVFS->get(path)), {}); + sol::load_result res = mLua.load(fileContent, path, sol::load_mode::text); + if (!res.valid()) + throw std::runtime_error("Lua error: " + res.get()); + mCompiledScripts[path] = res.get().dump(); + return res; + } + + std::string getLuaVersion() + { + #ifdef NO_LUAJIT + return LUA_RELEASE; + #else + return LUA_RELEASE " (" LUAJIT_VERSION ")"; + #endif + } + +} diff --git a/components/lua/luastate.hpp b/components/lua/luastate.hpp new file mode 100644 index 0000000000..9cb27fb114 --- /dev/null +++ b/components/lua/luastate.hpp @@ -0,0 +1,107 @@ +#ifndef COMPONENTS_LUA_LUASTATE_H +#define COMPONENTS_LUA_LUASTATE_H + +#include + +#include + +#include + +namespace LuaUtil +{ + + std::string getLuaVersion(); + + // Holds Lua state. + // Provides additional features: + // - Load scripts from the virtual filesystem; + // - Caching of loaded scripts; + // - Disable unsafe Lua functions; + // - Run every instance of every script in a separate sandbox; + // - Forbid any interactions between sandboxes except than via provided API; + // - Access to common read-only resources from different sandboxes; + // - Replace standard `require` with a safe version that allows to search + // Lua libraries (only source, no dll's) in the virtual filesystem; + // - Make `print` to add the script name to the every message and + // write to Log rather than directly to stdout; + class LuaState + { + public: + explicit LuaState(const VFS::Manager* vfs); + ~LuaState(); + + // Returns underlying sol::state. + sol::state& sol() { return mLua; } + + // A shortcut to create a new Lua table. + sol::table newTable() { return sol::table(mLua, sol::create); } + + // Makes a table read only (when accessed from Lua) by wrapping it with an empty userdata. + // Needed to forbid any changes in common resources that can accessed from different sandboxes. + sol::table makeReadOnly(sol::table); + sol::table getMutableFromReadOnly(const sol::userdata&); + + // Registers a package that will be available from every sandbox via `require(name)`. + // The package can be either a sol::table with an API or a sol::function. If it is a function, + // it will be evaluated (once per sandbox) the first time when requested. If the package + // is a table, then `makeReadOnly` is applied to it automatically (but not to other tables it contains). + void addCommonPackage(const std::string& packageName, const sol::object& package); + + // Creates a new sandbox, runs a script, and returns the result + // (the result is expected to be an interface of the script). + // Args: + // path: path to the script in the virtual filesystem; + // namePrefix: sandbox name will be "[]". Sandbox name + // will be added to every `print` output. + // packages: additional packages that should be available from the sandbox via `require`. Each package + // should be either a sol::table or a sol::function. If it is a function, it will be evaluated + // (once per sandbox) with the argument 'hiddenData' the first time when requested. + sol::protected_function_result runInNewSandbox(const std::string& path, + const std::string& namePrefix = "", + const std::map& packages = {}, + const sol::object& hiddenData = sol::nil); + + void dropScriptCache() { mCompiledScripts.clear(); } + + private: + static sol::protected_function_result throwIfError(sol::protected_function_result&&); + template + friend sol::protected_function_result call(sol::protected_function fn, Args&&... args); + + sol::protected_function loadScript(const std::string& path); + + sol::state mLua; + sol::table mSandboxEnv; + std::map mCompiledScripts; + std::map mCommonPackages; + const VFS::Manager* mVFS; + }; + + // Should be used for every call of every Lua function. + // It is a workaround for a bug in `sol`. See https://github.com/ThePhD/sol2/issues/1078 + template + sol::protected_function_result call(sol::protected_function fn, Args&&... args) + { + try + { + return LuaState::throwIfError(fn(std::forward(args)...)); + } + catch (std::exception&) { throw; } + catch (...) { throw std::runtime_error("Unknown error"); } + } + + // getFieldOrNil(table, "a", "b", "c") returns table["a"]["b"]["c"] or nil if some of the fields doesn't exist. + template + sol::object getFieldOrNil(const sol::object& table, std::string_view first, const Str&... str) + { + if (!table.is()) + return sol::nil; + if constexpr (sizeof...(str) == 0) + return table.as()[first]; + else + return getFieldOrNil(table.as()[first], str...); + } + +} + +#endif // COMPONENTS_LUA_LUASTATE_H diff --git a/components/lua/utilpackage.cpp b/components/lua/utilpackage.cpp new file mode 100644 index 0000000000..abcc6d424e --- /dev/null +++ b/components/lua/utilpackage.cpp @@ -0,0 +1,98 @@ +#include "utilpackage.hpp" + +#include +#include + +#include + +#include + +namespace sol +{ + template <> + struct is_automagical : std::false_type {}; + + template <> + struct is_automagical : std::false_type {}; +} + +namespace LuaUtil +{ + + sol::table initUtilPackage(sol::state& lua) + { + sol::table util(lua, sol::create); + + // TODO: Add bindings for osg::Matrix + + // Lua bindings for osg::Vec2f + util["vector2"] = [](float x, float y) { return osg::Vec2f(x, y); }; + sol::usertype vec2Type = lua.new_usertype("Vec2"); + vec2Type["x"] = sol::readonly_property([](const osg::Vec2f& v) -> float { return v.x(); } ); + vec2Type["y"] = sol::readonly_property([](const osg::Vec2f& v) -> float { return v.y(); } ); + vec2Type[sol::meta_function::to_string] = [](const osg::Vec2f& v) { + std::stringstream ss; + ss << "(" << v.x() << ", " << v.y() << ")"; + return ss.str(); + }; + vec2Type[sol::meta_function::unary_minus] = [](const osg::Vec2f& a) { return -a; }; + vec2Type[sol::meta_function::addition] = [](const osg::Vec2f& a, const osg::Vec2f& b) { return a + b; }; + vec2Type[sol::meta_function::subtraction] = [](const osg::Vec2f& a, const osg::Vec2f& b) { return a - b; }; + vec2Type[sol::meta_function::equal_to] = [](const osg::Vec2f& a, const osg::Vec2f& b) { return a == b; }; + vec2Type[sol::meta_function::multiplication] = sol::overload( + [](const osg::Vec2f& a, float c) { return a * c; }, + [](const osg::Vec2f& a, const osg::Vec2f& b) { return a * b; }); + vec2Type[sol::meta_function::division] = [](const osg::Vec2f& a, float c) { return a / c; }; + vec2Type["dot"] = [](const osg::Vec2f& a, const osg::Vec2f& b) { return a * b; }; + vec2Type["length"] = &osg::Vec2f::length; + vec2Type["length2"] = &osg::Vec2f::length2; + vec2Type["normalize"] = [](const osg::Vec2f& v) { + float len = v.length(); + if (len == 0) + return std::make_tuple(osg::Vec2f(), 0.f); + else + return std::make_tuple(v * (1.f / len), len); + }; + vec2Type["rotate"] = &Misc::rotateVec2f; + + // Lua bindings for osg::Vec3f + util["vector3"] = [](float x, float y, float z) { return osg::Vec3f(x, y, z); }; + sol::usertype vec3Type = lua.new_usertype("Vec3"); + vec3Type["x"] = sol::readonly_property([](const osg::Vec3f& v) -> float { return v.x(); } ); + vec3Type["y"] = sol::readonly_property([](const osg::Vec3f& v) -> float { return v.y(); } ); + vec3Type["z"] = sol::readonly_property([](const osg::Vec3f& v) -> float { return v.z(); } ); + vec3Type[sol::meta_function::to_string] = [](const osg::Vec3f& v) { + std::stringstream ss; + ss << "(" << v.x() << ", " << v.y() << ", " << v.z() << ")"; + return ss.str(); + }; + vec3Type[sol::meta_function::unary_minus] = [](const osg::Vec3f& a) { return -a; }; + vec3Type[sol::meta_function::addition] = [](const osg::Vec3f& a, const osg::Vec3f& b) { return a + b; }; + vec3Type[sol::meta_function::subtraction] = [](const osg::Vec3f& a, const osg::Vec3f& b) { return a - b; }; + vec3Type[sol::meta_function::equal_to] = [](const osg::Vec3f& a, const osg::Vec3f& b) { return a == b; }; + vec3Type[sol::meta_function::multiplication] = sol::overload( + [](const osg::Vec3f& a, float c) { return a * c; }, + [](const osg::Vec3f& a, const osg::Vec3f& b) { return a * b; }); + vec3Type[sol::meta_function::division] = [](const osg::Vec3f& a, float c) { return a / c; }; + vec3Type[sol::meta_function::involution] = [](const osg::Vec3f& a, const osg::Vec3f& b) { return a ^ b; }; + vec3Type["dot"] = [](const osg::Vec3f& a, const osg::Vec3f& b) { return a * b; }; + vec3Type["cross"] = [](const osg::Vec3f& a, const osg::Vec3f& b) { return a ^ b; }; + vec3Type["length"] = &osg::Vec3f::length; + vec3Type["length2"] = &osg::Vec3f::length2; + vec3Type["normalize"] = [](const osg::Vec3f& v) { + float len = v.length(); + if (len == 0) + return std::make_tuple(osg::Vec3f(), 0.f); + else + return std::make_tuple(v * (1.f / len), len); + }; + + // Utility functions + util["clamp"] = [](float value, float from, float to) { return std::clamp(value, from, to); }; + // NOTE: `util["clamp"] = std::clamp` causes error 'AddressSanitizer: stack-use-after-scope' + util["normalizeAngle"] = &Misc::normalizeAngle; + + return util; + } + +} diff --git a/components/lua/utilpackage.hpp b/components/lua/utilpackage.hpp new file mode 100644 index 0000000000..06996fb96a --- /dev/null +++ b/components/lua/utilpackage.hpp @@ -0,0 +1,13 @@ +#ifndef COMPONENTS_LUA_UTILPACKAGE_H +#define COMPONENTS_LUA_UTILPACKAGE_H + +#include + +namespace LuaUtil +{ + + sol::table initUtilPackage(sol::state&); + +} + +#endif // COMPONENTS_LUA_UTILPACKAGE_H