diff --git a/src/hook.rs b/src/hook.rs index e65c088..185e216 100644 --- a/src/hook.rs +++ b/src/hook.rs @@ -1,5 +1,6 @@ use std::ffi::CStr; use std::marker::PhantomData; +use std::ops::{BitOr, BitOrAssign}; use std::os::raw::{c_char, c_int}; use crate::ffi::{self, lua_Debug, lua_State}; @@ -169,6 +170,46 @@ pub struct HookTriggers { } impl HookTriggers { + /// Returns a new instance of `HookTriggers` with [`on_calls`] trigger set. + /// + /// [`on_calls`]: #structfield.on_calls + pub fn on_calls() -> Self { + HookTriggers { + on_calls: true, + ..Default::default() + } + } + + /// Returns a new instance of `HookTriggers` with [`on_returns`] trigger set. + /// + /// [`on_returns`]: #structfield.on_returns + pub fn on_returns() -> Self { + HookTriggers { + on_returns: true, + ..Default::default() + } + } + + /// Returns a new instance of `HookTriggers` with [`every_line`] trigger set. + /// + /// [`every_line`]: #structfield.every_line + pub fn every_line() -> Self { + HookTriggers { + every_line: true, + ..Default::default() + } + } + + /// Returns a new instance of `HookTriggers` with [`every_nth_instruction`] trigger set. + /// + /// [`every_nth_instruction`]: #structfield.every_nth_instruction + pub fn every_nth_instruction(n: u32) -> Self { + HookTriggers { + every_nth_instruction: Some(n), + ..Default::default() + } + } + // Compute the mask to pass to `lua_sethook`. pub(crate) fn mask(&self) -> c_int { let mut mask: c_int = 0; @@ -194,6 +235,26 @@ impl HookTriggers { } } +impl BitOr for HookTriggers { + type Output = Self; + + fn bitor(mut self, rhs: Self) -> Self::Output { + self.on_calls |= rhs.on_calls; + self.on_returns |= rhs.on_returns; + self.every_line |= rhs.every_line; + if self.every_nth_instruction.is_none() && rhs.every_nth_instruction.is_some() { + self.every_nth_instruction = rhs.every_nth_instruction; + } + self + } +} + +impl BitOrAssign for HookTriggers { + fn bitor_assign(&mut self, rhs: Self) { + *self = *self | rhs; + } +} + pub(crate) unsafe extern "C" fn hook_proc(state: *mut lua_State, ar: *mut lua_Debug) { callback_error(state, |_| { let debug = Debug { diff --git a/src/lua.rs b/src/lua.rs index b799161..98b2fb2 100644 --- a/src/lua.rs +++ b/src/lua.rs @@ -628,9 +628,7 @@ impl Lua { /// # use mlua::{Lua, HookTriggers, Result}; /// # fn main() -> Result<()> { /// let lua = Lua::new(); - /// lua.set_hook(HookTriggers { - /// every_line: true, ..Default::default() - /// }, |_lua, debug| { + /// lua.set_hook(HookTriggers::every_line(), |_lua, debug| { /// println!("line {}", debug.curr_line()); /// Ok(()) /// })?; diff --git a/src/serde/ser.rs b/src/serde/ser.rs index 4c6423c..e61a357 100644 --- a/src/serde/ser.rs +++ b/src/serde/ser.rs @@ -61,7 +61,7 @@ impl Default for Options { } impl Options { - /// Retruns a new instance of `Options` with default parameters. + /// Returns a new instance of `Options` with default parameters. pub fn new() -> Self { Self::default() } diff --git a/tests/hooks.rs b/tests/hooks.rs index 73a8888..3e8e876 100644 --- a/tests/hooks.rs +++ b/tests/hooks.rs @@ -5,23 +5,30 @@ use std::sync::{Arc, Mutex}; use mlua::{DebugEvent, Error, HookTriggers, Lua, Result, Value}; +#[test] +fn test_hook_triggers_bitor() { + let trigger = HookTriggers::on_calls() + | HookTriggers::on_returns() + | HookTriggers::every_line() + | HookTriggers::every_nth_instruction(5); + + assert!(trigger.on_calls); + assert!(trigger.on_returns); + assert!(trigger.every_line); + assert_eq!(trigger.every_nth_instruction, Some(5)); +} + #[test] fn test_line_counts() -> Result<()> { let output = Arc::new(Mutex::new(Vec::new())); let hook_output = output.clone(); let lua = Lua::new(); - lua.set_hook( - HookTriggers { - every_line: true, - ..Default::default() - }, - move |_lua, debug| { - assert_eq!(debug.event(), DebugEvent::Line); - hook_output.lock().unwrap().push(debug.curr_line()); - Ok(()) - }, - )?; + lua.set_hook(HookTriggers::every_line(), move |_lua, debug| { + assert_eq!(debug.event(), DebugEvent::Line); + hook_output.lock().unwrap().push(debug.curr_line()); + Ok(()) + })?; lua.load( r#" local x = 2 + 3 @@ -49,21 +56,15 @@ fn test_function_calls() -> Result<()> { let hook_output = output.clone(); let lua = Lua::new(); - lua.set_hook( - HookTriggers { - on_calls: true, - ..Default::default() - }, - move |_lua, debug| { - assert_eq!(debug.event(), DebugEvent::Call); - let names = debug.names(); - let source = debug.source(); - let name = names.name.map(|s| str::from_utf8(s).unwrap().to_owned()); - let what = source.what.map(|s| str::from_utf8(s).unwrap().to_owned()); - hook_output.lock().unwrap().push((name, what)); - Ok(()) - }, - )?; + lua.set_hook(HookTriggers::on_calls(), move |_lua, debug| { + assert_eq!(debug.event(), DebugEvent::Call); + let names = debug.names(); + let source = debug.source(); + let name = names.name.map(|s| str::from_utf8(s).unwrap().to_owned()); + let what = source.what.map(|s| str::from_utf8(s).unwrap().to_owned()); + hook_output.lock().unwrap().push((name, what)); + Ok(()) + })?; lua.load( r#" @@ -99,17 +100,12 @@ fn test_function_calls() -> Result<()> { #[test] fn test_error_within_hook() -> Result<()> { let lua = Lua::new(); - lua.set_hook( - HookTriggers { - every_line: true, - ..Default::default() - }, - |_lua, _debug| { - Err(Error::RuntimeError( - "Something happened in there!".to_string(), - )) - }, - )?; + + lua.set_hook(HookTriggers::every_line(), |_lua, _debug| { + Err(Error::RuntimeError( + "Something happened in there!".to_string(), + )) + })?; let err = lua .load("x = 1") @@ -137,10 +133,7 @@ fn test_limit_execution_instructions() -> Result<()> { lua.load("jit.off()").exec()?; lua.set_hook( - HookTriggers { - every_nth_instruction: Some(30), - ..Default::default() - }, + HookTriggers::every_nth_instruction(30), move |_lua, debug| { assert_eq!(debug.event(), DebugEvent::Count); max_instructions -= 30; @@ -171,17 +164,11 @@ fn test_limit_execution_instructions() -> Result<()> { fn test_hook_removal() -> Result<()> { let lua = Lua::new(); - lua.set_hook( - HookTriggers { - every_nth_instruction: Some(1), - ..Default::default() - }, - |_lua, _debug| { - Err(Error::RuntimeError( - "this hook should've been removed by this time".to_string(), - )) - }, - )?; + lua.set_hook(HookTriggers::every_nth_instruction(1), |_lua, _debug| { + Err(Error::RuntimeError( + "this hook should've been removed by this time".to_string(), + )) + })?; assert!(lua.load("local x = 1").exec().is_err()); lua.remove_hook(); @@ -201,19 +188,14 @@ fn test_hook_swap_within_hook() -> Result<()> { }); TL_LUA.with(|tl| { - tl.borrow().as_ref().unwrap().set_hook( - HookTriggers { - every_line: true, - ..Default::default() - }, - move |lua, _debug| { + tl.borrow() + .as_ref() + .unwrap() + .set_hook(HookTriggers::every_line(), move |lua, _debug| { lua.globals().set("ok", 1i64)?; TL_LUA.with(|tl| { tl.borrow().as_ref().unwrap().set_hook( - HookTriggers { - every_line: true, - ..Default::default() - }, + HookTriggers::every_line(), move |lua, _debug| { lua.load( r#" @@ -231,8 +213,7 @@ fn test_hook_swap_within_hook() -> Result<()> { }, ) }) - }, - ) + }) })?; TL_LUA.with(|tl| {