Add create_userdata_ref/create_userdata_ref_mut for scope (#206)

New methods would allow creating userdata objects from (mutable) reference
to a UserData of registered type.
This commit is contained in:
Alex Orlenko 2023-02-11 21:58:43 +00:00
parent b790b525c1
commit f52abf919e
No known key found for this signature in database
GPG Key ID: 4C150C250863B96D
6 changed files with 333 additions and 140 deletions

View File

@ -121,6 +121,7 @@ pub use crate::types::{Integer, LightUserData, Number, RegistryKey};
pub use crate::userdata::{ pub use crate::userdata::{
AnyUserData, MetaMethod, UserData, UserDataFields, UserDataMetatable, UserDataMethods, AnyUserData, MetaMethod, UserData, UserDataFields, UserDataMetatable, UserDataMethods,
}; };
pub use crate::userdata_impl::UserDataRegistrar;
pub use crate::value::{FromLua, FromLuaMulti, IntoLua, IntoLuaMulti, MultiValue, Nil, Value}; pub use crate::value::{FromLua, FromLuaMulti, IntoLua, IntoLuaMulti, MultiValue, Nil, Value};
#[cfg(not(feature = "luau"))] #[cfg(not(feature = "luau"))]

View File

@ -1763,23 +1763,12 @@ impl Lua {
/// Otherwise, the userdata object will have an empty metatable. /// Otherwise, the userdata object will have an empty metatable.
/// ///
/// All userdata instances of the same type `T` shares the same metatable. /// All userdata instances of the same type `T` shares the same metatable.
#[inline]
pub fn create_any_userdata<T>(&self, data: T) -> Result<AnyUserData> pub fn create_any_userdata<T>(&self, data: T) -> Result<AnyUserData>
where where
T: MaybeSend + 'static, T: MaybeSend + 'static,
{ {
unsafe { unsafe { self.make_any_userdata(UserDataCell::new(data)) }
self.make_userdata_with_metatable(UserDataCell::new(data), || {
// Check if userdata/metatable is already registered
let type_id = TypeId::of::<T>();
if let Some(&table_id) = (*self.extra.get()).registered_userdata.get(&type_id) {
return Ok(table_id as Integer);
}
// Create empty metatable
let registry = UserDataRegistrar::new();
self.register_userdata_metatable::<T>(registry)
})
}
} }
/// Registers a custom Rust type in Lua to use in userdata objects. /// Registers a custom Rust type in Lua to use in userdata objects.
@ -1891,12 +1880,10 @@ impl Lua {
/// dropped. `Function` types will error when called, and `AnyUserData` will be typeless. It /// dropped. `Function` types will error when called, and `AnyUserData` will be typeless. It
/// would be impossible to prevent handles to scoped values from escaping anyway, since you /// would be impossible to prevent handles to scoped values from escaping anyway, since you
/// would always be able to smuggle them through Lua state. /// would always be able to smuggle them through Lua state.
pub fn scope<'lua, 'scope, R, F>(&'lua self, f: F) -> Result<R> pub fn scope<'lua, 'scope, R>(
where &'lua self,
'lua: 'scope, f: impl FnOnce(&Scope<'lua, 'scope>) -> Result<R>,
R: 'static, ) -> Result<R> {
F: FnOnce(&Scope<'lua, 'scope>) -> Result<R>,
{
f(&Scope::new(self)) f(&Scope::new(self))
} }
@ -2961,6 +2948,23 @@ impl Lua {
}) })
} }
pub(crate) unsafe fn make_any_userdata<T>(&self, data: UserDataCell<T>) -> Result<AnyUserData>
where
T: 'static,
{
self.make_userdata_with_metatable(data, || {
// Check if userdata/metatable is already registered
let type_id = TypeId::of::<T>();
if let Some(&table_id) = (*self.extra.get()).registered_userdata.get(&type_id) {
return Ok(table_id as Integer);
}
// Create empty metatable
let registry = UserDataRegistrar::new();
self.register_userdata_metatable::<T>(registry)
})
}
unsafe fn make_userdata_with_metatable<T>( unsafe fn make_userdata_with_metatable<T>(
&self, &self,
data: UserDataCell<T>, data: UserDataCell<T>,

View File

@ -12,7 +12,7 @@ pub use crate::{
TableSequence as LuaTableSequence, Thread as LuaThread, ThreadStatus as LuaThreadStatus, TableSequence as LuaTableSequence, Thread as LuaThread, ThreadStatus as LuaThreadStatus,
UserData as LuaUserData, UserDataFields as LuaUserDataFields, UserData as LuaUserData, UserDataFields as LuaUserDataFields,
UserDataMetatable as LuaUserDataMetatable, UserDataMethods as LuaUserDataMethods, UserDataMetatable as LuaUserDataMetatable, UserDataMethods as LuaUserDataMethods,
Value as LuaValue, UserDataRegistrar as LuaUserDataRegistrar, Value as LuaValue,
}; };
#[cfg(not(feature = "luau"))] #[cfg(not(feature = "luau"))]

View File

@ -2,8 +2,7 @@ use std::any::Any;
use std::cell::{Cell, RefCell}; use std::cell::{Cell, RefCell};
use std::marker::PhantomData; use std::marker::PhantomData;
use std::mem; use std::mem;
use std::os::raw::{c_int, c_void}; use std::os::raw::c_int;
use std::rc::Rc;
#[cfg(feature = "serialize")] #[cfg(feature = "serialize")]
use serde::Serialize; use serde::Serialize;
@ -66,7 +65,7 @@ impl<'lua, 'scope> Scope<'lua, 'scope> {
where where
A: FromLuaMulti<'callback>, A: FromLuaMulti<'callback>,
R: IntoLuaMulti<'callback>, R: IntoLuaMulti<'callback>,
F: 'scope + Fn(&'callback Lua, A) -> Result<R>, F: Fn(&'callback Lua, A) -> Result<R> + 'scope,
{ {
// Safe, because 'scope must outlive 'callback (due to Self containing 'scope), however the // Safe, because 'scope must outlive 'callback (due to Self containing 'scope), however the
// callback itself must be 'scope lifetime, so the function should not be able to capture // callback itself must be 'scope lifetime, so the function should not be able to capture
@ -99,7 +98,7 @@ impl<'lua, 'scope> Scope<'lua, 'scope> {
where where
A: FromLuaMulti<'callback>, A: FromLuaMulti<'callback>,
R: IntoLuaMulti<'callback>, R: IntoLuaMulti<'callback>,
F: 'scope + FnMut(&'callback Lua, A) -> Result<R>, F: FnMut(&'callback Lua, A) -> Result<R> + 'scope,
{ {
let func = RefCell::new(func); let func = RefCell::new(func);
self.create_function(move |lua, args| { self.create_function(move |lua, args| {
@ -144,7 +143,7 @@ impl<'lua, 'scope> Scope<'lua, 'scope> {
} }
} }
/// Create a Lua userdata object from a custom userdata type. /// Creates a Lua userdata object from a custom userdata type.
/// ///
/// This is a version of [`Lua::create_userdata`] that creates a userdata which expires on /// This is a version of [`Lua::create_userdata`] that creates a userdata which expires on
/// scope drop, and does not require that the userdata type be Send (but still requires that the /// scope drop, and does not require that the userdata type be Send (but still requires that the
@ -155,12 +154,18 @@ impl<'lua, 'scope> Scope<'lua, 'scope> {
/// [`Lua::scope`]: crate::Lua::scope /// [`Lua::scope`]: crate::Lua::scope
pub fn create_userdata<T>(&self, data: T) -> Result<AnyUserData<'lua>> pub fn create_userdata<T>(&self, data: T) -> Result<AnyUserData<'lua>>
where where
T: 'static + UserData, T: UserData + 'static,
{ {
self.create_userdata_inner(UserDataCell::new(data)) // Safe even though T may not be Send, because the parent Lua cannot be sent to another
// thread while the Scope is alive (or the returned AnyUserData handle even).
unsafe {
let ud = self.lua.make_userdata(UserDataCell::new(data))?;
self.seal_userdata::<T>(&ud)?;
Ok(ud)
}
} }
/// Create a Lua userdata object from a custom serializable userdata type. /// Creates a Lua userdata object from a custom serializable userdata type.
/// ///
/// This is a version of [`Lua::create_ser_userdata`] that creates a userdata which expires on /// This is a version of [`Lua::create_ser_userdata`] that creates a userdata which expires on
/// scope drop, and does not require that the userdata type be Send (but still requires that the /// scope drop, and does not require that the userdata type be Send (but still requires that the
@ -175,60 +180,125 @@ impl<'lua, 'scope> Scope<'lua, 'scope> {
#[cfg_attr(docsrs, doc(cfg(feature = "serialize")))] #[cfg_attr(docsrs, doc(cfg(feature = "serialize")))]
pub fn create_ser_userdata<T>(&self, data: T) -> Result<AnyUserData<'lua>> pub fn create_ser_userdata<T>(&self, data: T) -> Result<AnyUserData<'lua>>
where where
T: 'static + UserData + Serialize, T: UserData + Serialize + 'static,
{ {
self.create_userdata_inner(UserDataCell::new_ser(data))
}
fn create_userdata_inner<T>(&self, data: UserDataCell<T>) -> Result<AnyUserData<'lua>>
where
T: 'static + UserData,
{
// Safe even though T may not be Send, because the parent Lua cannot be sent to another
// thread while the Scope is alive (or the returned AnyUserData handle even).
unsafe { unsafe {
let ud = self.lua.make_userdata(data)?; let ud = self.lua.make_userdata(UserDataCell::new_ser(data))?;
self.seal_userdata::<T>(&ud)?;
#[cfg(any(feature = "lua51", feature = "luajit"))]
let newtable = self.lua.create_table()?;
let destructor: DestructorCallback = Box::new(move |ud| {
let state = ud.lua.state();
let _sg = StackGuard::new(state);
assert_stack(state, 2);
// Check that userdata is not destructed (via `take()` call)
if ud.lua.push_userdata_ref(&ud).is_err() {
return vec![];
}
// Clear associated user values
#[cfg(feature = "lua54")]
for i in 1..=USER_VALUE_MAXSLOT {
ffi::lua_pushnil(state);
ffi::lua_setiuservalue(state, -2, i as c_int);
}
#[cfg(any(feature = "lua53", feature = "lua52", feature = "luau"))]
{
ffi::lua_pushnil(state);
ffi::lua_setuservalue(state, -2);
}
#[cfg(any(feature = "lua51", feature = "luajit"))]
{
ud.lua.push_ref(&newtable.0);
ffi::lua_setuservalue(state, -2);
}
vec![Box::new(take_userdata::<UserDataCell<T>>(state))]
});
self.destructors
.borrow_mut()
.push((ud.0.clone(), destructor));
Ok(ud) Ok(ud)
} }
} }
/// Create a Lua userdata object from a custom userdata type. /// Creates a Lua userdata object from a reference to custom userdata type.
///
/// This is a version of [`Lua::create_userdata`] that creates a userdata which expires on
/// scope drop, and does not require that the userdata type be Send. This method takes non-'static
/// reference to the data. See [`Lua::scope`] for more details.
///
/// Userdata created with this method will not be able to be mutated from Lua.
pub fn create_userdata_ref<T>(&self, data: &'scope T) -> Result<AnyUserData<'lua>>
where
T: UserData + 'static,
{
unsafe {
let ud = self.lua.make_userdata(UserDataCell::new_ref(data))?;
self.seal_userdata::<T>(&ud)?;
Ok(ud)
}
}
/// Creates a Lua userdata object from a mutable reference to custom userdata type.
///
/// This is a version of [`Lua::create_userdata`] that creates a userdata which expires on
/// scope drop, and does not require that the userdata type be Send. This method takes non-'static
/// mutable reference to the data. See [`Lua::scope`] for more details.
pub fn create_userdata_ref_mut<T>(&self, data: &'scope mut T) -> Result<AnyUserData<'lua>>
where
T: UserData + 'static,
{
unsafe {
let ud = self.lua.make_userdata(UserDataCell::new_ref_mut(data))?;
self.seal_userdata::<T>(&ud)?;
Ok(ud)
}
}
/// Creates a Lua userdata object from a reference to custom Rust type.
///
/// This is a version of [`Lua::create_any_userdata`] that creates a userdata which expires on
/// scope drop, and does not require that the Rust type be Send. This method takes non-'static
/// reference to the data. See [`Lua::scope`] for more details.
///
/// Userdata created with this method will not be able to be mutated from Lua.
pub fn create_any_userdata_ref<T>(&self, data: &'scope T) -> Result<AnyUserData<'lua>>
where
T: 'static,
{
unsafe {
let ud = self.lua.make_any_userdata(UserDataCell::new_ref(data))?;
self.seal_userdata::<T>(&ud)?;
Ok(ud)
}
}
/// Creates a Lua userdata object from a mutable reference to custom Rust type.
///
/// This is a version of [`Lua::create_any_userdata`] that creates a userdata which expires on
/// scope drop, and does not require that the Rust type be Send. This method takes non-'static
/// mutable reference to the data. See [`Lua::scope`] for more details.
pub fn create_any_userdata_ref_mut<T>(&self, data: &'scope mut T) -> Result<AnyUserData<'lua>>
where
T: 'static,
{
let lua = self.lua;
unsafe {
let ud = lua.make_any_userdata(UserDataCell::new_ref_mut(data))?;
self.seal_userdata::<T>(&ud)?;
Ok(ud)
}
}
/// Shortens the lifetime of a userdata to the lifetime of the scope.
unsafe fn seal_userdata<T: 'static>(&self, ud: &AnyUserData<'lua>) -> Result<()> {
#[cfg(any(feature = "lua51", feature = "luajit"))]
let newtable = self.lua.create_table()?;
let destructor: DestructorCallback = Box::new(move |ud| {
let state = ud.lua.state();
let _sg = StackGuard::new(state);
assert_stack(state, 2);
// Check that userdata is not destructed (via `take()` call)
if ud.lua.push_userdata_ref(&ud).is_err() {
return vec![];
}
// Clear associated user values
#[cfg(feature = "lua54")]
for i in 1..=USER_VALUE_MAXSLOT {
ffi::lua_pushnil(state);
ffi::lua_setiuservalue(state, -2, i as c_int);
}
#[cfg(any(feature = "lua53", feature = "lua52", feature = "luau"))]
{
ffi::lua_pushnil(state);
ffi::lua_setuservalue(state, -2);
}
#[cfg(any(feature = "lua51", feature = "luajit"))]
{
ud.lua.push_ref(&newtable.0);
ffi::lua_setuservalue(state, -2);
}
vec![Box::new(take_userdata::<UserDataCell<T>>(state))]
});
self.destructors
.borrow_mut()
.push((ud.0.clone(), destructor));
Ok(())
}
/// Creates a Lua userdata object from a custom userdata type.
/// ///
/// This is a version of [`Lua::create_userdata`] that creates a userdata which expires on /// This is a version of [`Lua::create_userdata`] that creates a userdata which expires on
/// scope drop, and does not require that the userdata type be Send or 'static. See /// scope drop, and does not require that the userdata type be Send or 'static. See
@ -253,10 +323,8 @@ impl<'lua, 'scope> Scope<'lua, 'scope> {
/// [`UserDataMethods`]: crate::UserDataMethods /// [`UserDataMethods`]: crate::UserDataMethods
pub fn create_nonstatic_userdata<T>(&self, data: T) -> Result<AnyUserData<'lua>> pub fn create_nonstatic_userdata<T>(&self, data: T) -> Result<AnyUserData<'lua>>
where where
T: 'scope + UserData, T: UserData + 'scope,
{ {
let data = Rc::new(RefCell::new(data));
// 'callback outliving 'scope is a lie to make the types work out, required due to the // 'callback outliving 'scope is a lie to make the types work out, required due to the
// inability to work with the more correct callback type that is universally quantified over // inability to work with the more correct callback type that is universally quantified over
// 'lua. This is safe though, because `UserData::add_methods` does not get to pick the 'lua // 'lua. This is safe though, because `UserData::add_methods` does not get to pick the 'lua
@ -264,8 +332,7 @@ impl<'lua, 'scope> Scope<'lua, 'scope> {
// parameters. // parameters.
fn wrap_method<'scope, 'lua, 'callback: 'scope, T: 'scope>( fn wrap_method<'scope, 'lua, 'callback: 'scope, T: 'scope>(
scope: &Scope<'lua, 'scope>, scope: &Scope<'lua, 'scope>,
data: Rc<RefCell<T>>, ud_ptr: *const UserDataCell<T>,
ud_ptr: *const c_void,
method: NonStaticMethod<'callback, T>, method: NonStaticMethod<'callback, T>,
) -> Result<Function<'lua>> { ) -> Result<Function<'lua>> {
// On methods that actually receive the userdata, we fake a type check on the passed in // On methods that actually receive the userdata, we fake a type check on the passed in
@ -275,7 +342,7 @@ impl<'lua, 'scope> Scope<'lua, 'scope> {
// with a type mismatch, but here without this check would proceed as though you had // with a type mismatch, but here without this check would proceed as though you had
// called the method on the original value (since we otherwise completely ignore the // called the method on the original value (since we otherwise completely ignore the
// first argument). // first argument).
let check_ud_type = move |lua: &'callback Lua, value| { let check_ud_type = move |lua: &Lua, value| -> Result<&UserDataCell<T>> {
if let Some(Value::UserData(ud)) = value { if let Some(Value::UserData(ud)) = value {
let state = lua.state(); let state = lua.state();
unsafe { unsafe {
@ -283,7 +350,7 @@ impl<'lua, 'scope> Scope<'lua, 'scope> {
check_stack(state, 2)?; check_stack(state, 2)?;
lua.push_userdata_ref(&ud.0)?; lua.push_userdata_ref(&ud.0)?;
if get_userdata(state, -1) as *const _ == ud_ptr { if get_userdata(state, -1) as *const _ == ud_ptr {
return Ok(()); return Ok(&*ud_ptr);
} }
} }
}; };
@ -293,8 +360,8 @@ impl<'lua, 'scope> Scope<'lua, 'scope> {
match method { match method {
NonStaticMethod::Method(method) => { NonStaticMethod::Method(method) => {
let f = Box::new(move |lua, mut args: MultiValue<'callback>| { let f = Box::new(move |lua, mut args: MultiValue<'callback>| {
check_ud_type(lua, args.pop_front())?; let data = check_ud_type(lua, args.pop_front())?;
let data = data.try_borrow().map_err(|_| Error::UserDataBorrowError)?; let data = data.try_borrow()?;
method(lua, &*data, args) method(lua, &*data, args)
}); });
unsafe { scope.create_callback(f) } unsafe { scope.create_callback(f) }
@ -302,13 +369,11 @@ impl<'lua, 'scope> Scope<'lua, 'scope> {
NonStaticMethod::MethodMut(method) => { NonStaticMethod::MethodMut(method) => {
let method = RefCell::new(method); let method = RefCell::new(method);
let f = Box::new(move |lua, mut args: MultiValue<'callback>| { let f = Box::new(move |lua, mut args: MultiValue<'callback>| {
check_ud_type(lua, args.pop_front())?; let data = check_ud_type(lua, args.pop_front())?;
let mut method = method let mut method = method
.try_borrow_mut() .try_borrow_mut()
.map_err(|_| Error::RecursiveMutCallback)?; .map_err(|_| Error::RecursiveMutCallback)?;
let mut data = data let mut data = data.try_borrow_mut()?;
.try_borrow_mut()
.map_err(|_| Error::UserDataBorrowMutError)?;
(*method)(lua, &mut *data, args) (*method)(lua, &mut *data, args)
}); });
unsafe { scope.create_callback(f) } unsafe { scope.create_callback(f) }
@ -342,8 +407,7 @@ impl<'lua, 'scope> Scope<'lua, 'scope> {
#[cfg(not(feature = "luau"))] #[cfg(not(feature = "luau"))]
#[allow(clippy::let_and_return)] #[allow(clippy::let_and_return)]
let ud_ptr = protect_lua!(state, 0, 1, |state| { let ud_ptr = protect_lua!(state, 0, 1, |state| {
let ud = let ud = ffi::lua_newuserdata(state, mem::size_of::<UserDataCell<T>>());
ffi::lua_newuserdata(state, mem::size_of::<UserDataCell<Rc<RefCell<T>>>>());
// Set empty environment for Lua 5.1 // Set empty environment for Lua 5.1
#[cfg(any(feature = "lua51", feature = "luajit"))] #[cfg(any(feature = "lua51", feature = "luajit"))]
@ -352,16 +416,12 @@ impl<'lua, 'scope> Scope<'lua, 'scope> {
ffi::lua_setuservalue(state, -2); ffi::lua_setuservalue(state, -2);
} }
ud ud as *const UserDataCell<T>
})?; })?;
#[cfg(feature = "luau")] #[cfg(feature = "luau")]
let ud_ptr = { let ud_ptr = {
crate::util::push_userdata::<UserDataCell<Rc<RefCell<T>>>>( crate::util::push_userdata(state, UserDataCell::new(data), true)?;
state, ffi::lua_touserdata(state, -1) as *const UserDataCell<T>
UserDataCell::new(data.clone()),
true,
)?;
ffi::lua_touserdata(state, -1)
}; };
// Prepare metatable, add meta methods first and then meta fields // Prepare metatable, add meta methods first and then meta fields
@ -369,8 +429,7 @@ impl<'lua, 'scope> Scope<'lua, 'scope> {
push_table(state, 0, meta_methods_nrec as c_int, true)?; push_table(state, 0, meta_methods_nrec as c_int, true)?;
for (k, m) in ud_methods.meta_methods { for (k, m) in ud_methods.meta_methods {
let data = data.clone(); lua.push_value(Value::Function(wrap_method(self, ud_ptr, m)?))?;
lua.push_value(Value::Function(wrap_method(self, data, ud_ptr, m)?))?;
rawset_field(state, -2, MetaMethod::validate(&k)?)?; rawset_field(state, -2, MetaMethod::validate(&k)?)?;
} }
for (k, f) in ud_fields.meta_fields { for (k, f) in ud_fields.meta_fields {
@ -384,8 +443,7 @@ impl<'lua, 'scope> Scope<'lua, 'scope> {
if field_getters_nrec > 0 { if field_getters_nrec > 0 {
push_table(state, 0, field_getters_nrec as c_int, true)?; push_table(state, 0, field_getters_nrec as c_int, true)?;
for (k, m) in ud_fields.field_getters { for (k, m) in ud_fields.field_getters {
let data = data.clone(); lua.push_value(Value::Function(wrap_method(self, ud_ptr, m)?))?;
lua.push_value(Value::Function(wrap_method(self, data, ud_ptr, m)?))?;
rawset_field(state, -2, &k)?; rawset_field(state, -2, &k)?;
} }
field_getters_index = Some(ffi::lua_absindex(state, -1)); field_getters_index = Some(ffi::lua_absindex(state, -1));
@ -396,8 +454,7 @@ impl<'lua, 'scope> Scope<'lua, 'scope> {
if field_setters_nrec > 0 { if field_setters_nrec > 0 {
push_table(state, 0, field_setters_nrec as c_int, true)?; push_table(state, 0, field_setters_nrec as c_int, true)?;
for (k, m) in ud_fields.field_setters { for (k, m) in ud_fields.field_setters {
let data = data.clone(); lua.push_value(Value::Function(wrap_method(self, ud_ptr, m)?))?;
lua.push_value(Value::Function(wrap_method(self, data, ud_ptr, m)?))?;
rawset_field(state, -2, &k)?; rawset_field(state, -2, &k)?;
} }
field_setters_index = Some(ffi::lua_absindex(state, -1)); field_setters_index = Some(ffi::lua_absindex(state, -1));
@ -409,14 +466,13 @@ impl<'lua, 'scope> Scope<'lua, 'scope> {
// Create table used for methods lookup // Create table used for methods lookup
push_table(state, 0, methods_nrec as c_int, true)?; push_table(state, 0, methods_nrec as c_int, true)?;
for (k, m) in ud_methods.methods { for (k, m) in ud_methods.methods {
let data = data.clone(); lua.push_value(Value::Function(wrap_method(self, ud_ptr, m)?))?;
lua.push_value(Value::Function(wrap_method(self, data, ud_ptr, m)?))?;
rawset_field(state, -2, &k)?; rawset_field(state, -2, &k)?;
} }
methods_index = Some(ffi::lua_absindex(state, -1)); methods_index = Some(ffi::lua_absindex(state, -1));
} }
init_userdata_metatable::<UserDataCell<Rc<RefCell<T>>>>( init_userdata_metatable::<UserDataCell<T>>(
state, state,
metatable_index, metatable_index,
field_getters_index, field_getters_index,
@ -478,8 +534,8 @@ impl<'lua, 'scope> Scope<'lua, 'scope> {
mem::transmute(f) mem::transmute(f)
} }
let ud = Box::new(seal(take_userdata::<UserDataCell<Rc<RefCell<T>>>>(state))); let ud = take_userdata::<UserDataCell<T>>(state);
vec![ud] vec![Box::new(seal(ud))]
}); });
self.destructors self.destructors
.borrow_mut() .borrow_mut()

View File

@ -578,12 +578,22 @@ pub trait UserData: Sized {
} }
// Wraps UserData in a way to always implement `serde::Serialize` trait. // Wraps UserData in a way to always implement `serde::Serialize` trait.
pub(crate) struct UserDataCell<T>(RefCell<UserDataWrapped<T>>); pub(crate) struct UserDataCell<T>(RefCell<UserDataVariant<T>>);
impl<T> UserDataCell<T> { impl<T> UserDataCell<T> {
#[inline] #[inline]
pub(crate) fn new(data: T) -> Self { pub(crate) fn new(data: T) -> Self {
UserDataCell(RefCell::new(UserDataWrapped::new(data))) UserDataCell(RefCell::new(UserDataVariant::new(data)))
}
#[inline]
pub(crate) fn new_ref(data: &T) -> Self {
UserDataCell(RefCell::new(UserDataVariant::new_ref(data)))
}
#[inline]
pub(crate) fn new_ref_mut(data: &mut T) -> Self {
UserDataCell(RefCell::new(UserDataVariant::new_ref_mut(data)))
} }
#[cfg(feature = "serialize")] #[cfg(feature = "serialize")]
@ -592,7 +602,7 @@ impl<T> UserDataCell<T> {
where where
T: Serialize + 'static, T: Serialize + 'static,
{ {
UserDataCell(RefCell::new(UserDataWrapped::new_ser(data))) UserDataCell(RefCell::new(UserDataVariant::new_ser(data)))
} }
// Immutably borrows the wrapped value. // Immutably borrows the wrapped value.
@ -609,27 +619,42 @@ impl<T> UserDataCell<T> {
pub(crate) fn try_borrow_mut(&self) -> Result<RefMut<T>> { pub(crate) fn try_borrow_mut(&self) -> Result<RefMut<T>> {
self.0 self.0
.try_borrow_mut() .try_borrow_mut()
.map(|r| RefMut::map(r, |r| r.deref_mut()))
.map_err(|_| Error::UserDataBorrowMutError) .map_err(|_| Error::UserDataBorrowMutError)
.and_then(|r| {
RefMut::filter_map(r, |r| r.try_deref_mut().ok())
.map_err(|_| Error::UserDataBorrowMutError)
})
} }
// Consumes this `UserDataCell`, returning the wrapped value. // Consumes this `UserDataCell`, returning the wrapped value.
#[inline] #[inline]
unsafe fn into_inner(self) -> T { fn into_inner(self) -> Result<T> {
self.0.into_inner().into_inner() self.0.into_inner().into_inner()
} }
} }
pub(crate) enum UserDataWrapped<T> { pub(crate) enum UserDataVariant<T> {
Default(Box<T>), Default(Box<T>),
Ref(*const T),
RefMut(*mut T),
#[cfg(feature = "serialize")] #[cfg(feature = "serialize")]
Serializable(Box<dyn erased_serde::Serialize>), Serializable(Box<dyn erased_serde::Serialize>),
} }
impl<T> UserDataWrapped<T> { impl<T> UserDataVariant<T> {
#[inline] #[inline]
fn new(data: T) -> Self { fn new(data: T) -> Self {
UserDataWrapped::Default(Box::new(data)) UserDataVariant::Default(Box::new(data))
}
#[inline]
fn new_ref(data: &T) -> Self {
UserDataVariant::Ref(data)
}
#[inline]
fn new_ref_mut(data: &mut T) -> Self {
UserDataVariant::RefMut(data)
} }
#[cfg(feature = "serialize")] #[cfg(feature = "serialize")]
@ -638,26 +663,42 @@ impl<T> UserDataWrapped<T> {
where where
T: Serialize + 'static, T: Serialize + 'static,
{ {
UserDataWrapped::Serializable(Box::new(data)) UserDataVariant::Serializable(Box::new(data))
} }
#[inline] #[inline]
unsafe fn into_inner(self) -> T { fn try_deref_mut(&mut self) -> Result<&mut T> {
match self { match self {
Self::Default(data) => *data, Self::Default(data) => Ok(data.deref_mut()),
Self::Ref(_) => Err(Error::UserDataBorrowMutError),
Self::RefMut(data) => unsafe { Ok(&mut **data) },
#[cfg(feature = "serialize")] #[cfg(feature = "serialize")]
Self::Serializable(data) => *Box::from_raw(Box::into_raw(data) as *mut T), Self::Serializable(data) => unsafe { Ok(&mut *(data.as_mut() as *mut _ as *mut T)) },
}
}
#[inline]
fn into_inner(self) -> Result<T> {
match self {
Self::Default(data) => Ok(*data),
Self::Ref(_) | Self::RefMut(_) => Err(Error::UserDataTypeMismatch),
#[cfg(feature = "serialize")]
Self::Serializable(data) => unsafe {
Ok(*Box::from_raw(Box::into_raw(data) as *mut T))
},
} }
} }
} }
impl<T> Deref for UserDataWrapped<T> { impl<T> Deref for UserDataVariant<T> {
type Target = T; type Target = T;
#[inline] #[inline]
fn deref(&self) -> &Self::Target { fn deref(&self) -> &Self::Target {
match self { match self {
Self::Default(data) => data, Self::Default(data) => data,
Self::Ref(data) => unsafe { &**data },
Self::RefMut(data) => unsafe { &**data },
#[cfg(feature = "serialize")] #[cfg(feature = "serialize")]
Self::Serializable(data) => unsafe { Self::Serializable(data) => unsafe {
&*(data.as_ref() as *const _ as *const Self::Target) &*(data.as_ref() as *const _ as *const Self::Target)
@ -666,19 +707,6 @@ impl<T> Deref for UserDataWrapped<T> {
} }
} }
impl<T> DerefMut for UserDataWrapped<T> {
#[inline]
fn deref_mut(&mut self) -> &mut Self::Target {
match self {
Self::Default(data) => data,
#[cfg(feature = "serialize")]
Self::Serializable(data) => unsafe {
&mut *(data.as_mut() as *mut _ as *mut Self::Target)
},
}
}
}
#[cfg(feature = "serialize")] #[cfg(feature = "serialize")]
struct UserDataSerializeError; struct UserDataSerializeError;
@ -771,7 +799,7 @@ impl<'lua> AnyUserData<'lua> {
Some(type_id) if type_id == TypeId::of::<T>() => { Some(type_id) if type_id == TypeId::of::<T>() => {
// Try to borrow userdata exclusively // Try to borrow userdata exclusively
let _ = (*get_userdata::<UserDataCell<T>>(state, -1)).try_borrow_mut()?; let _ = (*get_userdata::<UserDataCell<T>>(state, -1)).try_borrow_mut()?;
Ok(take_userdata::<UserDataCell<T>>(state).into_inner()) take_userdata::<UserDataCell<T>>(state).into_inner()
} }
_ => Err(Error::UserDataTypeMismatch), _ => Err(Error::UserDataTypeMismatch),
} }
@ -1043,8 +1071,8 @@ impl<'lua> AnyUserData<'lua> {
let ud = &*get_userdata::<UserDataCell<()>>(state, -1); let ud = &*get_userdata::<UserDataCell<()>>(state, -1);
match &*ud.0.try_borrow().map_err(|_| Error::UserDataBorrowError)? { match &*ud.0.try_borrow().map_err(|_| Error::UserDataBorrowError)? {
UserDataWrapped::Default(_) => Result::Ok(false), UserDataVariant::Serializable(_) => Result::Ok(true),
UserDataWrapped::Serializable(_) => Result::Ok(true), _ => Result::Ok(false),
} }
}; };
is_serializable().unwrap_or(false) is_serializable().unwrap_or(false)
@ -1184,8 +1212,8 @@ impl<'lua> Serialize for AnyUserData<'lua> {
.map_err(|_| ser::Error::custom(Error::UserDataBorrowError))? .map_err(|_| ser::Error::custom(Error::UserDataBorrowError))?
}; };
match &*data { match &*data {
UserDataWrapped::Default(_) => UserDataSerializeError.serialize(serializer), UserDataVariant::Serializable(ser) => ser.serialize(serializer),
UserDataWrapped::Serializable(ser) => ser.serialize(serializer), _ => UserDataSerializeError.serialize(serializer),
} }
} }
} }

View File

@ -356,3 +356,107 @@ fn test_scope_nonstatic_userdata_drop() -> Result<()> {
Ok(()) Ok(())
} }
#[test]
fn test_scope_userdata_ref() -> Result<()> {
let lua = Lua::new();
struct MyUserData(Cell<i64>);
impl UserData for MyUserData {
fn add_methods<'lua, M: UserDataMethods<'lua, Self>>(methods: &mut M) {
methods.add_method("inc", |_, data, ()| {
data.0.set(data.0.get() + 1);
Ok(())
});
methods.add_method("dec", |_, data, ()| {
data.0.set(data.0.get() - 1);
Ok(())
});
}
}
let data = MyUserData(Cell::new(1));
lua.scope(|scope| {
let ud = scope.create_userdata_ref(&data)?;
modify_userdata(&lua, ud)
})?;
assert_eq!(data.0.get(), 2);
Ok(())
}
#[test]
fn test_scope_userdata_ref_mut() -> Result<()> {
let lua = Lua::new();
struct MyUserData(i64);
impl UserData for MyUserData {
fn add_methods<'lua, M: UserDataMethods<'lua, Self>>(methods: &mut M) {
methods.add_method_mut("inc", |_, data, ()| {
data.0 += 1;
Ok(())
});
methods.add_method_mut("dec", |_, data, ()| {
data.0 -= 1;
Ok(())
});
}
}
let mut data = MyUserData(1);
lua.scope(|scope| {
let ud = scope.create_userdata_ref_mut(&mut data)?;
modify_userdata(&lua, ud)
})?;
assert_eq!(data.0, 2);
Ok(())
}
#[test]
fn test_scope_any_userdata_ref() -> Result<()> {
let lua = Lua::new();
lua.register_userdata_type::<Cell<i64>>(|reg| {
reg.add_method("inc", |_, data, ()| {
data.set(data.get() + 1);
Ok(())
});
reg.add_method("dec", |_, data, ()| {
data.set(data.get() - 1);
Ok(())
});
})?;
let data = Cell::new(1i64);
lua.scope(|scope| {
let ud = scope.create_any_userdata_ref(&data)?;
modify_userdata(&lua, ud)
})?;
assert_eq!(data.get(), 2);
Ok(())
}
fn modify_userdata(lua: &Lua, ud: AnyUserData) -> Result<()> {
let f: Function = lua
.load(
r#"
function(u)
u:inc()
u:dec()
u:inc()
end
"#,
)
.eval()?;
f.call(ud)?;
Ok(())
}