diff --git a/components/lua/luastate.cpp b/components/lua/luastate.cpp index 9af617020e..453e9d1586 100644 --- a/components/lua/luastate.cpp +++ b/components/lua/luastate.cpp @@ -183,6 +183,13 @@ namespace LuaUtil mSol["writeToLog"] = [](std::string_view s) { Log(Debug::Level::Info) << s; }; + mSol["setEnvironment"] + = [](const sol::environment& env, const sol::function& fn) { sol::set_environment(env, fn); }; + mSol["loadFromVFS"] = [this](std::string_view packageName) { + return loadScriptAndCache(packageNameToVfsPath(packageName, mVFS)); + }; + mSol["loadInternalLib"] = [this](std::string_view packageName) { return loadInternalLib(packageName); }; + // Some fixes for compatibility between different Lua versions if (mSol["unpack"] == sol::nil) mSol["unpack"] = mSol["table"]["unpack"]; @@ -208,6 +215,19 @@ namespace LuaUtil end printGen = function(name) return function(...) return printToLog(name, ...) end end + function requireGen(env, loaded, loadFn) + return function(packageName) + local p = loaded[packageName] + if p == nil then + local loader = loadFn(packageName) + setEnvironment(env, loader) + p = loader(packageName) + loaded[packageName] = p + end + return p + end + end + function createStrictIndexFn(tbl) return function(_, key) local res = tbl[key] @@ -327,16 +347,7 @@ namespace LuaUtil loaded[key] = maybeRunLoader(value); for (const auto& [key, value] : packages) loaded[key] = maybeRunLoader(value); - env["require"] = [this, env, loaded, hiddenData](std::string_view packageName) mutable { - sol::object package = loaded[packageName]; - if (package != sol::nil) - return package; - sol::protected_function packageLoader = loadScriptAndCache(packageNameToVfsPath(packageName, mVFS)); - sol::set_environment(env, packageLoader); - package = call(packageLoader, packageName); - loaded[packageName] = package; - return package; - }; + env["require"] = mSol["requireGen"](env, loaded, mSol["loadFromVFS"]); sol::set_environment(env, script); return call(scriptId, script); @@ -348,14 +359,7 @@ namespace LuaUtil sol::table loaded(mSol, sol::create); for (const std::string& s : safePackages) loaded[s] = static_cast(mSandboxEnv[s]); - env["require"] = [this, loaded, env](const std::string& module) mutable { - if (loaded[module] != sol::nil) - return loaded[module]; - sol::protected_function initializer = loadInternalLib(module); - sol::set_environment(env, initializer); - loaded[module] = call({}, initializer, module); - return loaded[module]; - }; + env["require"] = mSol["requireGen"](env, loaded, mSol["loadInternalLib"]); return env; }