From 4daa7de9978d7dfe43c40e8081a94aaaacf48645 Mon Sep 17 00:00:00 2001 From: Alex Orlenko Date: Wed, 26 Apr 2023 15:34:15 +0100 Subject: [PATCH] Various improvements for owned types, including: - tests - shortcuts for `OwnedFunction` and `OwnedAnyUserData` --- .github/workflows/main.yml | 7 +++---- src/function.rs | 37 +++++++++++++++++++++++++++++++++ src/lua.rs | 31 ++++++++++------------------ src/table.rs | 6 ++++++ src/types.rs | 26 ++++++++++------------- src/userdata.rs | 30 +++++++++++++++++++++++++++ tests/async.rs | 24 +++++++++++++++++----- tests/function.rs | 42 ++++++++++++++++++++++++++++++++++++++ tests/table.rs | 14 +++++++++++++ tests/userdata.rs | 19 +++++++++++++++++ 10 files changed, 192 insertions(+), 44 deletions(-) diff --git a/.github/workflows/main.yml b/.github/workflows/main.yml index d3c8f19..5faaca0 100644 --- a/.github/workflows/main.yml +++ b/.github/workflows/main.yml @@ -122,10 +122,9 @@ jobs: - name: Run ${{ matrix.lua }} tests run: | cargo test --features "${{ matrix.lua }},vendored" - cargo test --features "${{ matrix.lua }},vendored,async,send,serialize,macros,parking_lot,unstable" + cargo test --features "${{ matrix.lua }},vendored,async,send,serialize,macros,parking_lot" + cargo test --features "${{ matrix.lua }},vendored,async,serialize,macros,parking_lot,unstable" shell: bash - env: - RUSTFLAGS: --cfg mlua_test - name: Run compile tests (macos lua54) if: ${{ matrix.os == 'macos-latest' && matrix.lua == 'lua54' }} run: | @@ -157,7 +156,7 @@ jobs: cargo test --tests --features "${{ matrix.lua }},vendored,async,send,serialize,macros,parking_lot,unstable" --target x86_64-unknown-linux-gnu -- --skip test_too_many_recursions shell: bash env: - RUSTFLAGS: --cfg mlua_test -Z sanitizer=address + RUSTFLAGS: -Z sanitizer=address test_modules: name: Test modules diff --git a/src/function.rs b/src/function.rs index 813517b..85287f0 100644 --- a/src/function.rs +++ b/src/function.rs @@ -25,6 +25,12 @@ use { pub struct Function<'lua>(pub(crate) LuaRef<'lua>); /// Owned handle to an internal Lua function. +/// +/// The owned handle holds a *strong* reference to the current Lua instance. +/// Be warned, if you place it into a Lua type (eg. [`UserData`] or a Rust callback), it is *very easy* +/// to accidentally cause reference cycles that would prevent destroying Lua instance. +/// +/// [`UserData`]: crate::UserData #[cfg(feature = "unstable")] #[cfg_attr(docsrs, doc(cfg(feature = "unstable")))] #[derive(Clone, Debug)] @@ -415,6 +421,37 @@ impl<'lua> PartialEq for Function<'lua> { } } +// Additional shortcuts +#[cfg(feature = "unstable")] +impl OwnedFunction { + /// Calls the function, passing `args` as function arguments. + /// + /// This is a shortcut for [`Function::call()`]. + #[inline] + pub fn call<'lua, A, R>(&'lua self, args: A) -> Result + where + A: IntoLuaMulti<'lua>, + R: FromLuaMulti<'lua>, + { + self.to_ref().call(args) + } + + /// Returns a future that, when polled, calls `self`, passing `args` as function arguments, + /// and drives the execution. + /// + /// This is a shortcut for [`Function::call_async()`]. + #[cfg(feature = "async")] + #[cfg_attr(docsrs, doc(cfg(feature = "async")))] + #[inline] + pub fn call_async<'lua, A, R>(&'lua self, args: A) -> LocalBoxFuture<'lua, Result> + where + A: IntoLuaMulti<'lua>, + R: FromLuaMulti<'lua> + 'lua, + { + self.to_ref().call_async(args) + } +} + pub(crate) struct WrappedFunction<'lua>(pub(crate) Callback<'lua, 'static>); #[cfg(feature = "async")] diff --git a/src/lua.rs b/src/lua.rs index 319ce14..89ab6f1 100644 --- a/src/lua.rs +++ b/src/lua.rs @@ -230,30 +230,21 @@ const MULTIVALUE_POOL_SIZE: usize = 64; #[cfg_attr(docsrs, doc(cfg(feature = "send")))] unsafe impl Send for Lua {} +#[cfg(not(feature = "module"))] +impl Drop for Lua { + fn drop(&mut self) { + let _ = self.gc_collect(); + } +} + #[cfg(not(feature = "module"))] impl Drop for LuaInner { fn drop(&mut self) { unsafe { - let extra = &mut *self.extra.get(); - let drain_iter = extra.wrapped_failure_pool.drain(..); - #[cfg(feature = "async")] - let drain_iter = drain_iter.chain(extra.thread_pool.drain(..)); - for index in drain_iter { - ffi::lua_pushnil(extra.ref_thread); - ffi::lua_replace(extra.ref_thread, index); - extra.ref_free.push(index); - } #[cfg(feature = "luau")] { (*ffi::lua_callbacks(self.state())).userdata = ptr::null_mut(); } - // This is an internal assertion used in integration tests - #[cfg(mlua_test)] - mlua_debug_assert!( - ffi::lua_gettop(extra.ref_thread) == extra.ref_stack_top - && extra.ref_stack_top as usize == extra.ref_free.len(), - "reference leak detected" - ); ffi::lua_close(self.main_state); } } @@ -2477,12 +2468,12 @@ impl Lua { #[cfg(all(feature = "unstable", not(feature = "send")))] pub(crate) fn adopt_owned_ref(&self, loref: crate::types::LuaOwnedRef) -> LuaRef { assert!( - Arc::ptr_eq(&loref.lua.0, &self.0), + Arc::ptr_eq(&loref.inner, &self.0), "Lua instance passed Value created from a different main Lua state" ); let index = loref.index; unsafe { - ptr::read(&loref.lua); + ptr::read(&loref.inner); mem::forget(loref); } LuaRef::new(self, index) @@ -3011,8 +3002,8 @@ impl Lua { #[cfg(feature = "unstable")] #[inline] - pub(crate) fn clone(&self) -> Self { - Lua(Arc::clone(&self.0)) + pub(crate) fn clone(&self) -> Arc { + Arc::clone(&self.0) } } diff --git a/src/table.rs b/src/table.rs index 13eeeb6..37cdae9 100644 --- a/src/table.rs +++ b/src/table.rs @@ -25,6 +25,12 @@ use {futures_core::future::LocalBoxFuture, futures_util::future}; pub struct Table<'lua>(pub(crate) LuaRef<'lua>); /// Owned handle to an internal Lua table. +/// +/// The owned handle holds a *strong* reference to the current Lua instance. +/// Be warned, if you place it into a Lua type (eg. [`UserData`] or a Rust callback), it is *very easy* +/// to accidentally cause reference cycles that would prevent destroying Lua instance. +/// +/// [`UserData`]: crate::UserData #[cfg(feature = "unstable")] #[cfg_attr(docsrs, doc(cfg(feature = "unstable")))] #[derive(Clone, Debug)] diff --git a/src/types.rs b/src/types.rs index e06529d..6af7a67 100644 --- a/src/types.rs +++ b/src/types.rs @@ -18,6 +18,9 @@ use crate::lua::{ExtraData, Lua}; use crate::util::{assert_stack, StackGuard}; use crate::value::MultiValue; +#[cfg(feature = "unstable")] +use {crate::lua::LuaInner, std::marker::PhantomData}; + /// Type of Lua integer numbers. pub type Integer = ffi::lua_Integer; /// Type of Lua floating point numbers. @@ -237,9 +240,9 @@ impl<'lua> PartialEq for LuaRef<'lua> { #[cfg(feature = "unstable")] pub(crate) struct LuaOwnedRef { - pub(crate) lua: Lua, + pub(crate) inner: Arc, pub(crate) index: c_int, - _non_send: std::marker::PhantomData<*const ()>, + _non_send: PhantomData<*const ()>, } #[cfg(feature = "unstable")] @@ -259,31 +262,24 @@ impl Clone for LuaOwnedRef { #[cfg(feature = "unstable")] impl Drop for LuaOwnedRef { fn drop(&mut self) { - self.lua.drop_ref_index(self.index); + let lua: &Lua = unsafe { mem::transmute(&self.inner) }; + lua.drop_ref_index(self.index); } } #[cfg(feature = "unstable")] impl LuaOwnedRef { - pub(crate) const fn new(lua: Lua, index: c_int) -> Self { - #[cfg(feature = "send")] - { - let _lua = lua; - let _index = index; - panic!("mlua must be compiled without \"send\" feature to use Owned types"); - } - - #[cfg(not(feature = "send"))] + pub(crate) const fn new(inner: Arc, index: c_int) -> Self { LuaOwnedRef { - lua, + inner, index, - _non_send: std::marker::PhantomData, + _non_send: PhantomData, } } pub(crate) const fn to_ref(&self) -> LuaRef { LuaRef { - lua: &self.lua, + lua: unsafe { mem::transmute(&self.inner) }, index: self.index, drop: false, } diff --git a/src/userdata.rs b/src/userdata.rs index d5770ca..f405074 100644 --- a/src/userdata.rs +++ b/src/userdata.rs @@ -740,6 +740,10 @@ impl Serialize for UserDataSerializeError { pub struct AnyUserData<'lua>(pub(crate) LuaRef<'lua>); /// Owned handle to an internal Lua userdata. +/// +/// The owned handle holds a *strong* reference to the current Lua instance. +/// Be warned, if you place it into a Lua type (eg. [`UserData`] or a Rust callback), it is *very easy* +/// to accidentally cause reference cycles that would prevent destroying Lua instance. #[cfg(feature = "unstable")] #[cfg_attr(docsrs, doc(cfg(feature = "unstable")))] #[derive(Clone, Debug)] @@ -1124,6 +1128,32 @@ unsafe fn getuservalue_table(state: *mut ffi::lua_State, idx: c_int) -> c_int { return ffi::lua_getuservalue(state, idx); } +// Additional shortcuts +#[cfg(feature = "unstable")] +impl OwnedAnyUserData { + /// Borrow this userdata immutably if it is of type `T`. + /// + /// This is a shortcut for [`AnyUserData::borrow()`] + #[inline] + pub fn borrow(&self) -> Result> { + let ud = self.to_ref(); + let t = ud.borrow::()?; + // Reattach lifetime to &self + Ok(unsafe { mem::transmute::, Ref>(t) }) + } + + /// Borrow this userdata mutably if it is of type `T`. + /// + /// This is a shortcut for [`AnyUserData::borrow_mut()`] + #[inline] + pub fn borrow_mut(&self) -> Result> { + let ud = self.to_ref(); + let t = ud.borrow_mut::()?; + // Reattach lifetime to &self + Ok(unsafe { mem::transmute::, RefMut>(t) }) + } +} + /// Handle to a `UserData` metatable. #[derive(Clone, Debug)] pub struct UserDataMetatable<'lua>(pub(crate) Table<'lua>); diff --git a/tests/async.rs b/tests/async.rs index 9ba06da..bc1448f 100644 --- a/tests/async.rs +++ b/tests/async.rs @@ -271,7 +271,7 @@ async fn test_async_thread() -> Result<()> { } #[test] -fn test_async_thread_leak() -> Result<()> { +fn test_async_thread_capture() -> Result<()> { let lua = Lua::new(); let f = lua.create_async_function(move |_lua, v: Value| async move { @@ -285,10 +285,6 @@ fn test_async_thread_leak() -> Result<()> { thread.resume::<_, ()>("abc").unwrap(); drop(thread); - // Without running garbage collection, the captured `v` would trigger "reference leak detected" error - // with `cfg(mlua_test)` - lua.gc_collect()?; - Ok(()) } @@ -488,3 +484,21 @@ async fn test_async_thread_error() -> Result<()> { Ok(()) } + +#[cfg(all(feature = "unstable", not(feature = "send")))] +#[tokio::test] +async fn test_owned_async_call() -> Result<()> { + let lua = Lua::new(); + + let hello = lua + .create_async_function(|_, name: String| async move { + Delay::new(Duration::from_millis(10)).await; + Ok(format!("hello, {}!", name)) + })? + .into_owned(); + drop(lua); + + assert_eq!(hello.call_async::<_, String>("alex").await?, "hello, alex!"); + + Ok(()) +} diff --git a/tests/function.rs b/tests/function.rs index 509a36c..e0cda3e 100644 --- a/tests/function.rs +++ b/tests/function.rs @@ -199,3 +199,45 @@ fn test_function_wrap() -> Result<()> { Ok(()) } + +#[cfg(all(feature = "unstable", not(feature = "send")))] +#[test] +fn test_owned_function() -> Result<()> { + let lua = Lua::new(); + + let f = lua + .create_function(|_, ()| Ok("hello, world!"))? + .into_owned(); + drop(lua); + + // We still should be able to call the function despite Lua is dropped + let s = f.call::<_, String>(())?; + assert_eq!(s.to_string_lossy(), "hello, world!"); + + Ok(()) +} + +#[cfg(all(feature = "unstable", not(feature = "send")))] +#[test] +fn test_owned_function_drop() -> Result<()> { + let rc = std::sync::Arc::new(()); + + { + let lua = Lua::new(); + + lua.set_app_data(rc.clone()); + + let f1 = lua + .create_function(|_, ()| Ok("hello, world!"))? + .into_owned(); + let f2 = + lua.create_function(move |_, ()| f1.to_ref().call::<_, std::string::String>(()))?; + assert_eq!(f2.call::<_, String>(())?.to_string_lossy(), "hello, world!"); + } + + // Check that Lua is properly destroyed + // It works because we collect garbage when Lua goes out of scope + assert_eq!(std::sync::Arc::strong_count(&rc), 1); + + Ok(()) +} diff --git a/tests/table.rs b/tests/table.rs index 148e935..c71b5a3 100644 --- a/tests/table.rs +++ b/tests/table.rs @@ -392,3 +392,17 @@ fn test_table_call() -> Result<()> { Ok(()) } + +#[cfg(all(feature = "unstable", not(feature = "send")))] +#[test] +fn test_owned_table() -> Result<()> { + let lua = Lua::new(); + + let table = lua.create_table()?.into_owned(); + drop(lua); + + table.to_ref().set("abc", 123)?; + assert_eq!(table.to_ref().get::<_, i64>("abc")?, 123); + + Ok(()) +} diff --git a/tests/userdata.rs b/tests/userdata.rs index 15d27f8..36f05b5 100644 --- a/tests/userdata.rs +++ b/tests/userdata.rs @@ -820,3 +820,22 @@ fn test_userdata_method_errors() -> Result<()> { Ok(()) } + +#[cfg(all(feature = "unstable", not(feature = "send")))] +#[test] +fn test_owned_userdata() -> Result<()> { + let lua = Lua::new(); + + let ud = lua.create_any_userdata("abc")?.into_owned(); + drop(lua); + + assert_eq!(*ud.borrow::<&str>()?, "abc"); + *ud.borrow_mut()? = "cba"; + assert_eq!(*ud.to_ref().borrow::<&str>()?, "cba"); + assert!(matches!( + ud.borrow::(), + Err(Error::UserDataTypeMismatch) + )); + + Ok(()) +}