Add `Thread::set_hook()` function

This commit is contained in:
Alex Orlenko 2023-04-08 23:53:48 +01:00
parent 288934c82c
commit 3e83753466
No known key found for this signature in database
GPG Key ID: 4C150C250863B96D
3 changed files with 120 additions and 28 deletions

View File

@ -121,6 +121,8 @@ pub(crate) struct ExtraData {
#[cfg(not(feature = "luau"))]
hook_callback: Option<HookCallback>,
#[cfg(not(feature = "luau"))]
hook_thread: *mut ffi::lua_State,
#[cfg(feature = "lua54")]
warn_callback: Option<WarnCallback>,
#[cfg(feature = "luau")]
@ -511,6 +513,8 @@ impl Lua {
waker: NonNull::from(noop_waker_ref()),
#[cfg(not(feature = "luau"))]
hook_callback: None,
#[cfg(not(feature = "luau"))]
hook_thread: ptr::null_mut(),
#[cfg(feature = "lua54")]
warn_callback: None,
#[cfg(feature = "luau")]
@ -815,6 +819,11 @@ impl Lua {
/// limited form of execution limits by setting [`HookTriggers.every_nth_instruction`] and
/// erroring once an instruction limit has been reached.
///
/// This method sets a hook function for the main thread (if available) of this Lua instance.
/// If you want to set a hook function for a thread (coroutine), use [`Thread::set_hook()`] instead.
///
/// Please note you cannot have more than one hook function set at a time for this Lua instance.
///
/// # Example
///
/// Shows each line number of code being executed by the Lua interpreter.
@ -823,7 +832,7 @@ impl Lua {
/// # use mlua::{Lua, HookTriggers, Result};
/// # fn main() -> Result<()> {
/// let lua = Lua::new();
/// lua.set_hook(HookTriggers::every_line(), |_lua, debug| {
/// lua.set_hook(HookTriggers::EVERY_LINE, |_lua, debug| {
/// println!("line {}", debug.curr_line());
/// Ok(())
/// })?;
@ -842,47 +851,71 @@ impl Lua {
#[cfg_attr(docsrs, doc(cfg(not(feature = "luau"))))]
pub fn set_hook<F>(&self, triggers: HookTriggers, callback: F) -> Result<()>
where
F: 'static + MaybeSend + Fn(&Lua, Debug) -> Result<()>,
F: Fn(&Lua, Debug) -> Result<()> + MaybeSend + 'static,
{
unsafe {
let state = get_main_state(self.main_state).ok_or(Error::MainThreadNotAvailable)?;
self.set_thread_hook(state, triggers, callback);
}
Ok(())
}
/// Sets a 'hook' function for a thread (coroutine).
#[cfg(not(feature = "luau"))]
pub(crate) unsafe fn set_thread_hook<F>(
&self,
state: *mut ffi::lua_State,
triggers: HookTriggers,
callback: F,
) where
F: Fn(&Lua, Debug) -> Result<()> + MaybeSend + 'static,
{
unsafe extern "C" fn hook_proc(state: *mut ffi::lua_State, ar: *mut ffi::lua_Debug) {
let lua = match Lua::try_from_ptr(state) {
Some(lua) => lua,
None => return,
};
let extra = lua.extra.get();
let extra = extra_data(state);
if extra.is_null() {
return;
}
if (*extra).hook_thread != state {
// Hook was destined for a different thread, ignore
ffi::lua_sethook(state, None, 0, 0);
return;
}
callback_error_ext(state, extra, move |_| {
let debug = Debug::new(&lua, ar);
let hook_cb = (*extra).hook_callback.clone();
let hook_cb = mlua_expect!(hook_cb, "no hook callback set in hook_proc");
if Arc::strong_count(&hook_cb) > 2 {
return Ok(()); // Don't allow recursion
}
hook_cb(&lua, debug)
let lua: &Lua = mem::transmute((*extra).inner.as_ref().unwrap());
let _guard = StateGuard::new(&lua.0, state);
let debug = Debug::new(lua, ar);
hook_cb(lua, debug)
})
}
unsafe {
let state = get_main_state(self.main_state).ok_or(Error::MainThreadNotAvailable)?;
(*self.extra.get()).hook_callback = Some(Arc::new(callback));
ffi::lua_sethook(state, Some(hook_proc), triggers.mask(), triggers.count());
}
Ok(())
(*self.extra.get()).hook_callback = Some(Arc::new(callback));
(*self.extra.get()).hook_thread = state; // Mark for what thread the hook is set
ffi::lua_sethook(state, Some(hook_proc), triggers.mask(), triggers.count());
}
/// Removes any hook previously set by `set_hook`.
/// Removes any hook previously set by [`Lua::set_hook()`] or [`Thread::set_hook()`].
///
/// This function has no effect if a hook was not previously set.
#[cfg(not(feature = "luau"))]
#[cfg_attr(docsrs, doc(cfg(not(feature = "luau"))))]
pub fn remove_hook(&self) {
unsafe {
// If main_state is not available, then sethook wasn't called.
let state = match get_main_state(self.main_state) {
Some(state) => state,
None => return,
let state = self.state();
ffi::lua_sethook(state, None, 0, 0);
match get_main_state(self.main_state) {
Some(main_state) if !ptr::eq(state, main_state) => {
// If main_state is different from state, remove hook from it too
ffi::lua_sethook(main_state, None, 0, 0);
}
_ => {}
};
(*self.extra.get()).hook_callback = None;
ffi::lua_sethook(state, None, 0, 0);
(*self.extra.get()).hook_thread = ptr::null_mut();
}
}
@ -951,6 +984,7 @@ impl Lua {
return Ok(VmState::Continue); // Don't allow recursion
}
let lua: &Lua = mem::transmute((*extra).inner.as_ref().unwrap());
let _guard = StateGuard::new(&lua.0, state);
interrupt_cb(lua)
});
match result {

View File

@ -3,6 +3,8 @@ use std::os::raw::c_int;
use crate::error::{Error, Result};
use crate::ffi;
#[allow(unused)]
use crate::lua::Lua;
use crate::types::LuaRef;
use crate::util::{check_stack, error_traceback_thread, pop_error, StackGuard};
use crate::value::{FromLuaMulti, IntoLuaMulti};
@ -14,10 +16,16 @@ use crate::value::{FromLuaMulti, IntoLuaMulti};
))]
use crate::function::Function;
#[cfg(not(feature = "luau"))]
use crate::{
hook::{Debug, HookTriggers},
types::MaybeSend,
};
#[cfg(feature = "async")]
use {
crate::{
lua::{Lua, ASYNC_POLL_PENDING},
lua::ASYNC_POLL_PENDING,
value::{MultiValue, Value},
},
futures_core::{future::Future, stream::Stream},
@ -178,6 +186,23 @@ impl<'lua> Thread<'lua> {
}
}
/// Sets a 'hook' function that will periodically be called as Lua code executes.
///
/// This function is similar or [`Lua::set_hook()`] except that it sets for the thread.
/// To remove a hook call [`Lua::remove_hook()`].
#[cfg(not(feature = "luau"))]
#[cfg_attr(docsrs, doc(cfg(not(feature = "luau"))))]
pub fn set_hook<F>(&self, triggers: HookTriggers, callback: F)
where
F: Fn(&Lua, Debug) -> Result<()> + MaybeSend + 'static,
{
let lua = self.0.lua;
unsafe {
let thread_state = ffi::lua_tothread(lua.ref_thread(), self.0.index);
lua.set_thread_hook(thread_state, triggers, callback);
}
}
/// Resets a thread
///
/// In [Lua 5.4]: cleans its call stack and closes all pending to-be-closed variables.

View File

@ -9,12 +9,9 @@ use std::sync::{Arc, Mutex};
use mlua::{DebugEvent, Error, HookTriggers, Lua, Result, Value};
#[test]
fn test_hook_triggers_bitor() {
let trigger = HookTriggers::new()
.on_calls()
.on_returns()
.every_line()
.every_nth_instruction(5);
fn test_hook_triggers() {
let trigger = HookTriggers::new().on_calls().on_returns()
| HookTriggers::new().every_line().every_nth_instruction(5);
assert!(trigger.on_calls);
assert!(trigger.on_returns);
@ -237,3 +234,39 @@ fn test_hook_swap_within_hook() -> Result<()> {
Ok(())
})
}
#[test]
fn test_hook_threads() -> Result<()> {
let lua = Lua::new();
let func = lua
.load(
r#"
local x = 2 + 3
local y = x * 63
local z = string.len(x..", "..y)
"#,
)
.into_function()?;
let co = lua.create_thread(func)?;
let output = Arc::new(Mutex::new(Vec::new()));
let hook_output = output.clone();
co.set_hook(HookTriggers::EVERY_LINE, move |_lua, debug| {
assert_eq!(debug.event(), DebugEvent::Line);
hook_output.lock().unwrap().push(debug.curr_line());
Ok(())
});
co.resume(())?;
lua.remove_hook();
let output = output.lock().unwrap();
if cfg!(feature = "luajit") && lua.load("jit.version_num").eval::<i64>()? >= 20100 {
assert_eq!(*output, vec![2, 3, 4, 0, 4]);
} else {
assert_eq!(*output, vec![2, 3, 4]);
}
Ok(())
}