diff --git a/src/lua.rs b/src/lua.rs index a8774f6..de92c2e 100644 --- a/src/lua.rs +++ b/src/lua.rs @@ -121,6 +121,8 @@ pub(crate) struct ExtraData { #[cfg(not(feature = "luau"))] hook_callback: Option, + #[cfg(not(feature = "luau"))] + hook_thread: *mut ffi::lua_State, #[cfg(feature = "lua54")] warn_callback: Option, #[cfg(feature = "luau")] @@ -511,6 +513,8 @@ impl Lua { waker: NonNull::from(noop_waker_ref()), #[cfg(not(feature = "luau"))] hook_callback: None, + #[cfg(not(feature = "luau"))] + hook_thread: ptr::null_mut(), #[cfg(feature = "lua54")] warn_callback: None, #[cfg(feature = "luau")] @@ -815,6 +819,11 @@ impl Lua { /// limited form of execution limits by setting [`HookTriggers.every_nth_instruction`] and /// erroring once an instruction limit has been reached. /// + /// This method sets a hook function for the main thread (if available) of this Lua instance. + /// If you want to set a hook function for a thread (coroutine), use [`Thread::set_hook()`] instead. + /// + /// Please note you cannot have more than one hook function set at a time for this Lua instance. + /// /// # Example /// /// Shows each line number of code being executed by the Lua interpreter. @@ -823,7 +832,7 @@ impl Lua { /// # use mlua::{Lua, HookTriggers, Result}; /// # fn main() -> Result<()> { /// let lua = Lua::new(); - /// lua.set_hook(HookTriggers::every_line(), |_lua, debug| { + /// lua.set_hook(HookTriggers::EVERY_LINE, |_lua, debug| { /// println!("line {}", debug.curr_line()); /// Ok(()) /// })?; @@ -842,47 +851,71 @@ impl Lua { #[cfg_attr(docsrs, doc(cfg(not(feature = "luau"))))] pub fn set_hook(&self, triggers: HookTriggers, callback: F) -> Result<()> where - F: 'static + MaybeSend + Fn(&Lua, Debug) -> Result<()>, + F: Fn(&Lua, Debug) -> Result<()> + MaybeSend + 'static, + { + unsafe { + let state = get_main_state(self.main_state).ok_or(Error::MainThreadNotAvailable)?; + self.set_thread_hook(state, triggers, callback); + } + Ok(()) + } + + /// Sets a 'hook' function for a thread (coroutine). + #[cfg(not(feature = "luau"))] + pub(crate) unsafe fn set_thread_hook( + &self, + state: *mut ffi::lua_State, + triggers: HookTriggers, + callback: F, + ) where + F: Fn(&Lua, Debug) -> Result<()> + MaybeSend + 'static, { unsafe extern "C" fn hook_proc(state: *mut ffi::lua_State, ar: *mut ffi::lua_Debug) { - let lua = match Lua::try_from_ptr(state) { - Some(lua) => lua, - None => return, - }; - let extra = lua.extra.get(); + let extra = extra_data(state); + if extra.is_null() { + return; + } + if (*extra).hook_thread != state { + // Hook was destined for a different thread, ignore + ffi::lua_sethook(state, None, 0, 0); + return; + } callback_error_ext(state, extra, move |_| { - let debug = Debug::new(&lua, ar); let hook_cb = (*extra).hook_callback.clone(); let hook_cb = mlua_expect!(hook_cb, "no hook callback set in hook_proc"); if Arc::strong_count(&hook_cb) > 2 { return Ok(()); // Don't allow recursion } - hook_cb(&lua, debug) + let lua: &Lua = mem::transmute((*extra).inner.as_ref().unwrap()); + let _guard = StateGuard::new(&lua.0, state); + let debug = Debug::new(lua, ar); + hook_cb(lua, debug) }) } - unsafe { - let state = get_main_state(self.main_state).ok_or(Error::MainThreadNotAvailable)?; - (*self.extra.get()).hook_callback = Some(Arc::new(callback)); - ffi::lua_sethook(state, Some(hook_proc), triggers.mask(), triggers.count()); - } - Ok(()) + (*self.extra.get()).hook_callback = Some(Arc::new(callback)); + (*self.extra.get()).hook_thread = state; // Mark for what thread the hook is set + ffi::lua_sethook(state, Some(hook_proc), triggers.mask(), triggers.count()); } - /// Removes any hook previously set by `set_hook`. + /// Removes any hook previously set by [`Lua::set_hook()`] or [`Thread::set_hook()`]. /// /// This function has no effect if a hook was not previously set. #[cfg(not(feature = "luau"))] #[cfg_attr(docsrs, doc(cfg(not(feature = "luau"))))] pub fn remove_hook(&self) { unsafe { - // If main_state is not available, then sethook wasn't called. - let state = match get_main_state(self.main_state) { - Some(state) => state, - None => return, + let state = self.state(); + ffi::lua_sethook(state, None, 0, 0); + match get_main_state(self.main_state) { + Some(main_state) if !ptr::eq(state, main_state) => { + // If main_state is different from state, remove hook from it too + ffi::lua_sethook(main_state, None, 0, 0); + } + _ => {} }; (*self.extra.get()).hook_callback = None; - ffi::lua_sethook(state, None, 0, 0); + (*self.extra.get()).hook_thread = ptr::null_mut(); } } @@ -951,6 +984,7 @@ impl Lua { return Ok(VmState::Continue); // Don't allow recursion } let lua: &Lua = mem::transmute((*extra).inner.as_ref().unwrap()); + let _guard = StateGuard::new(&lua.0, state); interrupt_cb(lua) }); match result { diff --git a/src/thread.rs b/src/thread.rs index 00e9226..174cb97 100644 --- a/src/thread.rs +++ b/src/thread.rs @@ -3,6 +3,8 @@ use std::os::raw::c_int; use crate::error::{Error, Result}; use crate::ffi; +#[allow(unused)] +use crate::lua::Lua; use crate::types::LuaRef; use crate::util::{check_stack, error_traceback_thread, pop_error, StackGuard}; use crate::value::{FromLuaMulti, IntoLuaMulti}; @@ -14,10 +16,16 @@ use crate::value::{FromLuaMulti, IntoLuaMulti}; ))] use crate::function::Function; +#[cfg(not(feature = "luau"))] +use crate::{ + hook::{Debug, HookTriggers}, + types::MaybeSend, +}; + #[cfg(feature = "async")] use { crate::{ - lua::{Lua, ASYNC_POLL_PENDING}, + lua::ASYNC_POLL_PENDING, value::{MultiValue, Value}, }, futures_core::{future::Future, stream::Stream}, @@ -178,6 +186,23 @@ impl<'lua> Thread<'lua> { } } + /// Sets a 'hook' function that will periodically be called as Lua code executes. + /// + /// This function is similar or [`Lua::set_hook()`] except that it sets for the thread. + /// To remove a hook call [`Lua::remove_hook()`]. + #[cfg(not(feature = "luau"))] + #[cfg_attr(docsrs, doc(cfg(not(feature = "luau"))))] + pub fn set_hook(&self, triggers: HookTriggers, callback: F) + where + F: Fn(&Lua, Debug) -> Result<()> + MaybeSend + 'static, + { + let lua = self.0.lua; + unsafe { + let thread_state = ffi::lua_tothread(lua.ref_thread(), self.0.index); + lua.set_thread_hook(thread_state, triggers, callback); + } + } + /// Resets a thread /// /// In [Lua 5.4]: cleans its call stack and closes all pending to-be-closed variables. diff --git a/tests/hooks.rs b/tests/hooks.rs index 2549f10..f9a7023 100644 --- a/tests/hooks.rs +++ b/tests/hooks.rs @@ -9,12 +9,9 @@ use std::sync::{Arc, Mutex}; use mlua::{DebugEvent, Error, HookTriggers, Lua, Result, Value}; #[test] -fn test_hook_triggers_bitor() { - let trigger = HookTriggers::new() - .on_calls() - .on_returns() - .every_line() - .every_nth_instruction(5); +fn test_hook_triggers() { + let trigger = HookTriggers::new().on_calls().on_returns() + | HookTriggers::new().every_line().every_nth_instruction(5); assert!(trigger.on_calls); assert!(trigger.on_returns); @@ -237,3 +234,39 @@ fn test_hook_swap_within_hook() -> Result<()> { Ok(()) }) } + +#[test] +fn test_hook_threads() -> Result<()> { + let lua = Lua::new(); + + let func = lua + .load( + r#" + local x = 2 + 3 + local y = x * 63 + local z = string.len(x..", "..y) + "#, + ) + .into_function()?; + let co = lua.create_thread(func)?; + + let output = Arc::new(Mutex::new(Vec::new())); + let hook_output = output.clone(); + co.set_hook(HookTriggers::EVERY_LINE, move |_lua, debug| { + assert_eq!(debug.event(), DebugEvent::Line); + hook_output.lock().unwrap().push(debug.curr_line()); + Ok(()) + }); + + co.resume(())?; + lua.remove_hook(); + + let output = output.lock().unwrap(); + if cfg!(feature = "luajit") && lua.load("jit.version_num").eval::()? >= 20100 { + assert_eq!(*output, vec![2, 3, 4, 0, 4]); + } else { + assert_eq!(*output, vec![2, 3, 4]); + } + + Ok(()) +}