diff --git a/src/lua.rs b/src/lua.rs index 60f8bc9..de713b7 100644 --- a/src/lua.rs +++ b/src/lua.rs @@ -1712,15 +1712,16 @@ impl Lua { all(feature = "luajit", feature = "vendored"), feature = "luau", ))] - pub(crate) unsafe fn recycle_thread(&self, thread: &mut Thread) { + pub(crate) unsafe fn recycle_thread(&self, thread: &mut Thread) -> bool { let extra = &mut *self.extra.get(); - let thread_state = ffi::lua_tothread(extra.ref_thread, thread.0.index); if extra.recycled_thread_cache.len() < extra.recycled_thread_cache.capacity() { + let thread_state = ffi::lua_tothread(extra.ref_thread, thread.0.index); #[cfg(feature = "lua54")] let status = ffi::lua_resetthread(thread_state); #[cfg(feature = "lua54")] if status != ffi::LUA_OK { - return; + // Error object is on top, drop it + ffi::lua_settop(thread_state, 0); } #[cfg(all(feature = "luajit", feature = "vendored"))] ffi::lua_resetthread(self.state, thread_state); @@ -1728,7 +1729,9 @@ impl Lua { ffi::lua_resetthread(thread_state); extra.recycled_thread_cache.push(thread.0.index); thread.0.index = 0; + return true; } + false } /// Create a Lua userdata object from a custom userdata type. diff --git a/src/thread.rs b/src/thread.rs index dcdb565..6feab48 100644 --- a/src/thread.rs +++ b/src/thread.rs @@ -357,7 +357,16 @@ impl<'lua, R> Drop for AsyncThread<'lua, R> { fn drop(&mut self) { if self.recycle { unsafe { - self.thread.0.lua.recycle_thread(&mut self.thread); + let lua = self.thread.0.lua; + // For Lua 5.4 this also closes all pending to-be-closed variables + if !lua.recycle_thread(&mut self.thread) { + #[cfg(feature = "lua54")] + if self.thread.status() == ThreadStatus::Error { + let thread_state = + lua.ref_thread_exec(|t| ffi::lua_tothread(t, self.thread.0.index)); + ffi::lua_resetthread(thread_state); + } + } } } } diff --git a/tests/async.rs b/tests/async.rs index fcfb9c7..93fc125 100644 --- a/tests/async.rs +++ b/tests/async.rs @@ -174,6 +174,38 @@ async fn test_async_return_async_closure() -> Result<()> { Ok(()) } +#[cfg(feature = "lua54")] +#[tokio::test] +async fn test_async_lua54_to_be_closed() -> Result<()> { + let lua = Lua::new(); + + let globals = lua.globals(); + globals.set("close_count", 0)?; + + let code = r#" + local t = setmetatable({}, { + __close = function() + close_count = close_count + 1 + end + }) + error "test" + "#; + let f = lua.load(code).into_function()?; + + // Test close using call_async + let _ = f.call_async::<_, ()>(()).await; + assert_eq!(globals.get::<_, usize>("close_count")?, 1); + + // Don't close by default when awaiting async threads + let co = lua.create_thread(f.clone())?; + let _ = co.clone().into_async::<_, ()>(()).await; + assert_eq!(globals.get::<_, usize>("close_count")?, 1); + let _ = co.reset(f); + assert_eq!(globals.get::<_, usize>("close_count")?, 2); + + Ok(()) +} + #[tokio::test] async fn test_async_thread_stream() -> Result<()> { let lua = Lua::new(); @@ -278,6 +310,28 @@ async fn test_async_table() -> Result<()> { Ok(()) } +#[tokio::test] +async fn test_async_thread_cache() -> Result<()> { + let options = LuaOptions::new().thread_cache_size(4); + let lua = Lua::new_with(StdLib::ALL_SAFE, options)?; + + let error_f = lua.create_async_function(|_, ()| async move { + Delay::new(Duration::from_millis(10)).await; + Err::<(), _>(Error::RuntimeError("test".to_string())) + })?; + + let sleep = lua.create_async_function(|_, n| async move { + Delay::new(Duration::from_millis(n)).await; + Ok(format!("elapsed:{}ms", n)) + })?; + + assert!(error_f.call_async::<_, ()>(()).await.is_err()); + // Next call should use cached thread + assert_eq!(sleep.call_async::<_, String>(3).await?, "elapsed:3ms"); + + Ok(()) +} + #[tokio::test] async fn test_async_userdata() -> Result<()> { #[derive(Clone)]