1
0
Fork 0
mirror of https://github.com/OpenMW/openmw.git synced 2025-11-29 19:34:32 +00:00

Ensure LuaUtil::call is invoked from a safe context

This commit is contained in:
Evil Eye 2025-10-30 22:24:53 +01:00
parent 8b2be21eea
commit 4f8b418c23
8 changed files with 121 additions and 85 deletions

View file

@ -222,21 +222,23 @@ namespace MWLua
// Run event handlers for events that were sent before `finalizeEventBatch`.
mLuaEvents.callEventHandlers();
// Run queued callbacks
for (CallbackWithData& c : mQueuedCallbacks)
c.mCallback.tryCall(c.mArg);
mQueuedCallbacks.clear();
mLua.protectedCall([&](LuaUtil::LuaView& lua) {
// Run queued callbacks
for (CallbackWithData& c : mQueuedCallbacks)
c.mCallback.tryCall(c.mArg);
mQueuedCallbacks.clear();
// Run engine handlers
mEngineEvents.callEngineHandlers();
bool isPaused = timeManager.isPaused();
// Run engine handlers
mEngineEvents.callEngineHandlers();
bool isPaused = timeManager.isPaused();
float frameDuration = MWBase::Environment::get().getFrameDuration();
for (LocalScripts* scripts : mActiveLocalScripts)
scripts->update(isPaused ? 0 : frameDuration);
mGlobalScripts.update(isPaused ? 0 : frameDuration);
float frameDuration = MWBase::Environment::get().getFrameDuration();
for (LocalScripts* scripts : mActiveLocalScripts)
scripts->update(isPaused ? 0 : frameDuration);
mGlobalScripts.update(isPaused ? 0 : frameDuration);
mLua.protectedCall([&](LuaUtil::LuaView& lua) { mScriptTracker.unloadInactiveScripts(lua); });
mScriptTracker.unloadInactiveScripts(lua);
});
}
void LuaManager::objectTeleported(const MWWorld::Ptr& ptr)

View file

@ -50,7 +50,8 @@ namespace LuaUtil
}
public:
friend class LuaState;
template <class Function>
friend int invokeProtectedCall(lua_State*, Function&&);
// Returns underlying sol::state.
sol::state_view& sol() { return mSol; }
@ -67,6 +68,45 @@ namespace LuaUtil
return res;
}
// Pushing to the stack from outside a Lua context crashes the engine if no memory can be allocated to grow the
// stack
template <class Function>
[[nodiscard]] int invokeProtectedCall(lua_State* luaState, Function&& function)
{
if (!lua_checkstack(luaState, 2))
return LUA_ERRMEM;
lua_pushcfunction(luaState, [](lua_State* state) {
void* f = lua_touserdata(state, 1);
LuaView view(state);
(*static_cast<Function*>(f))(view);
return 0;
});
lua_pushlightuserdata(luaState, &function);
return lua_pcall(luaState, 1, 0, 0);
}
template <class Lambda>
void protectedCall(lua_State* luaState, Lambda&& f)
{
int result = invokeProtectedCall(luaState, std::forward<Lambda>(f));
switch (result)
{
case LUA_OK:
break;
case LUA_ERRMEM:
throw std::runtime_error("Lua error: out of memory");
case LUA_ERRRUN:
{
sol::optional<std::string> error = sol::stack::check_get<std::string>(luaState);
if (error)
throw std::runtime_error(*error);
}
[[fallthrough]];
default:
throw std::runtime_error("Lua error: " + std::to_string(result));
}
}
// Holds Lua state.
// Provides additional features:
// - Load scripts from the virtual filesystem;
@ -87,43 +127,10 @@ namespace LuaUtil
LuaState(const LuaState&) = delete;
LuaState(LuaState&&) = delete;
// Pushing to the stack from outside a Lua context crashes the engine if no memory can be allocated to grow the
// stack
template <class Function>
[[nodiscard]] int invokeProtectedCall(Function&& function) const
{
if (!lua_checkstack(mSol.lua_state(), 2))
return LUA_ERRMEM;
lua_pushcfunction(mSol.lua_state(), [](lua_State* state) {
void* f = lua_touserdata(state, 1);
LuaView view(state);
(*static_cast<Function*>(f))(view);
return 0;
});
lua_pushlightuserdata(mSol.lua_state(), &function);
return lua_pcall(mSol.lua_state(), 1, 0, 0);
}
template <class Lambda>
void protectedCall(Lambda&& f) const
{
int result = invokeProtectedCall(std::forward<Lambda>(f));
switch (result)
{
case LUA_OK:
break;
case LUA_ERRMEM:
throw std::runtime_error("Lua error: out of memory");
case LUA_ERRRUN:
{
sol::optional<std::string> error = sol::stack::check_get<std::string>(mSol.lua_state());
if (error)
throw std::runtime_error(*error);
}
[[fallthrough]];
default:
throw std::runtime_error("Lua error: " + std::to_string(result));
}
LuaUtil::protectedCall(mSol.lua_state(), std::forward<Lambda>(f));
}
// Note that constructing a sol::state_view is only safe from a Lua context. Use protectedCall to get one

