diff --git a/src/lua.rs b/src/lua.rs index e82a4ca..3f22774 100644 --- a/src/lua.rs +++ b/src/lua.rs @@ -868,7 +868,7 @@ impl Lua { #[cfg_attr(docsrs, doc(cfg(not(feature = "luau"))))] pub fn set_hook(&self, triggers: HookTriggers, callback: F) -> Result<()> where - F: 'static + MaybeSend + FnMut(&Lua, Debug) -> Result<()>, + F: 'static + MaybeSend + Fn(&Lua, Debug) -> Result<()>, { unsafe extern "C" fn hook_proc(state: *mut ffi::lua_State, ar: *mut ffi::lua_Debug) { let lua = match Lua::make_from_ptr(state) { @@ -880,29 +880,24 @@ impl Lua { let debug = Debug::new(&lua, ar); let hook_cb = (*lua.extra.get()).hook_callback.clone(); let hook_cb = mlua_expect!(hook_cb, "no hook callback set in hook_proc"); - - #[allow(clippy::match_wild_err_arm)] - match hook_cb.try_lock() { - Ok(mut cb) => cb(&lua, debug), - Err(_) => { - mlua_panic!("Lua should not allow hooks to be called within another hook") - } - }?; - - Ok(()) + if Arc::strong_count(&hook_cb) > 2 { + return Ok(()); // Don't allow recursion + } + hook_cb(&lua, debug) }) } let state = self.main_state.ok_or(Error::MainThreadNotAvailable)?; unsafe { - (*self.extra.get()).hook_callback = Some(Arc::new(Mutex::new(callback))); + (*self.extra.get()).hook_callback = Some(Arc::new(callback)); ffi::lua_sethook(state, Some(hook_proc), triggers.mask(), triggers.count()); } Ok(()) } - /// Remove any hook previously set by `set_hook`. This function has no effect if a hook was not - /// previously set. + /// Removes any hook previously set by `set_hook`. + /// + /// This function has no effect if a hook was not previously set. #[cfg(not(feature = "luau"))] #[cfg_attr(docsrs, doc(cfg(not(feature = "luau"))))] pub fn remove_hook(&self) { diff --git a/src/prelude.rs b/src/prelude.rs index 28a90e5..1dea012 100644 --- a/src/prelude.rs +++ b/src/prelude.rs @@ -15,6 +15,10 @@ pub use crate::{ Value as LuaValue, }; +#[cfg(not(feature = "luau"))] +#[doc(no_inline)] +pub use crate::HookTriggers as LuaHookTriggers; + #[cfg(feature = "luau")] #[doc(no_inline)] pub use crate::VmState as LuaVmState; diff --git a/src/types.rs b/src/types.rs index 6ff7a21..47d6de3 100644 --- a/src/types.rs +++ b/src/types.rs @@ -49,13 +49,8 @@ pub(crate) struct AsyncPollUpvalue<'lua> { pub(crate) lua: Lua, pub(crate) fut: LocalBoxFuture<'lua, Result>>, } -#[cfg(all(feature = "send", not(feature = "luau")))] -pub(crate) type HookCallback = Arc Result<()> + Send>>; -#[cfg(all(not(feature = "send"), not(feature = "luau")))] -pub(crate) type HookCallback = Arc Result<()>>>; - -/// Type to set next Lua VM action after executing interrupt function. +/// Type to set next Luau VM action after executing interrupt function. #[cfg(any(feature = "luau", doc))] #[cfg_attr(docsrs, doc(cfg(feature = "luau")))] pub enum VmState { @@ -63,6 +58,12 @@ pub enum VmState { Yield, } +#[cfg(all(feature = "send", not(feature = "luau")))] +pub(crate) type HookCallback = Arc Result<()> + Send>; + +#[cfg(all(not(feature = "send"), not(feature = "luau")))] +pub(crate) type HookCallback = Arc Result<()>>; + #[cfg(all(feature = "luau", feature = "send"))] pub(crate) type InterruptCallback = Arc Result + Send>; diff --git a/tests/hooks.rs b/tests/hooks.rs index 8c44b5d..2fbd33d 100644 --- a/tests/hooks.rs +++ b/tests/hooks.rs @@ -3,6 +3,7 @@ use std::cell::RefCell; use std::ops::Deref; use std::str; +use std::sync::atomic::{AtomicI64, Ordering}; use std::sync::{Arc, Mutex}; use mlua::{DebugEvent, Error, HookTriggers, Lua, Result, Value}; @@ -128,18 +129,17 @@ fn test_error_within_hook() -> Result<()> { #[test] fn test_limit_execution_instructions() -> Result<()> { let lua = Lua::new(); - let mut max_instructions = 10000; - #[cfg(feature = "luajit")] // For LuaJIT disable JIT, as compiled code does not trigger hooks + #[cfg(feature = "luajit")] lua.load("jit.off()").exec()?; + let max_instructions = AtomicI64::new(10000); lua.set_hook( HookTriggers::every_nth_instruction(30), move |_lua, debug| { assert_eq!(debug.event(), DebugEvent::Count); - max_instructions -= 30; - if max_instructions < 0 { + if max_instructions.fetch_sub(30, Ordering::Relaxed) <= 30 { Err(Error::RuntimeError("time's up".to_string())) } else { Ok(())