From 68b60e2a0a24b73af4cb9dd7258953bdc8d689a1 Mon Sep 17 00:00:00 2001 From: Alex Orlenko Date: Sun, 26 Feb 2023 21:52:28 +0000 Subject: [PATCH] Add `UserDataRef` and `UserDataRefMut` types that implement `FromLua` and can be used as accessors to underlying `AnyUserData` type. --- src/conversion.rs | 16 +++++++++++- src/lib.rs | 1 + src/userdata.rs | 66 +++++++++++++++++++++++++++++++++++++++++++++-- tests/userdata.rs | 54 ++++++++++++++++++++------------------ 4 files changed, 109 insertions(+), 28 deletions(-) diff --git a/src/conversion.rs b/src/conversion.rs index c4ff6c2..028957c 100644 --- a/src/conversion.rs +++ b/src/conversion.rs @@ -15,7 +15,7 @@ use crate::string::String; use crate::table::Table; use crate::thread::Thread; use crate::types::{LightUserData, MaybeSend}; -use crate::userdata::{AnyUserData, UserData}; +use crate::userdata::{AnyUserData, UserData, UserDataRef, UserDataRefMut}; use crate::value::{FromLua, IntoLua, Nil, Value}; #[cfg(feature = "unstable")] @@ -217,6 +217,20 @@ impl<'lua, T: 'static + MaybeSend + UserData> IntoLua<'lua> for T { } } +impl<'lua, T: 'static> FromLua<'lua> for UserDataRef<'lua, T> { + #[inline] + fn from_lua(value: Value<'lua>, _: &'lua Lua) -> Result { + Self::from_value(value) + } +} + +impl<'lua, T: 'static> FromLua<'lua> for UserDataRefMut<'lua, T> { + #[inline] + fn from_lua(value: Value<'lua>, _: &'lua Lua) -> Result { + Self::from_value(value) + } +} + impl<'lua> IntoLua<'lua> for Error { #[inline] fn into_lua(self, _: &'lua Lua) -> Result> { diff --git a/src/lib.rs b/src/lib.rs index 9dac713..1bcad7e 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -121,6 +121,7 @@ pub use crate::thread::{Thread, ThreadStatus}; pub use crate::types::{Integer, LightUserData, Number, RegistryKey}; pub use crate::userdata::{ AnyUserData, MetaMethod, UserData, UserDataFields, UserDataMetatable, UserDataMethods, + UserDataRef, UserDataRefMut, }; pub use crate::userdata_ext::AnyUserDataExt; pub use crate::userdata_impl::UserDataRegistrar; diff --git a/src/userdata.rs b/src/userdata.rs index bcc735a..e43d52f 100644 --- a/src/userdata.rs +++ b/src/userdata.rs @@ -1,7 +1,8 @@ -use std::any::TypeId; +use std::any::{type_name, TypeId}; use std::cell::{Ref, RefCell, RefMut}; use std::fmt; use std::hash::Hash; +use std::mem; use std::ops::{Deref, DerefMut}; use std::os::raw::{c_char, c_int}; use std::string::String as StdString; @@ -22,7 +23,7 @@ use crate::lua::Lua; use crate::table::{Table, TablePairs}; use crate::types::{Callback, LuaRef, MaybeSend}; use crate::util::{check_stack, get_userdata, take_userdata, StackGuard}; -use crate::value::{FromLua, FromLuaMulti, IntoLua, IntoLuaMulti}; +use crate::value::{FromLua, FromLuaMulti, IntoLua, IntoLuaMulti, Value}; #[cfg(feature = "async")] use crate::types::AsyncCallback; @@ -1218,6 +1219,67 @@ impl<'lua> Serialize for AnyUserData<'lua> { } } +/// A wrapper type for an immutably borrowed value from a `AnyUserData`. +/// +/// It implements [`FromLua`] and can be used to receive a typed userdata from Lua. +pub struct UserDataRef<'lua, T: 'static>(AnyUserData<'lua>, Ref<'lua, T>); + +impl<'lua, T: 'static> Deref for UserDataRef<'lua, T> { + type Target = T; + + fn deref(&self) -> &Self::Target { + &self.1 + } +} + +impl<'lua, T: 'static> UserDataRef<'lua, T> { + pub(crate) fn from_value(value: Value<'lua>) -> Result { + let ud = try_value_to_userdata::(value)?; + // It's safe to lift lifetime of `Ref` to `'lua` as long as we hold AnyUserData to it. + let this = unsafe { mem::transmute(ud.borrow::()?) }; + Ok(UserDataRef(ud, this)) + } +} + +/// A wrapper type for a mutably borrowed value from a `AnyUserData`. +/// +/// It implements [`FromLua`] and can be used to receive a typed userdata from Lua. +pub struct UserDataRefMut<'lua, T: 'static>(AnyUserData<'lua>, RefMut<'lua, T>); + +impl<'lua, T: 'static> Deref for UserDataRefMut<'lua, T> { + type Target = T; + + fn deref(&self) -> &Self::Target { + &self.1 + } +} + +impl<'lua, T: 'static> DerefMut for UserDataRefMut<'lua, T> { + fn deref_mut(&mut self) -> &mut Self::Target { + &mut self.1 + } +} + +impl<'lua, T: 'static> UserDataRefMut<'lua, T> { + pub(crate) fn from_value(value: Value<'lua>) -> Result { + let ud = try_value_to_userdata::(value)?; + // It's safe to lift lifetime of `RefMut` to `'lua` as long as we hold AnyUserData to it. + let this = unsafe { mem::transmute(ud.borrow_mut::()?) }; + Ok(UserDataRefMut(ud, this)) + } +} + +fn try_value_to_userdata(value: Value) -> Result { + match value { + Value::UserData(ud) => Ok(ud), + _ => Err(Error::FromLuaConversionError { + from: value.type_name(), + to: "userdata", + message: Some(format!("expected userdata of type {}", type_name::())), + }), + } +} + #[cfg(test)] mod assertions { use super::*; diff --git a/tests/userdata.rs b/tests/userdata.rs index bba4661..718a736 100644 --- a/tests/userdata.rs +++ b/tests/userdata.rs @@ -13,8 +13,8 @@ use std::{cell::RefCell, rc::Rc}; use std::sync::atomic::{AtomicI64, Ordering}; use mlua::{ - AnyUserData, AnyUserDataExt, Error, ExternalError, FromLua, Function, Lua, MetaMethod, Nil, - Result, String, UserData, UserDataFields, UserDataMethods, Value, + AnyUserData, AnyUserDataExt, Error, ExternalError, Function, Lua, MetaMethod, Nil, Result, + String, UserData, UserDataFields, UserDataMethods, UserDataRef, Value, }; #[test] @@ -96,29 +96,25 @@ fn test_metamethods() -> Result<()> { #[derive(Copy, Clone)] struct MyUserData(i64); - impl<'lua> FromLua<'lua> for MyUserData { - fn from_lua(value: Value<'lua>, _: &'lua Lua) -> Result { - match value { - Value::UserData(ud) => Ok(ud.borrow::()?.clone()), - _ => unreachable!(), - } - } - } - impl UserData for MyUserData { fn add_methods<'lua, M: UserDataMethods<'lua, Self>>(methods: &mut M) { methods.add_method("get", |_, data, ()| Ok(data.0)); methods.add_meta_function( MetaMethod::Add, - |_, (lhs, rhs): (MyUserData, MyUserData)| Ok(MyUserData(lhs.0 + rhs.0)), + |_, (lhs, rhs): (UserDataRef, UserDataRef)| { + Ok(MyUserData(lhs.0 + rhs.0)) + }, ); methods.add_meta_function( MetaMethod::Sub, - |_, (lhs, rhs): (MyUserData, MyUserData)| Ok(MyUserData(lhs.0 - rhs.0)), + |_, (lhs, rhs): (UserDataRef, UserDataRef)| { + Ok(MyUserData(lhs.0 - rhs.0)) + }, + ); + methods.add_meta_function( + MetaMethod::Eq, + |_, (lhs, rhs): (UserDataRef, UserDataRef)| Ok(lhs.0 == rhs.0), ); - methods.add_meta_function(MetaMethod::Eq, |_, (lhs, rhs): (MyUserData, MyUserData)| { - Ok(lhs.0 == rhs.0) - }); methods.add_meta_method(MetaMethod::Index, |_, data, index: String| { if index.to_str()? == "inner" { Ok(data.0) @@ -134,13 +130,14 @@ fn test_metamethods() -> Result<()> { ))] methods.add_meta_method(MetaMethod::Pairs, |lua, data, ()| { use std::iter::FromIterator; - let stateless_iter = lua.create_function(|_, (data, i): (MyUserData, i64)| { - let i = i + 1; - if i <= data.0 { - return Ok(mlua::Variadic::from_iter(vec![i, i])); - } - return Ok(mlua::Variadic::new()); - })?; + let stateless_iter = + lua.create_function(|_, (data, i): (UserDataRef, i64)| { + let i = i + 1; + if i <= data.0 { + return Ok(mlua::Variadic::from_iter(vec![i, i])); + } + return Ok(mlua::Variadic::new()); + })?; Ok((stateless_iter, data.clone(), 0)) }); } @@ -152,7 +149,9 @@ fn test_metamethods() -> Result<()> { globals.set("userdata2", MyUserData(3))?; globals.set("userdata3", MyUserData(3))?; assert_eq!( - lua.load("userdata1 + userdata2").eval::()?.0, + lua.load("userdata1 + userdata2") + .eval::>()? + .0, 10 ); @@ -176,7 +175,12 @@ fn test_metamethods() -> Result<()> { ) .eval::()?; - assert_eq!(lua.load("userdata1 - userdata2").eval::()?.0, 4); + assert_eq!( + lua.load("userdata1 - userdata2") + .eval::>()? + .0, + 4 + ); assert_eq!(lua.load("userdata1:get()").eval::()?, 7); assert_eq!(lua.load("userdata2.inner").eval::()?, 3); assert!(lua.load("userdata2.nonexist_field").eval::<()>().is_err());