From c3822219e0a25d0b564f0e7a766b85aa5719d77d Mon Sep 17 00:00:00 2001 From: Alex Orlenko Date: Fri, 22 May 2020 00:35:03 +0100 Subject: [PATCH] Add hooks support (based on rlua v0.17 implementation) This feature works on lua54, lua53, lua52 and lua51 only. LuaJIT is unstable. --- src/ffi/lua.rs | 4 +- src/hook.rs | 204 +++++++++++++++++++++++++++++++++++++++++++ src/lib.rs | 10 +++ src/lua.rs | 154 +++++++++++++++++++++++++++++++-- src/types.rs | 4 + tests/hooks.rs | 230 +++++++++++++++++++++++++++++++++++++++++++++++++ 6 files changed, 598 insertions(+), 8 deletions(-) create mode 100644 src/hook.rs create mode 100644 tests/hooks.rs diff --git a/src/ffi/lua.rs b/src/ffi/lua.rs index bade2c3..16d163d 100644 --- a/src/ffi/lua.rs +++ b/src/ffi/lua.rs @@ -739,7 +739,7 @@ pub const LUA_MASKLINE: c_int = 1 << (LUA_HOOKLINE as usize); pub const LUA_MASKCOUNT: c_int = 1 << (LUA_HOOKCOUNT as usize); /// Type for functions to be called on debug events. -pub type lua_Hook = extern "C" fn(L: *mut lua_State, ar: *mut lua_Debug); +pub type lua_Hook = unsafe extern "C" fn(L: *mut lua_State, ar: *mut lua_Debug); extern "C" { pub fn lua_getstack(L: *mut lua_State, level: c_int, ar: *mut lua_Debug) -> c_int; @@ -754,7 +754,7 @@ extern "C" { #[cfg(any(feature = "lua54", feature = "lua53", feature = "lua52"))] pub fn lua_upvaluejoin(L: *mut lua_State, fidx1: c_int, n1: c_int, fidx2: c_int, n2: c_int); - pub fn lua_sethook(L: *mut lua_State, func: lua_Hook, mask: c_int, count: c_int); + pub fn lua_sethook(L: *mut lua_State, func: Option, mask: c_int, count: c_int); pub fn lua_gethook(L: *mut lua_State) -> Option; pub fn lua_gethookmask(L: *mut lua_State) -> c_int; pub fn lua_gethookcount(L: *mut lua_State) -> c_int; diff --git a/src/hook.rs b/src/hook.rs new file mode 100644 index 0000000..bf3a26c --- /dev/null +++ b/src/hook.rs @@ -0,0 +1,204 @@ +#![cfg_attr( + not(any(feature = "lua54", feature = "lua53", feature = "lua52", feature = "lua51")), + allow(dead_code) +)] + +use std::ffi::CStr; +use std::marker::PhantomData; +use std::os::raw::{c_char, c_int}; + +use crate::ffi::{self, lua_Debug, lua_State}; +use crate::lua::Lua; +use crate::util::callback_error; + +/// Contains information about currently executing Lua code. +/// +/// The `Debug` structure is provided as a parameter to the hook function set with +/// [`Lua::set_hook`]. You may call the methods on this structure to retrieve information about the +/// Lua code executing at the time that the hook function was called. Further information can be +/// found in the [Lua 5.3 documentaton][lua_doc]. +/// +/// Requires `feature = "lua54/lua53/lua52/lua51"` +/// +/// [lua_doc]: https://www.lua.org/manual/5.3/manual.html#lua_Debug +/// [`Lua::set_hook`]: struct.Lua.html#method.set_hook +#[derive(Clone)] +pub struct Debug<'a> { + ar: *mut lua_Debug, + state: *mut lua_State, + _phantom: PhantomData<&'a ()>, +} + +impl<'a> Debug<'a> { + /// Corresponds to the `n` what mask. + pub fn names(&self) -> DebugNames<'a> { + unsafe { + mlua_assert!( + ffi::lua_getinfo(self.state, cstr!("n"), self.ar) != 0, + "lua_getinfo failed with `n`" + ); + DebugNames { + name: ptr_to_str((*self.ar).name), + name_what: ptr_to_str((*self.ar).namewhat), + } + } + } + + /// Corresponds to the `n` what mask. + pub fn source(&self) -> DebugSource<'a> { + unsafe { + mlua_assert!( + ffi::lua_getinfo(self.state, cstr!("S"), self.ar) != 0, + "lua_getinfo failed with `S`" + ); + DebugSource { + source: ptr_to_str((*self.ar).source), + short_src: ptr_to_str((*self.ar).short_src.as_ptr()), + line_defined: (*self.ar).linedefined as i32, + last_line_defined: (*self.ar).lastlinedefined as i32, + what: ptr_to_str((*self.ar).what), + } + } + } + + /// Corresponds to the `l` what mask. Returns the current line. + pub fn curr_line(&self) -> i32 { + unsafe { + mlua_assert!( + ffi::lua_getinfo(self.state, cstr!("l"), self.ar) != 0, + "lua_getinfo failed with `l`" + ); + (*self.ar).currentline as i32 + } + } + + /// Corresponds to the `t` what mask. Returns true if the hook is in a function tail call, false + /// otherwise. + pub fn is_tail_call(&self) -> bool { + unsafe { + mlua_assert!( + ffi::lua_getinfo(self.state, cstr!("t"), self.ar) != 0, + "lua_getinfo failed with `t`" + ); + (*self.ar).currentline != 0 + } + } + + /// Corresponds to the `u` what mask. + pub fn stack(&self) -> DebugStack { + unsafe { + mlua_assert!( + ffi::lua_getinfo(self.state, cstr!("u"), self.ar) != 0, + "lua_getinfo failed with `u`" + ); + DebugStack { + num_ups: (*self.ar).nups as i32, + #[cfg(any(feature = "lua52", feature = "lua53", feature = "lua54"))] + num_params: (*self.ar).nparams as i32, + #[cfg(any(feature = "lua52", feature = "lua53", feature = "lua54"))] + is_vararg: (*self.ar).isvararg != 0, + } + } + } +} + +#[derive(Clone, Debug)] +pub struct DebugNames<'a> { + pub name: Option<&'a [u8]>, + pub name_what: Option<&'a [u8]>, +} + +#[derive(Clone, Debug)] +pub struct DebugSource<'a> { + pub source: Option<&'a [u8]>, + pub short_src: Option<&'a [u8]>, + pub line_defined: i32, + pub last_line_defined: i32, + pub what: Option<&'a [u8]>, +} + +#[derive(Copy, Clone, Debug)] +pub struct DebugStack { + pub num_ups: i32, + /// Requires `feature = "lua54/lua53/lua52"` + #[cfg(any(feature = "lua54", feature = "lua53", feature = "lua52", doc))] + pub num_params: i32, + /// Requires `feature = "lua54/lua53/lua52"` + #[cfg(any(feature = "lua54", feature = "lua53", feature = "lua52", doc))] + pub is_vararg: bool, +} + +/// Determines when a hook function will be called by Lua. +/// +/// Requires `feature = "lua54/lua53/lua52/lua51"` +#[derive(Clone, Copy, Debug, Default)] +pub struct HookTriggers { + /// Before a function call. + pub on_calls: bool, + /// When Lua returns from a function. + pub on_returns: bool, + /// Before executing a new line, or returning from a function call. + pub every_line: bool, + /// After a certain number of VM instructions have been executed. When set to `Some(count)`, + /// `count` is the number of VM instructions to execute before calling the hook. + /// + /// # Performance + /// + /// Setting this option to a low value can incur a very high overhead. + pub every_nth_instruction: Option, +} + +impl HookTriggers { + // Compute the mask to pass to `lua_sethook`. + pub(crate) fn mask(&self) -> c_int { + let mut mask: c_int = 0; + if self.on_calls { + mask |= ffi::LUA_MASKCALL + } + if self.on_returns { + mask |= ffi::LUA_MASKRET + } + if self.every_line { + mask |= ffi::LUA_MASKLINE + } + if self.every_nth_instruction.is_some() { + mask |= ffi::LUA_MASKCOUNT + } + mask + } + + // Returns the `count` parameter to pass to `lua_sethook`, if applicable. Otherwise, zero is + // returned. + pub(crate) fn count(&self) -> c_int { + self.every_nth_instruction.unwrap_or(0) as c_int + } +} + +pub(crate) unsafe extern "C" fn hook_proc(state: *mut lua_State, ar: *mut lua_Debug) { + callback_error(state, |_| { + let debug = Debug { + ar, + state, + _phantom: PhantomData, + }; + + let lua = Lua::make_from_ptr(state); + let hook_cb = mlua_expect!(lua.hook_callback(), "no hook callback set in hook_proc"); + + #[allow(clippy::match_wild_err_arm)] + match hook_cb.try_borrow_mut() { + Ok(mut b) => (&mut *b)(&lua, debug), + Err(_) => mlua_panic!("Lua should not allow hooks to be called within another hook"), + }?; + + Ok(()) + }); +} + +unsafe fn ptr_to_str<'a>(input: *const c_char) -> Option<&'a [u8]> { + if input.is_null() { + None + } else { + Some(CStr::from_ptr(input).to_bytes()) + } +} diff --git a/src/lib.rs b/src/lib.rs index 0a03e69..b4b6606 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -63,6 +63,7 @@ mod conversion; mod error; mod ffi; mod function; +mod hook; mod lua; mod multi; mod scope; @@ -90,6 +91,15 @@ pub use crate::types::{Integer, LightUserData, Number, RegistryKey}; pub use crate::userdata::{AnyUserData, MetaMethod, UserData, UserDataMethods}; pub use crate::value::{FromLua, FromLuaMulti, MultiValue, Nil, ToLua, ToLuaMulti, Value}; +#[cfg(any( + feature = "lua54", + feature = "lua53", + feature = "lua52", + feature = "lua51", + doc +))] +pub use crate::hook::{Debug, DebugNames, DebugSource, DebugStack, HookTriggers}; + #[cfg(feature = "async")] pub use crate::thread::AsyncThread; diff --git a/src/lua.rs b/src/lua.rs index 78b295a..e13dcb5 100644 --- a/src/lua.rs +++ b/src/lua.rs @@ -4,7 +4,7 @@ use std::collections::HashMap; use std::ffi::CString; use std::marker::PhantomData; use std::os::raw::{c_char, c_int, c_void}; -use std::sync::{Arc, Mutex}; +use std::sync::{Arc, Mutex, Weak}; use std::{mem, ptr, str}; use crate::error::{Error, Result}; @@ -15,7 +15,9 @@ use crate::stdlib::StdLib; use crate::string::String; use crate::table::Table; use crate::thread::Thread; -use crate::types::{Callback, Integer, LightUserData, LuaRef, MaybeSend, Number, RegistryKey}; +use crate::types::{ + Callback, HookCallback, Integer, LightUserData, LuaRef, MaybeSend, Number, RegistryKey, +}; use crate::userdata::{AnyUserData, MetaMethod, UserData, UserDataMethods}; use crate::util::{ assert_stack, callback_error, check_stack, get_gc_userdata, get_main_state, @@ -25,6 +27,15 @@ use crate::util::{ }; use crate::value::{FromLua, FromLuaMulti, MultiValue, Nil, ToLua, ToLuaMulti, Value}; +#[cfg(any( + feature = "lua54", + feature = "lua53", + feature = "lua52", + feature = "lua51", + doc +))] +use crate::hook::{hook_proc, Debug, HookTriggers}; + #[cfg(feature = "async")] use { crate::types::AsyncCallback, @@ -58,6 +69,8 @@ struct ExtraData { ref_stack_size: c_int, ref_stack_max: c_int, ref_free: Vec, + + hook_callback: Option, } #[cfg_attr(any(feature = "lua51", feature = "luajit"), allow(dead_code))] @@ -85,6 +98,7 @@ pub enum GCMode { pub(crate) struct AsyncPollPending; #[cfg(feature = "async")] pub(crate) static WAKER_REGISTRY_KEY: u8 = 0; +pub(crate) static EXTRA_REGISTRY_KEY: u8 = 0; /// Requires `feature = "send"` #[cfg(feature = "send")] @@ -276,6 +290,7 @@ impl Lua { init_gc_metatable_for::(state, None); init_gc_metatable_for::(state, None); + init_gc_metatable_for::>>(state, None); #[cfg(feature = "async")] { init_gc_metatable_for::(state, None); @@ -305,8 +320,24 @@ impl Lua { ref_stack_size: ffi::LUA_MINSTACK - 1, ref_stack_max: 0, ref_free: Vec::new(), + hook_callback: None, })); + mlua_expect!( + push_gc_userdata(state, Arc::downgrade(&extra)), + "Error while storing extra data", + ); + mlua_expect!( + protect_lua_closure(main_state, 1, 0, |state| { + ffi::lua_rawsetp( + state, + ffi::LUA_REGISTRYINDEX, + &EXTRA_REGISTRY_KEY as *const u8 as *mut c_void, + ); + }), + "Error while storing extra data" + ); + mlua_debug_assert!( ffi::lua_gettop(main_state) == main_state_top, "stack leak during creation" @@ -387,6 +418,90 @@ impl Lua { unsafe { self.push_value(cb.call(())?).map(|_| 1) } } + /// Sets a 'hook' function that will periodically be called as Lua code executes. + /// + /// When exactly the hook function is called depends on the contents of the `triggers` + /// parameter, see [`HookTriggers`] for more details. + /// + /// The provided hook function can error, and this error will be propagated through the Lua code + /// that was executing at the time the hook was triggered. This can be used to implement a + /// limited form of execution limits by setting [`HookTriggers.every_nth_instruction`] and + /// erroring once an instruction limit has been reached. + /// + /// Requires `feature = "lua54/lua53/lua52/lua51"` + /// + /// # Example + /// + /// Shows each line number of code being executed by the Lua interpreter. + /// + /// ``` + /// # #[cfg(any(feature = "lua54", feature = "lua53", feature = "lua52", feature = "lua51"))] + /// # use mlua::{Lua, HookTriggers, Result}; + /// # #[cfg(any(feature = "lua54", feature = "lua53", feature = "lua52", feature = "lua51"))] + /// # fn main() -> Result<()> { + /// let lua = Lua::new(); + /// lua.set_hook(HookTriggers { + /// every_line: true, ..Default::default() + /// }, |_lua, debug| { + /// println!("line {}", debug.curr_line()); + /// Ok(()) + /// }); + /// + /// lua.load(r#" + /// local x = 2 + 3 + /// local y = x * 63 + /// local z = string.len(x..", "..y) + /// "#).exec() + /// # } + /// + /// # #[cfg(not(any(feature = "lua54", feature = "lua53", feature = "lua52", feature = "lua51")))] + /// # fn main() {} + /// ``` + /// + /// [`HookTriggers`]: struct.HookTriggers.html + /// [`HookTriggers.every_nth_instruction`]: struct.HookTriggers.html#field.every_nth_instruction + #[cfg(any( + feature = "lua54", + feature = "lua53", + feature = "lua52", + feature = "lua51", + doc + ))] + pub fn set_hook(&self, triggers: HookTriggers, callback: F) + where + F: 'static + MaybeSend + FnMut(&Lua, Debug) -> Result<()>, + { + unsafe { + let mut extra = mlua_expect!(self.extra.lock(), "extra is poisoned"); + extra.hook_callback = Some(Arc::new(RefCell::new(callback))); + ffi::lua_sethook( + self.main_state, + Some(hook_proc), + triggers.mask(), + triggers.count(), + ); + } + } + + /// Remove any hook previously set by `set_hook`. This function has no effect if a hook was not + /// previously set. + /// + /// Requires `feature = "lua54/lua53/lua52/lua51"` + #[cfg(any( + feature = "lua54", + feature = "lua53", + feature = "lua52", + feature = "lua51", + doc + ))] + pub fn remove_hook(&self) { + let mut extra = mlua_expect!(self.extra.lock(), "extra is poisoned"); + unsafe { + extra.hook_callback = None; + ffi::lua_sethook(self.main_state, None, 0, 0); + } + } + /// Returns the amount of memory (in bytes) currently used inside this Lua state. pub fn used_memory(&self) -> usize { let extra = mlua_expect!(self.extra.lock(), "extra is poisoned"); @@ -1544,10 +1659,7 @@ impl Lua { let mut waker = noop_waker(); // Try to get an outer poll waker - ffi::lua_pushlightuserdata( - state, - &WAKER_REGISTRY_KEY as *const u8 as *mut c_void, - ); + ffi::lua_pushlightuserdata(state, &WAKER_REGISTRY_KEY as *const u8 as *mut c_void); ffi::lua_rawget(state, ffi::LUA_REGISTRYINDEX); if let Some(w) = get_gc_userdata::(state, -1).as_ref() { waker = (*w).clone(); @@ -1675,6 +1787,36 @@ impl Lua { Ok(()) } + + pub(crate) unsafe fn make_from_ptr(state: *mut ffi::lua_State) -> Self { + let _sg = StackGuard::new(state); + assert_stack(state, 3); + + ffi::lua_rawgetp( + state, + ffi::LUA_REGISTRYINDEX, + &EXTRA_REGISTRY_KEY as *const u8 as *mut c_void, + ); + let extra = mlua_expect!( + (*get_gc_userdata::>>(state, -1)).upgrade(), + "extra is destroyed" + ); + ffi::lua_pop(state, 1); + + Lua { + state, + main_state: get_main_state(state), + extra, + ephemeral: true, + safe: true, // TODO: Inherit the attribute + _no_ref_unwind_safe: PhantomData, + } + } + + pub(crate) unsafe fn hook_callback(&self) -> Option { + let extra = mlua_expect!(self.extra.lock(), "extra is poisoned"); + extra.hook_callback.clone() + } } /// Returned from [`Lua::load`] and is used to finalize loading and executing Lua main chunks. diff --git a/src/types.rs b/src/types.rs index 1aa43d5..70d1139 100644 --- a/src/types.rs +++ b/src/types.rs @@ -1,3 +1,4 @@ +use std::cell::RefCell; use std::os::raw::{c_int, c_void}; use std::sync::{Arc, Mutex}; use std::{fmt, mem, ptr}; @@ -7,6 +8,7 @@ use futures_core::future::LocalBoxFuture; use crate::error::Result; use crate::ffi; +use crate::hook::Debug; use crate::lua::Lua; use crate::util::{assert_stack, StackGuard}; use crate::value::MultiValue; @@ -27,6 +29,8 @@ pub(crate) type Callback<'lua, 'a> = pub(crate) type AsyncCallback<'lua, 'a> = Box) -> LocalBoxFuture<'lua, Result>> + 'a>; +pub(crate) type HookCallback = Arc Result<()>>>; + #[cfg(feature = "send")] pub trait MaybeSend: Send {} #[cfg(feature = "send")] diff --git a/tests/hooks.rs b/tests/hooks.rs new file mode 100644 index 0000000..32f4f67 --- /dev/null +++ b/tests/hooks.rs @@ -0,0 +1,230 @@ +#![cfg(any( + feature = "lua54", + feature = "lua53", + feature = "lua52", + feature = "lua51" +))] + +use std::cell::RefCell; +use std::ops::Deref; +use std::str; +use std::sync::{Arc, Mutex}; + +use mlua::{Error, HookTriggers, Lua, Result, Value}; + +#[test] +fn line_counts() -> Result<()> { + let output = Arc::new(Mutex::new(Vec::new())); + let hook_output = output.clone(); + + let lua = Lua::new(); + lua.set_hook( + HookTriggers { + every_line: true, + ..Default::default() + }, + move |_lua, debug| { + hook_output.lock().unwrap().push(debug.curr_line()); + Ok(()) + }, + ); + lua.load( + r#" + local x = 2 + 3 + local y = x * 63 + local z = string.len(x..", "..y) + "#, + ) + .exec()?; + + let output = output.lock().unwrap(); + assert_eq!(*output, vec![2, 3, 4]); + + Ok(()) +} + +#[test] +fn function_calls() -> Result<()> { + let output = Arc::new(Mutex::new(Vec::new())); + let hook_output = output.clone(); + + let lua = Lua::new(); + lua.set_hook( + HookTriggers { + on_calls: true, + ..Default::default() + }, + move |_lua, debug| { + let names = debug.names(); + let source = debug.source(); + let name = names.name.map(|s| str::from_utf8(s).unwrap().to_owned()); + let what = source.what.map(|s| str::from_utf8(s).unwrap().to_owned()); + hook_output.lock().unwrap().push((name, what)); + Ok(()) + }, + ); + + lua.load( + r#" + local v = string.len("Hello World") + "#, + ) + .exec()?; + + let output = output.lock().unwrap(); + assert_eq!( + *output, + vec![ + (None, Some("main".to_string())), + (Some("len".to_string()), Some("C".to_string())) + ] + ); + + Ok(()) +} + +#[test] +fn error_within_hook() { + let lua = Lua::new(); + lua.set_hook( + HookTriggers { + every_line: true, + ..Default::default() + }, + |_lua, _debug| { + Err(Error::RuntimeError( + "Something happened in there!".to_string(), + )) + }, + ); + + let err = lua + .load("x = 1") + .exec() + .expect_err("panic didn't propagate"); + + match err { + Error::CallbackError { cause, .. } => match cause.deref() { + Error::RuntimeError(s) => assert_eq!(s, "Something happened in there!"), + _ => panic!("wrong callback error kind caught"), + }, + _ => panic!("wrong error kind caught"), + }; +} + +#[test] +fn limit_execution_instructions() { + let lua = Lua::new(); + let mut max_instructions = 10000; + + lua.set_hook( + HookTriggers { + every_nth_instruction: Some(30), + ..Default::default() + }, + move |_lua, _debug| { + max_instructions -= 30; + if max_instructions < 0 { + Err(Error::RuntimeError("time's up".to_string())) + } else { + Ok(()) + } + }, + ); + + lua.globals().set("x", Value::Integer(0)).unwrap(); + let _ = lua + .load( + r#" + for i = 1, 10000 do + x = x + 1 + end + "#, + ) + .exec() + .expect_err("instruction limit didn't occur"); +} + +#[test] +fn hook_removal() { + let lua = Lua::new(); + + lua.set_hook( + HookTriggers { + every_nth_instruction: Some(1), + ..Default::default() + }, + |_lua, _debug| { + Err(Error::RuntimeError( + "this hook should've been removed by this time".to_string(), + )) + }, + ); + + assert!(lua.load("local x = 1").exec().is_err()); + lua.remove_hook(); + assert!(lua.load("local x = 1").exec().is_ok()); +} + +#[test] +fn hook_swap_within_hook() { + thread_local! { + static TL_LUA: RefCell> = RefCell::new(None); + } + + TL_LUA.with(|tl| { + *tl.borrow_mut() = Some(Lua::new()); + }); + + TL_LUA.with(|tl| { + tl.borrow().as_ref().unwrap().set_hook( + HookTriggers { + every_line: true, + ..Default::default() + }, + move |lua, _debug| { + lua.globals().set("ok", 1i64).unwrap(); + TL_LUA.with(|tl| { + tl.borrow().as_ref().unwrap().set_hook( + HookTriggers { + every_line: true, + ..Default::default() + }, + move |lua, _debug| { + lua.load( + r#" + if ok ~= nil then + ok = ok + 1 + end + "#, + ) + .exec() + .expect("exec failure within hook"); + TL_LUA.with(|tl| { + tl.borrow().as_ref().unwrap().remove_hook(); + }); + Ok(()) + }, + ); + }); + Ok(()) + }, + ); + }); + + TL_LUA.with(|tl| { + let tl = tl.borrow(); + let lua = tl.as_ref().unwrap(); + assert!(lua + .load( + r#" + local x = 1 + x = 2 + local y = 3 + "#, + ) + .exec() + .is_ok()); + assert_eq!(lua.globals().get::<_, i64>("ok").unwrap_or(-1), 2); + }); +}