Improve/fix scoped UserData drop

This commit is contained in:
Alex Orlenko 2021-02-20 01:02:35 +00:00
parent 2b2df708f9
commit 1635903d3f
3 changed files with 150 additions and 106 deletions

View File

@ -186,6 +186,14 @@ impl<'lua, 'scope> Scope<'lua, 'scope> {
let state = u.lua.state; let state = u.lua.state;
assert_stack(state, 2); assert_stack(state, 2);
u.lua.push_ref(&u); u.lua.push_ref(&u);
// Clear uservalue
#[cfg(any(feature = "lua54", feature = "lua53", feature = "lua52"))]
ffi::lua_pushnil(state);
#[cfg(any(feature = "lua51", feature = "luajit"))]
ffi::lua_newtable(state);
ffi::lua_setuservalue(state, -2);
// We know the destructor has not run yet because we hold a reference to the // We know the destructor has not run yet because we hold a reference to the
// userdata. // userdata.
vec![Box::new(take_userdata::<UserDataCell<T>>(state))] vec![Box::new(take_userdata::<UserDataCell<T>>(state))]
@ -244,28 +252,28 @@ impl<'lua, 'scope> Scope<'lua, 'scope> {
let check_ud_type = move |lua: &'callback Lua, value| { let check_ud_type = move |lua: &'callback Lua, value| {
if let Some(Value::UserData(ud)) = value { if let Some(Value::UserData(ud)) = value {
unsafe { unsafe {
assert_stack(lua.state, 1); let _sg = StackGuard::new(lua.state);
assert_stack(lua.state, 3);
lua.push_ref(&ud.0); lua.push_ref(&ud.0);
ffi::lua_getuservalue(lua.state, -1); if ffi::lua_getmetatable(lua.state, -1) == 0 {
#[cfg(any(feature = "lua52", feature = "lua51", feature = "luajit"))] return Err(Error::UserDataTypeMismatch);
{ }
ffi::lua_rawgeti(lua.state, -1, 1); ffi::lua_pushstring(lua.state, cstr!("__mlua"));
ffi::lua_remove(lua.state, -2); if ffi::lua_rawget(lua.state, -2) == ffi::LUA_TLIGHTUSERDATA {
let ud_ptr = ffi::lua_touserdata(lua.state, -1);
if ud_ptr == check_data.as_ptr() as *mut c_void {
return Ok(());
}
} }
return ffi::lua_touserdata(lua.state, -1)
== check_data.as_ptr() as *mut c_void;
} }
} };
Err(Error::UserDataTypeMismatch)
false
}; };
match method { match method {
NonStaticMethod::Method(method) => { NonStaticMethod::Method(method) => {
let f = Box::new(move |lua, mut args: MultiValue<'callback>| { let f = Box::new(move |lua, mut args: MultiValue<'callback>| {
if !check_ud_type(lua, args.pop_front()) { check_ud_type(lua, args.pop_front())?;
return Err(Error::UserDataTypeMismatch);
}
let data = data let data = data
.try_borrow() .try_borrow()
.map(|cell| Ref::map(cell, AsRef::as_ref)) .map(|cell| Ref::map(cell, AsRef::as_ref))
@ -277,9 +285,7 @@ impl<'lua, 'scope> Scope<'lua, 'scope> {
NonStaticMethod::MethodMut(method) => { NonStaticMethod::MethodMut(method) => {
let method = RefCell::new(method); let method = RefCell::new(method);
let f = Box::new(move |lua, mut args: MultiValue<'callback>| { let f = Box::new(move |lua, mut args: MultiValue<'callback>| {
if !check_ud_type(lua, args.pop_front()) { check_ud_type(lua, args.pop_front())?;
return Err(Error::UserDataTypeMismatch);
}
let mut method = method let mut method = method
.try_borrow_mut() .try_borrow_mut()
.map_err(|_| Error::RecursiveMutCallback)?; .map_err(|_| Error::RecursiveMutCallback)?;
@ -314,24 +320,19 @@ impl<'lua, 'scope> Scope<'lua, 'scope> {
unsafe { unsafe {
let lua = self.lua; let lua = self.lua;
let _sg = StackGuard::new(lua.state); let _sg = StackGuard::new(lua.state);
assert_stack(lua.state, 6); assert_stack(lua.state, 13);
// We need to wrap dummy userdata because their memory can be accessed by serializer // We need to wrap dummy userdata because their memory can be accessed by serializer
push_userdata(lua.state, UserDataCell::new(UserDataWrapped::new(())))?; push_userdata(lua.state, UserDataCell::new(UserDataWrapped::new(())))?;
#[cfg(any(feature = "lua54", feature = "lua53"))]
ffi::lua_pushlightuserdata(lua.state, data.as_ptr() as *mut c_void);
#[cfg(any(feature = "lua52", feature = "lua51", feature = "luajit"))]
protect_lua_closure(lua.state, 0, 1, |state| {
// Lua 5.2/5.1 allows to store only table. Then we will wrap the value.
ffi::lua_createtable(state, 1, 0);
ffi::lua_pushlightuserdata(state, data.as_ptr() as *mut c_void);
ffi::lua_rawseti(state, -2, 1);
})?;
ffi::lua_setuservalue(lua.state, -2);
// Prepare metatable, add meta methods first and then meta fields // Prepare metatable, add meta methods first and then meta fields
protect_lua_closure(lua.state, 0, 1, move |state| { protect_lua_closure(lua.state, 0, 1, |state| {
ffi::lua_newtable(state); ffi::lua_newtable(state);
// Add internal metamethod to store reference to the data
ffi::lua_pushstring(state, cstr!("__mlua"));
ffi::lua_pushlightuserdata(lua.state, data.as_ptr() as *mut c_void);
ffi::lua_rawset(state, -3);
})?; })?;
for (k, m) in ud_methods.meta_methods { for (k, m) in ud_methods.meta_methods {
push_string(lua.state, k.validate()?.name())?; push_string(lua.state, k.validate()?.name())?;
@ -415,19 +416,31 @@ impl<'lua, 'scope> Scope<'lua, 'scope> {
let mt_id = ffi::lua_topointer(lua.state, -1); let mt_id = ffi::lua_topointer(lua.state, -1);
ffi::lua_setmetatable(lua.state, -2); ffi::lua_setmetatable(lua.state, -2);
let ud = AnyUserData(lua.pop_ref()); let ud = AnyUserData(lua.pop_ref());
lua.register_userdata_metatable(mt_id as isize); lua.register_userdata_metatable(mt_id as isize);
self.destructors.borrow_mut().push((ud.0.clone(), |ud| { self.destructors.borrow_mut().push((ud.0.clone(), |ud| {
// We know the destructor has not run yet because we hold a reference to the userdata.
let state = ud.lua.state; let state = ud.lua.state;
assert_stack(state, 2); assert_stack(state, 2);
ud.lua.push_ref(&ud); ud.lua.push_ref(&ud);
// Deregister metatable
ffi::lua_getmetatable(state, -1); ffi::lua_getmetatable(state, -1);
let mt_id = ffi::lua_topointer(state, -1); let mt_id = ffi::lua_topointer(state, -1);
ffi::lua_pop(state, 1); ffi::lua_pop(state, 1);
ud.lua.deregister_userdata_metatable(mt_id as isize); ud.lua.deregister_userdata_metatable(mt_id as isize);
// Clear uservalue
#[cfg(any(feature = "lua54", feature = "lua53", feature = "lua52"))]
ffi::lua_pushnil(state);
#[cfg(any(feature = "lua51", feature = "luajit"))]
ffi::lua_newtable(state);
ffi::lua_setuservalue(state, -2);
vec![Box::new(take_userdata::<UserDataCell<()>>(state))] vec![Box::new(take_userdata::<UserDataCell<()>>(state))]
})); }));
Ok(ud) Ok(ud)
} }
} }

View File

@ -186,6 +186,7 @@ impl MetaMethod {
MetaMethod::Custom(name) if name == "__metatable" => { MetaMethod::Custom(name) if name == "__metatable" => {
Err(Error::MetaMethodRestricted(name)) Err(Error::MetaMethodRestricted(name))
} }
MetaMethod::Custom(name) if name == "__mlua" => Err(Error::MetaMethodRestricted(name)),
_ => Ok(self), _ => Ok(self),
} }
} }

View File

@ -1,5 +1,6 @@
use std::cell::Cell; use std::cell::Cell;
use std::rc::Rc; use std::rc::Rc;
use std::sync::Arc;
use mlua::{ use mlua::{
AnyUserData, Error, Function, Lua, MetaMethod, Result, String, UserData, UserDataFields, AnyUserData, Error, Function, Lua, MetaMethod, Result, String, UserData, UserDataFields,
@ -26,87 +27,16 @@ fn scope_func() -> Result<()> {
assert_eq!(Rc::strong_count(&rc), 1); assert_eq!(Rc::strong_count(&rc), 1);
match lua.globals().get::<_, Function>("bad")?.call::<_, ()>(()) { match lua.globals().get::<_, Function>("bad")?.call::<_, ()>(()) {
Err(Error::CallbackError { .. }) => {} Err(Error::CallbackError { ref cause, .. }) => match *cause.as_ref() {
Error::CallbackDestructed => {}
ref err => panic!("wrong error type {:?}", err),
},
r => panic!("improper return for destructed function: {:?}", r), r => panic!("improper return for destructed function: {:?}", r),
}; };
Ok(()) Ok(())
} }
#[test]
fn scope_drop() -> Result<()> {
let lua = Lua::new();
struct MyUserdata(Rc<()>);
impl UserData for MyUserdata {
fn add_methods<'lua, M: UserDataMethods<'lua, Self>>(methods: &mut M) {
methods.add_method("method", |_, _, ()| Ok(()));
}
}
let rc = Rc::new(());
lua.scope(|scope| {
lua.globals()
.set("static_ud", scope.create_userdata(MyUserdata(rc.clone()))?)?;
assert_eq!(Rc::strong_count(&rc), 2);
Ok(())
})?;
assert_eq!(Rc::strong_count(&rc), 1);
match lua.load("static_ud:method()").exec() {
Err(Error::CallbackError { ref cause, .. }) => match cause.as_ref() {
Error::CallbackDestructed => {}
e => panic!("expected CallbackDestructed, got {:?}", e),
},
r => panic!("improper return for destructed userdata: {:?}", r),
};
let static_ud = lua.globals().get::<_, AnyUserData>("static_ud")?;
match static_ud.borrow::<MyUserdata>() {
Ok(_) => panic!("borrowed destructed userdata"),
Err(Error::UserDataDestructed) => {}
Err(e) => panic!("expected UserDataDestructed, got {:?}", e),
}
// Check non-static UserData drop
struct MyUserDataRef<'a>(&'a Cell<i64>);
impl<'a> UserData for MyUserDataRef<'a> {
fn add_methods<'lua, M: UserDataMethods<'lua, Self>>(methods: &mut M) {
methods.add_method("inc", |_, data, ()| {
data.0.set(data.0.get() + 1);
Ok(())
});
}
}
let i = Cell::new(1);
lua.scope(|scope| {
lua.globals().set(
"nonstatic_ud",
scope.create_nonstatic_userdata(MyUserDataRef(&i))?,
)
})?;
match lua.load("nonstatic_ud:inc(1)").exec() {
Err(Error::CallbackError { ref cause, .. }) => match cause.as_ref() {
Error::CallbackDestructed => {}
e => panic!("expected CallbackDestructed, got {:?}", e),
},
r => panic!("improper return for destructed userdata: {:?}", r),
};
let nonstatic_ud = lua.globals().get::<_, AnyUserData>("nonstatic_ud")?;
match nonstatic_ud.borrow::<MyUserDataRef>() {
Ok(_) => panic!("borrowed destructed userdata"),
Err(Error::UserDataDestructed) => {}
Err(e) => panic!("expected UserDataDestructed, got {:?}", e),
}
Ok(())
}
#[test] #[test]
fn scope_capture() -> Result<()> { fn scope_capture() -> Result<()> {
let lua = Lua::new(); let lua = Lua::new();
@ -126,7 +56,7 @@ fn scope_capture() -> Result<()> {
} }
#[test] #[test]
fn outer_lua_access() -> Result<()> { fn scope_outer_lua_access() -> Result<()> {
let lua = Lua::new(); let lua = Lua::new();
let table = lua.create_table()?; let table = lua.create_table()?;
@ -309,3 +239,103 @@ fn scope_userdata_mismatch() -> Result<()> {
Ok(()) Ok(())
} }
#[test]
fn scope_userdata_drop() -> Result<()> {
let lua = Lua::new();
struct MyUserData(Rc<()>);
impl UserData for MyUserData {
fn add_methods<'lua, M: UserDataMethods<'lua, Self>>(methods: &mut M) {
methods.add_method("method", |_, _, ()| Ok(()));
}
}
struct MyUserDataArc(Arc<()>);
impl UserData for MyUserDataArc {}
let rc = Rc::new(());
let arc = Arc::new(());
lua.scope(|scope| {
let ud = scope.create_userdata(MyUserData(rc.clone()))?;
ud.set_user_value(MyUserDataArc(arc.clone()))?;
lua.globals().set("ud", ud)?;
assert_eq!(Rc::strong_count(&rc), 2);
assert_eq!(Arc::strong_count(&arc), 2);
Ok(())
})?;
lua.gc_collect()?;
assert_eq!(Rc::strong_count(&rc), 1);
assert_eq!(Arc::strong_count(&arc), 1);
match lua.load("ud:method()").exec() {
Err(Error::CallbackError { ref cause, .. }) => match cause.as_ref() {
Error::CallbackDestructed => {}
err => panic!("expected CallbackDestructed, got {:?}", err),
},
r => panic!("improper return for destructed userdata: {:?}", r),
};
let ud = lua.globals().get::<_, AnyUserData>("ud")?;
match ud.borrow::<MyUserData>() {
Ok(_) => panic!("succesfull borrow for destructed userdata"),
Err(Error::UserDataDestructed) => {}
Err(err) => panic!("improper borrow error for destructed userdata: {:?}", err),
}
Ok(())
}
#[test]
fn scope_nonstatic_userdata_drop() -> Result<()> {
let lua = Lua::new();
struct MyUserData<'a>(&'a Cell<i64>);
impl<'a> UserData for MyUserData<'a> {
fn add_methods<'lua, M: UserDataMethods<'lua, Self>>(methods: &mut M) {
methods.add_method("inc", |_, data, ()| {
data.0.set(data.0.get() + 1);
Ok(())
});
}
}
struct MyUserDataArc(Arc<()>);
impl UserData for MyUserDataArc {}
let i = Cell::new(1);
let arc = Arc::new(());
lua.scope(|scope| {
let ud = scope.create_nonstatic_userdata(MyUserData(&i))?;
ud.set_user_value(MyUserDataArc(arc.clone()))?;
lua.globals().set("ud", ud)?;
lua.load("ud:inc()").exec()?;
assert_eq!(Arc::strong_count(&arc), 2);
Ok(())
})?;
lua.gc_collect()?;
assert_eq!(Arc::strong_count(&arc), 1);
match lua.load("ud:inc()").exec() {
Err(Error::CallbackError { ref cause, .. }) => match cause.as_ref() {
Error::CallbackDestructed => {}
err => panic!("expected CallbackDestructed, got {:?}", err),
},
r => panic!("improper return for destructed userdata: {:?}", r),
};
let ud = lua.globals().get::<_, AnyUserData>("ud")?;
match ud.borrow::<MyUserData>() {
Ok(_) => panic!("succesfull borrow for destructed userdata"),
Err(Error::UserDataDestructed) => {}
Err(err) => panic!("improper borrow error for destructed userdata: {:?}", err),
}
Ok(())
}