Implement `UserData` for Rc<RefCell>/Arc<Mutex>/Arc<RwLock> wrappers

This commit is contained in:
Alex Orlenko 2021-05-31 23:33:44 +01:00
parent bae424672a
commit a944f4ad6f
4 changed files with 370 additions and 85 deletions

View File

@ -6,7 +6,7 @@ use std::fmt;
use std::marker::PhantomData; use std::marker::PhantomData;
use std::os::raw::{c_char, c_int, c_void}; use std::os::raw::{c_char, c_int, c_void};
use std::panic::resume_unwind; use std::panic::resume_unwind;
use std::sync::{Arc, Mutex, MutexGuard, Weak}; use std::sync::{Arc, Mutex, MutexGuard, RwLock, Weak};
use std::{mem, ptr, str}; use std::{mem, ptr, str};
use crate::error::{Error, Result}; use crate::error::{Error, Result};
@ -32,6 +32,9 @@ use crate::util::{
}; };
use crate::value::{FromLua, FromLuaMulti, MultiValue, Nil, ToLua, ToLuaMulti, Value}; use crate::value::{FromLua, FromLuaMulti, MultiValue, Nil, ToLua, ToLuaMulti, Value};
#[cfg(not(feature = "send"))]
use std::rc::Rc;
#[cfg(feature = "async")] #[cfg(feature = "async")]
use { use {
crate::types::AsyncCallback, crate::types::AsyncCallback,
@ -1607,6 +1610,11 @@ impl Lua {
self.push_value(f(self)?)?; self.push_value(f(self)?)?;
ffi::safe::lua_rawsetfield(self.state, -2, k.validate()?.name())?; ffi::safe::lua_rawsetfield(self.state, -2, k.validate()?.name())?;
} }
// Add special `__mlua_type_id` field
let type_id_ptr =
ffi::safe::lua_newuserdata(self.state, mem::size_of::<TypeId>())? as *mut TypeId;
ptr::write(type_id_ptr, type_id);
ffi::safe::lua_rawsetfield(self.state, -2, "__mlua_type_id")?;
let metatable_index = ffi::lua_absindex(self.state, -1); let metatable_index = ffi::lua_absindex(self.state, -1);
let mut extra_tables_count = 0; let mut extra_tables_count = 0;
@ -1690,7 +1698,7 @@ impl Lua {
// Pushes a LuaRef value onto the stack, checking that it's a registered // Pushes a LuaRef value onto the stack, checking that it's a registered
// and not destructed UserData. // and not destructed UserData.
// Uses 3 stack spaces, does not call checkstack. // Uses 3 stack spaces, does not call checkstack.
pub(crate) unsafe fn push_userdata_ref(&self, lref: &LuaRef) -> Result<()> { pub(crate) unsafe fn push_userdata_ref(&self, lref: &LuaRef, with_mt: bool) -> Result<()> {
self.push_ref(lref); self.push_ref(lref);
if ffi::lua_getmetatable(self.state, -1) == 0 { if ffi::lua_getmetatable(self.state, -1) == 0 {
return Err(Error::UserDataTypeMismatch); return Err(Error::UserDataTypeMismatch);
@ -1699,7 +1707,9 @@ impl Lua {
let ptr = ffi::lua_topointer(self.state, -1); let ptr = ffi::lua_topointer(self.state, -1);
let extra = mlua_expect!(self.extra.lock(), "extra is poisoned"); let extra = mlua_expect!(self.extra.lock(), "extra is poisoned");
if extra.registered_userdata_mt.contains(&(ptr as isize)) { if extra.registered_userdata_mt.contains(&(ptr as isize)) {
ffi::lua_pop(self.state, 1); if !with_mt {
ffi::lua_pop(self.state, 1);
}
return Ok(()); return Ok(());
} }
// Maybe userdata was destructed? // Maybe userdata was destructed?
@ -2488,6 +2498,21 @@ impl<'lua, T: 'static + UserData> UserDataMethods<'lua, T> for StaticUserDataMet
self.meta_methods self.meta_methods
.push((meta.into(), Self::box_function_mut(function))); .push((meta.into(), Self::box_function_mut(function)));
} }
// Below are internal methods used in generated code
fn add_callback(&mut self, name: Vec<u8>, callback: Callback<'lua, 'static>) {
self.methods.push((name, callback));
}
#[cfg(feature = "async")]
fn add_async_callback(&mut self, name: Vec<u8>, callback: AsyncCallback<'lua, 'static>) {
self.async_methods.push((name, callback));
}
fn add_meta_callback(&mut self, meta: MetaMethod, callback: Callback<'lua, 'static>) {
self.meta_methods.push((meta, callback));
}
} }
impl<'lua, T: 'static + UserData> StaticUserDataMethods<'lua, T> { impl<'lua, T: 'static + UserData> StaticUserDataMethods<'lua, T> {
@ -2500,8 +2525,29 @@ impl<'lua, T: 'static + UserData> StaticUserDataMethods<'lua, T> {
Box::new(move |lua, mut args| { Box::new(move |lua, mut args| {
if let Some(front) = args.pop_front() { if let Some(front) = args.pop_front() {
let userdata = AnyUserData::from_lua(front, lua)?; let userdata = AnyUserData::from_lua(front, lua)?;
let userdata = userdata.borrow::<T>()?; match userdata.type_id()? {
method(lua, &userdata, A::from_lua_multi(args, lua)?)?.to_lua_multi(lua) id if id == TypeId::of::<T>() => {
let ud = userdata.borrow::<T>()?;
method(lua, &ud, A::from_lua_multi(args, lua)?)?.to_lua_multi(lua)
}
#[cfg(not(feature = "send"))]
id if id == TypeId::of::<Rc<RefCell<T>>>() => {
let ud = userdata.borrow::<Rc<RefCell<T>>>()?;
let ud = ud.try_borrow().map_err(|_| Error::UserDataBorrowError)?;
method(lua, &ud, A::from_lua_multi(args, lua)?)?.to_lua_multi(lua)
}
id if id == TypeId::of::<Arc<Mutex<T>>>() => {
let ud = userdata.borrow::<Arc<Mutex<T>>>()?;
let ud = ud.try_lock().map_err(|_| Error::UserDataBorrowError)?;
method(lua, &ud, A::from_lua_multi(args, lua)?)?.to_lua_multi(lua)
}
id if id == TypeId::of::<Arc<RwLock<T>>>() => {
let ud = userdata.borrow::<Arc<RwLock<T>>>()?;
let ud = ud.try_read().map_err(|_| Error::UserDataBorrowError)?;
method(lua, &ud, A::from_lua_multi(args, lua)?)?.to_lua_multi(lua)
}
_ => Err(Error::UserDataTypeMismatch),
}
} else { } else {
Err(Error::FromLuaConversionError { Err(Error::FromLuaConversionError {
from: "missing argument", from: "missing argument",
@ -2522,11 +2568,34 @@ impl<'lua, T: 'static + UserData> StaticUserDataMethods<'lua, T> {
Box::new(move |lua, mut args| { Box::new(move |lua, mut args| {
if let Some(front) = args.pop_front() { if let Some(front) = args.pop_front() {
let userdata = AnyUserData::from_lua(front, lua)?; let userdata = AnyUserData::from_lua(front, lua)?;
let mut userdata = userdata.borrow_mut::<T>()?;
let mut method = method let mut method = method
.try_borrow_mut() .try_borrow_mut()
.map_err(|_| Error::RecursiveMutCallback)?; .map_err(|_| Error::RecursiveMutCallback)?;
(&mut *method)(lua, &mut userdata, A::from_lua_multi(args, lua)?)?.to_lua_multi(lua) match userdata.type_id()? {
id if id == TypeId::of::<T>() => {
let mut ud = userdata.borrow_mut::<T>()?;
method(lua, &mut ud, A::from_lua_multi(args, lua)?)?.to_lua_multi(lua)
}
#[cfg(not(feature = "send"))]
id if id == TypeId::of::<Rc<RefCell<T>>>() => {
let ud = userdata.borrow::<Rc<RefCell<T>>>()?;
let mut ud = ud
.try_borrow_mut()
.map_err(|_| Error::UserDataBorrowMutError)?;
method(lua, &mut ud, A::from_lua_multi(args, lua)?)?.to_lua_multi(lua)
}
id if id == TypeId::of::<Arc<Mutex<T>>>() => {
let ud = userdata.borrow::<Arc<Mutex<T>>>()?;
let mut ud = ud.try_lock().map_err(|_| Error::UserDataBorrowMutError)?;
method(lua, &mut ud, A::from_lua_multi(args, lua)?)?.to_lua_multi(lua)
}
id if id == TypeId::of::<Arc<RwLock<T>>>() => {
let ud = userdata.borrow::<Arc<RwLock<T>>>()?;
let mut ud = ud.try_write().map_err(|_| Error::UserDataBorrowMutError)?;
method(lua, &mut ud, A::from_lua_multi(args, lua)?)?.to_lua_multi(lua)
}
_ => Err(Error::UserDataTypeMismatch),
}
} else { } else {
Err(Error::FromLuaConversionError { Err(Error::FromLuaConversionError {
from: "missing argument", from: "missing argument",
@ -2709,4 +2778,51 @@ impl<'lua, T: 'static + UserData> UserDataFields<'lua, T> for StaticUserDataFiel
}), }),
)); ));
} }
// Below are internal methods
fn add_field_getter(&mut self, name: Vec<u8>, callback: Callback<'lua, 'static>) {
self.field_getters.push((name, callback));
}
fn add_field_setter(&mut self, name: Vec<u8>, callback: Callback<'lua, 'static>) {
self.field_setters.push((name, callback));
}
} }
macro_rules! lua_userdata_impl {
($type:ty) => {
impl<T: 'static + UserData> UserData for $type {
fn add_fields<'lua, F: UserDataFields<'lua, Self>>(fields: &mut F) {
let mut orig_fields = StaticUserDataFields::default();
T::add_fields(&mut orig_fields);
for (name, callback) in orig_fields.field_getters {
fields.add_field_getter(name, callback);
}
for (name, callback) in orig_fields.field_setters {
fields.add_field_setter(name, callback);
}
}
fn add_methods<'lua, M: UserDataMethods<'lua, Self>>(methods: &mut M) {
let mut orig_methods = StaticUserDataMethods::default();
T::add_methods(&mut orig_methods);
for (name, callback) in orig_methods.methods {
methods.add_callback(name, callback);
}
#[cfg(feature = "async")]
for (name, callback) in orig_methods.async_methods {
methods.add_async_callback(name, callback);
}
for (meta, callback) in orig_methods.meta_methods {
methods.add_meta_callback(meta, callback);
}
}
}
};
}
#[cfg(not(feature = "send"))]
lua_userdata_impl!(Rc<RefCell<T>>);
lua_userdata_impl!(Arc<Mutex<T>>);
lua_userdata_impl!(Arc<RwLock<T>>);

View File

@ -1,8 +1,9 @@
use std::any::Any; use std::any::Any;
use std::cell::{Cell, Ref, RefCell, RefMut}; use std::cell::{Cell, RefCell};
use std::marker::PhantomData; use std::marker::PhantomData;
use std::mem; use std::mem;
use std::os::raw::{c_int, c_void}; use std::os::raw::{c_int, c_void};
use std::rc::Rc;
#[cfg(feature = "serialize")] #[cfg(feature = "serialize")]
use serde::Serialize; use serde::Serialize;
@ -238,7 +239,7 @@ impl<'lua, 'scope> Scope<'lua, 'scope> {
where where
T: 'scope + UserData, T: 'scope + UserData,
{ {
let data = UserDataCell::new_arc(data); let data = Rc::new(RefCell::new(data));
// 'callback outliving 'scope is a lie to make the types work out, required due to the // 'callback outliving 'scope is a lie to make the types work out, required due to the
// inability to work with the more correct callback type that is universally quantified over // inability to work with the more correct callback type that is universally quantified over
@ -247,7 +248,7 @@ impl<'lua, 'scope> Scope<'lua, 'scope> {
// parameters. // parameters.
fn wrap_method<'scope, 'lua, 'callback: 'scope, T: 'scope>( fn wrap_method<'scope, 'lua, 'callback: 'scope, T: 'scope>(
scope: &Scope<'lua, 'scope>, scope: &Scope<'lua, 'scope>,
data: UserDataCell<T>, data: Rc<RefCell<T>>,
data_ptr: *mut c_void, data_ptr: *mut c_void,
method: NonStaticMethod<'callback, T>, method: NonStaticMethod<'callback, T>,
) -> Result<Function<'lua>> { ) -> Result<Function<'lua>> {
@ -263,7 +264,7 @@ impl<'lua, 'scope> Scope<'lua, 'scope> {
unsafe { unsafe {
let _sg = StackGuard::new(lua.state); let _sg = StackGuard::new(lua.state);
check_stack(lua.state, 3)?; check_stack(lua.state, 3)?;
lua.push_userdata_ref(&ud.0)?; lua.push_userdata_ref(&ud.0, false)?;
if get_userdata(lua.state, -1) == data_ptr { if get_userdata(lua.state, -1) == data_ptr {
return Ok(()); return Ok(());
} }
@ -276,10 +277,7 @@ impl<'lua, 'scope> Scope<'lua, 'scope> {
NonStaticMethod::Method(method) => { NonStaticMethod::Method(method) => {
let f = Box::new(move |lua, mut args: MultiValue<'callback>| { let f = Box::new(move |lua, mut args: MultiValue<'callback>| {
check_ud_type(lua, args.pop_front())?; check_ud_type(lua, args.pop_front())?;
let data = data let data = data.try_borrow().map_err(|_| Error::UserDataBorrowError)?;
.try_borrow()
.map(|cell| Ref::map(cell, AsRef::as_ref))
.map_err(|_| Error::UserDataBorrowError)?;
method(lua, &*data, args) method(lua, &*data, args)
}); });
unsafe { scope.create_callback(f) } unsafe { scope.create_callback(f) }
@ -293,7 +291,6 @@ impl<'lua, 'scope> Scope<'lua, 'scope> {
.map_err(|_| Error::RecursiveMutCallback)?; .map_err(|_| Error::RecursiveMutCallback)?;
let mut data = data let mut data = data
.try_borrow_mut() .try_borrow_mut()
.map(|cell| RefMut::map(cell, AsMut::as_mut))
.map_err(|_| Error::UserDataBorrowMutError)?; .map_err(|_| Error::UserDataBorrowMutError)?;
(&mut *method)(lua, &mut *data, args) (&mut *method)(lua, &mut *data, args)
}); });
@ -324,7 +321,7 @@ impl<'lua, 'scope> Scope<'lua, 'scope> {
let _sg = StackGuard::new(lua.state); let _sg = StackGuard::new(lua.state);
check_stack(lua.state, 13)?; check_stack(lua.state, 13)?;
push_userdata(lua.state, data.clone())?; push_userdata(lua.state, UserDataCell::new(data.clone()))?;
let data_ptr = ffi::lua_touserdata(lua.state, -1); let data_ptr = ffi::lua_touserdata(lua.state, -1);
// Prepare metatable, add meta methods first and then meta fields // Prepare metatable, add meta methods first and then meta fields
@ -379,7 +376,7 @@ impl<'lua, 'scope> Scope<'lua, 'scope> {
methods_index = Some(ffi::lua_absindex(lua.state, -1)); methods_index = Some(ffi::lua_absindex(lua.state, -1));
} }
init_userdata_metatable::<()>( init_userdata_metatable::<UserDataCell<Rc<RefCell<T>>>>(
lua.state, lua.state,
metatable_index, metatable_index,
field_getters_index, field_getters_index,
@ -427,7 +424,8 @@ impl<'lua, 'scope> Scope<'lua, 'scope> {
mem::transmute(f) mem::transmute(f)
} }
vec![Box::new(seal(take_userdata::<UserDataCell<T>>(state)))] let ud = Box::new(seal(take_userdata::<UserDataCell<Rc<RefCell<T>>>>(state)));
vec![ud]
}); });
self.destructors self.destructors
.borrow_mut() .borrow_mut()

View File

@ -1,9 +1,9 @@
use std::any::TypeId;
use std::cell::{Ref, RefCell, RefMut}; use std::cell::{Ref, RefCell, RefMut};
use std::fmt; use std::fmt;
use std::hash::{Hash, Hasher}; use std::hash::{Hash, Hasher};
use std::ops::Deref; use std::ops::{Deref, DerefMut};
use std::string::String as StdString; use std::string::String as StdString;
use std::sync::Arc;
#[cfg(feature = "async")] #[cfg(feature = "async")]
use std::future::Future; use std::future::Future;
@ -19,13 +19,16 @@ use crate::ffi;
use crate::function::Function; use crate::function::Function;
use crate::lua::Lua; use crate::lua::Lua;
use crate::table::{Table, TablePairs}; use crate::table::{Table, TablePairs};
use crate::types::{LuaRef, MaybeSend}; use crate::types::{Callback, LuaRef, MaybeSend};
use crate::util::{check_stack, get_destructed_userdata_metatable, get_userdata, StackGuard}; use crate::util::{check_stack, get_destructed_userdata_metatable, get_userdata, StackGuard};
use crate::value::{FromLua, FromLuaMulti, ToLua, ToLuaMulti}; use crate::value::{FromLua, FromLuaMulti, ToLua, ToLuaMulti};
#[cfg(any(feature = "lua52", feature = "lua51", feature = "luajit"))] #[cfg(any(feature = "lua52", feature = "lua51", feature = "luajit"))]
use crate::value::Value; use crate::value::Value;
#[cfg(feature = "async")]
use crate::types::AsyncCallback;
/// Kinds of metamethods that can be overridden. /// Kinds of metamethods that can be overridden.
/// ///
/// Currently, this mechanism does not allow overriding the `__gc` metamethod, since there is /// Currently, this mechanism does not allow overriding the `__gc` metamethod, since there is
@ -403,6 +406,20 @@ pub trait UserDataMethods<'lua, T: UserData> {
A: FromLuaMulti<'lua>, A: FromLuaMulti<'lua>,
R: ToLuaMulti<'lua>, R: ToLuaMulti<'lua>,
F: 'static + MaybeSend + FnMut(&'lua Lua, A) -> Result<R>; F: 'static + MaybeSend + FnMut(&'lua Lua, A) -> Result<R>;
//
// Below are internal methods used in generated code
//
#[doc(hidden)]
fn add_callback(&mut self, _name: Vec<u8>, _callback: Callback<'lua, 'static>) {}
#[doc(hidden)]
#[cfg(feature = "async")]
fn add_async_callback(&mut self, _name: Vec<u8>, _callback: AsyncCallback<'lua, 'static>) {}
#[doc(hidden)]
fn add_meta_callback(&mut self, _meta: MetaMethod, _callback: Callback<'lua, 'static>) {}
} }
/// Field registry for [`UserData`] implementors. /// Field registry for [`UserData`] implementors.
@ -474,6 +491,16 @@ pub trait UserDataFields<'lua, T: UserData> {
S: Into<MetaMethod>, S: Into<MetaMethod>,
F: 'static + MaybeSend + Fn(&'lua Lua) -> Result<R>, F: 'static + MaybeSend + Fn(&'lua Lua) -> Result<R>,
R: ToLua<'lua>; R: ToLua<'lua>;
//
// Below are internal methods used in generated code
//
#[doc(hidden)]
fn add_field_getter(&mut self, _name: Vec<u8>, _callback: Callback<'lua, 'static>) {}
#[doc(hidden)]
fn add_field_setter(&mut self, _name: Vec<u8>, _callback: Callback<'lua, 'static>) {}
} }
/// Trait for custom userdata types. /// Trait for custom userdata types.
@ -550,26 +577,11 @@ pub trait UserData: Sized {
} }
// Wraps UserData in a way to always implement `serde::Serialize` trait. // Wraps UserData in a way to always implement `serde::Serialize` trait.
pub(crate) enum UserDataCell<T> { pub(crate) struct UserDataCell<T>(RefCell<UserDataWrapped<T>>);
Arc(Arc<RefCell<UserDataWrapped<T>>>),
Plain(RefCell<UserDataWrapped<T>>),
}
impl<T> UserDataCell<T> { impl<T> UserDataCell<T> {
pub(crate) fn new(data: T) -> Self { pub(crate) fn new(data: T) -> Self {
UserDataCell::Plain(RefCell::new(UserDataWrapped { UserDataCell(RefCell::new(UserDataWrapped::new(data)))
data: Box::into_raw(Box::new(data)),
#[cfg(feature = "serialize")]
ser: Box::into_raw(Box::new(UserDataSerializeError)),
}))
}
pub(crate) fn new_arc(data: T) -> Self {
UserDataCell::Arc(Arc::new(RefCell::new(UserDataWrapped {
data: Box::into_raw(Box::new(data)),
#[cfg(feature = "serialize")]
ser: Box::into_raw(Box::new(UserDataSerializeError)),
})))
} }
#[cfg(feature = "serialize")] #[cfg(feature = "serialize")]
@ -577,40 +589,135 @@ impl<T> UserDataCell<T> {
where where
T: 'static + Serialize, T: 'static + Serialize,
{ {
let data_raw = Box::into_raw(Box::new(data)); UserDataCell(RefCell::new(UserDataWrapped::new_ser(data)))
UserDataCell::Plain(RefCell::new(UserDataWrapped { }
data: data_raw,
ser: data_raw, // Immutably borrows the wrapped value.
})) fn try_borrow(&self) -> Result<UserDataRef<T>> {
self.0
.try_borrow()
.map(|r| UserDataRef(UserDataRefInner::Ref(r)))
.map_err(|_| Error::UserDataBorrowError)
}
// Mutably borrows the wrapped value.
fn try_borrow_mut(&self) -> Result<UserDataRefMut<T>> {
self.0
.try_borrow_mut()
.map(|r| UserDataRefMut(UserDataRefMutInner::Ref(r)))
.map_err(|_| Error::UserDataBorrowMutError)
} }
} }
impl<T> Deref for UserDataCell<T> { #[cfg(feature = "serialize")]
type Target = RefCell<UserDataWrapped<T>>; impl Serialize for UserDataCell<()> {
fn serialize<S>(&self, serializer: S) -> StdResult<S::Ok, S::Error>
where
S: Serializer,
{
let ser = self
.0
.try_borrow()
.map_err(|_| ser::Error::custom(Error::UserDataBorrowError))?
.ser;
unsafe { (&*ser).serialize(serializer) }
}
}
/// A wrapper type for an immutably borrowed value from an `AnyUserData`.
pub struct UserDataRef<'a, T>(UserDataRefInner<'a, T>);
enum UserDataRefInner<'a, T> {
Ref(Ref<'a, UserDataWrapped<T>>),
}
/// A wrapper type for a mutably borrowed value from an `AnyUserData`.
pub struct UserDataRefMut<'a, T>(UserDataRefMutInner<'a, T>);
enum UserDataRefMutInner<'a, T> {
Ref(RefMut<'a, UserDataWrapped<T>>),
}
impl<T> Deref for UserDataRef<'_, T> {
type Target = T;
fn deref(&self) -> &Self::Target { fn deref(&self) -> &Self::Target {
match self { match &self.0 {
UserDataCell::Arc(t) => &*t, UserDataRefInner::Ref(x) => &*x,
UserDataCell::Plain(t) => &*t,
} }
} }
} }
impl<T> Clone for UserDataCell<T> { impl<T> Deref for UserDataRefMut<'_, T> {
fn clone(&self) -> Self { type Target = T;
match self {
UserDataCell::Arc(t) => UserDataCell::Arc(t.clone()), fn deref(&self) -> &Self::Target {
UserDataCell::Plain(_) => mlua_panic!("cannot clone non-arc userdata"), match &self.0 {
UserDataRefMutInner::Ref(x) => &*x,
} }
} }
} }
impl<T> DerefMut for UserDataRefMut<'_, T> {
fn deref_mut(&mut self) -> &mut Self::Target {
match &mut self.0 {
UserDataRefMutInner::Ref(x) => &mut *x,
}
}
}
impl<T: fmt::Debug> fmt::Debug for UserDataRef<'_, T> {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
fmt::Debug::fmt(&*self as &T, f)
}
}
impl<T: fmt::Debug> fmt::Debug for UserDataRefMut<'_, T> {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
fmt::Debug::fmt(&*self as &T, f)
}
}
impl<T: fmt::Display> fmt::Display for UserDataRef<'_, T> {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
fmt::Display::fmt(&*self as &T, f)
}
}
impl<T: fmt::Display> fmt::Display for UserDataRefMut<'_, T> {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
fmt::Display::fmt(&*self as &T, f)
}
}
pub(crate) struct UserDataWrapped<T> { pub(crate) struct UserDataWrapped<T> {
pub(crate) data: *mut T, pub(crate) data: *mut T,
#[cfg(feature = "serialize")] #[cfg(feature = "serialize")]
ser: *mut dyn erased_serde::Serialize, ser: *mut dyn erased_serde::Serialize,
} }
impl<T> UserDataWrapped<T> {
fn new(data: T) -> Self {
UserDataWrapped {
data: Box::into_raw(Box::new(data)),
#[cfg(feature = "serialize")]
ser: Box::into_raw(Box::new(UserDataSerializeError)),
}
}
#[cfg(feature = "serialize")]
fn new_ser(data: T) -> Self
where
T: 'static + Serialize,
{
let data_raw = Box::into_raw(Box::new(data));
UserDataWrapped {
data: data_raw,
ser: data_raw,
}
}
}
impl<T> Drop for UserDataWrapped<T> { impl<T> Drop for UserDataWrapped<T> {
fn drop(&mut self) { fn drop(&mut self) {
unsafe { unsafe {
@ -623,20 +730,22 @@ impl<T> Drop for UserDataWrapped<T> {
} }
} }
impl<T> AsRef<T> for UserDataWrapped<T> { impl<T> Deref for UserDataWrapped<T> {
fn as_ref(&self) -> &T { type Target = T;
fn deref(&self) -> &Self::Target {
unsafe { &*self.data } unsafe { &*self.data }
} }
} }
impl<T> AsMut<T> for UserDataWrapped<T> { impl<T> DerefMut for UserDataWrapped<T> {
fn as_mut(&mut self) -> &mut T { fn deref_mut(&mut self) -> &mut Self::Target {
unsafe { &mut *self.data } unsafe { &mut *self.data }
} }
} }
#[cfg(feature = "serialize")] #[cfg(feature = "serialize")]
pub(crate) struct UserDataSerializeError; struct UserDataSerializeError;
#[cfg(feature = "serialize")] #[cfg(feature = "serialize")]
impl Serialize for UserDataSerializeError { impl Serialize for UserDataSerializeError {
@ -683,26 +792,18 @@ impl<'lua> AnyUserData<'lua> {
/// ///
/// Returns a `UserDataBorrowError` if the userdata is already mutably borrowed. Returns a /// Returns a `UserDataBorrowError` if the userdata is already mutably borrowed. Returns a
/// `UserDataTypeMismatch` if the userdata is not of type `T`. /// `UserDataTypeMismatch` if the userdata is not of type `T`.
pub fn borrow<T: 'static + UserData>(&self) -> Result<Ref<T>> { pub fn borrow<T: 'static + UserData>(&self) -> Result<UserDataRef<T>> {
self.inspect(|cell| { self.inspect(|cell| cell.try_borrow())
let cell_ref = cell.try_borrow().map_err(|_| Error::UserDataBorrowError)?;
Ok(Ref::map(cell_ref, |x| unsafe { &*x.data }))
})
} }
/// Borrow this userdata mutably if it is of type `T`. /// Borrow this userdata mutably if it is of type `T`.
/// ///
/// # Errors /// # Errors
/// ///
/// Returns a `UserDataBorrowMutError` if the userdata is already borrowed. Returns a /// Returns a `UserDataBorrowMutError` if the userdata cannot be mutably borrowed.
/// `UserDataTypeMismatch` if the userdata is not of type `T`. /// Returns a `UserDataTypeMismatch` if the userdata is not of type `T`.
pub fn borrow_mut<T: 'static + UserData>(&self) -> Result<RefMut<T>> { pub fn borrow_mut<T: 'static + UserData>(&self) -> Result<UserDataRefMut<T>> {
self.inspect(|cell| { self.inspect(|cell| cell.try_borrow_mut())
let cell_ref = cell
.try_borrow_mut()
.map_err(|_| Error::UserDataBorrowMutError)?;
Ok(RefMut::map(cell_ref, |x| unsafe { &mut *x.data }))
})
} }
/// Sets an associated value to this `AnyUserData`. /// Sets an associated value to this `AnyUserData`.
@ -726,7 +827,7 @@ impl<'lua> AnyUserData<'lua> {
let _sg = StackGuard::new(lua.state); let _sg = StackGuard::new(lua.state);
check_stack(lua.state, 3)?; check_stack(lua.state, 3)?;
lua.push_userdata_ref(&self.0)?; lua.push_userdata_ref(&self.0, false)?;
lua.push_value(v)?; lua.push_value(v)?;
ffi::lua_setuservalue(lua.state, -2); ffi::lua_setuservalue(lua.state, -2);
@ -745,7 +846,7 @@ impl<'lua> AnyUserData<'lua> {
let _sg = StackGuard::new(lua.state); let _sg = StackGuard::new(lua.state);
check_stack(lua.state, 3)?; check_stack(lua.state, 3)?;
lua.push_userdata_ref(&self.0)?; lua.push_userdata_ref(&self.0, false)?;
ffi::lua_getuservalue(lua.state, -1); ffi::lua_getuservalue(lua.state, -1);
lua.pop_value() lua.pop_value()
}; };
@ -776,7 +877,7 @@ impl<'lua> AnyUserData<'lua> {
let _sg = StackGuard::new(lua.state); let _sg = StackGuard::new(lua.state);
check_stack(lua.state, 3)?; check_stack(lua.state, 3)?;
lua.push_userdata_ref(&self.0)?; lua.push_userdata_ref(&self.0, false)?;
ffi::lua_getmetatable(lua.state, -1); // Checked that non-empty on the previous call ffi::lua_getmetatable(lua.state, -1); // Checked that non-empty on the previous call
Ok(Table(lua.pop_ref())) Ok(Table(lua.pop_ref()))
} }
@ -803,6 +904,25 @@ impl<'lua> AnyUserData<'lua> {
Ok(false) Ok(false)
} }
pub(crate) fn type_id(&self) -> Result<TypeId> {
let lua = self.0.lua;
unsafe {
let _sg = StackGuard::new(lua.state);
check_stack(lua.state, 5)?;
// Push userdata with metatable
lua.push_userdata_ref(&self.0, true)?;
// Get the special `__mlua_type_id`
ffi::safe::lua_pushstring(lua.state, "__mlua_type_id")?;
if ffi::lua_rawget(lua.state, -2) != ffi::LUA_TUSERDATA {
return Err(Error::UserDataTypeMismatch);
}
Ok(*(ffi::lua_touserdata(lua.state, -1) as *const TypeId))
}
}
fn inspect<'a, T, R, F>(&'a self, func: F) -> Result<R> fn inspect<'a, T, R, F>(&'a self, func: F) -> Result<R>
where where
T: 'static + UserData, T: 'static + UserData,
@ -928,17 +1048,15 @@ impl<'lua> Serialize for AnyUserData<'lua> {
where where
S: Serializer, S: Serializer,
{ {
let res = (|| unsafe { unsafe {
let lua = self.0.lua; let lua = self.0.lua;
let _sg = StackGuard::new(lua.state); let _sg = StackGuard::new(lua.state);
check_stack(lua.state, 3)?; check_stack(lua.state, 3).map_err(ser::Error::custom)?;
lua.push_userdata_ref(&self.0)?; lua.push_userdata_ref(&self.0, false)
.map_err(ser::Error::custom)?;
let ud = &*get_userdata::<UserDataCell<()>>(lua.state, -1); let ud = &*get_userdata::<UserDataCell<()>>(lua.state, -1);
(*ud.try_borrow().map_err(|_| Error::UserDataBorrowError)?.ser) ud.serialize(serializer)
.serialize(serializer) }
.map_err(|err| Error::SerializeError(err.to_string()))
})();
res.map_err(ser::Error::custom)
} }
} }

