Add (hidden) method `UserData::take()` to take out value from userdata
This commit is contained in:
parent
235fba821e
commit
a544e41b33
|
@ -2129,10 +2129,11 @@ impl Lua {
|
|||
let _sg = StackGuard::new(self.state);
|
||||
check_stack(self.state, 3)?;
|
||||
|
||||
// It's safe to push userdata first and then metatable.
|
||||
// If the first push failed, unlikely we moved `data` to allocated memory.
|
||||
push_userdata(self.state, data)?;
|
||||
// We push metatable first to ensure having correct metatable with `__gc` method
|
||||
ffi::lua_pushnil(self.state);
|
||||
self.push_userdata_metatable::<T>()?;
|
||||
push_userdata(self.state, data)?;
|
||||
ffi::lua_replace(self.state, -3);
|
||||
ffi::lua_setmetatable(self.state, -2);
|
||||
|
||||
Ok(AnyUserData(self.pop_ref()))
|
||||
|
|
14
src/scope.rs
14
src/scope.rs
|
@ -192,9 +192,10 @@ impl<'lua, 'scope> Scope<'lua, 'scope> {
|
|||
let _sg = StackGuard::new(state);
|
||||
assert_stack(state, 2);
|
||||
|
||||
ud.lua.push_ref(&ud);
|
||||
|
||||
// We know the destructor has not run yet because we hold a reference to the userdata.
|
||||
// Check that userdata is not destructed (via `take()` call)
|
||||
if ud.lua.push_userdata_ref(&ud).is_err() {
|
||||
return vec![];
|
||||
}
|
||||
|
||||
// Clear uservalue
|
||||
#[cfg(any(feature = "lua54", feature = "lua53", feature = "lua52"))]
|
||||
|
@ -404,9 +405,10 @@ impl<'lua, 'scope> Scope<'lua, 'scope> {
|
|||
let _sg = StackGuard::new(state);
|
||||
assert_stack(state, 2);
|
||||
|
||||
ud.lua.push_ref(&ud);
|
||||
|
||||
// We know the destructor has not run yet because we hold a reference to the userdata.
|
||||
// Check that userdata is valid (very likely)
|
||||
if ud.lua.push_userdata_ref(&ud).is_err() {
|
||||
return vec![];
|
||||
}
|
||||
|
||||
// Deregister metatable
|
||||
ffi::lua_getmetatable(state, -1);
|
||||
|
|
|
@ -11,7 +11,6 @@ use std::future::Future;
|
|||
#[cfg(feature = "serialize")]
|
||||
use {
|
||||
serde::ser::{self, Serialize, Serializer},
|
||||
std::os::raw::c_void,
|
||||
std::result::Result as StdResult,
|
||||
};
|
||||
|
||||
|
@ -21,7 +20,7 @@ use crate::function::Function;
|
|||
use crate::lua::Lua;
|
||||
use crate::table::{Table, TablePairs};
|
||||
use crate::types::{Callback, LuaRef, MaybeSend};
|
||||
use crate::util::{check_stack, get_userdata, StackGuard};
|
||||
use crate::util::{check_stack, get_userdata, take_userdata, StackGuard};
|
||||
use crate::value::{FromLua, FromLuaMulti, ToLua, ToLuaMulti};
|
||||
|
||||
#[cfg(any(feature = "lua52", feature = "lua51", feature = "luajit"))]
|
||||
|
@ -626,18 +625,24 @@ impl<T> UserDataCell<T> {
|
|||
.map(|r| RefMut::map(r, |r| r.deref_mut()))
|
||||
.map_err(|_| Error::UserDataBorrowMutError)
|
||||
}
|
||||
|
||||
// Consumes this `UserDataCell`, returning the wrapped value.
|
||||
#[inline]
|
||||
fn into_inner(self) -> T {
|
||||
self.0.into_inner().into_inner()
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) enum UserDataWrapped<T> {
|
||||
Default(T),
|
||||
Default(Box<T>),
|
||||
#[cfg(feature = "serialize")]
|
||||
Serializable(*mut T, *const dyn erased_serde::Serialize),
|
||||
Serializable(Box<dyn erased_serde::Serialize>),
|
||||
}
|
||||
|
||||
impl<T> UserDataWrapped<T> {
|
||||
#[inline]
|
||||
fn new(data: T) -> Self {
|
||||
UserDataWrapped::Default(data)
|
||||
UserDataWrapped::Default(Box::new(data))
|
||||
}
|
||||
|
||||
#[cfg(feature = "serialize")]
|
||||
|
@ -646,16 +651,15 @@ impl<T> UserDataWrapped<T> {
|
|||
where
|
||||
T: 'static + Serialize,
|
||||
{
|
||||
let data_raw = Box::into_raw(Box::new(data));
|
||||
UserDataWrapped::Serializable(data_raw, data_raw)
|
||||
UserDataWrapped::Serializable(Box::new(data))
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(feature = "serialize")]
|
||||
impl<T> Drop for UserDataWrapped<T> {
|
||||
fn drop(&mut self) {
|
||||
if let UserDataWrapped::Serializable(data, _) = *self {
|
||||
drop(unsafe { Box::from_raw(data) });
|
||||
#[inline]
|
||||
fn into_inner(self) -> T {
|
||||
match self {
|
||||
Self::Default(data) => *data,
|
||||
#[cfg(feature = "serialize")]
|
||||
Self::Serializable(data) => unsafe { *Box::from_raw(Box::into_raw(data) as *mut T) },
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -668,7 +672,9 @@ impl<T> Deref for UserDataWrapped<T> {
|
|||
match self {
|
||||
Self::Default(data) => data,
|
||||
#[cfg(feature = "serialize")]
|
||||
Self::Serializable(data, _) => unsafe { &**data },
|
||||
Self::Serializable(data) => unsafe {
|
||||
&*(data.as_ref() as *const _ as *const Self::Target)
|
||||
},
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -679,7 +685,9 @@ impl<T> DerefMut for UserDataWrapped<T> {
|
|||
match self {
|
||||
Self::Default(data) => data,
|
||||
#[cfg(feature = "serialize")]
|
||||
Self::Serializable(data, _) => unsafe { &mut **data },
|
||||
Self::Serializable(data) => unsafe {
|
||||
&mut *(data.as_mut() as *mut _ as *mut Self::Target)
|
||||
},
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -748,6 +756,35 @@ impl<'lua> AnyUserData<'lua> {
|
|||
self.inspect(|cell| cell.try_borrow_mut())
|
||||
}
|
||||
|
||||
/// Takes out the value of `UserData` and sets the special "destructed" metatable that prevents
|
||||
/// any further operations with this userdata.
|
||||
#[doc(hidden)]
|
||||
pub fn take<T: 'static + UserData>(&self) -> Result<T> {
|
||||
let lua = self.0.lua;
|
||||
unsafe {
|
||||
let _sg = StackGuard::new(lua.state);
|
||||
check_stack(lua.state, 2)?;
|
||||
|
||||
let type_id = lua.push_userdata_ref(&self.0)?;
|
||||
match type_id {
|
||||
Some(type_id) if type_id == TypeId::of::<T>() => {
|
||||
// Try to borrow userdata exclusively
|
||||
let _ = (*get_userdata::<UserDataCell<T>>(lua.state, -1)).try_borrow_mut()?;
|
||||
|
||||
// Clear uservalue
|
||||
#[cfg(any(feature = "lua54", feature = "lua53", feature = "lua52"))]
|
||||
ffi::lua_pushnil(lua.state);
|
||||
#[cfg(any(feature = "lua51", feature = "luajit"))]
|
||||
protect_lua!(lua.state, 0, 1, fn(state) ffi::lua_newtable(state))?;
|
||||
ffi::lua_setuservalue(lua.state, -2);
|
||||
|
||||
Ok(take_userdata::<UserDataCell<T>>(lua.state).into_inner())
|
||||
}
|
||||
_ => Err(Error::UserDataTypeMismatch),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Sets an associated value to this `AnyUserData`.
|
||||
///
|
||||
/// The value may be any Lua value whatsoever, and can be retrieved with [`get_user_value`].
|
||||
|
@ -960,20 +997,19 @@ impl<'lua> Serialize for AnyUserData<'lua> {
|
|||
where
|
||||
S: Serializer,
|
||||
{
|
||||
unsafe {
|
||||
let lua = self.0.lua;
|
||||
let lua = self.0.lua;
|
||||
let data = unsafe {
|
||||
let _sg = StackGuard::new(lua.state);
|
||||
check_stack(lua.state, 3).map_err(ser::Error::custom)?;
|
||||
|
||||
lua.push_userdata_ref(&self.0).map_err(ser::Error::custom)?;
|
||||
let ud = &*get_userdata::<UserDataCell<c_void>>(lua.state, -1);
|
||||
let data =
|
||||
ud.0.try_borrow()
|
||||
.map_err(|_| ser::Error::custom(Error::UserDataBorrowError))?;
|
||||
match *data {
|
||||
UserDataWrapped::Default(_) => UserDataSerializeError.serialize(serializer),
|
||||
UserDataWrapped::Serializable(_, ser) => (&*ser).serialize(serializer),
|
||||
}
|
||||
let ud = &*get_userdata::<UserDataCell<()>>(lua.state, -1);
|
||||
ud.0.try_borrow()
|
||||
.map_err(|_| ser::Error::custom(Error::UserDataBorrowError))?
|
||||
};
|
||||
match &*data {
|
||||
UserDataWrapped::Default(_) => UserDataSerializeError.serialize(serializer),
|
||||
UserDataWrapped::Serializable(ser) => ser.serialize(serializer),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -805,6 +805,7 @@ pub unsafe fn init_error_registry(state: *mut ffi::lua_State) -> Result<()> {
|
|||
// Create destructed userdata metatable
|
||||
|
||||
unsafe extern "C" fn destructed_error(state: *mut ffi::lua_State) -> c_int {
|
||||
// TODO: Consider changing error to UserDataDestructed in v0.7
|
||||
callback_error(state, |_| Err(Error::CallbackDestructed))
|
||||
}
|
||||
|
||||
|
|
|
@ -36,6 +36,7 @@ fn test_user_data() -> Result<()> {
|
|||
|
||||
#[test]
|
||||
fn test_methods() -> Result<()> {
|
||||
#[cfg_attr(feature = "serialize", derive(serde::Serialize))]
|
||||
struct MyUserData(i64);
|
||||
|
||||
impl UserData for MyUserData {
|
||||
|
@ -48,29 +49,38 @@ fn test_methods() -> Result<()> {
|
|||
}
|
||||
}
|
||||
|
||||
let lua = Lua::new();
|
||||
let globals = lua.globals();
|
||||
let userdata = lua.create_userdata(MyUserData(42))?;
|
||||
globals.set("userdata", userdata.clone())?;
|
||||
lua.load(
|
||||
r#"
|
||||
function get_it()
|
||||
return userdata:get_value()
|
||||
end
|
||||
fn check_methods(lua: &Lua, userdata: AnyUserData) -> Result<()> {
|
||||
let globals = lua.globals();
|
||||
globals.set("userdata", userdata.clone())?;
|
||||
lua.load(
|
||||
r#"
|
||||
function get_it()
|
||||
return userdata:get_value()
|
||||
end
|
||||
|
||||
function set_it(i)
|
||||
return userdata:set_value(i)
|
||||
end
|
||||
"#,
|
||||
)
|
||||
.exec()?;
|
||||
let get = globals.get::<_, Function>("get_it")?;
|
||||
let set = globals.get::<_, Function>("set_it")?;
|
||||
assert_eq!(get.call::<_, i64>(())?, 42);
|
||||
userdata.borrow_mut::<MyUserData>()?.0 = 64;
|
||||
assert_eq!(get.call::<_, i64>(())?, 64);
|
||||
set.call::<_, ()>(100)?;
|
||||
assert_eq!(get.call::<_, i64>(())?, 100);
|
||||
function set_it(i)
|
||||
return userdata:set_value(i)
|
||||
end
|
||||
"#,
|
||||
)
|
||||
.exec()?;
|
||||
let get = globals.get::<_, Function>("get_it")?;
|
||||
let set = globals.get::<_, Function>("set_it")?;
|
||||
assert_eq!(get.call::<_, i64>(())?, 42);
|
||||
userdata.borrow_mut::<MyUserData>()?.0 = 64;
|
||||
assert_eq!(get.call::<_, i64>(())?, 64);
|
||||
set.call::<_, ()>(100)?;
|
||||
assert_eq!(get.call::<_, i64>(())?, 100);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
let lua = Lua::new();
|
||||
|
||||
check_methods(&lua, lua.create_userdata(MyUserData(42))?)?;
|
||||
|
||||
// Additionally check serializable userdata
|
||||
#[cfg(feature = "serialize")]
|
||||
check_methods(&lua, lua.create_ser_userdata(MyUserData(42))?)?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
@ -252,6 +262,76 @@ fn test_gc_userdata() -> Result<()> {
|
|||
Ok(())
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_userdata_take() -> Result<()> {
|
||||
#[derive(Debug)]
|
||||
struct MyUserdata(Arc<i64>);
|
||||
|
||||
impl UserData for MyUserdata {
|
||||
fn add_methods<'lua, M: UserDataMethods<'lua, Self>>(methods: &mut M) {
|
||||
methods.add_method("num", |_, this, ()| Ok(*this.0))
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(feature = "serialize")]
|
||||
impl serde::Serialize for MyUserdata {
|
||||
fn serialize<S>(&self, serializer: S) -> std::result::Result<S::Ok, S::Error>
|
||||
where
|
||||
S: serde::Serializer,
|
||||
{
|
||||
serializer.serialize_i64(*self.0)
|
||||
}
|
||||
}
|
||||
|
||||
fn check_userdata_take(lua: &Lua, userdata: AnyUserData, rc: Arc<i64>) -> Result<()> {
|
||||
lua.globals().set("userdata", userdata.clone())?;
|
||||
assert_eq!(Arc::strong_count(&rc), 2);
|
||||
let userdata_copy = userdata.clone();
|
||||
{
|
||||
let _value = userdata.borrow::<MyUserdata>()?;
|
||||
// We should not be able to take userdata if it's borrowed
|
||||
match userdata_copy.take::<MyUserdata>() {
|
||||
Err(Error::UserDataBorrowMutError) => {}
|
||||
r => panic!("expected `UserDataBorrowMutError` error, got {:?}", r),
|
||||
}
|
||||
}
|
||||
|
||||
let value = userdata_copy.take::<MyUserdata>()?;
|
||||
assert_eq!(*value.0, 18);
|
||||
drop(value);
|
||||
assert_eq!(Arc::strong_count(&rc), 1);
|
||||
|
||||
match userdata.borrow::<MyUserdata>() {
|
||||
Err(Error::UserDataDestructed) => {}
|
||||
r => panic!("expected `UserDataDestructed` error, got {:?}", r),
|
||||
}
|
||||
match lua.load("userdata:num()").exec() {
|
||||
Err(Error::CallbackError { ref cause, .. }) => match cause.as_ref() {
|
||||
Error::CallbackDestructed => {}
|
||||
err => panic!("expected `CallbackDestructed`, got {:?}", err),
|
||||
},
|
||||
r => panic!("improper return for destructed userdata: {:?}", r),
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
let lua = Lua::new();
|
||||
|
||||
let rc = Arc::new(18);
|
||||
let userdata = lua.create_userdata(MyUserdata(rc.clone()))?;
|
||||
check_userdata_take(&lua, userdata, rc)?;
|
||||
|
||||
// Additionally check serializable userdata
|
||||
#[cfg(feature = "serialize")]
|
||||
{
|
||||
let rc = Arc::new(18);
|
||||
let userdata = lua.create_ser_userdata(MyUserdata(rc.clone()))?;
|
||||
check_userdata_take(&lua, userdata, rc)?;
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_destroy_userdata() -> Result<()> {
|
||||
struct MyUserdata(Arc<()>);
|
||||
|
|
Loading…
Reference in New Issue