From 61578ca39beebd69fd25c1b06708ac178fcbb756 Mon Sep 17 00:00:00 2001 From: Petr Mikheev Date: Tue, 25 Apr 2023 21:07:59 +0200 Subject: [PATCH 1/2] Add function LuaUtil::cast --- apps/openmw_test_suite/lua/test_lua.cpp | 17 +++++++++++++++++ components/lua/luastate.cpp | 21 +++++++++++++++++++++ components/lua/luastate.hpp | 22 ++++++++++++++++------ 3 files changed, 54 insertions(+), 6 deletions(-) diff --git a/apps/openmw_test_suite/lua/test_lua.cpp b/apps/openmw_test_suite/lua/test_lua.cpp index 76d430b440..90c987522d 100644 --- a/apps/openmw_test_suite/lua/test_lua.cpp +++ b/apps/openmw_test_suite/lua/test_lua.cpp @@ -100,6 +100,23 @@ return { EXPECT_EQ(LuaUtil::toString(sol::make_object(mLua.sol(), "something")), "\"something\""); } + TEST_F(LuaStateTest, Cast) + { + EXPECT_EQ(LuaUtil::cast(sol::make_object(mLua.sol(), 3.14)), 3); + EXPECT_ERROR( + LuaUtil::cast(sol::make_object(mLua.sol(), "3.14")), "Value \"\"3.14\"\" can not be casted to int"); + EXPECT_ERROR(LuaUtil::cast(sol::make_object(mLua.sol(), sol::nil)), + "Value \"nil\" can not be casted to string"); + EXPECT_ERROR(LuaUtil::cast(sol::make_object(mLua.sol(), sol::nil)), + "Value \"nil\" can not be casted to string"); + EXPECT_ERROR(LuaUtil::cast(sol::make_object(mLua.sol(), sol::nil)), + "Value \"nil\" can not be casted to sol::table"); + EXPECT_ERROR(LuaUtil::cast(sol::make_object(mLua.sol(), "3.14")), + "Value \"\"3.14\"\" can not be casted to sol::function"); + EXPECT_ERROR(LuaUtil::cast(sol::make_object(mLua.sol(), "3.14")), + "Value \"\"3.14\"\" can not be casted to sol::function"); + } + TEST_F(LuaStateTest, ErrorHandling) { EXPECT_ERROR(mLua.runInNewSandbox("invalid.lua"), "[string \"invalid.lua\"]:1:"); diff --git a/components/lua/luastate.cpp b/components/lua/luastate.cpp index 453e9d1586..a8818ceee0 100644 --- a/components/lua/luastate.cpp +++ b/components/lua/luastate.cpp @@ -420,4 +420,25 @@ namespace LuaUtil return call(sol::state_view(obj.lua_state())["tostring"], obj); } + std::string internal::formatCastingError(const sol::object& obj, const std::type_info& t) + { + const char* typeName = t.name(); + if (t == typeid(int)) + typeName = "int"; + else if (t == typeid(unsigned)) + typeName = "uint32"; + else if (t == typeid(size_t)) + typeName = "size_t"; + else if (t == typeid(float)) + typeName = "float"; + else if (t == typeid(double)) + typeName = "double"; + else if (t == typeid(sol::table)) + typeName = "sol::table"; + else if (t == typeid(sol::function) || t == typeid(sol::protected_function)) + typeName = "sol::function"; + else if (t == typeid(std::string) || t == typeid(std::string_view)) + typeName = "string"; + return std::string("Value \"") + toString(obj) + std::string("\" can not be casted to ") + typeName; + } } diff --git a/components/lua/luastate.hpp b/components/lua/luastate.hpp index a41a29d283..aea1e32590 100644 --- a/components/lua/luastate.hpp +++ b/components/lua/luastate.hpp @@ -1,12 +1,12 @@ #ifndef COMPONENTS_LUA_LUASTATE_H #define COMPONENTS_LUA_LUASTATE_H +#include #include +#include #include -#include - #include "configuration.hpp" namespace VFS @@ -247,15 +247,25 @@ namespace LuaUtil // String representation of a Lua object. Should be used for debugging/logging purposes only. std::string toString(const sol::object&); + namespace internal + { + std::string formatCastingError(const sol::object& obj, const std::type_info&); + } + + template + decltype(auto) cast(const sol::object& obj) + { + if (!obj.is()) + throw std::runtime_error(internal::formatCastingError(obj, typeid(T))); + return obj.as(); + } + template T getValueOrDefault(const sol::object& obj, const T& defaultValue) { if (obj == sol::nil) return defaultValue; - if (obj.is()) - return obj.as(); - else - throw std::logic_error(std::string("Value \"") + toString(obj) + std::string("\" has unexpected type")); + return cast(obj); } // Makes a table read only (when accessed from Lua) by wrapping it with an empty userdata. From c362b2efa6bac8f47f49397fd503721e3004c2f6 Mon Sep 17 00:00:00 2001 From: Petr Mikheev Date: Tue, 25 Apr 2023 22:11:04 +0200 Subject: [PATCH 2/2] Use LuaUtil::cast for casting sol::object to prevent crashing on type mismatch in Lua scripts. --- apps/openmw/mwlua/magicbindings.cpp | 2 +- apps/openmw/mwlua/objectbindings.cpp | 4 ++-- apps/openmw/mwlua/stats.cpp | 20 +++++++++---------- apps/openmw/mwlua/types/actor.cpp | 6 +++--- components/lua/l10n.cpp | 9 +++++---- components/lua/scriptscontainer.cpp | 29 ++++++++++++++-------------- components/lua/serialization.cpp | 2 +- components/lua/storage.cpp | 8 ++++---- components/lua/utilpackage.cpp | 6 +++--- components/lua_ui/content.hpp | 4 ++-- components/lua_ui/element.cpp | 2 +- components/lua_ui/element.hpp | 2 +- 12 files changed, 48 insertions(+), 46 deletions(-) diff --git a/apps/openmw/mwlua/magicbindings.cpp b/apps/openmw/mwlua/magicbindings.cpp index f99e571c1f..1e7df15748 100644 --- a/apps/openmw/mwlua/magicbindings.cpp +++ b/apps/openmw/mwlua/magicbindings.cpp @@ -230,7 +230,7 @@ namespace MWLua if (spellOrId.is()) return spellOrId.as()->mId; else - return ESM::RefId::deserializeText(spellOrId.as()); + return ESM::RefId::deserializeText(LuaUtil::cast(spellOrId)); }; // types.Actor.spells(o):add(id) diff --git a/apps/openmw/mwlua/objectbindings.cpp b/apps/openmw/mwlua/objectbindings.cpp index 68db67f0b7..fcce6f6723 100644 --- a/apps/openmw/mwlua/objectbindings.cpp +++ b/apps/openmw/mwlua/objectbindings.cpp @@ -55,7 +55,7 @@ namespace MWLua cell = cellOrName.as().mStore; else { - std::string_view name = cellOrName.as(); + std::string_view name = LuaUtil::cast(cellOrName); if (name.empty()) cell = nullptr; // default exterior worldspace else @@ -195,7 +195,7 @@ namespace MWLua throw std::runtime_error("Attaching scripts to Static is not allowed: " + std::string(path)); if (initData != sol::nil) context.mLuaManager->addCustomLocalScript(object.ptr(), *scriptId, - LuaUtil::serialize(initData.as(), context.mSerializer)); + LuaUtil::serialize(LuaUtil::cast(initData), context.mSerializer)); else context.mLuaManager->addCustomLocalScript( object.ptr(), *scriptId, cfg[*scriptId].mInitializationData); diff --git a/apps/openmw/mwlua/stats.cpp b/apps/openmw/mwlua/stats.cpp index eeeb11cb91..54fd953b35 100644 --- a/apps/openmw/mwlua/stats.cpp +++ b/apps/openmw/mwlua/stats.cpp @@ -123,7 +123,7 @@ namespace MWLua { auto& stats = ptr.getClass().getCreatureStats(ptr); if (prop == "current") - stats.setLevel(value.as()); + stats.setLevel(LuaUtil::cast(value)); } }; @@ -167,7 +167,7 @@ namespace MWLua { auto& stats = ptr.getClass().getCreatureStats(ptr); auto stat = stats.getDynamic(index); - float floatValue = value.as(); + float floatValue = LuaUtil::cast(value); if (prop == "base") stat.setBase(floatValue); else if (prop == "current") @@ -201,9 +201,9 @@ namespace MWLua float getModified(const Context& context) const { - auto base = get(context, "base", &MWMechanics::AttributeValue::getBase).as(); - auto damage = get(context, "damage", &MWMechanics::AttributeValue::getDamage).as(); - auto modifier = get(context, "modifier", &MWMechanics::AttributeValue::getModifier).as(); + auto base = LuaUtil::cast(get(context, "base", &MWMechanics::AttributeValue::getBase)); + auto damage = LuaUtil::cast(get(context, "damage", &MWMechanics::AttributeValue::getDamage)); + auto modifier = LuaUtil::cast(get(context, "modifier", &MWMechanics::AttributeValue::getModifier)); return std::max(0.f, base - damage + modifier); // Should match AttributeValue::getModified } @@ -226,7 +226,7 @@ namespace MWLua { auto& stats = ptr.getClass().getCreatureStats(ptr); auto stat = stats.getAttribute(index); - float floatValue = value.as(); + float floatValue = LuaUtil::cast(value); if (prop == "base") stat.setBase(floatValue); else if (prop == "damage") @@ -278,9 +278,9 @@ namespace MWLua float getModified(const Context& context) const { - auto base = get(context, "base", &MWMechanics::SkillValue::getBase).as(); - auto damage = get(context, "damage", &MWMechanics::SkillValue::getDamage).as(); - auto modifier = get(context, "modifier", &MWMechanics::SkillValue::getModifier).as(); + auto base = LuaUtil::cast(get(context, "base", &MWMechanics::SkillValue::getBase)); + auto damage = LuaUtil::cast(get(context, "damage", &MWMechanics::SkillValue::getDamage)); + auto modifier = LuaUtil::cast(get(context, "modifier", &MWMechanics::SkillValue::getModifier)); return std::max(0.f, base - damage + modifier); // Should match SkillValue::getModified } @@ -311,7 +311,7 @@ namespace MWLua { auto& stats = ptr.getClass().getNpcStats(ptr); auto stat = stats.getSkill(index); - float floatValue = value.as(); + float floatValue = LuaUtil::cast(value); if (prop == "base") stat.setBase(floatValue); else if (prop == "damage") diff --git a/apps/openmw/mwlua/types/actor.cpp b/apps/openmw/mwlua/types/actor.cpp index 1672786648..adff313f49 100644 --- a/apps/openmw/mwlua/types/actor.cpp +++ b/apps/openmw/mwlua/types/actor.cpp @@ -272,11 +272,11 @@ namespace MWLua SetEquipmentAction::Equipment eqp; for (auto& [key, value] : equipment) { - int slot = key.as(); + int slot = LuaUtil::cast(key); if (value.is()) - eqp[slot] = value.as().id(); + eqp[slot] = LuaUtil::cast(value).id(); else - eqp[slot] = value.as(); + eqp[slot] = LuaUtil::cast(value); } context.mLuaManager->addAction( std::make_unique(context.mLua, obj.id(), std::move(eqp))); diff --git a/components/lua/l10n.cpp b/components/lua/l10n.cpp index 9bfad15ad8..542c81009a 100644 --- a/components/lua/l10n.cpp +++ b/components/lua/l10n.cpp @@ -2,6 +2,7 @@ #include #include +#include namespace { @@ -17,20 +18,20 @@ namespace { // Argument values if (value.is()) - args.push_back(icu::Formattable(value.as().c_str())); + args.push_back(icu::Formattable(LuaUtil::cast(value).c_str())); // Note: While we pass all numbers as doubles, they still seem to be handled appropriately. // Numbers can be forced to be integers using the argType number and argStyle integer // E.g. {var, number, integer} else if (value.is()) - args.push_back(icu::Formattable(value.as())); + args.push_back(icu::Formattable(LuaUtil::cast(value))); else { - Log(Debug::Error) << "Unrecognized argument type for key \"" << key.as() + Log(Debug::Error) << "Unrecognized argument type for key \"" << LuaUtil::cast(key) << "\" when formatting message \"" << messageId << "\""; } // Argument names - const auto str = key.as(); + const auto str = LuaUtil::cast(key); argNames.push_back(icu::UnicodeString::fromUTF8(icu::StringPiece(str.data(), str.size()))); } } diff --git a/components/lua/scriptscontainer.cpp b/components/lua/scriptscontainer.cpp index eaf8c27af5..e6cbfec791 100644 --- a/components/lua/scriptscontainer.cpp +++ b/components/lua/scriptscontainer.cpp @@ -94,33 +94,34 @@ namespace LuaUtil if (scriptOutput == sol::nil) return true; sol::object engineHandlers = sol::nil, eventHandlers = sol::nil; - for (const auto& [key, value] : sol::table(scriptOutput)) + for (const auto& [key, value] : cast(scriptOutput)) { - std::string_view sectionName = key.as(); + std::string_view sectionName = cast(key); if (sectionName == ENGINE_HANDLERS) engineHandlers = value; else if (sectionName == EVENT_HANDLERS) eventHandlers = value; else if (sectionName == INTERFACE_NAME) - script.mInterfaceName = value.as(); + script.mInterfaceName = cast(value); else if (sectionName == INTERFACE) - script.mInterface = value.as(); + script.mInterface = cast(value); else Log(Debug::Error) << "Not supported section '" << sectionName << "' in " << debugName; } if (engineHandlers != sol::nil) { - for (const auto& [key, fn] : sol::table(engineHandlers)) + for (const auto& [key, handler] : cast(engineHandlers)) { - std::string_view handlerName = key.as(); + std::string_view handlerName = cast(key); + sol::function fn = cast(handler); if (handlerName == HANDLER_INIT) - onInit = sol::function(fn); + onInit = fn; else if (handlerName == HANDLER_LOAD) - onLoad = sol::function(fn); + onLoad = fn; else if (handlerName == HANDLER_SAVE) - script.mOnSave = sol::function(fn); + script.mOnSave = fn; else if (handlerName == HANDLER_INTERFACE_OVERRIDE) - script.mOnOverride = sol::function(fn); + script.mOnOverride = fn; else { auto it = mEngineHandlers.find(handlerName); @@ -133,13 +134,13 @@ namespace LuaUtil } if (eventHandlers != sol::nil) { - for (const auto& [key, fn] : sol::table(eventHandlers)) + for (const auto& [key, fn] : cast(eventHandlers)) { - std::string_view eventName = key.as(); + std::string_view eventName = cast(key); auto it = mEventHandlers.find(eventName); if (it == mEventHandlers.end()) it = mEventHandlers.emplace(std::string(eventName), EventHandlerList()).first; - insertHandler(it->second, scriptId, fn); + insertHandler(it->second, scriptId, cast(fn)); } } @@ -318,7 +319,7 @@ namespace LuaUtil try { sol::object res = LuaUtil::call({ this, h.mScriptId }, h.mFn, data); - if (res != sol::nil && !res.as()) + if (res.is() && !res.as()) break; // Skip other handlers if 'false' was returned. } catch (std::exception& e) diff --git a/components/lua/serialization.cpp b/components/lua/serialization.cpp index 3976a3f94e..2a66702589 100644 --- a/components/lua/serialization.cpp +++ b/components/lua/serialization.cpp @@ -106,7 +106,7 @@ namespace LuaUtil bool BasicSerializer::serialize(BinaryData& out, const sol::userdata& data) const { - appendRefNum(out, data.as()); + appendRefNum(out, cast(data)); return true; } diff --git a/components/lua/storage.cpp b/components/lua/storage.cpp index 3932a43280..d23e1eb3d7 100644 --- a/components/lua/storage.cpp +++ b/components/lua/storage.cpp @@ -85,7 +85,7 @@ namespace LuaUtil if (values) { for (const auto& [k, v] : *values) - mValues[k.as()] = Value(v); + mValues[cast(k)] = Value(v); } if (mStorage->mListener) mStorage->mListener->sectionReplaced(mSectionName, values); @@ -166,9 +166,9 @@ namespace LuaUtil sol::table data = deserialize(mLua, serializedData); for (const auto& [sectionName, sectionTable] : data) { - const std::shared_ptr
& section = getSection(sectionName.as()); - for (const auto& [key, value] : sol::table(sectionTable)) - section->set(key.as(), value); + const std::shared_ptr
& section = getSection(cast(sectionName)); + for (const auto& [key, value] : cast(sectionTable)) + section->set(cast(key), value); } } catch (std::exception& e) diff --git a/components/lua/utilpackage.cpp b/components/lua/utilpackage.cpp index ddaebf9d58..932eb21d50 100644 --- a/components/lua/utilpackage.cpp +++ b/components/lua/utilpackage.cpp @@ -236,17 +236,17 @@ namespace LuaUtil { util["bitOr"] = [](unsigned a, sol::variadic_args va) { for (const auto& v : va) - a |= v.as(); + a |= cast(v); return a; }; util["bitAnd"] = [](unsigned a, sol::variadic_args va) { for (const auto& v : va) - a &= v.as(); + a &= cast(v); return a; }; util["bitXor"] = [](unsigned a, sol::variadic_args va) { for (const auto& v : va) - a ^= v.as(); + a ^= cast(v); return a; }; util["bitNot"] = [](unsigned a) { return ~a; }; diff --git a/components/lua_ui/content.hpp b/components/lua_ui/content.hpp index 2caa1ff8dc..c8bb82ecf3 100644 --- a/components/lua_ui/content.hpp +++ b/components/lua_ui/content.hpp @@ -78,7 +78,7 @@ namespace LuaUi { sol::object result = callMethod("indexOf", name); if (result.is()) - return fromLua(result.as()); + return fromLua(LuaUtil::cast(result)); else return std::nullopt; } @@ -86,7 +86,7 @@ namespace LuaUi { sol::object result = callMethod("indexOf", table); if (result.is()) - return fromLua(result.as()); + return fromLua(LuaUtil::cast(result)); else return std::nullopt; } diff --git a/components/lua_ui/element.cpp b/components/lua_ui/element.cpp index 13877b2371..6f06f55cd8 100644 --- a/components/lua_ui/element.cpp +++ b/components/lua_ui/element.cpp @@ -63,7 +63,7 @@ namespace LuaUi destroyWidget(w); return result; } - ContentView content(contentObj.as()); + ContentView content(LuaUtil::cast(contentObj)); result.resize(content.size()); size_t minSize = std::min(children.size(), content.size()); for (size_t i = 0; i < minSize; i++) diff --git a/components/lua_ui/element.hpp b/components/lua_ui/element.hpp index 1aee1c0506..b57af92fee 100644 --- a/components/lua_ui/element.hpp +++ b/components/lua_ui/element.hpp @@ -36,7 +36,7 @@ namespace LuaUi private: Element(sol::table layout); - sol::table layout() { return mLayout.as(); } + sol::table layout() { return LuaUtil::cast(mLayout); } static std::map> sAllElements; void updateAttachment(); };