impl UserData for Rc<T> and Arc<T> where T: UserData

This commit is contained in:
Alex Orlenko 2023-06-21 01:30:09 +01:00
parent aeacf6cacc
commit b05698d55b
No known key found for this signature in database
GPG Key ID: 4C150C250863B96D
3 changed files with 90 additions and 12 deletions

View File

@ -130,7 +130,7 @@ pub enum Error {
///
/// [`AnyUserData`]: crate::AnyUserData
UserDataDestructed,
/// An [`AnyUserData`] immutable borrow failed because it is already borrowed mutably.
/// An [`AnyUserData`] immutable borrow failed.
///
/// This error can occur when a method on a [`UserData`] type calls back into Lua, which then
/// tries to call a method on the same [`UserData`] type. Consider restructuring your API to
@ -139,7 +139,7 @@ pub enum Error {
/// [`AnyUserData`]: crate::AnyUserData
/// [`UserData`]: crate::UserData
UserDataBorrowError,
/// An [`AnyUserData`] mutable borrow failed because it is already borrowed.
/// An [`AnyUserData`] mutable borrow failed.
///
/// This error can occur when a method on a [`UserData`] type calls back into Lua, which then
/// tries to call a method on the same [`UserData`] type. Consider restructuring your API to
@ -270,8 +270,8 @@ impl fmt::Display for Error {
Error::CoroutineInactive => write!(fmt, "cannot resume inactive coroutine"),
Error::UserDataTypeMismatch => write!(fmt, "userdata is not expected type"),
Error::UserDataDestructed => write!(fmt, "userdata has been destructed"),
Error::UserDataBorrowError => write!(fmt, "userdata already mutably borrowed"),
Error::UserDataBorrowMutError => write!(fmt, "userdata already borrowed"),
Error::UserDataBorrowError => write!(fmt, "error borrowing userdata"),
Error::UserDataBorrowMutError => write!(fmt, "error mutably borrowing userdata"),
Error::MetaMethodRestricted(ref method) => write!(fmt, "metamethod {method} is restricted"),
Error::MetaMethodTypeError { ref method, type_name, ref message } => {
write!(fmt, "metamethod {method} has unsupported type {type_name}")?;

View File

@ -96,11 +96,20 @@ impl<'lua, T: 'static> UserDataRegistrar<'lua, T> {
call(&ud)
},
#[cfg(not(feature = "send"))]
Some(id) if id == TypeId::of::<Rc<T>>() => unsafe {
let ud = try_self_arg!(get_userdata_ref::<Rc<T>>(ref_thread, index));
call(&ud)
},
#[cfg(not(feature = "send"))]
Some(id) if id == TypeId::of::<Rc<RefCell<T>>>() => unsafe {
let ud = try_self_arg!(get_userdata_ref::<Rc<RefCell<T>>>(ref_thread, index));
let ud = try_self_arg!(ud.try_borrow(), Error::UserDataBorrowError);
call(&ud)
},
Some(id) if id == TypeId::of::<Arc<T>>() => unsafe {
let ud = try_self_arg!(get_userdata_ref::<Arc<T>>(ref_thread, index));
call(&ud)
},
Some(id) if id == TypeId::of::<Arc<Mutex<T>>>() => unsafe {
let ud = try_self_arg!(get_userdata_ref::<Arc<Mutex<T>>>(ref_thread, index));
let ud = try_self_arg!(ud.try_lock(), Error::UserDataBorrowError);
@ -169,11 +178,14 @@ impl<'lua, T: 'static> UserDataRegistrar<'lua, T> {
call(&mut ud)
},
#[cfg(not(feature = "send"))]
Some(id) if id == TypeId::of::<Rc<T>>() => Err(Error::UserDataBorrowMutError),
#[cfg(not(feature = "send"))]
Some(id) if id == TypeId::of::<Rc<RefCell<T>>>() => unsafe {
let ud = try_self_arg!(get_userdata_mut::<Rc<RefCell<T>>>(ref_thread, index));
let mut ud = try_self_arg!(ud.try_borrow_mut(), Error::UserDataBorrowMutError);
call(&mut ud)
},
Some(id) if id == TypeId::of::<Arc<T>>() => Err(Error::UserDataBorrowMutError),
Some(id) if id == TypeId::of::<Arc<Mutex<T>>>() => unsafe {
let ud = try_self_arg!(get_userdata_mut::<Arc<Mutex<T>>>(ref_thread, index));
let mut ud = try_self_arg!(ud.try_lock(), Error::UserDataBorrowMutError);
@ -244,6 +256,13 @@ impl<'lua, T: 'static> UserDataRegistrar<'lua, T> {
method(lua, ud, args).await?.into_lua_multi(lua)
},
#[cfg(not(feature = "send"))]
Some(id) if id == TypeId::of::<Rc<T>>() => unsafe {
let ud = try_self_arg!(get_userdata_ref::<Rc<T>>(ref_thread, index));
let ud = std::mem::transmute::<&T, &T>(&ud);
let args = A::from_lua_multi_args(args, 2, Some(&name), lua)?;
method(lua, ud, args).await?.into_lua_multi(lua)
},
#[cfg(not(feature = "send"))]
Some(id) if id == TypeId::of::<Rc<RefCell<T>>>() => unsafe {
let ud =
try_self_arg!(get_userdata_ref::<Rc<RefCell<T>>>(ref_thread, index));
@ -252,6 +271,12 @@ impl<'lua, T: 'static> UserDataRegistrar<'lua, T> {
let args = A::from_lua_multi_args(args, 2, Some(&name), lua)?;
method(lua, ud, args).await?.into_lua_multi(lua)
},
Some(id) if id == TypeId::of::<Arc<T>>() => unsafe {
let ud = try_self_arg!(get_userdata_ref::<Arc<T>>(ref_thread, index));
let ud = std::mem::transmute::<&T, &T>(&ud);
let args = A::from_lua_multi_args(args, 2, Some(&name), lua)?;
method(lua, ud, args).await?.into_lua_multi(lua)
},
Some(id) if id == TypeId::of::<Arc<Mutex<T>>>() => unsafe {
let ud =
try_self_arg!(get_userdata_ref::<Arc<Mutex<T>>>(ref_thread, index));
@ -333,6 +358,10 @@ impl<'lua, T: 'static> UserDataRegistrar<'lua, T> {
method(lua, ud, args).await?.into_lua_multi(lua)
},
#[cfg(not(feature = "send"))]
Some(id) if id == TypeId::of::<Rc<RefCell<T>>>() => {
Err(Error::UserDataBorrowMutError)
}
#[cfg(not(feature = "send"))]
Some(id) if id == TypeId::of::<Rc<RefCell<T>>>() => unsafe {
let ud =
try_self_arg!(get_userdata_mut::<Rc<RefCell<T>>>(ref_thread, index));
@ -342,6 +371,8 @@ impl<'lua, T: 'static> UserDataRegistrar<'lua, T> {
let args = A::from_lua_multi_args(args, 2, Some(&name), lua)?;
method(lua, ud, args).await?.into_lua_multi(lua)
},
#[cfg(not(feature = "send"))]
Some(id) if id == TypeId::of::<Arc<T>>() => Err(Error::UserDataBorrowMutError),
Some(id) if id == TypeId::of::<Arc<Mutex<T>>>() => unsafe {
let ud =
try_self_arg!(get_userdata_mut::<Arc<Mutex<T>>>(ref_thread, index));
@ -767,8 +798,12 @@ macro_rules! lua_userdata_impl {
};
}
#[cfg(not(feature = "send"))]
lua_userdata_impl!(Rc<T>);
#[cfg(not(feature = "send"))]
lua_userdata_impl!(Rc<RefCell<T>>);
lua_userdata_impl!(Arc<T>);
lua_userdata_impl!(Arc<Mutex<T>>);
lua_userdata_impl!(Arc<RwLock<T>>);
#[cfg(feature = "parking_lot")]

View File

@ -625,10 +625,33 @@ fn test_userdata_wrapped() -> Result<()> {
let lua = Lua::new();
let globals = lua.globals();
// Rc<T>
#[cfg(not(feature = "send"))]
{
let ud1 = Rc::new(RefCell::new(MyUserData(1)));
globals.set("rc_refcell_ud", ud1.clone())?;
let ud = Rc::new(MyUserData(1));
globals.set("rc_ud", ud.clone())?;
lua.load(
r#"
assert(rc_ud.static == "constant")
local ok, err = pcall(function() rc_ud.data = 2 end)
assert(
tostring(err):sub(1, 32) == "error mutably borrowing userdata",
"expected error mutably borrowing userdata, got " .. tostring(err)
)
assert(rc_ud.data == 1)
"#,
)
.exec()?;
globals.set("rc_ud", Nil)?;
lua.gc_collect()?;
assert_eq!(Rc::strong_count(&ud), 1);
}
// Rc<RefCell<T>>
#[cfg(not(feature = "send"))]
{
let ud = Rc::new(RefCell::new(MyUserData(1)));
globals.set("rc_refcell_ud", ud.clone())?;
lua.load(
r#"
assert(rc_refcell_ud.static == "constant")
@ -637,12 +660,32 @@ fn test_userdata_wrapped() -> Result<()> {
"#,
)
.exec()?;
assert_eq!(ud1.borrow().0, 2);
assert_eq!(ud.borrow().0, 2);
globals.set("rc_refcell_ud", Nil)?;
lua.gc_collect()?;
assert_eq!(Rc::strong_count(&ud1), 1);
assert_eq!(Rc::strong_count(&ud), 1);
}
// Arc<T>
let ud1 = Arc::new(MyUserData(2));
globals.set("arc_ud", ud1.clone())?;
lua.load(
r#"
assert(arc_ud.static == "constant")
local ok, err = pcall(function() arc_ud.data = 3 end)
assert(
tostring(err):sub(1, 32) == "error mutably borrowing userdata",
"expected error mutably borrowing userdata, got " .. tostring(err)
)
assert(arc_ud.data == 2)
"#,
)
.exec()?;
globals.set("arc_ud", Nil)?;
lua.gc_collect()?;
assert_eq!(Arc::strong_count(&ud1), 1);
// Arc<Mutex<T>>
let ud2 = Arc::new(Mutex::new(MyUserData(2)));
globals.set("arc_mutex_ud", ud2.clone())?;
lua.load(
@ -657,7 +700,11 @@ fn test_userdata_wrapped() -> Result<()> {
assert_eq!(ud2.lock().unwrap().0, 3);
#[cfg(feature = "parking_lot")]
assert_eq!(ud2.lock().0, 3);
globals.set("arc_mutex_ud", Nil)?;
lua.gc_collect()?;
assert_eq!(Arc::strong_count(&ud2), 1);
// Arc<RwLock<T>>
let ud3 = Arc::new(RwLock::new(MyUserData(3)));
globals.set("arc_rwlock_ud", ud3.clone())?;
lua.load(
@ -672,12 +719,8 @@ fn test_userdata_wrapped() -> Result<()> {
assert_eq!(ud3.read().unwrap().0, 4);
#[cfg(feature = "parking_lot")]
assert_eq!(ud3.read().0, 4);
// Test drop
globals.set("arc_mutex_ud", Nil)?;
globals.set("arc_rwlock_ud", Nil)?;
lua.gc_collect()?;
assert_eq!(Arc::strong_count(&ud2), 1);
assert_eq!(Arc::strong_count(&ud3), 1);
Ok(())