View File

@ -1,4 +1,7 @@
use std::sync::Arc; use std::sync::{Arc, Mutex, RwLock};
#[cfg(not(feature = "send"))]
use std::{cell::RefCell, rc::Rc};
#[cfg(feature = "lua54")] #[cfg(feature = "lua54")]
use std::sync::atomic::{AtomicI64, Ordering}; use std::sync::atomic::{AtomicI64, Ordering};
@ -451,3 +454,53 @@ fn test_metatable() -> Result<()> {
Ok(()) Ok(())
} }
#[test]
fn test_userdata_wrapped() -> Result<()> {
struct MyUserData(i64);
impl UserData for MyUserData {
fn add_fields<'lua, F: UserDataFields<'lua, Self>>(fields: &mut F) {
fields.add_field_method_get("data", |_, this| Ok(this.0));
fields.add_field_method_set("data", |_, this, val| {
this.0 = val;
Ok(())
})
}
}
let lua = Lua::new();
let globals = lua.globals();
#[cfg(not(feature = "send"))]
{
globals.set("rc_refcell_ud", Rc::new(RefCell::new(MyUserData(1))))?;
lua.load(
r#"
rc_refcell_ud.data = rc_refcell_ud.data + 1
assert(rc_refcell_ud.data == 2)
"#,
)
.exec()?;
}
globals.set("arc_mutex_ud", Arc::new(Mutex::new(MyUserData(2))))?;
lua.load(
r#"
arc_mutex_ud.data = arc_mutex_ud.data + 1
assert(arc_mutex_ud.data == 3)
"#,
)
.exec()?;
globals.set("arc_rwlock_ud", Arc::new(RwLock::new(MyUserData(3))))?;
lua.load(
r#"
arc_rwlock_ud.data = arc_rwlock_ud.data + 1
assert(arc_rwlock_ud.data == 4)
"#,
)
.exec()?;
Ok(())
}