Refactor UserData metatables handling

This commit is contained in:
Alex Orlenko 2021-09-26 01:05:37 +01:00
parent 01154c0616
commit d586eef0f5
No known key found for this signature in database
GPG Key ID: 4C150C250863B96D
4 changed files with 154 additions and 141 deletions

View File

@ -1,6 +1,6 @@
use std::any::TypeId;
use std::cell::{RefCell, UnsafeCell};
use std::collections::{HashMap, HashSet};
use std::cell::{Ref, RefCell, RefMut, UnsafeCell};
use std::collections::HashMap;
use std::ffi::CString;
use std::fmt;
use std::marker::PhantomData;
@ -19,8 +19,8 @@ use crate::string::String;
use crate::table::Table;
use crate::thread::Thread;
use crate::types::{
Callback, CallbackUpvalue, HookCallback, Integer, LightUserData, LuaRef, MaybeSend, Number,
RegistryKey,
Callback, CallbackUpvalue, DestructedUserdataMT, HookCallback, Integer, LightUserData, LuaRef,
MaybeSend, Number, RegistryKey,
};
use crate::userdata::{
AnyUserData, MetaMethod, UserData, UserDataCell, UserDataFields, UserDataMethods,
@ -64,7 +64,7 @@ pub struct Lua {
// Data associated with the Lua.
struct ExtraData {
registered_userdata: HashMap<TypeId, c_int>,
registered_userdata_mt: HashSet<isize>,
registered_userdata_mt: HashMap<*const c_void, Option<TypeId>>,
registry_unref_list: Arc<Mutex<Option<Vec<c_int>>>>,
libs: StdLib,
@ -444,7 +444,7 @@ impl Lua {
let extra = Arc::new(UnsafeCell::new(ExtraData {
registered_userdata: HashMap::new(),
registered_userdata_mt: HashSet::new(),
registered_userdata_mt: HashMap::new(),
registry_unref_list: Arc::new(Mutex::new(Some(Vec::new()))),
ref_thread,
libs: StdLib::NONE,
@ -469,6 +469,15 @@ impl Lua {
"Error while storing extra data",
);
// Register `DestructedUserdataMT` type
get_destructed_userdata_metatable(main_state);
let destructed_mt_ptr = ffi::lua_topointer(main_state, -1);
(*extra.get()).registered_userdata_mt.insert(
destructed_mt_ptr,
Some(TypeId::of::<DestructedUserdataMT>()),
);
ffi::lua_pop(main_state, 1);
mlua_debug_assert!(
ffi::lua_gettop(main_state) == main_state_top,
"stack leak during creation"
@ -1744,12 +1753,6 @@ impl Lua {
self.push_value(f(self)?)?;
rawset_field(self.state, -2, k.validate()?.name())?;
}
// Add special `__mlua_type_id` field
let type_id_ptr = protect_lua!(self.state, 0, 1, |state| {
ffi::lua_newuserdata(state, mem::size_of::<TypeId>()) as *mut TypeId
})?;
ptr::write(type_id_ptr, type_id);
rawset_field(self.state, -2, "__mlua_type_id")?;
let metatable_index = ffi::lua_absindex(self.state, -1);
let mut extra_tables_count = 0;
@ -1809,51 +1812,60 @@ impl Lua {
// Pop extra tables to get metatable on top of the stack
ffi::lua_pop(self.state, extra_tables_count);
let ptr = ffi::lua_topointer(self.state, -1);
let mt_ptr = ffi::lua_topointer(self.state, -1);
ffi::lua_pushvalue(self.state, -1);
let id = protect_lua!(self.state, 1, 0, |state| {
ffi::luaL_ref(state, ffi::LUA_REGISTRYINDEX)
})?;
extra.registered_userdata.insert(type_id, id);
extra.registered_userdata_mt.insert(ptr as isize);
extra.registered_userdata_mt.insert(mt_ptr, Some(type_id));
Ok(())
}
pub(crate) unsafe fn register_userdata_metatable(&self, id: isize) {
(*self.extra.get()).registered_userdata_mt.insert(id);
pub(crate) unsafe fn register_userdata_metatable(
&self,
ptr: *const c_void,
type_id: Option<TypeId>,
) {
let extra = &mut *self.extra.get();
extra.registered_userdata_mt.insert(ptr, type_id);
}
pub(crate) unsafe fn deregister_userdata_metatable(&self, id: isize) {
(*self.extra.get()).registered_userdata_mt.remove(&id);
pub(crate) unsafe fn deregister_userdata_metatable(&self, ptr: *const c_void) {
(*self.extra.get()).registered_userdata_mt.remove(&ptr);
}
// Pushes a LuaRef value onto the stack, checking that it's a registered
// and not destructed UserData.
// Uses 3 stack spaces, does not call checkstack.
pub(crate) unsafe fn push_userdata_ref(&self, lref: &LuaRef, with_mt: bool) -> Result<()> {
// Uses 2 stack spaces, does not call checkstack.
pub(crate) unsafe fn push_userdata_ref(&self, lref: &LuaRef) -> Result<Option<TypeId>> {
self.push_ref(lref);
if ffi::lua_getmetatable(self.state, -1) == 0 {
return Err(Error::UserDataTypeMismatch);
}
// Check that userdata is registered
let ptr = ffi::lua_topointer(self.state, -1);
let mt_ptr = ffi::lua_topointer(self.state, -1);
ffi::lua_pop(self.state, 1);
let extra = &*self.extra.get();
if extra.registered_userdata_mt.contains(&(ptr as isize)) {
if !with_mt {
ffi::lua_pop(self.state, 1);
match extra.registered_userdata_mt.get(&mt_ptr) {
Some(&type_id) if type_id == Some(TypeId::of::<DestructedUserdataMT>()) => {
Err(Error::UserDataDestructed)
}
return Ok(());
Some(&type_id) => Ok(type_id),
None => Err(Error::UserDataTypeMismatch),
}
// Maybe userdata was destructed?
get_destructed_userdata_metatable(self.state);
if ffi::lua_rawequal(self.state, -1, -2) != 0 {
ffi::lua_pop(self.state, 2);
return Err(Error::UserDataDestructed);
}
ffi::lua_pop(self.state, 2);
Err(Error::UserDataTypeMismatch)
}
#[inline]
unsafe fn get_userdata_ref<T>(&self) -> Result<Ref<T>> {
(*get_userdata::<UserDataCell<T>>(self.state, -1)).try_borrow()
}
#[inline]
unsafe fn get_userdata_mut<T>(&self) -> Result<RefMut<T>> {
(*get_userdata::<UserDataCell<T>>(self.state, -1)).try_borrow_mut()
}
// Creates a Function out of a Callback containing a 'static Fn. This is safe ONLY because the
@ -2082,7 +2094,7 @@ impl Lua {
T: 'static + UserData,
{
let _sg = StackGuard::new(self.state);
check_stack(self.state, 2)?;
check_stack(self.state, 3)?;
// It's safe to push userdata first and then metatable.
// If the first push failed, unlikely we moved `data` to allocated memory.
@ -2822,32 +2834,34 @@ impl<'lua, T: 'static + UserData> StaticUserDataMethods<'lua, T> {
Box::new(move |lua, mut args| {
if let Some(front) = args.pop_front() {
let userdata = AnyUserData::from_lua(front, lua)?;
// Try normal userdata first
let err = match userdata.borrow::<T>() {
Ok(ud) => {
return method(lua, &ud, A::from_lua_multi(args, lua)?)?.to_lua_multi(lua)
unsafe {
let _sg = StackGuard::new(lua.state);
check_stack(lua.state, 2)?;
let type_id = lua.push_userdata_ref(&userdata.0)?;
match type_id {
Some(id) if id == TypeId::of::<T>() => {
let ud = lua.get_userdata_ref::<T>()?;
method(lua, &ud, A::from_lua_multi(args, lua)?)?.to_lua_multi(lua)
}
#[cfg(not(feature = "send"))]
Some(id) if id == TypeId::of::<Rc<RefCell<T>>>() => {
let ud = lua.get_userdata_ref::<Rc<RefCell<T>>>()?;
let ud = ud.try_borrow().map_err(|_| Error::UserDataBorrowError)?;
method(lua, &ud, A::from_lua_multi(args, lua)?)?.to_lua_multi(lua)
}
Some(id) if id == TypeId::of::<Arc<Mutex<T>>>() => {
let ud = lua.get_userdata_ref::<Arc<Mutex<T>>>()?;
let ud = ud.try_lock().map_err(|_| Error::UserDataBorrowError)?;
method(lua, &ud, A::from_lua_multi(args, lua)?)?.to_lua_multi(lua)
}
Some(id) if id == TypeId::of::<Arc<RwLock<T>>>() => {
let ud = lua.get_userdata_ref::<Arc<RwLock<T>>>()?;
let ud = ud.try_read().map_err(|_| Error::UserDataBorrowError)?;
method(lua, &ud, A::from_lua_multi(args, lua)?)?.to_lua_multi(lua)
}
_ => Err(Error::UserDataTypeMismatch),
}
Err(err) => err,
};
match userdata.type_id()? {
id if id == TypeId::of::<T>() => Err(err),
#[cfg(not(feature = "send"))]
id if id == TypeId::of::<Rc<RefCell<T>>>() => {
let ud = userdata.borrow::<Rc<RefCell<T>>>()?;
let ud = ud.try_borrow().map_err(|_| Error::UserDataBorrowError)?;
method(lua, &ud, A::from_lua_multi(args, lua)?)?.to_lua_multi(lua)
}
id if id == TypeId::of::<Arc<Mutex<T>>>() => {
let ud = userdata.borrow::<Arc<Mutex<T>>>()?;
let ud = ud.try_lock().map_err(|_| Error::UserDataBorrowError)?;
method(lua, &ud, A::from_lua_multi(args, lua)?)?.to_lua_multi(lua)
}
id if id == TypeId::of::<Arc<RwLock<T>>>() => {
let ud = userdata.borrow::<Arc<RwLock<T>>>()?;
let ud = ud.try_read().map_err(|_| Error::UserDataBorrowError)?;
method(lua, &ud, A::from_lua_multi(args, lua)?)?.to_lua_multi(lua)
}
_ => Err(Error::UserDataTypeMismatch),
}
} else {
Err(Error::FromLuaConversionError {
@ -2872,35 +2886,38 @@ impl<'lua, T: 'static + UserData> StaticUserDataMethods<'lua, T> {
let mut method = method
.try_borrow_mut()
.map_err(|_| Error::RecursiveMutCallback)?;
// Try normal userdata first
let err = match userdata.borrow_mut::<T>() {
Ok(mut ud) => {
return method(lua, &mut ud, A::from_lua_multi(args, lua)?)?
.to_lua_multi(lua)
unsafe {
let _sg = StackGuard::new(lua.state);
check_stack(lua.state, 2)?;
let type_id = lua.push_userdata_ref(&userdata.0)?;
match type_id {
Some(id) if id == TypeId::of::<T>() => {
let mut ud = lua.get_userdata_mut::<T>()?;
method(lua, &mut ud, A::from_lua_multi(args, lua)?)?.to_lua_multi(lua)
}
#[cfg(not(feature = "send"))]
Some(id) if id == TypeId::of::<Rc<RefCell<T>>>() => {
let ud = lua.get_userdata_mut::<Rc<RefCell<T>>>()?;
let mut ud = ud
.try_borrow_mut()
.map_err(|_| Error::UserDataBorrowMutError)?;
method(lua, &mut ud, A::from_lua_multi(args, lua)?)?.to_lua_multi(lua)
}
Some(id) if id == TypeId::of::<Arc<Mutex<T>>>() => {
let ud = lua.get_userdata_mut::<Arc<Mutex<T>>>()?;
let mut ud =
ud.try_lock().map_err(|_| Error::UserDataBorrowMutError)?;
method(lua, &mut ud, A::from_lua_multi(args, lua)?)?.to_lua_multi(lua)
}
Some(id) if id == TypeId::of::<Arc<RwLock<T>>>() => {
let ud = lua.get_userdata_mut::<Arc<RwLock<T>>>()?;
let mut ud =
ud.try_write().map_err(|_| Error::UserDataBorrowMutError)?;
method(lua, &mut ud, A::from_lua_multi(args, lua)?)?.to_lua_multi(lua)
}
_ => Err(Error::UserDataTypeMismatch),
}
Err(err) => err,
};
match userdata.type_id()? {
id if id == TypeId::of::<T>() => Err(err),
#[cfg(not(feature = "send"))]
id if id == TypeId::of::<Rc<RefCell<T>>>() => {
let ud = userdata.borrow::<Rc<RefCell<T>>>()?;
let mut ud = ud
.try_borrow_mut()
.map_err(|_| Error::UserDataBorrowMutError)?;
method(lua, &mut ud, A::from_lua_multi(args, lua)?)?.to_lua_multi(lua)
}
id if id == TypeId::of::<Arc<Mutex<T>>>() => {
let ud = userdata.borrow::<Arc<Mutex<T>>>()?;
let mut ud = ud.try_lock().map_err(|_| Error::UserDataBorrowMutError)?;
method(lua, &mut ud, A::from_lua_multi(args, lua)?)?.to_lua_multi(lua)
}
id if id == TypeId::of::<Arc<RwLock<T>>>() => {
let ud = userdata.borrow::<Arc<RwLock<T>>>()?;
let mut ud = ud.try_write().map_err(|_| Error::UserDataBorrowMutError)?;
method(lua, &mut ud, A::from_lua_multi(args, lua)?)?.to_lua_multi(lua)
}
_ => Err(Error::UserDataTypeMismatch),
}
} else {
Err(Error::FromLuaConversionError {
@ -2925,8 +2942,35 @@ impl<'lua, T: 'static + UserData> StaticUserDataMethods<'lua, T> {
let fut_res = || {
if let Some(front) = args.pop_front() {
let userdata = AnyUserData::from_lua(front, lua)?;
let userdata = userdata.borrow::<T>()?.clone();
Ok(method(lua, userdata, A::from_lua_multi(args, lua)?))
unsafe {
let _sg = StackGuard::new(lua.state);
check_stack(lua.state, 2)?;
let type_id = lua.push_userdata_ref(&userdata.0)?;
match type_id {
Some(id) if id == TypeId::of::<T>() => {
let ud = lua.get_userdata_ref::<T>()?;
Ok(method(lua, ud.clone(), A::from_lua_multi(args, lua)?))
}
#[cfg(not(feature = "send"))]
Some(id) if id == TypeId::of::<Rc<RefCell<T>>>() => {
let ud = lua.get_userdata_ref::<Rc<RefCell<T>>>()?;
let ud = ud.try_borrow().map_err(|_| Error::UserDataBorrowError)?;
Ok(method(lua, ud.clone(), A::from_lua_multi(args, lua)?))
}
Some(id) if id == TypeId::of::<Arc<Mutex<T>>>() => {
let ud = lua.get_userdata_ref::<Arc<Mutex<T>>>()?;
let ud = ud.try_lock().map_err(|_| Error::UserDataBorrowError)?;
Ok(method(lua, ud.clone(), A::from_lua_multi(args, lua)?))
}
Some(id) if id == TypeId::of::<Arc<RwLock<T>>>() => {
let ud = lua.get_userdata_ref::<Arc<RwLock<T>>>()?;
let ud = ud.try_read().map_err(|_| Error::UserDataBorrowError)?;
Ok(method(lua, ud.clone(), A::from_lua_multi(args, lua)?))
}
_ => Err(Error::UserDataTypeMismatch),
}
}
} else {
Err(Error::FromLuaConversionError {
from: "missing argument",

View File

@ -265,7 +265,7 @@ impl<'lua, 'scope> Scope<'lua, 'scope> {
unsafe {
let _sg = StackGuard::new(lua.state);
check_stack(lua.state, 3)?;
lua.push_userdata_ref(&ud.0, false)?;
lua.push_userdata_ref(&ud.0)?;
if get_userdata(lua.state, -1) == data_ptr {
return Ok(());
}
@ -390,12 +390,12 @@ impl<'lua, 'scope> Scope<'lua, 'scope> {
+ methods_index.map(|_| 1).unwrap_or(0);
ffi::lua_pop(lua.state, count);
let mt_id = ffi::lua_topointer(lua.state, -1);
let mt_ptr = ffi::lua_topointer(lua.state, -1);
// Write userdata just before attaching metatable with `__gc` metamethod
ptr::write(data_ptr as _, UserDataCell::new(data));
ffi::lua_setmetatable(lua.state, -2);
let ud = AnyUserData(lua.pop_ref());
lua.register_userdata_metatable(mt_id as isize);
lua.register_userdata_metatable(mt_ptr, None);
#[cfg(any(feature = "lua51", feature = "luajit"))]
let newtable = lua.create_table()?;
@ -410,9 +410,9 @@ impl<'lua, 'scope> Scope<'lua, 'scope> {
// Deregister metatable
ffi::lua_getmetatable(state, -1);
let mt_id = ffi::lua_topointer(state, -1);
let mt_ptr = ffi::lua_topointer(state, -1);
ffi::lua_pop(state, 1);
ud.lua.deregister_userdata_metatable(mt_id as isize);
ud.lua.deregister_userdata_metatable(mt_ptr);
// Clear uservalue
#[cfg(any(feature = "lua54", feature = "lua53", feature = "lua52"))]

View File

@ -63,6 +63,8 @@ pub trait MaybeSend {}
#[cfg(not(feature = "send"))]
impl<T> MaybeSend for T {}
pub(crate) struct DestructedUserdataMT;
/// An auto generated key into the Lua registry.
///
/// This is a handle to a value stored inside the Lua registry. It is not automatically

View File

@ -21,9 +21,7 @@ use crate::function::Function;
use crate::lua::Lua;
use crate::table::{Table, TablePairs};
use crate::types::{Callback, LuaRef, MaybeSend};
use crate::util::{
check_stack, get_destructed_userdata_metatable, get_userdata, push_string, StackGuard,
};
use crate::util::{check_stack, get_userdata, StackGuard};
use crate::value::{FromLua, FromLuaMulti, ToLua, ToLuaMulti};
#[cfg(any(feature = "lua52", feature = "lua51", feature = "luajit"))]
@ -613,7 +611,7 @@ impl<T> UserDataCell<T> {
// Immutably borrows the wrapped value.
#[inline]
fn try_borrow(&self) -> Result<Ref<T>> {
pub(crate) fn try_borrow(&self) -> Result<Ref<T>> {
self.0
.try_borrow()
.map(|r| Ref::map(r, |r| r.deref()))
@ -622,7 +620,7 @@ impl<T> UserDataCell<T> {
// Mutably borrows the wrapped value.
#[inline]
fn try_borrow_mut(&self) -> Result<RefMut<T>> {
pub(crate) fn try_borrow_mut(&self) -> Result<RefMut<T>> {
self.0
.try_borrow_mut()
.map(|r| RefMut::map(r, |r| r.deref_mut()))
@ -771,7 +769,7 @@ impl<'lua> AnyUserData<'lua> {
let _sg = StackGuard::new(lua.state);
check_stack(lua.state, 3)?;
lua.push_userdata_ref(&self.0, false)?;
lua.push_userdata_ref(&self.0)?;
lua.push_value(v)?;
ffi::lua_setuservalue(lua.state, -2);
@ -790,7 +788,7 @@ impl<'lua> AnyUserData<'lua> {
let _sg = StackGuard::new(lua.state);
check_stack(lua.state, 3)?;
lua.push_userdata_ref(&self.0, false)?;
lua.push_userdata_ref(&self.0)?;
ffi::lua_getuservalue(lua.state, -1);
lua.pop_value()
};
@ -821,7 +819,7 @@ impl<'lua> AnyUserData<'lua> {
let _sg = StackGuard::new(lua.state);
check_stack(lua.state, 3)?;
lua.push_userdata_ref(&self.0, false)?;
lua.push_userdata_ref(&self.0)?;
ffi::lua_getmetatable(lua.state, -1); // Checked that non-empty on the previous call
Ok(Table(lua.pop_ref()))
}
@ -848,25 +846,6 @@ impl<'lua> AnyUserData<'lua> {
Ok(false)
}
pub(crate) fn type_id(&self) -> Result<TypeId> {
let lua = self.0.lua;
unsafe {
let _sg = StackGuard::new(lua.state);
check_stack(lua.state, 5)?;
// Push userdata with metatable
lua.push_userdata_ref(&self.0, true)?;
// Get the special `__mlua_type_id`
push_string(lua.state, "__mlua_type_id")?;
if ffi::lua_rawget(lua.state, -2) != ffi::LUA_TUSERDATA {
return Err(Error::UserDataTypeMismatch);
}
Ok(*(ffi::lua_touserdata(lua.state, -1) as *const TypeId))
}
}
fn inspect<'a, T, R, F>(&'a self, func: F) -> Result<R>
where
T: 'static + UserData,
@ -875,25 +854,14 @@ impl<'lua> AnyUserData<'lua> {
let lua = self.0.lua;
unsafe {
let _sg = StackGuard::new(lua.state);
check_stack(lua.state, 3)?;
check_stack(lua.state, 2)?;
lua.push_ref(&self.0);
if ffi::lua_getmetatable(lua.state, -1) == 0 {
return Err(Error::UserDataTypeMismatch);
}
lua.push_userdata_metatable::<T>()?;
if ffi::lua_rawequal(lua.state, -1, -2) == 0 {
// Maybe UserData destructed?
ffi::lua_pop(lua.state, 1);
get_destructed_userdata_metatable(lua.state);
if ffi::lua_rawequal(lua.state, -1, -2) == 1 {
Err(Error::UserDataDestructed)
} else {
Err(Error::UserDataTypeMismatch)
let type_id = lua.push_userdata_ref(&self.0)?;
match type_id {
Some(type_id) if type_id == TypeId::of::<T>() => {
func(&*get_userdata::<UserDataCell<T>>(lua.state, -1))
}
} else {
func(&*get_userdata::<UserDataCell<T>>(lua.state, -3))
_ => Err(Error::UserDataTypeMismatch),
}
}
}
@ -997,8 +965,7 @@ impl<'lua> Serialize for AnyUserData<'lua> {
let _sg = StackGuard::new(lua.state);
check_stack(lua.state, 3).map_err(ser::Error::custom)?;
lua.push_userdata_ref(&self.0, false)
.map_err(ser::Error::custom)?;
lua.push_userdata_ref(&self.0).map_err(ser::Error::custom)?;
let ud = &*get_userdata::<UserDataCell<c_void>>(lua.state, -1);
let data =
ud.0.try_borrow()