View file

@ -407,6 +407,16 @@ namespace LuaUtil
}
void ScriptsContainer::save(ESM::LuaScripts& data)
{
if (const UnloadedData* unloadedData = std::get_if<UnloadedData>(&mData))
{
data.mScripts = unloadedData->mScripts;
return;
}
mLua.protectedCall([&](LuaView& view) { save(view, data); });
}
void ScriptsContainer::save(LuaView&, ESM::LuaScripts& data)
{
if (const UnloadedData* unloadedData = std::get_if<UnloadedData>(&mData))
{
@ -614,12 +624,12 @@ namespace LuaUtil
return data;
}
ScriptsContainer::UnloadedData& ScriptsContainer::ensureUnloaded(LuaView&)
ScriptsContainer::UnloadedData& ScriptsContainer::ensureUnloaded(LuaView& view)
{
if (UnloadedData* data = std::get_if<UnloadedData>(&mData))
return *data;
UnloadedData data;
save(data);
save(view, data);
mAPI.erase("openmw.interfaces");
UnloadedData& out = mData.emplace<UnloadedData>(std::move(data));
for (auto& [_, handlers] : mEngineHandlers)
@ -751,9 +761,11 @@ namespace LuaUtil
void ScriptsContainer::processTimers(double simulationTime, double gameTime)
{
LoadedData& data = ensureLoaded();
updateTimerQueue(data.mSimulationTimersQueue, simulationTime);
updateTimerQueue(data.mGameTimersQueue, gameTime);
mLua.protectedCall([&](LuaView& view) {
LoadedData& data = ensureLoaded();
updateTimerQueue(data.mSimulationTimersQueue, simulationTime);
updateTimerQueue(data.mGameTimersQueue, gameTime);
});
}
static constexpr float instructionCountAvgCoef = 1.0f / 30; // averaging over approximately 30 frames

View file

@ -286,6 +286,7 @@ namespace LuaUtil
static void removeHandler(std::vector<Handler>& list, int scriptId);
void insertInterface(int scriptId, const Script& script);
void removeInterface(int scriptId, const Script& script);
void save(LuaView&, ESM::LuaScripts&);
ScriptIdsWithInitializationData mAutoStartScripts;
const UserdataSerializer* mSerializer = nullptr;

View file

@ -46,7 +46,9 @@ namespace LuaUi
void LuaTextEdit::textChange(MyGUI::EditBox*)
{
triggerEvent("textChanged", sol::make_object(lua(), mEditBox->getCaption().asUTF8()));
protectedCall([=](LuaUtil::LuaView& view) {
triggerEvent("textChanged", sol::make_object(view.sol(), mEditBox->getCaption().asUTF8()));
});
}
void LuaTextEdit::updateCoord()

View file

@ -169,27 +169,22 @@ namespace LuaUi
return result;
}
sol::table WidgetExtension::makeTable() const
{
return sol::table(lua(), sol::create);
}
sol::object WidgetExtension::keyEvent(MyGUI::KeyCode code) const
sol::object WidgetExtension::keyEvent(LuaUtil::LuaView& view, MyGUI::KeyCode code) const
{
auto keySym = SDL_Keysym();
keySym.sym = SDLUtil::myGuiKeyToSdl(code);
keySym.scancode = SDL_GetScancodeFromKey(keySym.sym);
keySym.mod = SDL_GetModState();
return sol::make_object(lua(), keySym);
return sol::make_object(view.sol(), keySym);
}
sol::object WidgetExtension::mouseEvent(
int left, int top, MyGUI::MouseButton button = MyGUI::MouseButton::None) const
LuaUtil::LuaView& view, int left, int top, MyGUI::MouseButton button = MyGUI::MouseButton::None) const
{
osg::Vec2f position(left, top);
MyGUI::IntPoint absolutePosition = mWidget->getAbsolutePosition();
osg::Vec2f offset = position - osg::Vec2f(absolutePosition.left, absolutePosition.top);
sol::table table = makeTable();
sol::table table = view.newTable();
int sdlButton = SDLUtil::myGuiMouseButtonToSdl(button);
table["position"] = position;
table["offset"] = offset;
@ -372,31 +367,39 @@ namespace LuaUi
void WidgetExtension::keyPress(MyGUI::Widget*, MyGUI::KeyCode code, MyGUI::Char ch)
{
if (code == MyGUI::KeyCode::None)
{
propagateEvent("textInput", [ch](auto w) {
MyGUI::UString uString;
uString.push_back(static_cast<MyGUI::UString::unicode_char>(ch));
return sol::make_object(w->lua(), uString.asUTF8());
});
}
else
propagateEvent("keyPress", [code](auto w) { return w->keyEvent(code); });
protectedCall([=](LuaUtil::LuaView& view) {
if (code == MyGUI::KeyCode::None)
{
propagateEvent("textInput", [&](auto w) {
MyGUI::UString uString;
uString.push_back(static_cast<MyGUI::UString::unicode_char>(ch));
return sol::make_object(view.sol(), uString.asUTF8());
});
}
else
propagateEvent("keyPress", [&](auto w) { return w->keyEvent(view, code); });
});
}
void WidgetExtension::keyRelease(MyGUI::Widget*, MyGUI::KeyCode code)
{
propagateEvent("keyRelease", [code](auto w) { return w->keyEvent(code); });
protectedCall([=](LuaUtil::LuaView& view) {
propagateEvent("keyRelease", [&](auto w) { return w->keyEvent(view, code); });
});
}
void WidgetExtension::mouseMove(MyGUI::Widget*, int left, int top)
{
propagateEvent("mouseMove", [left, top](auto w) { return w->mouseEvent(left, top); });
protectedCall([=](LuaUtil::LuaView& view) {
propagateEvent("mouseMove", [&](auto w) { return w->mouseEvent(view, left, top); });
});
}
void WidgetExtension::mouseDrag(MyGUI::Widget*, int left, int top, MyGUI::MouseButton button)
{
propagateEvent("mouseMove", [left, top, button](auto w) { return w->mouseEvent(left, top, button); });
protectedCall([=](LuaUtil::LuaView& view) {
propagateEvent("mouseMove", [&](auto w) { return w->mouseEvent(view, left, top, button); });
});
}
void WidgetExtension::mouseClick(MyGUI::Widget* /*widget*/)
@ -411,12 +414,16 @@ namespace LuaUi
void WidgetExtension::mousePress(MyGUI::Widget*, int left, int top, MyGUI::MouseButton button)
{
propagateEvent("mousePress", [left, top, button](auto w) { return w->mouseEvent(left, top, button); });
protectedCall([=](LuaUtil::LuaView& view) {
propagateEvent("mousePress", [&](auto w) { return w->mouseEvent(view, left, top, button); });
});
}
void WidgetExtension::mouseRelease(MyGUI::Widget*, int left, int top, MyGUI::MouseButton button)
{
propagateEvent("mouseRelease", [left, top, button](auto w) { return w->mouseEvent(left, top, button); });
protectedCall([=](LuaUtil::LuaView& view) {
propagateEvent("mouseRelease", [&](auto w) { return w->mouseEvent(view, left, top, button); });
});
}
void WidgetExtension::focusGain(MyGUI::Widget*, MyGUI::Widget*)

View file

@ -83,9 +83,8 @@ namespace LuaUi
void registerEvents(MyGUI::Widget* w);
void clearEvents(MyGUI::Widget* w);
sol::table makeTable() const;
sol::object keyEvent(MyGUI::KeyCode) const;
sol::object mouseEvent(int left, int top, MyGUI::MouseButton button) const;
sol::object keyEvent(LuaUtil::LuaView& view, MyGUI::KeyCode) const;
sol::object mouseEvent(LuaUtil::LuaView& view, int left, int top, MyGUI::MouseButton button) const;
MyGUI::IntSize parentSize() const;
virtual MyGUI::IntSize childScalingSize() const;
@ -104,7 +103,11 @@ namespace LuaUi
virtual void updateProperties();
virtual void updateChildren() {}
lua_State* lua() const { return mLua; }
template <class Lambda>
void protectedCall(Lambda&& f) const
{
LuaUtil::protectedCall(mLua, std::forward<Lambda>(f));
}
void triggerEvent(std::string_view name, sol::object argument) const;
template <class ArgFactory>

View file

@ -78,9 +78,11 @@ namespace LuaUi
mPreviousMouse.left = left;
mPreviousMouse.top = top;
sol::table table = makeTable();
table["position"] = osg::Vec2f(mCoord.left, mCoord.top);
table["size"] = osg::Vec2f(mCoord.width, mCoord.height);
triggerEvent("windowDrag", table);
protectedCall([=](LuaUtil::LuaView& view) {
sol::table table = view.newTable();
table["position"] = osg::Vec2f(mCoord.left, mCoord.top);
table["size"] = osg::Vec2f(mCoord.width, mCoord.height);
triggerEvent("windowDrag", table);
});
}
}