From 1635903d3f3d3a5adc12a7ff93986844b26784b0 Mon Sep 17 00:00:00 2001 From: Alex Orlenko Date: Sat, 20 Feb 2021 01:02:35 +0000 Subject: [PATCH] Improve/fix scoped UserData drop --- src/scope.rs | 73 +++++++++++-------- src/userdata.rs | 1 + tests/scope.rs | 182 ++++++++++++++++++++++++++++-------------------- 3 files changed, 150 insertions(+), 106 deletions(-) diff --git a/src/scope.rs b/src/scope.rs index 52809b5..6ebd521 100644 --- a/src/scope.rs +++ b/src/scope.rs @@ -186,6 +186,14 @@ impl<'lua, 'scope> Scope<'lua, 'scope> { let state = u.lua.state; assert_stack(state, 2); u.lua.push_ref(&u); + + // Clear uservalue + #[cfg(any(feature = "lua54", feature = "lua53", feature = "lua52"))] + ffi::lua_pushnil(state); + #[cfg(any(feature = "lua51", feature = "luajit"))] + ffi::lua_newtable(state); + ffi::lua_setuservalue(state, -2); + // We know the destructor has not run yet because we hold a reference to the // userdata. vec![Box::new(take_userdata::>(state))] @@ -244,28 +252,28 @@ impl<'lua, 'scope> Scope<'lua, 'scope> { let check_ud_type = move |lua: &'callback Lua, value| { if let Some(Value::UserData(ud)) = value { unsafe { - assert_stack(lua.state, 1); + let _sg = StackGuard::new(lua.state); + assert_stack(lua.state, 3); lua.push_ref(&ud.0); - ffi::lua_getuservalue(lua.state, -1); - #[cfg(any(feature = "lua52", feature = "lua51", feature = "luajit"))] - { - ffi::lua_rawgeti(lua.state, -1, 1); - ffi::lua_remove(lua.state, -2); + if ffi::lua_getmetatable(lua.state, -1) == 0 { + return Err(Error::UserDataTypeMismatch); + } + ffi::lua_pushstring(lua.state, cstr!("__mlua")); + if ffi::lua_rawget(lua.state, -2) == ffi::LUA_TLIGHTUSERDATA { + let ud_ptr = ffi::lua_touserdata(lua.state, -1); + if ud_ptr == check_data.as_ptr() as *mut c_void { + return Ok(()); + } } - return ffi::lua_touserdata(lua.state, -1) - == check_data.as_ptr() as *mut c_void; } - } - - false + }; + Err(Error::UserDataTypeMismatch) }; match method { NonStaticMethod::Method(method) => { let f = Box::new(move |lua, mut args: MultiValue<'callback>| { - if !check_ud_type(lua, args.pop_front()) { - return Err(Error::UserDataTypeMismatch); - } + check_ud_type(lua, args.pop_front())?; let data = data .try_borrow() .map(|cell| Ref::map(cell, AsRef::as_ref)) @@ -277,9 +285,7 @@ impl<'lua, 'scope> Scope<'lua, 'scope> { NonStaticMethod::MethodMut(method) => { let method = RefCell::new(method); let f = Box::new(move |lua, mut args: MultiValue<'callback>| { - if !check_ud_type(lua, args.pop_front()) { - return Err(Error::UserDataTypeMismatch); - } + check_ud_type(lua, args.pop_front())?; let mut method = method .try_borrow_mut() .map_err(|_| Error::RecursiveMutCallback)?; @@ -314,24 +320,19 @@ impl<'lua, 'scope> Scope<'lua, 'scope> { unsafe { let lua = self.lua; let _sg = StackGuard::new(lua.state); - assert_stack(lua.state, 6); + assert_stack(lua.state, 13); // We need to wrap dummy userdata because their memory can be accessed by serializer push_userdata(lua.state, UserDataCell::new(UserDataWrapped::new(())))?; - #[cfg(any(feature = "lua54", feature = "lua53"))] - ffi::lua_pushlightuserdata(lua.state, data.as_ptr() as *mut c_void); - #[cfg(any(feature = "lua52", feature = "lua51", feature = "luajit"))] - protect_lua_closure(lua.state, 0, 1, |state| { - // Lua 5.2/5.1 allows to store only table. Then we will wrap the value. - ffi::lua_createtable(state, 1, 0); - ffi::lua_pushlightuserdata(state, data.as_ptr() as *mut c_void); - ffi::lua_rawseti(state, -2, 1); - })?; - ffi::lua_setuservalue(lua.state, -2); // Prepare metatable, add meta methods first and then meta fields - protect_lua_closure(lua.state, 0, 1, move |state| { + protect_lua_closure(lua.state, 0, 1, |state| { ffi::lua_newtable(state); + + // Add internal metamethod to store reference to the data + ffi::lua_pushstring(state, cstr!("__mlua")); + ffi::lua_pushlightuserdata(lua.state, data.as_ptr() as *mut c_void); + ffi::lua_rawset(state, -3); })?; for (k, m) in ud_methods.meta_methods { push_string(lua.state, k.validate()?.name())?; @@ -415,19 +416,31 @@ impl<'lua, 'scope> Scope<'lua, 'scope> { let mt_id = ffi::lua_topointer(lua.state, -1); ffi::lua_setmetatable(lua.state, -2); - let ud = AnyUserData(lua.pop_ref()); lua.register_userdata_metatable(mt_id as isize); + self.destructors.borrow_mut().push((ud.0.clone(), |ud| { + // We know the destructor has not run yet because we hold a reference to the userdata. let state = ud.lua.state; assert_stack(state, 2); ud.lua.push_ref(&ud); + + // Deregister metatable ffi::lua_getmetatable(state, -1); let mt_id = ffi::lua_topointer(state, -1); ffi::lua_pop(state, 1); ud.lua.deregister_userdata_metatable(mt_id as isize); + + // Clear uservalue + #[cfg(any(feature = "lua54", feature = "lua53", feature = "lua52"))] + ffi::lua_pushnil(state); + #[cfg(any(feature = "lua51", feature = "luajit"))] + ffi::lua_newtable(state); + ffi::lua_setuservalue(state, -2); + vec![Box::new(take_userdata::>(state))] })); + Ok(ud) } } diff --git a/src/userdata.rs b/src/userdata.rs index e4b4397..d28a374 100644 --- a/src/userdata.rs +++ b/src/userdata.rs @@ -186,6 +186,7 @@ impl MetaMethod { MetaMethod::Custom(name) if name == "__metatable" => { Err(Error::MetaMethodRestricted(name)) } + MetaMethod::Custom(name) if name == "__mlua" => Err(Error::MetaMethodRestricted(name)), _ => Ok(self), } } diff --git a/tests/scope.rs b/tests/scope.rs index 95591be..7c29bff 100644 --- a/tests/scope.rs +++ b/tests/scope.rs @@ -1,5 +1,6 @@ use std::cell::Cell; use std::rc::Rc; +use std::sync::Arc; use mlua::{ AnyUserData, Error, Function, Lua, MetaMethod, Result, String, UserData, UserDataFields, @@ -26,87 +27,16 @@ fn scope_func() -> Result<()> { assert_eq!(Rc::strong_count(&rc), 1); match lua.globals().get::<_, Function>("bad")?.call::<_, ()>(()) { - Err(Error::CallbackError { .. }) => {} + Err(Error::CallbackError { ref cause, .. }) => match *cause.as_ref() { + Error::CallbackDestructed => {} + ref err => panic!("wrong error type {:?}", err), + }, r => panic!("improper return for destructed function: {:?}", r), }; Ok(()) } -#[test] -fn scope_drop() -> Result<()> { - let lua = Lua::new(); - - struct MyUserdata(Rc<()>); - impl UserData for MyUserdata { - fn add_methods<'lua, M: UserDataMethods<'lua, Self>>(methods: &mut M) { - methods.add_method("method", |_, _, ()| Ok(())); - } - } - - let rc = Rc::new(()); - - lua.scope(|scope| { - lua.globals() - .set("static_ud", scope.create_userdata(MyUserdata(rc.clone()))?)?; - assert_eq!(Rc::strong_count(&rc), 2); - Ok(()) - })?; - assert_eq!(Rc::strong_count(&rc), 1); - - match lua.load("static_ud:method()").exec() { - Err(Error::CallbackError { ref cause, .. }) => match cause.as_ref() { - Error::CallbackDestructed => {} - e => panic!("expected CallbackDestructed, got {:?}", e), - }, - r => panic!("improper return for destructed userdata: {:?}", r), - }; - - let static_ud = lua.globals().get::<_, AnyUserData>("static_ud")?; - match static_ud.borrow::() { - Ok(_) => panic!("borrowed destructed userdata"), - Err(Error::UserDataDestructed) => {} - Err(e) => panic!("expected UserDataDestructed, got {:?}", e), - } - - // Check non-static UserData drop - struct MyUserDataRef<'a>(&'a Cell); - - impl<'a> UserData for MyUserDataRef<'a> { - fn add_methods<'lua, M: UserDataMethods<'lua, Self>>(methods: &mut M) { - methods.add_method("inc", |_, data, ()| { - data.0.set(data.0.get() + 1); - Ok(()) - }); - } - } - - let i = Cell::new(1); - lua.scope(|scope| { - lua.globals().set( - "nonstatic_ud", - scope.create_nonstatic_userdata(MyUserDataRef(&i))?, - ) - })?; - - match lua.load("nonstatic_ud:inc(1)").exec() { - Err(Error::CallbackError { ref cause, .. }) => match cause.as_ref() { - Error::CallbackDestructed => {} - e => panic!("expected CallbackDestructed, got {:?}", e), - }, - r => panic!("improper return for destructed userdata: {:?}", r), - }; - - let nonstatic_ud = lua.globals().get::<_, AnyUserData>("nonstatic_ud")?; - match nonstatic_ud.borrow::() { - Ok(_) => panic!("borrowed destructed userdata"), - Err(Error::UserDataDestructed) => {} - Err(e) => panic!("expected UserDataDestructed, got {:?}", e), - } - - Ok(()) -} - #[test] fn scope_capture() -> Result<()> { let lua = Lua::new(); @@ -126,7 +56,7 @@ fn scope_capture() -> Result<()> { } #[test] -fn outer_lua_access() -> Result<()> { +fn scope_outer_lua_access() -> Result<()> { let lua = Lua::new(); let table = lua.create_table()?; @@ -309,3 +239,103 @@ fn scope_userdata_mismatch() -> Result<()> { Ok(()) } + +#[test] +fn scope_userdata_drop() -> Result<()> { + let lua = Lua::new(); + + struct MyUserData(Rc<()>); + + impl UserData for MyUserData { + fn add_methods<'lua, M: UserDataMethods<'lua, Self>>(methods: &mut M) { + methods.add_method("method", |_, _, ()| Ok(())); + } + } + + struct MyUserDataArc(Arc<()>); + + impl UserData for MyUserDataArc {} + + let rc = Rc::new(()); + let arc = Arc::new(()); + lua.scope(|scope| { + let ud = scope.create_userdata(MyUserData(rc.clone()))?; + ud.set_user_value(MyUserDataArc(arc.clone()))?; + lua.globals().set("ud", ud)?; + assert_eq!(Rc::strong_count(&rc), 2); + assert_eq!(Arc::strong_count(&arc), 2); + Ok(()) + })?; + + lua.gc_collect()?; + assert_eq!(Rc::strong_count(&rc), 1); + assert_eq!(Arc::strong_count(&arc), 1); + + match lua.load("ud:method()").exec() { + Err(Error::CallbackError { ref cause, .. }) => match cause.as_ref() { + Error::CallbackDestructed => {} + err => panic!("expected CallbackDestructed, got {:?}", err), + }, + r => panic!("improper return for destructed userdata: {:?}", r), + }; + + let ud = lua.globals().get::<_, AnyUserData>("ud")?; + match ud.borrow::() { + Ok(_) => panic!("succesfull borrow for destructed userdata"), + Err(Error::UserDataDestructed) => {} + Err(err) => panic!("improper borrow error for destructed userdata: {:?}", err), + } + + Ok(()) +} + +#[test] +fn scope_nonstatic_userdata_drop() -> Result<()> { + let lua = Lua::new(); + + struct MyUserData<'a>(&'a Cell); + + impl<'a> UserData for MyUserData<'a> { + fn add_methods<'lua, M: UserDataMethods<'lua, Self>>(methods: &mut M) { + methods.add_method("inc", |_, data, ()| { + data.0.set(data.0.get() + 1); + Ok(()) + }); + } + } + + struct MyUserDataArc(Arc<()>); + + impl UserData for MyUserDataArc {} + + let i = Cell::new(1); + let arc = Arc::new(()); + lua.scope(|scope| { + let ud = scope.create_nonstatic_userdata(MyUserData(&i))?; + ud.set_user_value(MyUserDataArc(arc.clone()))?; + lua.globals().set("ud", ud)?; + lua.load("ud:inc()").exec()?; + assert_eq!(Arc::strong_count(&arc), 2); + Ok(()) + })?; + + lua.gc_collect()?; + assert_eq!(Arc::strong_count(&arc), 1); + + match lua.load("ud:inc()").exec() { + Err(Error::CallbackError { ref cause, .. }) => match cause.as_ref() { + Error::CallbackDestructed => {} + err => panic!("expected CallbackDestructed, got {:?}", err), + }, + r => panic!("improper return for destructed userdata: {:?}", r), + }; + + let ud = lua.globals().get::<_, AnyUserData>("ud")?; + match ud.borrow::() { + Ok(_) => panic!("succesfull borrow for destructed userdata"), + Err(Error::UserDataDestructed) => {} + Err(err) => panic!("improper borrow error for destructed userdata: {:?}", err), + } + + Ok(()) +}