Refactor UserDataCell

This commit is contained in:
Alex Orlenko 2021-04-27 18:31:57 +01:00
parent b5f1325f2f
commit 463fc646bc
4 changed files with 93 additions and 54 deletions

View File

@ -19,10 +19,9 @@ use crate::table::Table;
use crate::thread::Thread; use crate::thread::Thread;
use crate::types::{ use crate::types::{
Callback, HookCallback, Integer, LightUserData, LuaRef, MaybeSend, Number, RegistryKey, Callback, HookCallback, Integer, LightUserData, LuaRef, MaybeSend, Number, RegistryKey,
UserDataCell,
}; };
use crate::userdata::{ use crate::userdata::{
AnyUserData, MetaMethod, UserData, UserDataFields, UserDataMethods, UserDataWrapped, AnyUserData, MetaMethod, UserData, UserDataCell, UserDataFields, UserDataMethods,
}; };
use crate::util::{ use crate::util::{
assert_stack, callback_error, check_stack, get_destructed_userdata_metatable, get_gc_userdata, assert_stack, callback_error, check_stack, get_destructed_userdata_metatable, get_gc_userdata,
@ -982,7 +981,7 @@ impl Lua {
where where
T: 'static + MaybeSend + UserData, T: 'static + MaybeSend + UserData,
{ {
unsafe { self.make_userdata(UserDataWrapped::new(data)) } unsafe { self.make_userdata(UserDataCell::new(data)) }
} }
/// Create a Lua userdata object from a custom serializable userdata type. /// Create a Lua userdata object from a custom serializable userdata type.
@ -994,7 +993,7 @@ impl Lua {
where where
T: 'static + MaybeSend + UserData + Serialize, T: 'static + MaybeSend + UserData + Serialize,
{ {
unsafe { self.make_userdata(UserDataWrapped::new_ser(data)) } unsafe { self.make_userdata(UserDataCell::new_ser(data)) }
} }
/// Returns a handle to the global environment. /// Returns a handle to the global environment.
@ -1825,7 +1824,7 @@ impl Lua {
.into_function() .into_function()
} }
pub(crate) unsafe fn make_userdata<T>(&self, data: UserDataWrapped<T>) -> Result<AnyUserData> pub(crate) unsafe fn make_userdata<T>(&self, data: UserDataCell<T>) -> Result<AnyUserData>
where where
T: 'static + UserData, T: 'static + UserData,
{ {
@ -1833,7 +1832,7 @@ impl Lua {
assert_stack(self.state, 4); assert_stack(self.state, 4);
let ud_index = self.userdata_metatable::<T>()?; let ud_index = self.userdata_metatable::<T>()?;
push_userdata::<UserDataCell<T>>(self.state, RefCell::new(data))?; push_userdata(self.state, data)?;
ffi::lua_rawgeti(self.state, ffi::LUA_REGISTRYINDEX, ud_index as Integer); ffi::lua_rawgeti(self.state, ffi::LUA_REGISTRYINDEX, ud_index as Integer);
ffi::lua_setmetatable(self.state, -2); ffi::lua_setmetatable(self.state, -2);

View File

@ -2,7 +2,7 @@ use std::any::Any;
use std::cell::{Cell, Ref, RefCell, RefMut}; use std::cell::{Cell, Ref, RefCell, RefMut};
use std::marker::PhantomData; use std::marker::PhantomData;
use std::mem; use std::mem;
use std::os::raw::c_int; use std::os::raw::{c_int, c_void};
#[cfg(feature = "serialize")] #[cfg(feature = "serialize")]
use serde::Serialize; use serde::Serialize;
@ -11,9 +11,9 @@ use crate::error::{Error, Result};
use crate::ffi; use crate::ffi;
use crate::function::Function; use crate::function::Function;
use crate::lua::Lua; use crate::lua::Lua;
use crate::types::{Callback, LuaRef, MaybeSend, UserDataCell}; use crate::types::{Callback, LuaRef, MaybeSend};
use crate::userdata::{ use crate::userdata::{
AnyUserData, MetaMethod, UserData, UserDataFields, UserDataMethods, UserDataWrapped, AnyUserData, MetaMethod, UserData, UserDataCell, UserDataFields, UserDataMethods,
}; };
use crate::util::{ use crate::util::{
assert_stack, get_userdata, init_userdata_metatable, push_userdata, take_userdata, StackGuard, assert_stack, get_userdata, init_userdata_metatable, push_userdata, take_userdata, StackGuard,
@ -150,7 +150,7 @@ impl<'lua, 'scope> Scope<'lua, 'scope> {
where where
T: 'static + UserData, T: 'static + UserData,
{ {
self.create_userdata_inner(UserDataWrapped::new(data)) self.create_userdata_inner(UserDataCell::new(data))
} }
/// Create a Lua userdata object from a custom serializable userdata type. /// Create a Lua userdata object from a custom serializable userdata type.
@ -170,10 +170,10 @@ impl<'lua, 'scope> Scope<'lua, 'scope> {
where where
T: 'static + UserData + Serialize, T: 'static + UserData + Serialize,
{ {
self.create_userdata_inner(UserDataWrapped::new_ser(data)) self.create_userdata_inner(UserDataCell::new_ser(data))
} }
fn create_userdata_inner<T>(&self, data: UserDataWrapped<T>) -> Result<AnyUserData<'lua>> fn create_userdata_inner<T>(&self, data: UserDataCell<T>) -> Result<AnyUserData<'lua>>
where where
T: 'static + UserData, T: 'static + UserData,
{ {
@ -236,7 +236,7 @@ impl<'lua, 'scope> Scope<'lua, 'scope> {
where where
T: 'scope + UserData, T: 'scope + UserData,
{ {
let data = UserDataCell::new(UserDataWrapped::new(data)); let data = UserDataCell::new_arc(data);
// 'callback outliving 'scope is a lie to make the types work out, required due to the // 'callback outliving 'scope is a lie to make the types work out, required due to the
// inability to work with the more correct callback type that is universally quantified over // inability to work with the more correct callback type that is universally quantified over
@ -245,7 +245,8 @@ impl<'lua, 'scope> Scope<'lua, 'scope> {
// parameters. // parameters.
fn wrap_method<'scope, 'lua, 'callback: 'scope, T: 'scope>( fn wrap_method<'scope, 'lua, 'callback: 'scope, T: 'scope>(
scope: &Scope<'lua, 'scope>, scope: &Scope<'lua, 'scope>,
data: *mut UserDataCell<T>, data: UserDataCell<T>,
data_ptr: *mut c_void,
method: NonStaticMethod<'callback, T>, method: NonStaticMethod<'callback, T>,
) -> Result<Function<'lua>> { ) -> Result<Function<'lua>> {
// On methods that actually receive the userdata, we fake a type check on the passed in // On methods that actually receive the userdata, we fake a type check on the passed in
@ -255,14 +256,13 @@ impl<'lua, 'scope> Scope<'lua, 'scope> {
// with a type mismatch, but here without this check would proceed as though you had // with a type mismatch, but here without this check would proceed as though you had
// called the method on the original value (since we otherwise completely ignore the // called the method on the original value (since we otherwise completely ignore the
// first argument). // first argument).
let check_data = data;
let check_ud_type = move |lua: &'callback Lua, value| { let check_ud_type = move |lua: &'callback Lua, value| {
if let Some(Value::UserData(ud)) = value { if let Some(Value::UserData(ud)) = value {
unsafe { unsafe {
let _sg = StackGuard::new(lua.state); let _sg = StackGuard::new(lua.state);
assert_stack(lua.state, 3); assert_stack(lua.state, 3);
lua.push_userdata_ref(&ud.0)?; lua.push_userdata_ref(&ud.0)?;
if get_userdata(lua.state, -1) == check_data { if get_userdata(lua.state, -1) == data_ptr {
return Ok(()); return Ok(());
} }
} }
@ -274,7 +274,7 @@ impl<'lua, 'scope> Scope<'lua, 'scope> {
NonStaticMethod::Method(method) => { NonStaticMethod::Method(method) => {
let f = Box::new(move |lua, mut args: MultiValue<'callback>| { let f = Box::new(move |lua, mut args: MultiValue<'callback>| {
check_ud_type(lua, args.pop_front())?; check_ud_type(lua, args.pop_front())?;
let data = unsafe { &*data } let data = data
.try_borrow() .try_borrow()
.map(|cell| Ref::map(cell, AsRef::as_ref)) .map(|cell| Ref::map(cell, AsRef::as_ref))
.map_err(|_| Error::UserDataBorrowError)?; .map_err(|_| Error::UserDataBorrowError)?;
@ -289,7 +289,7 @@ impl<'lua, 'scope> Scope<'lua, 'scope> {
let mut method = method let mut method = method
.try_borrow_mut() .try_borrow_mut()
.map_err(|_| Error::RecursiveMutCallback)?; .map_err(|_| Error::RecursiveMutCallback)?;
let mut data = unsafe { &mut *data } let mut data = data
.try_borrow_mut() .try_borrow_mut()
.map(|cell| RefMut::map(cell, AsMut::as_mut)) .map(|cell| RefMut::map(cell, AsMut::as_mut))
.map_err(|_| Error::UserDataBorrowMutError)?; .map_err(|_| Error::UserDataBorrowMutError)?;
@ -322,15 +322,16 @@ impl<'lua, 'scope> Scope<'lua, 'scope> {
let _sg = StackGuard::new(lua.state); let _sg = StackGuard::new(lua.state);
assert_stack(lua.state, 13); assert_stack(lua.state, 13);
push_userdata(lua.state, data)?; push_userdata(lua.state, data.clone())?;
let data = get_userdata::<UserDataCell<T>>(lua.state, -1); let data_ptr = ffi::lua_touserdata(lua.state, -1);
// Prepare metatable, add meta methods first and then meta fields // Prepare metatable, add meta methods first and then meta fields
let meta_methods_nrec = ud_methods.meta_methods.len() + ud_fields.meta_fields.len() + 1; let meta_methods_nrec = ud_methods.meta_methods.len() + ud_fields.meta_fields.len() + 1;
ffi::safe::lua_createtable(lua.state, 0, meta_methods_nrec as c_int)?; ffi::safe::lua_createtable(lua.state, 0, meta_methods_nrec as c_int)?;
for (k, m) in ud_methods.meta_methods { for (k, m) in ud_methods.meta_methods {
lua.push_value(Value::Function(wrap_method(self, data, m)?))?; let data = data.clone();
lua.push_value(Value::Function(wrap_method(self, data, data_ptr, m)?))?;
ffi::safe::lua_rawsetfield(lua.state, -2, k.validate()?.name())?; ffi::safe::lua_rawsetfield(lua.state, -2, k.validate()?.name())?;
} }
for (k, f) in ud_fields.meta_fields { for (k, f) in ud_fields.meta_fields {
@ -344,7 +345,8 @@ impl<'lua, 'scope> Scope<'lua, 'scope> {
if field_getters_nrec > 0 { if field_getters_nrec > 0 {
ffi::safe::lua_createtable(lua.state, 0, field_getters_nrec as c_int)?; ffi::safe::lua_createtable(lua.state, 0, field_getters_nrec as c_int)?;
for (k, m) in ud_fields.field_getters { for (k, m) in ud_fields.field_getters {
lua.push_value(Value::Function(wrap_method(self, data, m)?))?; let data = data.clone();
lua.push_value(Value::Function(wrap_method(self, data, data_ptr, m)?))?;
ffi::safe::lua_rawsetfield(lua.state, -2, &k)?; ffi::safe::lua_rawsetfield(lua.state, -2, &k)?;
} }
field_getters_index = Some(ffi::lua_absindex(lua.state, -1)); field_getters_index = Some(ffi::lua_absindex(lua.state, -1));
@ -355,7 +357,8 @@ impl<'lua, 'scope> Scope<'lua, 'scope> {
if field_setters_nrec > 0 { if field_setters_nrec > 0 {
ffi::safe::lua_createtable(lua.state, 0, field_setters_nrec as c_int)?; ffi::safe::lua_createtable(lua.state, 0, field_setters_nrec as c_int)?;
for (k, m) in ud_fields.field_setters { for (k, m) in ud_fields.field_setters {
lua.push_value(Value::Function(wrap_method(self, data, m)?))?; let data = data.clone();
lua.push_value(Value::Function(wrap_method(self, data, data_ptr, m)?))?;
ffi::safe::lua_rawsetfield(lua.state, -2, &k)?; ffi::safe::lua_rawsetfield(lua.state, -2, &k)?;
} }
field_setters_index = Some(ffi::lua_absindex(lua.state, -1)); field_setters_index = Some(ffi::lua_absindex(lua.state, -1));
@ -367,7 +370,8 @@ impl<'lua, 'scope> Scope<'lua, 'scope> {
// Create table used for methods lookup // Create table used for methods lookup
ffi::safe::lua_createtable(lua.state, 0, methods_nrec as c_int)?; ffi::safe::lua_createtable(lua.state, 0, methods_nrec as c_int)?;
for (k, m) in ud_methods.methods { for (k, m) in ud_methods.methods {
lua.push_value(Value::Function(wrap_method(self, data, m)?))?; let data = data.clone();
lua.push_value(Value::Function(wrap_method(self, data, data_ptr, m)?))?;
ffi::safe::lua_rawsetfield(lua.state, -2, &k)?; ffi::safe::lua_rawsetfield(lua.state, -2, &k)?;
} }
methods_index = Some(ffi::lua_absindex(lua.state, -1)); methods_index = Some(ffi::lua_absindex(lua.state, -1));
@ -414,9 +418,13 @@ impl<'lua, 'scope> Scope<'lua, 'scope> {
ud.lua.push_ref(&newtable.0); ud.lua.push_ref(&newtable.0);
ffi::lua_setuservalue(state, -2); ffi::lua_setuservalue(state, -2);
// We cannot put `T` into the vec because `T` does not implement `Any` // A hack to drop non-static `T`
drop(take_userdata::<UserDataCell<T>>(state)); unsafe fn seal<T>(t: T) -> Box<dyn FnOnce() + 'static> {
vec![] let f: Box<dyn FnOnce()> = Box::new(move || drop(t));
mem::transmute(f)
}
vec![Box::new(seal(take_userdata::<UserDataCell<T>>(state)))]
}); });
self.destructors self.destructors
.borrow_mut() .borrow_mut()

View File

@ -10,7 +10,6 @@ use crate::error::Result;
use crate::ffi; use crate::ffi;
use crate::hook::Debug; use crate::hook::Debug;
use crate::lua::Lua; use crate::lua::Lua;
use crate::userdata::UserDataWrapped;
use crate::util::{assert_stack, StackGuard}; use crate::util::{assert_stack, StackGuard};
use crate::value::MultiValue; use crate::value::MultiValue;
@ -32,8 +31,6 @@ pub(crate) type AsyncCallback<'lua, 'a> =
pub(crate) type HookCallback = Arc<RefCell<dyn FnMut(&Lua, Debug) -> Result<()>>>; pub(crate) type HookCallback = Arc<RefCell<dyn FnMut(&Lua, Debug) -> Result<()>>>;
pub(crate) type UserDataCell<T> = RefCell<UserDataWrapped<T>>;
#[cfg(feature = "send")] #[cfg(feature = "send")]
pub trait MaybeSend: Send {} pub trait MaybeSend: Send {}
#[cfg(feature = "send")] #[cfg(feature = "send")]

View File

@ -1,7 +1,9 @@
use std::cell::{Ref, RefMut}; use std::cell::{Ref, RefCell, RefMut};
use std::fmt; use std::fmt;
use std::hash::{Hash, Hasher}; use std::hash::{Hash, Hasher};
use std::ops::Deref;
use std::string::String as StdString; use std::string::String as StdString;
use std::sync::Arc;
#[cfg(feature = "async")] #[cfg(feature = "async")]
use std::future::Future; use std::future::Future;
@ -17,7 +19,7 @@ use crate::ffi;
use crate::function::Function; use crate::function::Function;
use crate::lua::Lua; use crate::lua::Lua;
use crate::table::{Table, TablePairs}; use crate::table::{Table, TablePairs};
use crate::types::{Integer, LuaRef, MaybeSend, UserDataCell}; use crate::types::{Integer, LuaRef, MaybeSend};
use crate::util::{assert_stack, get_destructed_userdata_metatable, get_userdata, StackGuard}; use crate::util::{assert_stack, get_destructed_userdata_metatable, get_userdata, StackGuard};
use crate::value::{FromLua, FromLuaMulti, ToLua, ToLuaMulti, Value}; use crate::value::{FromLua, FromLuaMulti, ToLua, ToLuaMulti, Value};
@ -544,6 +546,61 @@ pub trait UserData: Sized {
fn add_methods<'lua, M: UserDataMethods<'lua, Self>>(_methods: &mut M) {} fn add_methods<'lua, M: UserDataMethods<'lua, Self>>(_methods: &mut M) {}
} }
pub(crate) enum UserDataCell<T> {
Arc(Arc<RefCell<UserDataWrapped<T>>>),
Plain(RefCell<UserDataWrapped<T>>),
}
impl<T> UserDataCell<T> {
pub(crate) fn new(data: T) -> Self {
UserDataCell::Plain(RefCell::new(UserDataWrapped {
data: Box::into_raw(Box::new(data)),
#[cfg(feature = "serialize")]
ser: Box::into_raw(Box::new(UserDataSerializeError)),
}))
}
pub(crate) fn new_arc(data: T) -> Self {
UserDataCell::Arc(Arc::new(RefCell::new(UserDataWrapped {
data: Box::into_raw(Box::new(data)),
#[cfg(feature = "serialize")]
ser: Box::into_raw(Box::new(UserDataSerializeError)),
})))
}
#[cfg(feature = "serialize")]
pub(crate) fn new_ser(data: T) -> Self
where
T: 'static + Serialize,
{
let data_raw = Box::into_raw(Box::new(data));
UserDataCell::Plain(RefCell::new(UserDataWrapped {
data: data_raw,
ser: data_raw,
}))
}
}
impl<T> Deref for UserDataCell<T> {
type Target = RefCell<UserDataWrapped<T>>;
fn deref(&self) -> &Self::Target {
match self {
UserDataCell::Arc(t) => &*t,
UserDataCell::Plain(t) => &*t,
}
}
}
impl<T> Clone for UserDataCell<T> {
fn clone(&self) -> Self {
match self {
UserDataCell::Arc(t) => UserDataCell::Arc(t.clone()),
UserDataCell::Plain(_) => mlua_panic!("cannot clone plain userdata"),
}
}
}
pub(crate) struct UserDataWrapped<T> { pub(crate) struct UserDataWrapped<T> {
pub(crate) data: *mut T, pub(crate) data: *mut T,
#[cfg(feature = "serialize")] #[cfg(feature = "serialize")]
@ -562,28 +619,6 @@ impl<T> Drop for UserDataWrapped<T> {
} }
} }
impl<T> UserDataWrapped<T> {
pub(crate) fn new(data: T) -> Self {
UserDataWrapped {
data: Box::into_raw(Box::new(data)),
#[cfg(feature = "serialize")]
ser: Box::into_raw(Box::new(UserDataSerializeError)),
}
}
#[cfg(feature = "serialize")]
pub(crate) fn new_ser(data: T) -> Self
where
T: 'static + Serialize,
{
let data_raw = Box::into_raw(Box::new(data));
UserDataWrapped {
data: data_raw,
ser: data_raw,
}
}
}
impl<T> AsRef<T> for UserDataWrapped<T> { impl<T> AsRef<T> for UserDataWrapped<T> {
fn as_ref(&self) -> &T { fn as_ref(&self) -> &T {
unsafe { &*self.data } unsafe { &*self.data }