Allow deserializing values from serializable UserData using `Lua::from_value()` method.

Closes #240
This commit is contained in:
Alex Orlenko 2023-01-05 21:32:15 +00:00
parent a62061f453
commit 88da28a68d
No known key found for this signature in database
GPG Key ID: 4C150C250863B96D
4 changed files with 124 additions and 15 deletions

View File

@ -37,7 +37,7 @@ vendored = ["lua-src", "luajit-src"]
module = ["mlua_derive"]
async = ["futures-core", "futures-task", "futures-util"]
send = []
serialize = ["serde", "erased-serde"]
serialize = ["serde", "erased-serde", "serde-value"]
macros = ["mlua_derive/macros"]
unstable = []
@ -52,6 +52,7 @@ futures-task = { version = "0.3.5", optional = true }
futures-util = { version = "0.3.5", optional = true }
serde = { version = "1.0", optional = true }
erased-serde = { version = "0.3", optional = true }
serde-value = { version = "0.7", optional = true }
parking_lot = { version = "0.12", optional = true }
[build-dependencies]

View File

@ -9,6 +9,7 @@ use serde::de::{self, IntoDeserializer};
use crate::error::{Error, Result};
use crate::table::{Table, TablePairs, TableSequence};
use crate::userdata::AnyUserData;
use crate::value::Value;
/// A struct for deserializing Lua values into Rust values.
@ -131,6 +132,9 @@ impl<'lua, 'de> serde::Deserializer<'de> for Deserializer<'lua> {
Value::Table(ref t) if t.raw_len() > 0 || t.is_array() => self.deserialize_seq(visitor),
Value::Table(_) => self.deserialize_map(visitor),
Value::LightUserData(ud) if ud.0.is_null() => visitor.visit_none(),
Value::UserData(ud) if ud.is_serializable() => {
serde_userdata(ud, |value| value.deserialize_any(visitor))
}
Value::Function(_)
| Value::Thread(_)
| Value::UserData(_)
@ -163,8 +167,8 @@ impl<'lua, 'de> serde::Deserializer<'de> for Deserializer<'lua> {
#[inline]
fn deserialize_enum<V>(
self,
_name: &str,
_variants: &'static [&'static str],
name: &'static str,
variants: &'static [&'static str],
visitor: V,
) -> Result<V::Value>
where
@ -198,6 +202,9 @@ impl<'lua, 'de> serde::Deserializer<'de> for Deserializer<'lua> {
(variant, Some(value), Some(_guard))
}
Value::String(variant) => (variant.to_str()?.to_owned(), None, None),
Value::UserData(ud) if ud.is_serializable() => {
return serde_userdata(ud, |value| value.deserialize_enum(name, variants, visitor));
}
_ => return Err(de::Error::custom("bad enum value")),
};
@ -244,6 +251,9 @@ impl<'lua, 'de> serde::Deserializer<'de> for Deserializer<'lua> {
))
}
}
Value::UserData(ud) if ud.is_serializable() => {
serde_userdata(ud, |value| value.deserialize_seq(visitor))
}
value => Err(de::Error::invalid_type(
de::Unexpected::Other(value.type_name()),
&"table",
@ -299,6 +309,9 @@ impl<'lua, 'de> serde::Deserializer<'de> for Deserializer<'lua> {
))
}
}
Value::UserData(ud) if ud.is_serializable() => {
serde_userdata(ud, |value| value.deserialize_map(visitor))
}
value => Err(de::Error::invalid_type(
de::Unexpected::Other(value.type_name()),
&"table",
@ -320,11 +333,16 @@ impl<'lua, 'de> serde::Deserializer<'de> for Deserializer<'lua> {
}
#[inline]
fn deserialize_newtype_struct<V>(self, _name: &'static str, visitor: V) -> Result<V::Value>
fn deserialize_newtype_struct<V>(self, name: &'static str, visitor: V) -> Result<V::Value>
where
V: de::Visitor<'de>,
{
visitor.visit_newtype_struct(self)
match self.value {
Value::UserData(ud) if ud.is_serializable() => {
serde_userdata(ud, |value| value.deserialize_newtype_struct(name, visitor))
}
_ => visitor.visit_newtype_struct(self),
}
}
serde::forward_to_deserialize_any! {
@ -590,6 +608,7 @@ fn check_value_if_skip(
return Ok(true); // skip
}
}
Value::UserData(ud) if ud.is_serializable() => {}
Value::Function(_)
| Value::Thread(_)
| Value::UserData(_)
@ -603,3 +622,11 @@ fn check_value_if_skip(
}
Ok(false) // do not skip
}
fn serde_userdata<V>(
ud: AnyUserData,
f: impl FnOnce(serde_value::Value) -> std::result::Result<V, serde_value::DeserializerError>,
) -> Result<V> {
let value = serde_value::to_value(ud).map_err(|err| Error::SerializeError(err.to_string()))?;
f(value).map_err(|err| Error::DeserializeError(err.to_string()))
}

View File

@ -1029,6 +1029,27 @@ impl<'lua> AnyUserData<'lua> {
Ok(false)
}
/// Returns true if this `AnyUserData` is serializable (eg. was created using `create_ser_userdata`).
#[cfg(feature = "serialize")]
pub(crate) fn is_serializable(&self) -> bool {
let lua = self.0.lua;
let state = lua.state();
let is_serializable = || unsafe {
let _sg = StackGuard::new(state);
check_stack(state, 2)?;
// Userdata can be unregistered or destructed
lua.push_userdata_ref(&self.0)?;
let ud = &*get_userdata::<UserDataCell<()>>(state, -1);
match &*ud.0.try_borrow().map_err(|_| Error::UserDataBorrowError)? {
UserDataWrapped::Default(_) => Result::Ok(false),
UserDataWrapped::Serializable(_) => Result::Ok(true),
}
};
is_serializable().unwrap_or(false)
}
fn inspect<'a, T, F, R>(&'a self, func: F) -> Result<R>
where
T: UserData + 'static,

View File

@ -1,6 +1,7 @@
#![cfg(feature = "serialize")]
use std::collections::HashMap;
use std::error::Error as StdError;
use mlua::{
DeserializeOptions, Error, Lua, LuaSerdeExt, Result as LuaResult, SerializeOptions, UserData,
@ -9,7 +10,7 @@ use mlua::{
use serde::{Deserialize, Serialize};
#[test]
fn test_serialize() -> Result<(), Box<dyn std::error::Error>> {
fn test_serialize() -> Result<(), Box<dyn StdError>> {
#[derive(Serialize)]
struct MyUserData(i64, String);
@ -115,7 +116,7 @@ fn test_serialize_in_scope() -> LuaResult<()> {
}
#[test]
fn test_serialize_failure() -> Result<(), Box<dyn std::error::Error>> {
fn test_serialize_failure() -> Result<(), Box<dyn StdError>> {
#[derive(Serialize)]
struct MyUserData(i64);
@ -146,7 +147,7 @@ fn test_serialize_failure() -> Result<(), Box<dyn std::error::Error>> {
#[cfg(feature = "luau")]
#[test]
fn test_serialize_vector() -> Result<(), Box<dyn std::error::Error>> {
fn test_serialize_vector() -> Result<(), Box<dyn StdError>> {
let lua = Lua::new();
let globals = lua.globals();
@ -235,7 +236,7 @@ fn test_to_value_enum() -> LuaResult<()> {
}
#[test]
fn test_to_value_with_options() -> Result<(), Box<dyn std::error::Error>> {
fn test_to_value_with_options() -> Result<(), Box<dyn StdError>> {
let lua = Lua::new();
let globals = lua.globals();
globals.set("null", lua.null())?;
@ -305,7 +306,7 @@ fn test_to_value_with_options() -> Result<(), Box<dyn std::error::Error>> {
}
#[test]
fn test_from_value_nested_tables() -> Result<(), Box<dyn std::error::Error>> {
fn test_from_value_nested_tables() -> Result<(), Box<dyn StdError>> {
let lua = Lua::new();
let value = lua
@ -335,7 +336,7 @@ fn test_from_value_nested_tables() -> Result<(), Box<dyn std::error::Error>> {
}
#[test]
fn test_from_value_struct() -> Result<(), Box<dyn std::error::Error>> {
fn test_from_value_struct() -> Result<(), Box<dyn StdError>> {
let lua = Lua::new();
#[derive(Deserialize, PartialEq, Debug)]
@ -376,7 +377,7 @@ fn test_from_value_struct() -> Result<(), Box<dyn std::error::Error>> {
}
#[test]
fn test_from_value_newtype_struct() -> Result<(), Box<dyn std::error::Error>> {
fn test_from_value_newtype_struct() -> Result<(), Box<dyn StdError>> {
let lua = Lua::new();
#[derive(Deserialize, PartialEq, Debug)]
@ -389,7 +390,7 @@ fn test_from_value_newtype_struct() -> Result<(), Box<dyn std::error::Error>> {
}
#[test]
fn test_from_value_enum() -> Result<(), Box<dyn std::error::Error>> {
fn test_from_value_enum() -> Result<(), Box<dyn StdError>> {
let lua = Lua::new();
#[derive(Deserialize, PartialEq, Debug)]
@ -420,7 +421,7 @@ fn test_from_value_enum() -> Result<(), Box<dyn std::error::Error>> {
}
#[test]
fn test_from_value_enum_untagged() -> Result<(), Box<dyn std::error::Error>> {
fn test_from_value_enum_untagged() -> Result<(), Box<dyn StdError>> {
let lua = Lua::new();
lua.globals().set("null", lua.null())?;
@ -460,7 +461,7 @@ fn test_from_value_enum_untagged() -> Result<(), Box<dyn std::error::Error>> {
}
#[test]
fn test_from_value_with_options() -> Result<(), Box<dyn std::error::Error>> {
fn test_from_value_with_options() -> Result<(), Box<dyn StdError>> {
let lua = Lua::new();
// Deny unsupported types by default
@ -515,3 +516,62 @@ fn test_from_value_with_options() -> Result<(), Box<dyn std::error::Error>> {
Ok(())
}
#[test]
fn test_from_value_userdata() -> Result<(), Box<dyn StdError>> {
let lua = Lua::new();
// Tuple struct
#[derive(Serialize, Deserialize)]
struct MyUserData(i64, String);
impl UserData for MyUserData {}
let ud = lua.create_ser_userdata(MyUserData(123, "test userdata".into()))?;
match lua.from_value::<MyUserData>(Value::UserData(ud)) {
Ok(_) => {}
Err(err) => panic!("expected no errors, got {err:?}"),
};
// Newtype struct
#[derive(Serialize, Deserialize)]
struct NewtypeUserdata(String);
impl UserData for NewtypeUserdata {}
let ud = lua.create_ser_userdata(NewtypeUserdata("newtype userdata".into()))?;
match lua.from_value::<NewtypeUserdata>(Value::UserData(ud)) {
Ok(_) => {}
Err(err) => panic!("expected no errors, got {err:?}"),
};
// Option
#[derive(Serialize, Deserialize)]
struct UnitUserdata;
impl UserData for UnitUserdata {}
let ud = lua.create_ser_userdata(UnitUserdata)?;
match lua.from_value::<Option<()>>(Value::UserData(ud)) {
Ok(Some(_)) => {}
Ok(_) => panic!("expected `Some`, got `None`"),
Err(err) => panic!("expected no errors, got {err:?}"),
};
// Destructed userdata with skip option
let ud = lua.create_ser_userdata(NewtypeUserdata("newtype userdata".into()))?;
let _ = ud.take::<NewtypeUserdata>()?;
match lua.from_value_with::<()>(
Value::UserData(ud),
DeserializeOptions::new().deny_unsupported_types(false),
) {
Ok(_) => {}
Err(err) => panic!("expected no errors, got {err:?}"),
};
Ok(())
}