Support Luau interrupts (closes #138)
This commit is contained in:
parent
87c10ca93d
commit
595bc3a2b3
|
@ -126,7 +126,7 @@ pub use crate::hook::HookTriggers;
|
||||||
|
|
||||||
#[cfg(any(feature = "luau", doc))]
|
#[cfg(any(feature = "luau", doc))]
|
||||||
#[cfg_attr(docsrs, doc(cfg(feature = "luau")))]
|
#[cfg_attr(docsrs, doc(cfg(feature = "luau")))]
|
||||||
pub use crate::chunk::Compiler;
|
pub use crate::{chunk::Compiler, types::VmState};
|
||||||
|
|
||||||
#[cfg(feature = "async")]
|
#[cfg(feature = "async")]
|
||||||
pub use crate::thread::AsyncThread;
|
pub use crate::thread::AsyncThread;
|
||||||
|
|
148
src/lua.rs
148
src/lua.rs
|
@ -47,6 +47,9 @@ use {
|
||||||
#[cfg(not(feature = "luau"))]
|
#[cfg(not(feature = "luau"))]
|
||||||
use crate::{hook::HookTriggers, types::HookCallback};
|
use crate::{hook::HookTriggers, types::HookCallback};
|
||||||
|
|
||||||
|
#[cfg(feature = "luau")]
|
||||||
|
use crate::types::{InterruptCallback, VmState};
|
||||||
|
|
||||||
#[cfg(feature = "async")]
|
#[cfg(feature = "async")]
|
||||||
use {
|
use {
|
||||||
crate::types::{AsyncCallback, AsyncCallbackUpvalue, AsyncPollUpvalue},
|
crate::types::{AsyncCallback, AsyncCallbackUpvalue, AsyncPollUpvalue},
|
||||||
|
@ -108,6 +111,8 @@ struct ExtraData {
|
||||||
hook_callback: Option<HookCallback>,
|
hook_callback: Option<HookCallback>,
|
||||||
#[cfg(feature = "lua54")]
|
#[cfg(feature = "lua54")]
|
||||||
warn_callback: Option<WarnCallback>,
|
warn_callback: Option<WarnCallback>,
|
||||||
|
#[cfg(feature = "luau")]
|
||||||
|
interrupt_callback: Option<InterruptCallback>,
|
||||||
|
|
||||||
#[cfg(feature = "luau")]
|
#[cfg(feature = "luau")]
|
||||||
sandboxed: bool,
|
sandboxed: bool,
|
||||||
|
@ -235,6 +240,13 @@ impl Drop for Lua {
|
||||||
ffi::lua_replace(extra.ref_thread, extra.ref_waker_idx);
|
ffi::lua_replace(extra.ref_thread, extra.ref_waker_idx);
|
||||||
extra.ref_free.push(extra.ref_waker_idx);
|
extra.ref_free.push(extra.ref_waker_idx);
|
||||||
}
|
}
|
||||||
|
#[cfg(feature = "luau")]
|
||||||
|
{
|
||||||
|
let callbacks = ffi::lua_callbacks(self.state);
|
||||||
|
let extra_ptr = (*callbacks).userdata as *mut Arc<UnsafeCell<ExtraData>>;
|
||||||
|
drop(Box::from_raw(extra_ptr));
|
||||||
|
(*callbacks).userdata = ptr::null_mut();
|
||||||
|
}
|
||||||
mlua_debug_assert!(
|
mlua_debug_assert!(
|
||||||
ffi::lua_gettop(extra.ref_thread) == extra.ref_stack_top
|
ffi::lua_gettop(extra.ref_thread) == extra.ref_stack_top
|
||||||
&& extra.ref_stack_top as usize == extra.ref_free.len(),
|
&& extra.ref_stack_top as usize == extra.ref_free.len(),
|
||||||
|
@ -552,6 +564,8 @@ impl Lua {
|
||||||
#[cfg(feature = "lua54")]
|
#[cfg(feature = "lua54")]
|
||||||
warn_callback: None,
|
warn_callback: None,
|
||||||
#[cfg(feature = "luau")]
|
#[cfg(feature = "luau")]
|
||||||
|
interrupt_callback: None,
|
||||||
|
#[cfg(feature = "luau")]
|
||||||
sandboxed: false,
|
sandboxed: false,
|
||||||
}));
|
}));
|
||||||
|
|
||||||
|
@ -581,6 +595,14 @@ impl Lua {
|
||||||
);
|
);
|
||||||
assert_stack(main_state, ffi::LUA_MINSTACK);
|
assert_stack(main_state, ffi::LUA_MINSTACK);
|
||||||
|
|
||||||
|
// Set Luau callbacks userdata to extra data
|
||||||
|
// We can use global callbacks userdata since we don't allow C modules in Luau
|
||||||
|
#[cfg(feature = "luau")]
|
||||||
|
{
|
||||||
|
let extra_raw = Box::into_raw(Box::new(Arc::clone(&extra)));
|
||||||
|
(*ffi::lua_callbacks(main_state)).userdata = extra_raw as *mut c_void;
|
||||||
|
}
|
||||||
|
|
||||||
Lua {
|
Lua {
|
||||||
state,
|
state,
|
||||||
main_state: maybe_main_state,
|
main_state: maybe_main_state,
|
||||||
|
@ -895,6 +917,102 @@ impl Lua {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Sets an 'interrupt' function that will periodically be called by Luau VM.
|
||||||
|
///
|
||||||
|
/// Any Luau code is guaranteed to call this handler "eventually"
|
||||||
|
/// (in practice this can happen at any function call or at any loop iteration).
|
||||||
|
///
|
||||||
|
/// The provided interrupt function can error, and this error will be propagated through
|
||||||
|
/// the Luau code that was executing at the time the interrupt was triggered.
|
||||||
|
/// Also this can be used to implement continuous execution limits by instructing Luau VM to yield
|
||||||
|
/// by returning [`VmState::Yield`].
|
||||||
|
///
|
||||||
|
/// This is similar to [`Lua::set_hook`] but in more simplified form.
|
||||||
|
///
|
||||||
|
/// # Example
|
||||||
|
///
|
||||||
|
/// Periodically yield Luau VM to suspend execution.
|
||||||
|
///
|
||||||
|
/// ```
|
||||||
|
/// # use std::sync::{Arc, atomic::{AtomicU64, Ordering}};
|
||||||
|
/// # use mlua::{Lua, Result, ThreadStatus, VmState};
|
||||||
|
/// # fn main() -> Result<()> {
|
||||||
|
/// let lua = Lua::new();
|
||||||
|
/// let count = Arc::new(AtomicU64::new(0));
|
||||||
|
/// lua.set_interrupt(move |_lua| {
|
||||||
|
/// if count.fetch_add(1, Ordering::Relaxed) % 2 == 0 {
|
||||||
|
/// return Ok(VmState::Yield);
|
||||||
|
/// }
|
||||||
|
/// Ok(VmState::Continue)
|
||||||
|
/// });
|
||||||
|
///
|
||||||
|
/// let co = lua.create_thread(
|
||||||
|
/// lua.load(r#"
|
||||||
|
/// local b = 0
|
||||||
|
/// for _, x in ipairs({1, 2, 3}) do b += x end
|
||||||
|
/// "#)
|
||||||
|
/// .into_function()?,
|
||||||
|
/// )?;
|
||||||
|
/// while co.status() == ThreadStatus::Resumable {
|
||||||
|
/// co.resume(())?;
|
||||||
|
/// }
|
||||||
|
/// # Ok(())
|
||||||
|
/// # }
|
||||||
|
/// ```
|
||||||
|
#[cfg(feature = "luau")]
|
||||||
|
#[cfg_attr(docsrs, doc(cfg(feature = "luau")))]
|
||||||
|
pub fn set_interrupt<'lua, F>(&'lua self, callback: F)
|
||||||
|
where
|
||||||
|
F: 'static + MaybeSend + Fn(&Lua) -> Result<VmState>,
|
||||||
|
{
|
||||||
|
unsafe extern "C" fn interrupt_proc(state: *mut ffi::lua_State, gc: c_int) {
|
||||||
|
if gc != -1 {
|
||||||
|
// We don't support GC interrupts since they cannot survive Lua exceptions
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
// TODO: think about not using drop types here
|
||||||
|
let lua = match Lua::make_from_ptr(state) {
|
||||||
|
Some(lua) => lua,
|
||||||
|
None => return,
|
||||||
|
};
|
||||||
|
let extra = lua.extra.get();
|
||||||
|
let result = callback_error_ext(state, extra, move |_| {
|
||||||
|
let interrupt_cb = (*extra).interrupt_callback.clone();
|
||||||
|
let interrupt_cb =
|
||||||
|
mlua_expect!(interrupt_cb, "no interrupt callback set in interrupt_proc");
|
||||||
|
if Arc::strong_count(&interrupt_cb) > 2 {
|
||||||
|
return Ok(VmState::Continue); // Don't allow recursion
|
||||||
|
}
|
||||||
|
interrupt_cb(&lua)
|
||||||
|
});
|
||||||
|
match result {
|
||||||
|
VmState::Continue => {}
|
||||||
|
VmState::Yield => {
|
||||||
|
ffi::lua_yield(state, 0);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
let state = mlua_expect!(self.main_state, "Luau should always has main state");
|
||||||
|
unsafe {
|
||||||
|
(*self.extra.get()).interrupt_callback = Some(Arc::new(callback));
|
||||||
|
(*ffi::lua_callbacks(state)).interrupt = Some(interrupt_proc);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Removes any 'interrupt' previously set by `set_interrupt`.
|
||||||
|
///
|
||||||
|
/// This function has no effect if an 'interrupt' was not previously set.
|
||||||
|
#[cfg(feature = "luau")]
|
||||||
|
#[cfg_attr(docsrs, doc(cfg(feature = "luau")))]
|
||||||
|
pub fn remove_interrupt(&self) {
|
||||||
|
let state = mlua_expect!(self.main_state, "Luau should always has main state");
|
||||||
|
unsafe {
|
||||||
|
(*self.extra.get()).interrupt_callback = None;
|
||||||
|
(*ffi::lua_callbacks(state)).interrupt = None;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
/// Sets the warning function to be used by Lua to emit warnings.
|
/// Sets the warning function to be used by Lua to emit warnings.
|
||||||
///
|
///
|
||||||
/// Requires `feature = "lua54"`
|
/// Requires `feature = "lua54"`
|
||||||
|
@ -2759,14 +2877,7 @@ impl Lua {
|
||||||
let _sg = StackGuard::new(state);
|
let _sg = StackGuard::new(state);
|
||||||
assert_stack(state, 1);
|
assert_stack(state, 1);
|
||||||
|
|
||||||
let extra_key = &EXTRA_REGISTRY_KEY as *const u8 as *const c_void;
|
let extra = extra_data(state)?;
|
||||||
if ffi::lua_rawgetp(state, ffi::LUA_REGISTRYINDEX, extra_key) != ffi::LUA_TUSERDATA {
|
|
||||||
return None;
|
|
||||||
}
|
|
||||||
let extra_ptr = ffi::lua_touserdata(state, -1) as *mut Arc<UnsafeCell<ExtraData>>;
|
|
||||||
let extra = Arc::clone(&*extra_ptr);
|
|
||||||
ffi::lua_pop(state, 1);
|
|
||||||
|
|
||||||
let safe = (*extra.get()).safe;
|
let safe = (*extra.get()).safe;
|
||||||
Some(Lua {
|
Some(Lua {
|
||||||
state,
|
state,
|
||||||
|
@ -2798,6 +2909,27 @@ impl Lua {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[cfg(feature = "luau")]
|
||||||
|
unsafe fn extra_data(state: *mut ffi::lua_State) -> Option<Arc<UnsafeCell<ExtraData>>> {
|
||||||
|
let extra_ptr = (*ffi::lua_callbacks(state)).userdata as *mut Arc<UnsafeCell<ExtraData>>;
|
||||||
|
if extra_ptr.is_null() {
|
||||||
|
return None;
|
||||||
|
}
|
||||||
|
Some(Arc::clone(&*extra_ptr))
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(not(feature = "luau"))]
|
||||||
|
unsafe fn extra_data(state: *mut ffi::lua_State) -> Option<Arc<UnsafeCell<ExtraData>>> {
|
||||||
|
let extra_key = &EXTRA_REGISTRY_KEY as *const u8 as *const c_void;
|
||||||
|
if ffi::lua_rawgetp(state, ffi::LUA_REGISTRYINDEX, extra_key) != ffi::LUA_TUSERDATA {
|
||||||
|
return None;
|
||||||
|
}
|
||||||
|
let extra_ptr = ffi::lua_touserdata(state, -1) as *mut Arc<UnsafeCell<ExtraData>>;
|
||||||
|
let extra = Arc::clone(&*extra_ptr);
|
||||||
|
ffi::lua_pop(state, 1);
|
||||||
|
Some(extra)
|
||||||
|
}
|
||||||
|
|
||||||
// Creates required entries in the metatable cache (see `util::METATABLE_CACHE`)
|
// Creates required entries in the metatable cache (see `util::METATABLE_CACHE`)
|
||||||
pub(crate) fn init_metatable_cache(cache: &mut FxHashMap<TypeId, u8>) {
|
pub(crate) fn init_metatable_cache(cache: &mut FxHashMap<TypeId, u8>) {
|
||||||
cache.insert(TypeId::of::<Arc<UnsafeCell<ExtraData>>>(), 0);
|
cache.insert(TypeId::of::<Arc<UnsafeCell<ExtraData>>>(), 0);
|
||||||
|
|
|
@ -15,6 +15,10 @@ pub use crate::{
|
||||||
Value as LuaValue,
|
Value as LuaValue,
|
||||||
};
|
};
|
||||||
|
|
||||||
|
#[cfg(feature = "luau")]
|
||||||
|
#[doc(no_inline)]
|
||||||
|
pub use crate::VmState as LuaVmState;
|
||||||
|
|
||||||
#[cfg(feature = "async")]
|
#[cfg(feature = "async")]
|
||||||
#[doc(no_inline)]
|
#[doc(no_inline)]
|
||||||
pub use crate::AsyncThread as LuaAsyncThread;
|
pub use crate::AsyncThread as LuaAsyncThread;
|
||||||
|
|
14
src/types.rs
14
src/types.rs
|
@ -55,6 +55,20 @@ pub(crate) type HookCallback = Arc<Mutex<dyn FnMut(&Lua, Debug) -> Result<()> +
|
||||||
#[cfg(all(not(feature = "send"), not(feature = "luau")))]
|
#[cfg(all(not(feature = "send"), not(feature = "luau")))]
|
||||||
pub(crate) type HookCallback = Arc<Mutex<dyn FnMut(&Lua, Debug) -> Result<()>>>;
|
pub(crate) type HookCallback = Arc<Mutex<dyn FnMut(&Lua, Debug) -> Result<()>>>;
|
||||||
|
|
||||||
|
/// Type to set next Lua VM action after executing interrupt function.
|
||||||
|
#[cfg(any(feature = "luau", doc))]
|
||||||
|
#[cfg_attr(docsrs, doc(cfg(feature = "luau")))]
|
||||||
|
pub enum VmState {
|
||||||
|
Continue,
|
||||||
|
Yield,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(all(feature = "luau", feature = "send"))]
|
||||||
|
pub(crate) type InterruptCallback = Arc<dyn Fn(&Lua) -> Result<VmState> + Send>;
|
||||||
|
|
||||||
|
#[cfg(all(feature = "luau", not(feature = "send")))]
|
||||||
|
pub(crate) type InterruptCallback = Arc<dyn Fn(&Lua) -> Result<VmState>>;
|
||||||
|
|
||||||
#[cfg(all(feature = "send", feature = "lua54"))]
|
#[cfg(all(feature = "send", feature = "lua54"))]
|
||||||
pub(crate) type WarnCallback = Box<dyn Fn(&Lua, &CStr, bool) -> Result<()> + Send>;
|
pub(crate) type WarnCallback = Box<dyn Fn(&Lua, &CStr, bool) -> Result<()> + Send>;
|
||||||
|
|
||||||
|
|
|
@ -2,8 +2,10 @@
|
||||||
|
|
||||||
use std::env;
|
use std::env;
|
||||||
use std::fs;
|
use std::fs;
|
||||||
|
use std::sync::atomic::{AtomicU64, Ordering};
|
||||||
|
use std::sync::Arc;
|
||||||
|
|
||||||
use mlua::{Error, Lua, Result, Table, Value};
|
use mlua::{Error, Lua, Result, Table, ThreadStatus, Value, VmState};
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn test_require() -> Result<()> {
|
fn test_require() -> Result<()> {
|
||||||
|
@ -125,3 +127,73 @@ fn test_sandbox_threads() -> Result<()> {
|
||||||
|
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_interrupts() -> Result<()> {
|
||||||
|
let lua = Lua::new();
|
||||||
|
|
||||||
|
let interrupts_count = Arc::new(AtomicU64::new(0));
|
||||||
|
let interrupts_count2 = interrupts_count.clone();
|
||||||
|
|
||||||
|
lua.set_interrupt(move |_lua| {
|
||||||
|
interrupts_count2.fetch_add(1, Ordering::Relaxed);
|
||||||
|
Ok(VmState::Continue)
|
||||||
|
});
|
||||||
|
let f = lua
|
||||||
|
.load(
|
||||||
|
r#"
|
||||||
|
local x = 2 + 3
|
||||||
|
local y = x * 63
|
||||||
|
local z = string.len(x..", "..y)
|
||||||
|
"#,
|
||||||
|
)
|
||||||
|
.into_function()?;
|
||||||
|
f.call(())?;
|
||||||
|
|
||||||
|
assert!(interrupts_count.load(Ordering::Relaxed) > 0);
|
||||||
|
|
||||||
|
//
|
||||||
|
// Test yields from interrupt
|
||||||
|
//
|
||||||
|
let yield_count = Arc::new(AtomicU64::new(0));
|
||||||
|
let yield_count2 = yield_count.clone();
|
||||||
|
lua.set_interrupt(move |_lua| {
|
||||||
|
if yield_count2.fetch_add(1, Ordering::Relaxed) == 1 {
|
||||||
|
return Ok(VmState::Yield);
|
||||||
|
}
|
||||||
|
Ok(VmState::Continue)
|
||||||
|
});
|
||||||
|
let co = lua.create_thread(
|
||||||
|
lua.load(
|
||||||
|
r#"
|
||||||
|
local a = {1, 2, 3}
|
||||||
|
local b = 0
|
||||||
|
for _, x in ipairs(a) do b += x end
|
||||||
|
return b
|
||||||
|
"#,
|
||||||
|
)
|
||||||
|
.into_function()?,
|
||||||
|
)?;
|
||||||
|
co.resume(())?;
|
||||||
|
assert_eq!(co.status(), ThreadStatus::Resumable);
|
||||||
|
let result: i32 = co.resume(())?;
|
||||||
|
assert_eq!(result, 6);
|
||||||
|
assert_eq!(yield_count.load(Ordering::Relaxed), 7);
|
||||||
|
assert_eq!(co.status(), ThreadStatus::Unresumable);
|
||||||
|
|
||||||
|
//
|
||||||
|
// Test errors in interrupts
|
||||||
|
//
|
||||||
|
lua.set_interrupt(|_| Err(Error::RuntimeError("error from interrupt".into())));
|
||||||
|
match f.call::<_, ()>(()) {
|
||||||
|
Err(Error::CallbackError { cause, .. }) => match *cause {
|
||||||
|
Error::RuntimeError(ref m) if m == "error from interrupt" => {}
|
||||||
|
ref e => panic!("expected RuntimeError with a specific message, got {:?}", e),
|
||||||
|
},
|
||||||
|
r => panic!("expected CallbackError, got {:?}", r),
|
||||||
|
}
|
||||||
|
|
||||||
|
lua.remove_interrupt();
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
Loading…
Reference in New Issue