From 4f8b418c2322603f3775ef13361d8dbfc1f49180 Mon Sep 17 00:00:00 2001 From: Evil Eye Date: Thu, 30 Oct 2025 22:24:53 +0100 Subject: [PATCH] Ensure LuaUtil::call is invoked from a safe context --- apps/openmw/mwlua/luamanagerimp.cpp | 26 +++++----- components/lua/luastate.hpp | 77 ++++++++++++++++------------- components/lua/scriptscontainer.cpp | 22 +++++++-- components/lua/scriptscontainer.hpp | 1 + components/lua_ui/textedit.cpp | 4 +- components/lua_ui/widget.cpp | 55 ++++++++++++--------- components/lua_ui/widget.hpp | 11 +++-- components/lua_ui/window.cpp | 10 ++-- 8 files changed, 121 insertions(+), 85 deletions(-) diff --git a/apps/openmw/mwlua/luamanagerimp.cpp b/apps/openmw/mwlua/luamanagerimp.cpp index c0146234f5..7aef11caa4 100644 --- a/apps/openmw/mwlua/luamanagerimp.cpp +++ b/apps/openmw/mwlua/luamanagerimp.cpp @@ -222,21 +222,23 @@ namespace MWLua // Run event handlers for events that were sent before `finalizeEventBatch`. mLuaEvents.callEventHandlers(); - // Run queued callbacks - for (CallbackWithData& c : mQueuedCallbacks) - c.mCallback.tryCall(c.mArg); - mQueuedCallbacks.clear(); + mLua.protectedCall([&](LuaUtil::LuaView& lua) { + // Run queued callbacks + for (CallbackWithData& c : mQueuedCallbacks) + c.mCallback.tryCall(c.mArg); + mQueuedCallbacks.clear(); - // Run engine handlers - mEngineEvents.callEngineHandlers(); - bool isPaused = timeManager.isPaused(); + // Run engine handlers + mEngineEvents.callEngineHandlers(); + bool isPaused = timeManager.isPaused(); - float frameDuration = MWBase::Environment::get().getFrameDuration(); - for (LocalScripts* scripts : mActiveLocalScripts) - scripts->update(isPaused ? 0 : frameDuration); - mGlobalScripts.update(isPaused ? 0 : frameDuration); + float frameDuration = MWBase::Environment::get().getFrameDuration(); + for (LocalScripts* scripts : mActiveLocalScripts) + scripts->update(isPaused ? 0 : frameDuration); + mGlobalScripts.update(isPaused ? 0 : frameDuration); - mLua.protectedCall([&](LuaUtil::LuaView& lua) { mScriptTracker.unloadInactiveScripts(lua); }); + mScriptTracker.unloadInactiveScripts(lua); + }); } void LuaManager::objectTeleported(const MWWorld::Ptr& ptr) diff --git a/components/lua/luastate.hpp b/components/lua/luastate.hpp index 178d19aef2..1a432ec749 100644 --- a/components/lua/luastate.hpp +++ b/components/lua/luastate.hpp @@ -50,7 +50,8 @@ namespace LuaUtil } public: - friend class LuaState; + template + friend int invokeProtectedCall(lua_State*, Function&&); // Returns underlying sol::state. sol::state_view& sol() { return mSol; } @@ -67,6 +68,45 @@ namespace LuaUtil return res; } + // Pushing to the stack from outside a Lua context crashes the engine if no memory can be allocated to grow the + // stack + template + [[nodiscard]] int invokeProtectedCall(lua_State* luaState, Function&& function) + { + if (!lua_checkstack(luaState, 2)) + return LUA_ERRMEM; + lua_pushcfunction(luaState, [](lua_State* state) { + void* f = lua_touserdata(state, 1); + LuaView view(state); + (*static_cast(f))(view); + return 0; + }); + lua_pushlightuserdata(luaState, &function); + return lua_pcall(luaState, 1, 0, 0); + } + + template + void protectedCall(lua_State* luaState, Lambda&& f) + { + int result = invokeProtectedCall(luaState, std::forward(f)); + switch (result) + { + case LUA_OK: + break; + case LUA_ERRMEM: + throw std::runtime_error("Lua error: out of memory"); + case LUA_ERRRUN: + { + sol::optional error = sol::stack::check_get(luaState); + if (error) + throw std::runtime_error(*error); + } + [[fallthrough]]; + default: + throw std::runtime_error("Lua error: " + std::to_string(result)); + } + } + // Holds Lua state. // Provides additional features: // - Load scripts from the virtual filesystem; @@ -87,43 +127,10 @@ namespace LuaUtil LuaState(const LuaState&) = delete; LuaState(LuaState&&) = delete; - // Pushing to the stack from outside a Lua context crashes the engine if no memory can be allocated to grow the - // stack - template - [[nodiscard]] int invokeProtectedCall(Function&& function) const - { - if (!lua_checkstack(mSol.lua_state(), 2)) - return LUA_ERRMEM; - lua_pushcfunction(mSol.lua_state(), [](lua_State* state) { - void* f = lua_touserdata(state, 1); - LuaView view(state); - (*static_cast(f))(view); - return 0; - }); - lua_pushlightuserdata(mSol.lua_state(), &function); - return lua_pcall(mSol.lua_state(), 1, 0, 0); - } - template void protectedCall(Lambda&& f) const { - int result = invokeProtectedCall(std::forward(f)); - switch (result) - { - case LUA_OK: - break; - case LUA_ERRMEM: - throw std::runtime_error("Lua error: out of memory"); - case LUA_ERRRUN: - { - sol::optional error = sol::stack::check_get(mSol.lua_state()); - if (error) - throw std::runtime_error(*error); - } - [[fallthrough]]; - default: - throw std::runtime_error("Lua error: " + std::to_string(result)); - } + LuaUtil::protectedCall(mSol.lua_state(), std::forward(f)); } // Note that constructing a sol::state_view is only safe from a Lua context. Use protectedCall to get one diff --git a/components/lua/scriptscontainer.cpp b/components/lua/scriptscontainer.cpp index 36bdbca8b1..6f58e72dd0 100644 --- a/components/lua/scriptscontainer.cpp +++ b/components/lua/scriptscontainer.cpp @@ -407,6 +407,16 @@ namespace LuaUtil } void ScriptsContainer::save(ESM::LuaScripts& data) + { + if (const UnloadedData* unloadedData = std::get_if(&mData)) + { + data.mScripts = unloadedData->mScripts; + return; + } + mLua.protectedCall([&](LuaView& view) { save(view, data); }); + } + + void ScriptsContainer::save(LuaView&, ESM::LuaScripts& data) { if (const UnloadedData* unloadedData = std::get_if(&mData)) { @@ -614,12 +624,12 @@ namespace LuaUtil return data; } - ScriptsContainer::UnloadedData& ScriptsContainer::ensureUnloaded(LuaView&) + ScriptsContainer::UnloadedData& ScriptsContainer::ensureUnloaded(LuaView& view) { if (UnloadedData* data = std::get_if(&mData)) return *data; UnloadedData data; - save(data); + save(view, data); mAPI.erase("openmw.interfaces"); UnloadedData& out = mData.emplace(std::move(data)); for (auto& [_, handlers] : mEngineHandlers) @@ -751,9 +761,11 @@ namespace LuaUtil void ScriptsContainer::processTimers(double simulationTime, double gameTime) { - LoadedData& data = ensureLoaded(); - updateTimerQueue(data.mSimulationTimersQueue, simulationTime); - updateTimerQueue(data.mGameTimersQueue, gameTime); + mLua.protectedCall([&](LuaView& view) { + LoadedData& data = ensureLoaded(); + updateTimerQueue(data.mSimulationTimersQueue, simulationTime); + updateTimerQueue(data.mGameTimersQueue, gameTime); + }); } static constexpr float instructionCountAvgCoef = 1.0f / 30; // averaging over approximately 30 frames diff --git a/components/lua/scriptscontainer.hpp b/components/lua/scriptscontainer.hpp index 275c300ac9..fe036bae8e 100644 --- a/components/lua/scriptscontainer.hpp +++ b/components/lua/scriptscontainer.hpp @@ -286,6 +286,7 @@ namespace LuaUtil static void removeHandler(std::vector& list, int scriptId); void insertInterface(int scriptId, const Script& script); void removeInterface(int scriptId, const Script& script); + void save(LuaView&, ESM::LuaScripts&); ScriptIdsWithInitializationData mAutoStartScripts; const UserdataSerializer* mSerializer = nullptr; diff --git a/components/lua_ui/textedit.cpp b/components/lua_ui/textedit.cpp index 9bd241884a..359b6c6591 100644 --- a/components/lua_ui/textedit.cpp +++ b/components/lua_ui/textedit.cpp @@ -46,7 +46,9 @@ namespace LuaUi void LuaTextEdit::textChange(MyGUI::EditBox*) { - triggerEvent("textChanged", sol::make_object(lua(), mEditBox->getCaption().asUTF8())); + protectedCall([=](LuaUtil::LuaView& view) { + triggerEvent("textChanged", sol::make_object(view.sol(), mEditBox->getCaption().asUTF8())); + }); } void LuaTextEdit::updateCoord() diff --git a/components/lua_ui/widget.cpp b/components/lua_ui/widget.cpp index d9465e0517..d00e7cc5b2 100644 --- a/components/lua_ui/widget.cpp +++ b/components/lua_ui/widget.cpp @@ -169,27 +169,22 @@ namespace LuaUi return result; } - sol::table WidgetExtension::makeTable() const - { - return sol::table(lua(), sol::create); - } - - sol::object WidgetExtension::keyEvent(MyGUI::KeyCode code) const + sol::object WidgetExtension::keyEvent(LuaUtil::LuaView& view, MyGUI::KeyCode code) const { auto keySym = SDL_Keysym(); keySym.sym = SDLUtil::myGuiKeyToSdl(code); keySym.scancode = SDL_GetScancodeFromKey(keySym.sym); keySym.mod = SDL_GetModState(); - return sol::make_object(lua(), keySym); + return sol::make_object(view.sol(), keySym); } sol::object WidgetExtension::mouseEvent( - int left, int top, MyGUI::MouseButton button = MyGUI::MouseButton::None) const + LuaUtil::LuaView& view, int left, int top, MyGUI::MouseButton button = MyGUI::MouseButton::None) const { osg::Vec2f position(left, top); MyGUI::IntPoint absolutePosition = mWidget->getAbsolutePosition(); osg::Vec2f offset = position - osg::Vec2f(absolutePosition.left, absolutePosition.top); - sol::table table = makeTable(); + sol::table table = view.newTable(); int sdlButton = SDLUtil::myGuiMouseButtonToSdl(button); table["position"] = position; table["offset"] = offset; @@ -372,31 +367,39 @@ namespace LuaUi void WidgetExtension::keyPress(MyGUI::Widget*, MyGUI::KeyCode code, MyGUI::Char ch) { - if (code == MyGUI::KeyCode::None) - { - propagateEvent("textInput", [ch](auto w) { - MyGUI::UString uString; - uString.push_back(static_cast(ch)); - return sol::make_object(w->lua(), uString.asUTF8()); - }); - } - else - propagateEvent("keyPress", [code](auto w) { return w->keyEvent(code); }); + protectedCall([=](LuaUtil::LuaView& view) { + if (code == MyGUI::KeyCode::None) + { + propagateEvent("textInput", [&](auto w) { + MyGUI::UString uString; + uString.push_back(static_cast(ch)); + return sol::make_object(view.sol(), uString.asUTF8()); + }); + } + else + propagateEvent("keyPress", [&](auto w) { return w->keyEvent(view, code); }); + }); } void WidgetExtension::keyRelease(MyGUI::Widget*, MyGUI::KeyCode code) { - propagateEvent("keyRelease", [code](auto w) { return w->keyEvent(code); }); + protectedCall([=](LuaUtil::LuaView& view) { + propagateEvent("keyRelease", [&](auto w) { return w->keyEvent(view, code); }); + }); } void WidgetExtension::mouseMove(MyGUI::Widget*, int left, int top) { - propagateEvent("mouseMove", [left, top](auto w) { return w->mouseEvent(left, top); }); + protectedCall([=](LuaUtil::LuaView& view) { + propagateEvent("mouseMove", [&](auto w) { return w->mouseEvent(view, left, top); }); + }); } void WidgetExtension::mouseDrag(MyGUI::Widget*, int left, int top, MyGUI::MouseButton button) { - propagateEvent("mouseMove", [left, top, button](auto w) { return w->mouseEvent(left, top, button); }); + protectedCall([=](LuaUtil::LuaView& view) { + propagateEvent("mouseMove", [&](auto w) { return w->mouseEvent(view, left, top, button); }); + }); } void WidgetExtension::mouseClick(MyGUI::Widget* /*widget*/) @@ -411,12 +414,16 @@ namespace LuaUi void WidgetExtension::mousePress(MyGUI::Widget*, int left, int top, MyGUI::MouseButton button) { - propagateEvent("mousePress", [left, top, button](auto w) { return w->mouseEvent(left, top, button); }); + protectedCall([=](LuaUtil::LuaView& view) { + propagateEvent("mousePress", [&](auto w) { return w->mouseEvent(view, left, top, button); }); + }); } void WidgetExtension::mouseRelease(MyGUI::Widget*, int left, int top, MyGUI::MouseButton button) { - propagateEvent("mouseRelease", [left, top, button](auto w) { return w->mouseEvent(left, top, button); }); + protectedCall([=](LuaUtil::LuaView& view) { + propagateEvent("mouseRelease", [&](auto w) { return w->mouseEvent(view, left, top, button); }); + }); } void WidgetExtension::focusGain(MyGUI::Widget*, MyGUI::Widget*) diff --git a/components/lua_ui/widget.hpp b/components/lua_ui/widget.hpp index 5fcf86d110..c20f70587e 100644 --- a/components/lua_ui/widget.hpp +++ b/components/lua_ui/widget.hpp @@ -83,9 +83,8 @@ namespace LuaUi void registerEvents(MyGUI::Widget* w); void clearEvents(MyGUI::Widget* w); - sol::table makeTable() const; - sol::object keyEvent(MyGUI::KeyCode) const; - sol::object mouseEvent(int left, int top, MyGUI::MouseButton button) const; + sol::object keyEvent(LuaUtil::LuaView& view, MyGUI::KeyCode) const; + sol::object mouseEvent(LuaUtil::LuaView& view, int left, int top, MyGUI::MouseButton button) const; MyGUI::IntSize parentSize() const; virtual MyGUI::IntSize childScalingSize() const; @@ -104,7 +103,11 @@ namespace LuaUi virtual void updateProperties(); virtual void updateChildren() {} - lua_State* lua() const { return mLua; } + template + void protectedCall(Lambda&& f) const + { + LuaUtil::protectedCall(mLua, std::forward(f)); + } void triggerEvent(std::string_view name, sol::object argument) const; template diff --git a/components/lua_ui/window.cpp b/components/lua_ui/window.cpp index a06da14972..94d73ac2c8 100644 --- a/components/lua_ui/window.cpp +++ b/components/lua_ui/window.cpp @@ -78,9 +78,11 @@ namespace LuaUi mPreviousMouse.left = left; mPreviousMouse.top = top; - sol::table table = makeTable(); - table["position"] = osg::Vec2f(mCoord.left, mCoord.top); - table["size"] = osg::Vec2f(mCoord.width, mCoord.height); - triggerEvent("windowDrag", table); + protectedCall([=](LuaUtil::LuaView& view) { + sol::table table = view.newTable(); + table["position"] = osg::Vec2f(mCoord.left, mCoord.top); + table["size"] = osg::Vec2f(mCoord.width, mCoord.height); + triggerEvent("windowDrag", table); + }); } }