Add (hidden) method `UserData::take()` to take out value from userdata

This commit is contained in:
Alex Orlenko 2021-10-05 13:04:43 +01:00
parent 235fba821e
commit a544e41b33
No known key found for this signature in database
GPG Key ID: 4C150C250863B96D
5 changed files with 176 additions and 56 deletions

View File

@ -2129,10 +2129,11 @@ impl Lua {
let _sg = StackGuard::new(self.state);
check_stack(self.state, 3)?;
// It's safe to push userdata first and then metatable.
// If the first push failed, unlikely we moved `data` to allocated memory.
push_userdata(self.state, data)?;
// We push metatable first to ensure having correct metatable with `__gc` method
ffi::lua_pushnil(self.state);
self.push_userdata_metatable::<T>()?;
push_userdata(self.state, data)?;
ffi::lua_replace(self.state, -3);
ffi::lua_setmetatable(self.state, -2);
Ok(AnyUserData(self.pop_ref()))

View File

@ -192,9 +192,10 @@ impl<'lua, 'scope> Scope<'lua, 'scope> {
let _sg = StackGuard::new(state);
assert_stack(state, 2);
ud.lua.push_ref(&ud);
// We know the destructor has not run yet because we hold a reference to the userdata.
// Check that userdata is not destructed (via `take()` call)
if ud.lua.push_userdata_ref(&ud).is_err() {
return vec![];
}
// Clear uservalue
#[cfg(any(feature = "lua54", feature = "lua53", feature = "lua52"))]
@ -404,9 +405,10 @@ impl<'lua, 'scope> Scope<'lua, 'scope> {
let _sg = StackGuard::new(state);
assert_stack(state, 2);
ud.lua.push_ref(&ud);
// We know the destructor has not run yet because we hold a reference to the userdata.
// Check that userdata is valid (very likely)
if ud.lua.push_userdata_ref(&ud).is_err() {
return vec![];
}
// Deregister metatable
ffi::lua_getmetatable(state, -1);

View File

@ -11,7 +11,6 @@ use std::future::Future;
#[cfg(feature = "serialize")]
use {
serde::ser::{self, Serialize, Serializer},
std::os::raw::c_void,
std::result::Result as StdResult,
};
@ -21,7 +20,7 @@ use crate::function::Function;
use crate::lua::Lua;
use crate::table::{Table, TablePairs};
use crate::types::{Callback, LuaRef, MaybeSend};
use crate::util::{check_stack, get_userdata, StackGuard};
use crate::util::{check_stack, get_userdata, take_userdata, StackGuard};
use crate::value::{FromLua, FromLuaMulti, ToLua, ToLuaMulti};
#[cfg(any(feature = "lua52", feature = "lua51", feature = "luajit"))]
@ -626,18 +625,24 @@ impl<T> UserDataCell<T> {
.map(|r| RefMut::map(r, |r| r.deref_mut()))
.map_err(|_| Error::UserDataBorrowMutError)
}
// Consumes this `UserDataCell`, returning the wrapped value.
#[inline]
fn into_inner(self) -> T {
self.0.into_inner().into_inner()
}
}
pub(crate) enum UserDataWrapped<T> {
Default(T),
Default(Box<T>),
#[cfg(feature = "serialize")]
Serializable(*mut T, *const dyn erased_serde::Serialize),
Serializable(Box<dyn erased_serde::Serialize>),
}
impl<T> UserDataWrapped<T> {
#[inline]
fn new(data: T) -> Self {
UserDataWrapped::Default(data)
UserDataWrapped::Default(Box::new(data))
}
#[cfg(feature = "serialize")]
@ -646,16 +651,15 @@ impl<T> UserDataWrapped<T> {
where
T: 'static + Serialize,
{
let data_raw = Box::into_raw(Box::new(data));
UserDataWrapped::Serializable(data_raw, data_raw)
UserDataWrapped::Serializable(Box::new(data))
}
}
#[cfg(feature = "serialize")]
impl<T> Drop for UserDataWrapped<T> {
fn drop(&mut self) {
if let UserDataWrapped::Serializable(data, _) = *self {
drop(unsafe { Box::from_raw(data) });
#[inline]
fn into_inner(self) -> T {
match self {
Self::Default(data) => *data,
#[cfg(feature = "serialize")]
Self::Serializable(data) => unsafe { *Box::from_raw(Box::into_raw(data) as *mut T) },
}
}
}
@ -668,7 +672,9 @@ impl<T> Deref for UserDataWrapped<T> {
match self {
Self::Default(data) => data,
#[cfg(feature = "serialize")]
Self::Serializable(data, _) => unsafe { &**data },
Self::Serializable(data) => unsafe {
&*(data.as_ref() as *const _ as *const Self::Target)
},
}
}
}
@ -679,7 +685,9 @@ impl<T> DerefMut for UserDataWrapped<T> {
match self {
Self::Default(data) => data,
#[cfg(feature = "serialize")]
Self::Serializable(data, _) => unsafe { &mut **data },
Self::Serializable(data) => unsafe {
&mut *(data.as_mut() as *mut _ as *mut Self::Target)
},
}
}
}
@ -748,6 +756,35 @@ impl<'lua> AnyUserData<'lua> {
self.inspect(|cell| cell.try_borrow_mut())
}
/// Takes out the value of `UserData` and sets the special "destructed" metatable that prevents
/// any further operations with this userdata.
#[doc(hidden)]
pub fn take<T: 'static + UserData>(&self) -> Result<T> {
let lua = self.0.lua;
unsafe {
let _sg = StackGuard::new(lua.state);
check_stack(lua.state, 2)?;
let type_id = lua.push_userdata_ref(&self.0)?;
match type_id {
Some(type_id) if type_id == TypeId::of::<T>() => {
// Try to borrow userdata exclusively
let _ = (*get_userdata::<UserDataCell<T>>(lua.state, -1)).try_borrow_mut()?;
// Clear uservalue
#[cfg(any(feature = "lua54", feature = "lua53", feature = "lua52"))]
ffi::lua_pushnil(lua.state);
#[cfg(any(feature = "lua51", feature = "luajit"))]
protect_lua!(lua.state, 0, 1, fn(state) ffi::lua_newtable(state))?;
ffi::lua_setuservalue(lua.state, -2);
Ok(take_userdata::<UserDataCell<T>>(lua.state).into_inner())
}
_ => Err(Error::UserDataTypeMismatch),
}
}
}
/// Sets an associated value to this `AnyUserData`.
///
/// The value may be any Lua value whatsoever, and can be retrieved with [`get_user_value`].
@ -960,20 +997,19 @@ impl<'lua> Serialize for AnyUserData<'lua> {
where
S: Serializer,
{
unsafe {
let lua = self.0.lua;
let lua = self.0.lua;
let data = unsafe {
let _sg = StackGuard::new(lua.state);
check_stack(lua.state, 3).map_err(ser::Error::custom)?;
lua.push_userdata_ref(&self.0).map_err(ser::Error::custom)?;
let ud = &*get_userdata::<UserDataCell<c_void>>(lua.state, -1);
let data =
ud.0.try_borrow()
.map_err(|_| ser::Error::custom(Error::UserDataBorrowError))?;
match *data {
UserDataWrapped::Default(_) => UserDataSerializeError.serialize(serializer),
UserDataWrapped::Serializable(_, ser) => (&*ser).serialize(serializer),
}
let ud = &*get_userdata::<UserDataCell<()>>(lua.state, -1);
ud.0.try_borrow()
.map_err(|_| ser::Error::custom(Error::UserDataBorrowError))?
};
match &*data {
UserDataWrapped::Default(_) => UserDataSerializeError.serialize(serializer),
UserDataWrapped::Serializable(ser) => ser.serialize(serializer),
}
}
}

View File

@ -805,6 +805,7 @@ pub unsafe fn init_error_registry(state: *mut ffi::lua_State) -> Result<()> {
// Create destructed userdata metatable
unsafe extern "C" fn destructed_error(state: *mut ffi::lua_State) -> c_int {
// TODO: Consider changing error to UserDataDestructed in v0.7
callback_error(state, |_| Err(Error::CallbackDestructed))
}

View File

@ -36,6 +36,7 @@ fn test_user_data() -> Result<()> {
#[test]
fn test_methods() -> Result<()> {
#[cfg_attr(feature = "serialize", derive(serde::Serialize))]
struct MyUserData(i64);
impl UserData for MyUserData {
@ -48,29 +49,38 @@ fn test_methods() -> Result<()> {
}
}
let lua = Lua::new();
let globals = lua.globals();
let userdata = lua.create_userdata(MyUserData(42))?;
globals.set("userdata", userdata.clone())?;
lua.load(
r#"
function get_it()
return userdata:get_value()
end
fn check_methods(lua: &Lua, userdata: AnyUserData) -> Result<()> {
let globals = lua.globals();
globals.set("userdata", userdata.clone())?;
lua.load(
r#"
function get_it()
return userdata:get_value()
end
function set_it(i)
return userdata:set_value(i)
end
"#,
)
.exec()?;
let get = globals.get::<_, Function>("get_it")?;
let set = globals.get::<_, Function>("set_it")?;
assert_eq!(get.call::<_, i64>(())?, 42);
userdata.borrow_mut::<MyUserData>()?.0 = 64;
assert_eq!(get.call::<_, i64>(())?, 64);
set.call::<_, ()>(100)?;
assert_eq!(get.call::<_, i64>(())?, 100);
function set_it(i)
return userdata:set_value(i)
end
"#,
)
.exec()?;
let get = globals.get::<_, Function>("get_it")?;
let set = globals.get::<_, Function>("set_it")?;
assert_eq!(get.call::<_, i64>(())?, 42);
userdata.borrow_mut::<MyUserData>()?.0 = 64;
assert_eq!(get.call::<_, i64>(())?, 64);
set.call::<_, ()>(100)?;
assert_eq!(get.call::<_, i64>(())?, 100);
Ok(())
}
let lua = Lua::new();
check_methods(&lua, lua.create_userdata(MyUserData(42))?)?;
// Additionally check serializable userdata
#[cfg(feature = "serialize")]
check_methods(&lua, lua.create_ser_userdata(MyUserData(42))?)?;
Ok(())
}
@ -252,6 +262,76 @@ fn test_gc_userdata() -> Result<()> {
Ok(())
}
#[test]
fn test_userdata_take() -> Result<()> {
#[derive(Debug)]
struct MyUserdata(Arc<i64>);
impl UserData for MyUserdata {
fn add_methods<'lua, M: UserDataMethods<'lua, Self>>(methods: &mut M) {
methods.add_method("num", |_, this, ()| Ok(*this.0))
}
}
#[cfg(feature = "serialize")]
impl serde::Serialize for MyUserdata {
fn serialize<S>(&self, serializer: S) -> std::result::Result<S::Ok, S::Error>
where
S: serde::Serializer,
{
serializer.serialize_i64(*self.0)
}
}
fn check_userdata_take(lua: &Lua, userdata: AnyUserData, rc: Arc<i64>) -> Result<()> {
lua.globals().set("userdata", userdata.clone())?;
assert_eq!(Arc::strong_count(&rc), 2);
let userdata_copy = userdata.clone();
{
let _value = userdata.borrow::<MyUserdata>()?;
// We should not be able to take userdata if it's borrowed
match userdata_copy.take::<MyUserdata>() {
Err(Error::UserDataBorrowMutError) => {}
r => panic!("expected `UserDataBorrowMutError` error, got {:?}", r),
}
}
let value = userdata_copy.take::<MyUserdata>()?;
assert_eq!(*value.0, 18);
drop(value);
assert_eq!(Arc::strong_count(&rc), 1);
match userdata.borrow::<MyUserdata>() {
Err(Error::UserDataDestructed) => {}
r => panic!("expected `UserDataDestructed` error, got {:?}", r),
}
match lua.load("userdata:num()").exec() {
Err(Error::CallbackError { ref cause, .. }) => match cause.as_ref() {
Error::CallbackDestructed => {}
err => panic!("expected `CallbackDestructed`, got {:?}", err),
},
r => panic!("improper return for destructed userdata: {:?}", r),
}
Ok(())
}
let lua = Lua::new();
let rc = Arc::new(18);
let userdata = lua.create_userdata(MyUserdata(rc.clone()))?;
check_userdata_take(&lua, userdata, rc)?;
// Additionally check serializable userdata
#[cfg(feature = "serialize")]
{
let rc = Arc::new(18);
let userdata = lua.create_ser_userdata(MyUserdata(rc.clone()))?;
check_userdata_take(&lua, userdata, rc)?;
}
Ok(())
}
#[test]
fn test_destroy_userdata() -> Result<()> {
struct MyUserdata(Arc<()>);