From 595bc3a2b3bd44323d7bfdb0afbadcdb9b854868 Mon Sep 17 00:00:00 2001 From: Alex Orlenko Date: Wed, 30 Mar 2022 22:01:06 +0100 Subject: [PATCH] Support Luau interrupts (closes #138) --- src/lib.rs | 2 +- src/lua.rs | 148 ++++++++++++++++++++++++++++++++++++++++++++++--- src/prelude.rs | 4 ++ src/types.rs | 14 +++++ tests/luau.rs | 74 ++++++++++++++++++++++++- 5 files changed, 232 insertions(+), 10 deletions(-) diff --git a/src/lib.rs b/src/lib.rs index 0ea54ac..a591fd6 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -126,7 +126,7 @@ pub use crate::hook::HookTriggers; #[cfg(any(feature = "luau", doc))] #[cfg_attr(docsrs, doc(cfg(feature = "luau")))] -pub use crate::chunk::Compiler; +pub use crate::{chunk::Compiler, types::VmState}; #[cfg(feature = "async")] pub use crate::thread::AsyncThread; diff --git a/src/lua.rs b/src/lua.rs index 9bfbebb..e82a4ca 100644 --- a/src/lua.rs +++ b/src/lua.rs @@ -47,6 +47,9 @@ use { #[cfg(not(feature = "luau"))] use crate::{hook::HookTriggers, types::HookCallback}; +#[cfg(feature = "luau")] +use crate::types::{InterruptCallback, VmState}; + #[cfg(feature = "async")] use { crate::types::{AsyncCallback, AsyncCallbackUpvalue, AsyncPollUpvalue}, @@ -108,6 +111,8 @@ struct ExtraData { hook_callback: Option, #[cfg(feature = "lua54")] warn_callback: Option, + #[cfg(feature = "luau")] + interrupt_callback: Option, #[cfg(feature = "luau")] sandboxed: bool, @@ -235,6 +240,13 @@ impl Drop for Lua { ffi::lua_replace(extra.ref_thread, extra.ref_waker_idx); extra.ref_free.push(extra.ref_waker_idx); } + #[cfg(feature = "luau")] + { + let callbacks = ffi::lua_callbacks(self.state); + let extra_ptr = (*callbacks).userdata as *mut Arc>; + drop(Box::from_raw(extra_ptr)); + (*callbacks).userdata = ptr::null_mut(); + } mlua_debug_assert!( ffi::lua_gettop(extra.ref_thread) == extra.ref_stack_top && extra.ref_stack_top as usize == extra.ref_free.len(), @@ -552,6 +564,8 @@ impl Lua { #[cfg(feature = "lua54")] warn_callback: None, #[cfg(feature = "luau")] + interrupt_callback: None, + #[cfg(feature = "luau")] sandboxed: false, })); @@ -581,6 +595,14 @@ impl Lua { ); assert_stack(main_state, ffi::LUA_MINSTACK); + // Set Luau callbacks userdata to extra data + // We can use global callbacks userdata since we don't allow C modules in Luau + #[cfg(feature = "luau")] + { + let extra_raw = Box::into_raw(Box::new(Arc::clone(&extra))); + (*ffi::lua_callbacks(main_state)).userdata = extra_raw as *mut c_void; + } + Lua { state, main_state: maybe_main_state, @@ -895,6 +917,102 @@ impl Lua { } } + /// Sets an 'interrupt' function that will periodically be called by Luau VM. + /// + /// Any Luau code is guaranteed to call this handler "eventually" + /// (in practice this can happen at any function call or at any loop iteration). + /// + /// The provided interrupt function can error, and this error will be propagated through + /// the Luau code that was executing at the time the interrupt was triggered. + /// Also this can be used to implement continuous execution limits by instructing Luau VM to yield + /// by returning [`VmState::Yield`]. + /// + /// This is similar to [`Lua::set_hook`] but in more simplified form. + /// + /// # Example + /// + /// Periodically yield Luau VM to suspend execution. + /// + /// ``` + /// # use std::sync::{Arc, atomic::{AtomicU64, Ordering}}; + /// # use mlua::{Lua, Result, ThreadStatus, VmState}; + /// # fn main() -> Result<()> { + /// let lua = Lua::new(); + /// let count = Arc::new(AtomicU64::new(0)); + /// lua.set_interrupt(move |_lua| { + /// if count.fetch_add(1, Ordering::Relaxed) % 2 == 0 { + /// return Ok(VmState::Yield); + /// } + /// Ok(VmState::Continue) + /// }); + /// + /// let co = lua.create_thread( + /// lua.load(r#" + /// local b = 0 + /// for _, x in ipairs({1, 2, 3}) do b += x end + /// "#) + /// .into_function()?, + /// )?; + /// while co.status() == ThreadStatus::Resumable { + /// co.resume(())?; + /// } + /// # Ok(()) + /// # } + /// ``` + #[cfg(feature = "luau")] + #[cfg_attr(docsrs, doc(cfg(feature = "luau")))] + pub fn set_interrupt<'lua, F>(&'lua self, callback: F) + where + F: 'static + MaybeSend + Fn(&Lua) -> Result, + { + unsafe extern "C" fn interrupt_proc(state: *mut ffi::lua_State, gc: c_int) { + if gc != -1 { + // We don't support GC interrupts since they cannot survive Lua exceptions + return; + } + // TODO: think about not using drop types here + let lua = match Lua::make_from_ptr(state) { + Some(lua) => lua, + None => return, + }; + let extra = lua.extra.get(); + let result = callback_error_ext(state, extra, move |_| { + let interrupt_cb = (*extra).interrupt_callback.clone(); + let interrupt_cb = + mlua_expect!(interrupt_cb, "no interrupt callback set in interrupt_proc"); + if Arc::strong_count(&interrupt_cb) > 2 { + return Ok(VmState::Continue); // Don't allow recursion + } + interrupt_cb(&lua) + }); + match result { + VmState::Continue => {} + VmState::Yield => { + ffi::lua_yield(state, 0); + } + } + } + + let state = mlua_expect!(self.main_state, "Luau should always has main state"); + unsafe { + (*self.extra.get()).interrupt_callback = Some(Arc::new(callback)); + (*ffi::lua_callbacks(state)).interrupt = Some(interrupt_proc); + } + } + + /// Removes any 'interrupt' previously set by `set_interrupt`. + /// + /// This function has no effect if an 'interrupt' was not previously set. + #[cfg(feature = "luau")] + #[cfg_attr(docsrs, doc(cfg(feature = "luau")))] + pub fn remove_interrupt(&self) { + let state = mlua_expect!(self.main_state, "Luau should always has main state"); + unsafe { + (*self.extra.get()).interrupt_callback = None; + (*ffi::lua_callbacks(state)).interrupt = None; + } + } + /// Sets the warning function to be used by Lua to emit warnings. /// /// Requires `feature = "lua54"` @@ -2759,14 +2877,7 @@ impl Lua { let _sg = StackGuard::new(state); assert_stack(state, 1); - let extra_key = &EXTRA_REGISTRY_KEY as *const u8 as *const c_void; - if ffi::lua_rawgetp(state, ffi::LUA_REGISTRYINDEX, extra_key) != ffi::LUA_TUSERDATA { - return None; - } - let extra_ptr = ffi::lua_touserdata(state, -1) as *mut Arc>; - let extra = Arc::clone(&*extra_ptr); - ffi::lua_pop(state, 1); - + let extra = extra_data(state)?; let safe = (*extra.get()).safe; Some(Lua { state, @@ -2798,6 +2909,27 @@ impl Lua { } } +#[cfg(feature = "luau")] +unsafe fn extra_data(state: *mut ffi::lua_State) -> Option>> { + let extra_ptr = (*ffi::lua_callbacks(state)).userdata as *mut Arc>; + if extra_ptr.is_null() { + return None; + } + Some(Arc::clone(&*extra_ptr)) +} + +#[cfg(not(feature = "luau"))] +unsafe fn extra_data(state: *mut ffi::lua_State) -> Option>> { + let extra_key = &EXTRA_REGISTRY_KEY as *const u8 as *const c_void; + if ffi::lua_rawgetp(state, ffi::LUA_REGISTRYINDEX, extra_key) != ffi::LUA_TUSERDATA { + return None; + } + let extra_ptr = ffi::lua_touserdata(state, -1) as *mut Arc>; + let extra = Arc::clone(&*extra_ptr); + ffi::lua_pop(state, 1); + Some(extra) +} + // Creates required entries in the metatable cache (see `util::METATABLE_CACHE`) pub(crate) fn init_metatable_cache(cache: &mut FxHashMap) { cache.insert(TypeId::of::>>(), 0); diff --git a/src/prelude.rs b/src/prelude.rs index 1abc7e7..28a90e5 100644 --- a/src/prelude.rs +++ b/src/prelude.rs @@ -15,6 +15,10 @@ pub use crate::{ Value as LuaValue, }; +#[cfg(feature = "luau")] +#[doc(no_inline)] +pub use crate::VmState as LuaVmState; + #[cfg(feature = "async")] #[doc(no_inline)] pub use crate::AsyncThread as LuaAsyncThread; diff --git a/src/types.rs b/src/types.rs index ed7305b..6ff7a21 100644 --- a/src/types.rs +++ b/src/types.rs @@ -55,6 +55,20 @@ pub(crate) type HookCallback = Arc Result<()> + #[cfg(all(not(feature = "send"), not(feature = "luau")))] pub(crate) type HookCallback = Arc Result<()>>>; +/// Type to set next Lua VM action after executing interrupt function. +#[cfg(any(feature = "luau", doc))] +#[cfg_attr(docsrs, doc(cfg(feature = "luau")))] +pub enum VmState { + Continue, + Yield, +} + +#[cfg(all(feature = "luau", feature = "send"))] +pub(crate) type InterruptCallback = Arc Result + Send>; + +#[cfg(all(feature = "luau", not(feature = "send")))] +pub(crate) type InterruptCallback = Arc Result>; + #[cfg(all(feature = "send", feature = "lua54"))] pub(crate) type WarnCallback = Box Result<()> + Send>; diff --git a/tests/luau.rs b/tests/luau.rs index ac6c290..608ea98 100644 --- a/tests/luau.rs +++ b/tests/luau.rs @@ -2,8 +2,10 @@ use std::env; use std::fs; +use std::sync::atomic::{AtomicU64, Ordering}; +use std::sync::Arc; -use mlua::{Error, Lua, Result, Table, Value}; +use mlua::{Error, Lua, Result, Table, ThreadStatus, Value, VmState}; #[test] fn test_require() -> Result<()> { @@ -125,3 +127,73 @@ fn test_sandbox_threads() -> Result<()> { Ok(()) } + +#[test] +fn test_interrupts() -> Result<()> { + let lua = Lua::new(); + + let interrupts_count = Arc::new(AtomicU64::new(0)); + let interrupts_count2 = interrupts_count.clone(); + + lua.set_interrupt(move |_lua| { + interrupts_count2.fetch_add(1, Ordering::Relaxed); + Ok(VmState::Continue) + }); + let f = lua + .load( + r#" + local x = 2 + 3 + local y = x * 63 + local z = string.len(x..", "..y) + "#, + ) + .into_function()?; + f.call(())?; + + assert!(interrupts_count.load(Ordering::Relaxed) > 0); + + // + // Test yields from interrupt + // + let yield_count = Arc::new(AtomicU64::new(0)); + let yield_count2 = yield_count.clone(); + lua.set_interrupt(move |_lua| { + if yield_count2.fetch_add(1, Ordering::Relaxed) == 1 { + return Ok(VmState::Yield); + } + Ok(VmState::Continue) + }); + let co = lua.create_thread( + lua.load( + r#" + local a = {1, 2, 3} + local b = 0 + for _, x in ipairs(a) do b += x end + return b + "#, + ) + .into_function()?, + )?; + co.resume(())?; + assert_eq!(co.status(), ThreadStatus::Resumable); + let result: i32 = co.resume(())?; + assert_eq!(result, 6); + assert_eq!(yield_count.load(Ordering::Relaxed), 7); + assert_eq!(co.status(), ThreadStatus::Unresumable); + + // + // Test errors in interrupts + // + lua.set_interrupt(|_| Err(Error::RuntimeError("error from interrupt".into()))); + match f.call::<_, ()>(()) { + Err(Error::CallbackError { cause, .. }) => match *cause { + Error::RuntimeError(ref m) if m == "error from interrupt" => {} + ref e => panic!("expected RuntimeError with a specific message, got {:?}", e), + }, + r => panic!("expected CallbackError, got {:?}", r), + } + + lua.remove_interrupt(); + + Ok(()) +}