diff --git a/apps/openmw_test_suite/CMakeLists.txt b/apps/openmw_test_suite/CMakeLists.txt index 3f82df1a52..8c54af0e00 100644 --- a/apps/openmw_test_suite/CMakeLists.txt +++ b/apps/openmw_test_suite/CMakeLists.txt @@ -17,6 +17,7 @@ if (GTEST_FOUND AND GMOCK_FOUND) lua/test_lua.cpp lua/test_utilpackage.cpp + lua/test_serialization.cpp misc/test_stringops.cpp misc/test_endianness.cpp diff --git a/apps/openmw_test_suite/lua/test_serialization.cpp b/apps/openmw_test_suite/lua/test_serialization.cpp new file mode 100644 index 0000000000..d3c01f6298 --- /dev/null +++ b/apps/openmw_test_suite/lua/test_serialization.cpp @@ -0,0 +1,207 @@ +#include "gmock/gmock.h" +#include + +#include +#include + +#include + +#include + +#include "testing_util.hpp" + +namespace +{ + using namespace testing; + + TEST(LuaSerializationTest, Nil) + { + sol::state lua; + EXPECT_EQ(LuaUtil::serialize(sol::nil), ""); + EXPECT_EQ(LuaUtil::deserialize(lua, ""), sol::nil); + } + + TEST(LuaSerializationTest, Number) + { + sol::state lua; + std::string serialized = LuaUtil::serialize(sol::make_object(lua, 3.14)); + EXPECT_EQ(serialized.size(), 10); // version, type, 8 bytes value + sol::object value = LuaUtil::deserialize(lua, serialized); + ASSERT_TRUE(value.is()); + EXPECT_FLOAT_EQ(value.as(), 3.14); + } + + TEST(LuaSerializationTest, Boolean) + { + sol::state lua; + { + std::string serialized = LuaUtil::serialize(sol::make_object(lua, true)); + EXPECT_EQ(serialized.size(), 3); // version, type, 1 byte value + sol::object value = LuaUtil::deserialize(lua, serialized); + EXPECT_FALSE(value.is()); + ASSERT_TRUE(value.is()); + EXPECT_TRUE(value.as()); + } + { + std::string serialized = LuaUtil::serialize(sol::make_object(lua, false)); + EXPECT_EQ(serialized.size(), 3); // version, type, 1 byte value + sol::object value = LuaUtil::deserialize(lua, serialized); + EXPECT_FALSE(value.is()); + ASSERT_TRUE(value.is()); + EXPECT_FALSE(value.as()); + } + } + + TEST(LuaSerializationTest, String) + { + sol::state lua; + std::string_view emptyString = ""; + std::string_view shortString = "abc"; + std::string_view longString = "It is a string with more than 32 characters..........................."; + + { + std::string serialized = LuaUtil::serialize(sol::make_object(lua, emptyString)); + EXPECT_EQ(serialized.size(), 2); // version, type + sol::object value = LuaUtil::deserialize(lua, serialized); + ASSERT_TRUE(value.is()); + EXPECT_EQ(value.as(), emptyString); + } + { + std::string serialized = LuaUtil::serialize(sol::make_object(lua, shortString)); + EXPECT_EQ(serialized.size(), 2 + shortString.size()); // version, type, str data + sol::object value = LuaUtil::deserialize(lua, serialized); + ASSERT_TRUE(value.is()); + EXPECT_EQ(value.as(), shortString); + } + { + std::string serialized = LuaUtil::serialize(sol::make_object(lua, longString)); + EXPECT_EQ(serialized.size(), 6 + longString.size()); // version, type, size, str data + sol::object value = LuaUtil::deserialize(lua, serialized); + ASSERT_TRUE(value.is()); + EXPECT_EQ(value.as(), longString); + } + } + + TEST(LuaSerializationTest, Vector) + { + sol::state lua; + osg::Vec2f vec2(1, 2); + osg::Vec3f vec3(1, 2, 3); + + { + std::string serialized = LuaUtil::serialize(sol::make_object(lua, vec2)); + EXPECT_EQ(serialized.size(), 10); // version, type, 2x float + sol::object value = LuaUtil::deserialize(lua, serialized); + ASSERT_TRUE(value.is()); + EXPECT_EQ(value.as(), vec2); + } + { + std::string serialized = LuaUtil::serialize(sol::make_object(lua, vec3)); + EXPECT_EQ(serialized.size(), 14); // version, type, 3x float + sol::object value = LuaUtil::deserialize(lua, serialized); + ASSERT_TRUE(value.is()); + EXPECT_EQ(value.as(), vec3); + } + } + + TEST(LuaSerializationTest, Table) + { + sol::state lua; + sol::table table(lua, sol::create); + table["aa"] = 1; + table["ab"] = true; + table["nested"] = sol::table(lua, sol::create); + table["nested"]["aa"] = 2; + table["nested"]["bb"] = "something"; + table["nested"][5] = -0.5; + table["nested_empty"] = sol::table(lua, sol::create); + table[1] = osg::Vec2f(1, 2); + table[2] = osg::Vec2f(2, 1); + + std::string serialized = LuaUtil::serialize(table); + EXPECT_EQ(serialized.size(), 123); + sol::table res_table = LuaUtil::deserialize(lua, serialized); + + EXPECT_EQ(res_table.get("aa"), 1); + EXPECT_EQ(res_table.get("ab"), true); + EXPECT_EQ(res_table.get("nested").get("aa"), 2); + EXPECT_EQ(res_table.get("nested").get("bb"), "something"); + EXPECT_FLOAT_EQ(res_table.get("nested").get(5), -0.5); + EXPECT_EQ(res_table.get(1), osg::Vec2f(1, 2)); + EXPECT_EQ(res_table.get(2), osg::Vec2f(2, 1)); + } + + struct TestStruct1 { double a, b; }; + struct TestStruct2 { int a, b; }; + + class TestSerializer final : public LuaUtil::UserdataSerializer + { + bool serialize(LuaUtil::BinaryData& out, const sol::userdata& data) const override + { + if (data.is()) + { + TestStruct1 t = data.as(); + t.a = Misc::toLittleEndian(t.a); + t.b = Misc::toLittleEndian(t.b); + append(out, "ts1", &t, sizeof(t)); + return true; + } + if (data.is()) + { + TestStruct2 t = data.as(); + t.a = Misc::toLittleEndian(t.a); + t.b = Misc::toLittleEndian(t.b); + append(out, "test_struct2", &t, sizeof(t)); + return true; + } + return false; + } + + bool deserialize(std::string_view typeName, std::string_view binaryData, sol::state& lua) const override + { + if (typeName == "ts1") + { + if (sizeof(TestStruct1) != binaryData.size()) + throw std::runtime_error("Incorrect binaryData.size() for TestStruct1: " + std::to_string(binaryData.size())); + TestStruct1 t = *reinterpret_cast(binaryData.data()); + t.a = Misc::fromLittleEndian(t.a); + t.b = Misc::fromLittleEndian(t.b); + sol::stack::push(lua, t); + return true; + } + if (typeName == "test_struct2") + { + if (sizeof(TestStruct2) != binaryData.size()) + throw std::runtime_error("Incorrect binaryData.size() for TestStruct2: " + std::to_string(binaryData.size())); + TestStruct2 t = *reinterpret_cast(binaryData.data()); + t.a = Misc::fromLittleEndian(t.a); + t.b = Misc::fromLittleEndian(t.b); + sol::stack::push(lua, t); + return true; + } + return false; + } + }; + + TEST(LuaSerializationTest, UserdataSerializer) + { + sol::state lua; + sol::table table(lua, sol::create); + table["x"] = TestStruct1{1.5, 2.5}; + table["y"] = TestStruct2{4, 3}; + TestSerializer serializer; + + EXPECT_ERROR(LuaUtil::serialize(table), "Unknown userdata"); + std::string serialized = LuaUtil::serialize(table, &serializer); + EXPECT_ERROR(LuaUtil::deserialize(lua, serialized), "Unknown type:"); + sol::table res = LuaUtil::deserialize(lua, serialized, &serializer); + + TestStruct1 rx = res.get("x"); + TestStruct2 ry = res.get("y"); + EXPECT_EQ(rx.a, 1.5); + EXPECT_EQ(rx.b, 2.5); + EXPECT_EQ(ry.a, 4); + EXPECT_EQ(ry.b, 3); + } + +} diff --git a/components/CMakeLists.txt b/components/CMakeLists.txt index c09e118cbe..6724a9fa29 100644 --- a/components/CMakeLists.txt +++ b/components/CMakeLists.txt @@ -29,7 +29,7 @@ endif (GIT_CHECKOUT) # source files add_component_dir (lua - luastate utilpackage + luastate utilpackage serialization ) add_component_dir (settings diff --git a/components/lua/serialization.cpp b/components/lua/serialization.cpp new file mode 100644 index 0000000000..acbe3c23c7 --- /dev/null +++ b/components/lua/serialization.cpp @@ -0,0 +1,255 @@ +#include "serialization.hpp" + +#include +#include + +#include + +namespace LuaUtil +{ + + constexpr unsigned char FORMAT_VERSION = 0; + + enum class SerializedType + { + NUMBER = 0x0, + LONG_STRING = 0x1, + BOOLEAN = 0x2, + TABLE_START = 0x3, + TABLE_END = 0x4, + + VEC2 = 0x10, + VEC3 = 0x11, + }; + constexpr unsigned char SHORT_STRING_FLAG = 0x20; // 0x001SSSSS. SSSSS = string length + constexpr unsigned char CUSTOM_FULL_FLAG = 0x40; // 0b01TTTTTT + 32bit dataSize + constexpr unsigned char CUSTOM_COMPACT_FLAG = 0x80; // 0b1SSSSTTT. SSSS = dataSize, TTT = (typeName size - 1) + + static void appendType(BinaryData& out, SerializedType type) + { + out.push_back(static_cast(type)); + } + + template + static void appendValue(BinaryData& out, T v) + { + v = Misc::toLittleEndian(v); + out.append(reinterpret_cast(&v), sizeof(v)); + } + + template + static T getValue(std::string_view& binaryData) + { + if (binaryData.size() < sizeof(T)) + throw std::runtime_error("Unexpected end"); + T v; + std::memcpy(&v, binaryData.data(), sizeof(T)); + binaryData = binaryData.substr(sizeof(T)); + return Misc::fromLittleEndian(v); + } + + static void appendString(BinaryData& out, std::string_view str) + { + if (str.size() < 32) + out.push_back(SHORT_STRING_FLAG | char(str.size())); + else + { + appendType(out, SerializedType::LONG_STRING); + appendValue(out, str.size()); + } + out.append(str.data(), str.size()); + } + + static void appendData(BinaryData& out, const void* data, size_t dataSize) + { + out.append(reinterpret_cast(data), dataSize); + } + + void UserdataSerializer::append(BinaryData& out, std::string_view typeName, const void* data, size_t dataSize) + { + assert(!typeName.empty() && typeName.size() <= 64); + if (typeName.size() <= 8 && dataSize < 16) + { // Compact form: 0b1SSSSTTT. SSSS = dataSize, TTT = (typeName size - 1). + unsigned char t = CUSTOM_COMPACT_FLAG | (dataSize << 3) | (typeName.size() - 1); + out.push_back(t); + } + else + { // Full form: 0b01TTTTTT + 32bit dataSize. + unsigned char t = CUSTOM_FULL_FLAG | (typeName.size() - 1); + out.push_back(t); + appendValue(out, dataSize); + } + out.append(typeName.data(), typeName.size()); + appendData(out, data, dataSize); + } + + static void serializeUserdata(BinaryData& out, const sol::userdata& data, const UserdataSerializer* customSerializer) + { + if (data.is()) + { + appendType(out, SerializedType::VEC2); + osg::Vec2f v = data.as(); + appendValue(out, v.x()); + appendValue(out, v.y()); + return; + } + if (data.is()) + { + appendType(out, SerializedType::VEC3); + osg::Vec3f v = data.as(); + appendValue(out, v.x()); + appendValue(out, v.y()); + appendValue(out, v.z()); + return; + } + if (customSerializer && customSerializer->serialize(out, data)) + return; + else + throw std::runtime_error("Unknown userdata"); + } + + static void serialize(BinaryData& out, const sol::object& obj, const UserdataSerializer* customSerializer, int recursionCounter) + { + if (obj.get_type() == sol::type::lightuserdata) + throw std::runtime_error("light userdata is not allowed to be serialized"); + if (obj.is()) + throw std::runtime_error("functions are not allowed to be serialized"); + else if (obj.is()) + serializeUserdata(out, obj, customSerializer); + else if (obj.is()) + { + if (recursionCounter >= 32) + throw std::runtime_error("Can not serialize more than 32 nested tables. Likely the table contains itself."); + sol::table table = obj; + appendType(out, SerializedType::TABLE_START); + for (auto& [key, value] : table) + { + serialize(out, key, customSerializer, recursionCounter + 1); + serialize(out, value, customSerializer, recursionCounter + 1); + } + appendType(out, SerializedType::TABLE_END); + } + else if (obj.is()) + { + appendType(out, SerializedType::NUMBER); + appendValue(out, obj.as()); + } + else if (obj.is()) + appendString(out, obj.as()); + else if (obj.is()) + { + char v = obj.as() ? 1 : 0; + appendType(out, SerializedType::BOOLEAN); + out.push_back(v); + } else + throw std::runtime_error("Unknown lua type"); + } + + static void deserializeImpl(sol::state& lua, std::string_view& binaryData, const UserdataSerializer* customSerializer) + { + if (binaryData.empty()) + throw std::runtime_error("Unexpected end"); + unsigned char type = binaryData[0]; + binaryData = binaryData.substr(1); + if (type & (CUSTOM_COMPACT_FLAG | CUSTOM_FULL_FLAG)) + { + size_t typeNameSize, dataSize; + if (type & CUSTOM_COMPACT_FLAG) + { // Compact form: 0b1SSSSTTT. SSSS = dataSize, TTT = (typeName size - 1). + typeNameSize = (type & 7) + 1; + dataSize = (type >> 3) & 15; + } + else + { // Full form: 0b01TTTTTT + 32bit dataSize. + typeNameSize = (type & 63) + 1; + dataSize = getValue(binaryData); + } + std::string_view typeName = binaryData.substr(0, typeNameSize); + std::string_view data = binaryData.substr(typeNameSize, dataSize); + binaryData = binaryData.substr(typeNameSize + dataSize); + if (!customSerializer || !customSerializer->deserialize(typeName, data, lua)) + throw std::runtime_error("Unknown type: " + std::string(typeName)); + return; + } + if (type & SHORT_STRING_FLAG) + { + size_t size = type & 0x1f; + sol::stack::push(lua.lua_state(), binaryData.substr(0, size)); + binaryData = binaryData.substr(size); + return; + } + switch (static_cast(type)) + { + case SerializedType::NUMBER: + sol::stack::push(lua.lua_state(), getValue(binaryData)); + return; + case SerializedType::BOOLEAN: + sol::stack::push(lua.lua_state(), getValue(binaryData) != 0); + return; + case SerializedType::LONG_STRING: + { + uint32_t size = getValue(binaryData); + sol::stack::push(lua.lua_state(), binaryData.substr(0, size)); + binaryData = binaryData.substr(size); + return; + } + case SerializedType::TABLE_START: + { + lua_createtable(lua, 0, 0); + while (!binaryData.empty() && binaryData[0] != char(SerializedType::TABLE_END)) + { + deserializeImpl(lua, binaryData, customSerializer); + deserializeImpl(lua, binaryData, customSerializer); + lua_settable(lua, -3); + } + if (binaryData.empty()) + throw std::runtime_error("Unexpected end"); + binaryData = binaryData.substr(1); + return; + } + case SerializedType::TABLE_END: + throw std::runtime_error("Unexpected table end"); + case SerializedType::VEC2: + { + float x = getValue(binaryData); + float y = getValue(binaryData); + sol::stack::push(lua.lua_state(), osg::Vec2f(x, y)); + return; + } + case SerializedType::VEC3: + { + float x = getValue(binaryData); + float y = getValue(binaryData); + float z = getValue(binaryData); + sol::stack::push(lua.lua_state(), osg::Vec3f(x, y, z)); + return; + } + default: throw std::runtime_error("Unknown type: " + std::to_string(type)); + } + } + + BinaryData serialize(const sol::object& obj, const UserdataSerializer* customSerializer) + { + if (obj == sol::nil) + return ""; + BinaryData res; + res.push_back(FORMAT_VERSION); + serialize(res, obj, customSerializer, 0); + return res; + } + + sol::object deserialize(sol::state& lua, std::string_view binaryData, const UserdataSerializer* customSerializer) + { + if (binaryData.empty()) + return sol::nil; + if (binaryData[0] != FORMAT_VERSION) + throw std::runtime_error("Incorrect version of Lua serialization format: " + + std::to_string(static_cast(binaryData[0]))); + binaryData = binaryData.substr(1); + deserializeImpl(lua, binaryData, customSerializer); + if (!binaryData.empty()) + throw std::runtime_error("Unexpected data after serialized object"); + return sol::stack::pop(lua.lua_state()); + } + +} diff --git a/components/lua/serialization.hpp b/components/lua/serialization.hpp new file mode 100644 index 0000000000..63f93baac8 --- /dev/null +++ b/components/lua/serialization.hpp @@ -0,0 +1,34 @@ +#ifndef COMPONENTS_LUA_SERIALIZATION_H +#define COMPONENTS_LUA_SERIALIZATION_H + +#include + +namespace LuaUtil +{ + + // Note: it can contain \0 + using BinaryData = std::string; + + class UserdataSerializer + { + public: + virtual ~UserdataSerializer() {} + + // Appends serialized sol::userdata to the end of BinaryData. + // Returns false if this type of userdata is not supported by this serializer. + virtual bool serialize(BinaryData&, const sol::userdata&) const = 0; + + // Deserializes userdata of type "typeName" from binaryData. Should push the result on stack using sol::stack::push. + // Returns false if this type is not supported by this serializer. + virtual bool deserialize(std::string_view typeName, std::string_view binaryData, sol::state&) const = 0; + + protected: + static void append(BinaryData&, std::string_view typeName, const void* data, size_t dataSize); + }; + + BinaryData serialize(const sol::object&, const UserdataSerializer* customSerializer = nullptr); + sol::object deserialize(sol::state& lua, std::string_view binaryData, const UserdataSerializer* customSerializer = nullptr); + +} + +#endif // COMPONENTS_LUA_SERIALIZATION_H