Refactor main_state handling

Don't allow to set hook if main_state is not available
Remove Lua 5.1 dirty hack
This commit is contained in:
Alex Orlenko 2020-06-07 15:16:12 +01:00
parent 2eb40deafd
commit 3d42bc4ca6
6 changed files with 90 additions and 202 deletions

View File

@ -43,6 +43,11 @@ pub enum Error {
/// This error can only happen when Lua state was not created by us and does not have the /// This error can only happen when Lua state was not created by us and does not have the
/// custom allocator attached. /// custom allocator attached.
MemoryLimitNotAvailable, MemoryLimitNotAvailable,
/// Main thread is not available.
///
/// This error can only happen in Lua5.1/LuaJIT module mode, when module loaded within a coroutine.
/// These Lua versions does not have `LUA_RIDX_MAINTHREAD` registry key.
MainThreadNotAvailable,
/// A mutable callback has triggered Lua code that has called the same mutable callback again. /// A mutable callback has triggered Lua code that has called the same mutable callback again.
/// ///
/// This is an error because a mutable callback can only be borrowed mutably once. /// This is an error because a mutable callback can only be borrowed mutably once.
@ -161,6 +166,9 @@ impl fmt::Display for Error {
Error::MemoryLimitNotAvailable => { Error::MemoryLimitNotAvailable => {
write!(fmt, "setting memory limit is not available") write!(fmt, "setting memory limit is not available")
} }
Error::MainThreadNotAvailable => {
write!(fmt, "main thread is not available in Lua 5.1")
}
Error::RecursiveMutCallback => write!(fmt, "mutable callback called recursively"), Error::RecursiveMutCallback => write!(fmt, "mutable callback called recursively"),
Error::CallbackDestructed => write!( Error::CallbackDestructed => write!(
fmt, fmt,

View File

@ -1,120 +0,0 @@
// The MIT License (MIT)
//
// Copyright (c) 2020 A. Orlenko
//
// Permission is hereby granted, free of charge, to any person obtaining a copy
// of this software and associated documentation files (the "Software"), to deal
// in the Software without restriction, including without limitation the rights
// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
// copies of the Software, and to permit persons to whom the Software is
// furnished to do so, subject to the following conditions:
//
// The above copyright notice and this permission notice shall be included in
// all copies or substantial portions of the Software.
//
// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
// THE SOFTWARE.
use std::os::raw::*;
use crate::ffi::{lua_Alloc, lua_CFunction, lua_Hook, lua_Number, lua_State};
#[repr(C)]
struct lua_StateExt {
next: *mut c_void,
tt: u8,
marked: u8,
status: u8,
top: *mut c_void,
base: *mut c_void,
l_G: *mut global_State,
ci: *mut c_void,
savedpc: *const c_void,
stack_last: *mut c_void,
stack: *mut c_void,
end_ci: *mut c_void,
base_ci: *mut c_void,
stacksize: c_int,
size_ci: c_int,
nCcalls: c_ushort,
baseCcalls: c_ushort,
hookmask: u8,
allowhook: u8,
basehookcount: c_int,
hookcount: c_int,
hook: Option<lua_Hook>,
l_gt: TValue,
env: TValue,
openupval: *mut c_void,
gclist: *mut c_void,
errorJmp: *mut c_void,
errfunc: isize,
}
#[repr(C)]
#[derive(Clone, Copy)]
struct TValue {
value: Value,
tt: c_int,
}
#[repr(C)]
#[derive(Clone, Copy)]
union Value {
gc: *mut c_void,
p: *mut c_void,
n: lua_Number,
b: c_int,
}
#[repr(C)]
struct global_State {
strt: stringtable,
frealloc: Option<lua_Alloc>,
ud: *mut c_void,
currentwhite: u8,
gcstate: u8,
sweepstrgc: c_int,
rootgc: *mut c_void,
sweepgc: *mut c_void,
gray: *mut c_void,
grayagain: *mut c_void,
weak: *mut c_void,
tmudata: *mut c_void,
buff: Mbuffer,
GCthreshold: usize,
totalbytes: usize,
estimate: usize,
gcdept: usize,
gcpause: c_int,
gcstepmul: c_int,
panic: Option<lua_CFunction>,
l_registry: TValue,
mainthread: *mut lua_State,
// Other fields ommited
}
#[repr(C)]
struct stringtable {
hash: *mut c_void,
nuse: c_uint,
size: c_int,
}
#[repr(C)]
struct Mbuffer {
buffer: *mut c_char,
n: usize,
buffsize: usize,
}
pub unsafe fn lua_getmainstate(state: *mut lua_State) -> *mut lua_State {
let state = state as *mut lua_StateExt;
let global = (*state).l_G;
(*global).mainthread
}

View File

@ -170,9 +170,6 @@ pub use self::lua::{lua_isyieldable, lua_version};
#[cfg(any(feature = "lua54", feature = "lua53", feature = "lua52"))] #[cfg(any(feature = "lua54", feature = "lua53", feature = "lua52"))]
pub use self::lua::{lua_callk, lua_pcallk, lua_upvalueid, lua_upvaluejoin, lua_yieldk}; pub use self::lua::{lua_callk, lua_pcallk, lua_upvalueid, lua_upvaluejoin, lua_yieldk};
#[cfg(feature = "lua51")]
pub use self::internals51::lua_getmainstate;
// auxiliary library types // auxiliary library types
pub use self::lauxlib::luaL_Reg; pub use self::lauxlib::luaL_Reg;
@ -287,9 +284,6 @@ mod glue {
#[cfg(any(feature = "lua52", feature = "lua51", feature = "luajit"))] #[cfg(any(feature = "lua52", feature = "lua51", feature = "luajit"))]
mod compat53; mod compat53;
#[cfg(feature = "lua51")]
mod internals51;
mod lauxlib; mod lauxlib;
mod lua; mod lua;
mod luaconf; mod luaconf;

View File

@ -50,7 +50,7 @@ use {
/// Top level Lua struct which holds the Lua state itself. /// Top level Lua struct which holds the Lua state itself.
pub struct Lua { pub struct Lua {
pub(crate) state: *mut ffi::lua_State, pub(crate) state: *mut ffi::lua_State,
main_state: *mut ffi::lua_State, main_state: Option<*mut ffi::lua_State>,
extra: Arc<Mutex<ExtraData>>, extra: Arc<Mutex<ExtraData>>,
ephemeral: bool, ephemeral: bool,
safe: bool, safe: bool,
@ -115,7 +115,7 @@ impl Drop for Lua {
"reference leak detected" "reference leak detected"
); );
*mlua_expect!(extra.registry_unref_list.lock(), "unref list poisoned") = None; *mlua_expect!(extra.registry_unref_list.lock(), "unref list poisoned") = None;
ffi::lua_close(self.main_state); ffi::lua_close(self.main_state.expect("main_state is null"));
if !extra.mem_info.is_null() { if !extra.mem_info.is_null() {
Box::from_raw(extra.mem_info); Box::from_raw(extra.mem_info);
} }
@ -266,7 +266,7 @@ impl Lua {
} }
mlua_expect!( mlua_expect!(
protect_lua_closure(lua.main_state, 0, 0, |state| { protect_lua_closure(lua.main_state.expect("main_state is null"), 0, 0, |state| {
load_from_std_lib(state, libs); load_from_std_lib(state, libs);
}), }),
"Error during loading standard libraries" "Error during loading standard libraries"
@ -278,7 +278,8 @@ impl Lua {
/// Constructs a new Lua instance from an existing raw state. /// Constructs a new Lua instance from an existing raw state.
#[allow(clippy::missing_safety_doc)] #[allow(clippy::missing_safety_doc)]
pub unsafe fn init_from_ptr(state: *mut ffi::lua_State) -> Lua { pub unsafe fn init_from_ptr(state: *mut ffi::lua_State) -> Lua {
let main_state = get_main_state(state); let maybe_main_state = get_main_state(state);
let main_state = maybe_main_state.unwrap_or(state);
let main_state_top = ffi::lua_gettop(main_state); let main_state_top = ffi::lua_gettop(main_state);
let ref_thread = mlua_expect!( let ref_thread = mlua_expect!(
@ -346,7 +347,7 @@ impl Lua {
Lua { Lua {
state, state,
main_state, main_state: maybe_main_state,
extra, extra,
ephemeral: true, ephemeral: true,
safe: false, safe: false,
@ -374,8 +375,9 @@ impl Lua {
} }
} }
let state = self.main_state.unwrap_or(self.state);
unsafe { unsafe {
protect_lua_closure(self.main_state, 0, 0, |state| { protect_lua_closure(state, 0, 0, |state| {
load_from_std_lib(state, libs); load_from_std_lib(state, libs);
}) })
} }
@ -445,7 +447,7 @@ impl Lua {
/// }, |_lua, debug| { /// }, |_lua, debug| {
/// println!("line {}", debug.curr_line()); /// println!("line {}", debug.curr_line());
/// Ok(()) /// Ok(())
/// }); /// })?;
/// ///
/// lua.load(r#" /// lua.load(r#"
/// local x = 2 + 3 /// local x = 2 + 3
@ -467,20 +469,17 @@ impl Lua {
feature = "lua51", feature = "lua51",
doc doc
))] ))]
pub fn set_hook<F>(&self, triggers: HookTriggers, callback: F) pub fn set_hook<F>(&self, triggers: HookTriggers, callback: F) -> Result<()>
where where
F: 'static + MaybeSend + FnMut(&Lua, Debug) -> Result<()>, F: 'static + MaybeSend + FnMut(&Lua, Debug) -> Result<()>,
{ {
let state = self.main_state.ok_or(Error::MainThreadNotAvailable)?;
unsafe { unsafe {
let mut extra = mlua_expect!(self.extra.lock(), "extra is poisoned"); let mut extra = mlua_expect!(self.extra.lock(), "extra is poisoned");
extra.hook_callback = Some(Arc::new(RefCell::new(callback))); extra.hook_callback = Some(Arc::new(RefCell::new(callback)));
ffi::lua_sethook( ffi::lua_sethook(state, Some(hook_proc), triggers.mask(), triggers.count());
self.main_state,
Some(hook_proc),
triggers.mask(),
triggers.count(),
);
} }
Ok(())
} }
/// Remove any hook previously set by `set_hook`. This function has no effect if a hook was not /// Remove any hook previously set by `set_hook`. This function has no effect if a hook was not
@ -495,21 +494,27 @@ impl Lua {
doc doc
))] ))]
pub fn remove_hook(&self) { pub fn remove_hook(&self) {
// If main_state is not available, then sethook wasn't called.
let state = match self.main_state {
Some(state) => state,
None => return,
};
let mut extra = mlua_expect!(self.extra.lock(), "extra is poisoned"); let mut extra = mlua_expect!(self.extra.lock(), "extra is poisoned");
unsafe { unsafe {
extra.hook_callback = None; extra.hook_callback = None;
ffi::lua_sethook(self.main_state, None, 0, 0); ffi::lua_sethook(state, None, 0, 0);
} }
} }
/// Returns the amount of memory (in bytes) currently used inside this Lua state. /// Returns the amount of memory (in bytes) currently used inside this Lua state.
pub fn used_memory(&self) -> usize { pub fn used_memory(&self) -> usize {
let extra = mlua_expect!(self.extra.lock(), "extra is poisoned"); let extra = mlua_expect!(self.extra.lock(), "extra is poisoned");
let state = self.main_state.unwrap_or(self.state);
if extra.mem_info.is_null() { if extra.mem_info.is_null() {
// Get data from the Lua GC // Get data from the Lua GC
unsafe { unsafe {
let used_kbytes = ffi::lua_gc(self.main_state, ffi::LUA_GCCOUNT, 0); let used_kbytes = ffi::lua_gc(state, ffi::LUA_GCCOUNT, 0);
let used_kbytes_rem = ffi::lua_gc(self.main_state, ffi::LUA_GCCOUNTB, 0); let used_kbytes_rem = ffi::lua_gc(state, ffi::LUA_GCCOUNTB, 0);
return (used_kbytes as usize) * 1024 + (used_kbytes_rem as usize); return (used_kbytes as usize) * 1024 + (used_kbytes_rem as usize);
} }
} }
@ -543,21 +548,20 @@ impl Lua {
/// Requires `feature = "lua54/lua53/lua52"` /// Requires `feature = "lua54/lua53/lua52"`
#[cfg(any(feature = "lua54", feature = "lua53", feature = "lua52", doc))] #[cfg(any(feature = "lua54", feature = "lua53", feature = "lua52", doc))]
pub fn gc_is_running(&self) -> bool { pub fn gc_is_running(&self) -> bool {
unsafe { ffi::lua_gc(self.main_state, ffi::LUA_GCISRUNNING, 0) != 0 } let state = self.main_state.unwrap_or(self.state);
unsafe { ffi::lua_gc(state, ffi::LUA_GCISRUNNING, 0) != 0 }
} }
/// Stop the Lua GC from running /// Stop the Lua GC from running
pub fn gc_stop(&self) { pub fn gc_stop(&self) {
unsafe { let state = self.main_state.unwrap_or(self.state);
ffi::lua_gc(self.main_state, ffi::LUA_GCSTOP, 0); unsafe { ffi::lua_gc(state, ffi::LUA_GCSTOP, 0) };
}
} }
/// Restarts the Lua GC if it is not running /// Restarts the Lua GC if it is not running
pub fn gc_restart(&self) { pub fn gc_restart(&self) {
unsafe { let state = self.main_state.unwrap_or(self.state);
ffi::lua_gc(self.main_state, ffi::LUA_GCRESTART, 0); unsafe { ffi::lua_gc(state, ffi::LUA_GCRESTART, 0) };
}
} }
/// Perform a full garbage-collection cycle. /// Perform a full garbage-collection cycle.
@ -565,8 +569,9 @@ impl Lua {
/// It may be necessary to call this function twice to collect all currently unreachable /// It may be necessary to call this function twice to collect all currently unreachable
/// objects. Once to finish the current gc cycle, and once to start and finish the next cycle. /// objects. Once to finish the current gc cycle, and once to start and finish the next cycle.
pub fn gc_collect(&self) -> Result<()> { pub fn gc_collect(&self) -> Result<()> {
let state = self.main_state.unwrap_or(self.state);
unsafe { unsafe {
protect_lua_closure(self.main_state, 0, 0, |state| { protect_lua_closure(state, 0, 0, |state| {
ffi::lua_gc(state, ffi::LUA_GCCOLLECT, 0); ffi::lua_gc(state, ffi::LUA_GCCOLLECT, 0);
}) })
} }
@ -584,8 +589,9 @@ impl Lua {
/// if `kbytes` is 0, then this is the same as calling `gc_step`. Returns true if this step has /// if `kbytes` is 0, then this is the same as calling `gc_step`. Returns true if this step has
/// finished a collection cycle. /// finished a collection cycle.
pub fn gc_step_kbytes(&self, kbytes: c_int) -> Result<bool> { pub fn gc_step_kbytes(&self, kbytes: c_int) -> Result<bool> {
let state = self.main_state.unwrap_or(self.state);
unsafe { unsafe {
protect_lua_closure(self.main_state, 0, 0, |state| { protect_lua_closure(state, 0, 0, |state| {
ffi::lua_gc(state, ffi::LUA_GCSTEP, kbytes) != 0 ffi::lua_gc(state, ffi::LUA_GCSTEP, kbytes) != 0
}) })
} }
@ -598,7 +604,8 @@ impl Lua {
/// ///
/// [lua_doc]: https://www.lua.org/manual/5.3/manual.html#2.5 /// [lua_doc]: https://www.lua.org/manual/5.3/manual.html#2.5
pub fn gc_set_pause(&self, pause: c_int) -> c_int { pub fn gc_set_pause(&self, pause: c_int) -> c_int {
unsafe { ffi::lua_gc(self.main_state, ffi::LUA_GCSETPAUSE, pause) } let state = self.main_state.unwrap_or(self.state);
unsafe { ffi::lua_gc(state, ffi::LUA_GCSETPAUSE, pause) }
} }
/// Sets the 'step multiplier' value of the collector. /// Sets the 'step multiplier' value of the collector.
@ -608,7 +615,8 @@ impl Lua {
/// ///
/// [lua_doc]: https://www.lua.org/manual/5.3/manual.html#2.5 /// [lua_doc]: https://www.lua.org/manual/5.3/manual.html#2.5
pub fn gc_set_step_multiplier(&self, step_multiplier: c_int) -> c_int { pub fn gc_set_step_multiplier(&self, step_multiplier: c_int) -> c_int {
unsafe { ffi::lua_gc(self.main_state, ffi::LUA_GCSETSTEPMUL, step_multiplier) } let state = self.main_state.unwrap_or(self.state);
unsafe { ffi::lua_gc(state, ffi::LUA_GCSETSTEPMUL, step_multiplier) }
} }
/// Changes the collector to incremental mode with the given parameters. /// Changes the collector to incremental mode with the given parameters.
@ -618,6 +626,8 @@ impl Lua {
/// ///
/// [lua_doc]: https://www.lua.org/manual/5.4/manual.html#2.5.1 /// [lua_doc]: https://www.lua.org/manual/5.4/manual.html#2.5.1
pub fn gc_inc(&self, pause: c_int, step_multiplier: c_int, step_size: c_int) -> GCMode { pub fn gc_inc(&self, pause: c_int, step_multiplier: c_int, step_size: c_int) -> GCMode {
let state = self.main_state.unwrap_or(self.state);
#[cfg(any( #[cfg(any(
feature = "lua53", feature = "lua53",
feature = "lua52", feature = "lua52",
@ -626,10 +636,10 @@ impl Lua {
))] ))]
{ {
if pause > 0 { if pause > 0 {
unsafe { ffi::lua_gc(self.main_state, ffi::LUA_GCSETPAUSE, pause) }; unsafe { ffi::lua_gc(state, ffi::LUA_GCSETPAUSE, pause) };
} }
if step_multiplier > 0 { if step_multiplier > 0 {
unsafe { ffi::lua_gc(self.main_state, ffi::LUA_GCSETSTEPMUL, step_multiplier) }; unsafe { ffi::lua_gc(state, ffi::LUA_GCSETSTEPMUL, step_multiplier) };
} }
let _ = step_size; // Ignored let _ = step_size; // Ignored
GCMode::Incremental GCMode::Incremental
@ -638,7 +648,7 @@ impl Lua {
#[cfg(feature = "lua54")] #[cfg(feature = "lua54")]
let prev_mode = unsafe { let prev_mode = unsafe {
ffi::lua_gc( ffi::lua_gc(
self.main_state, state,
ffi::LUA_GCSETPAUSE, ffi::LUA_GCSETPAUSE,
pause, pause,
step_multiplier, step_multiplier,
@ -663,14 +673,9 @@ impl Lua {
/// [lua_doc]: https://www.lua.org/manual/5.4/manual.html#2.5.2 /// [lua_doc]: https://www.lua.org/manual/5.4/manual.html#2.5.2
#[cfg(any(feature = "lua54", doc))] #[cfg(any(feature = "lua54", doc))]
pub fn gc_gen(&self, minor_multiplier: c_int, major_multiplier: c_int) -> GCMode { pub fn gc_gen(&self, minor_multiplier: c_int, major_multiplier: c_int) -> GCMode {
let prev_mode = unsafe { let state = self.main_state.unwrap_or(self.state);
ffi::lua_gc( let prev_mode =
self.main_state, unsafe { ffi::lua_gc(state, ffi::LUA_GCGEN, minor_multiplier, major_multiplier) };
ffi::LUA_GCGEN,
minor_multiplier,
major_multiplier,
)
};
match prev_mode { match prev_mode {
ffi::LUA_GCGEN => GCMode::Generational, ffi::LUA_GCGEN => GCMode::Generational,
ffi::LUA_GCINC => GCMode::Incremental, ffi::LUA_GCINC => GCMode::Incremental,
@ -1421,7 +1426,7 @@ impl Lua {
// Pushes a LuaRef value onto the stack, uses 1 stack space, does not call checkstack // Pushes a LuaRef value onto the stack, uses 1 stack space, does not call checkstack
pub(crate) unsafe fn push_ref<'lua>(&'lua self, lref: &LuaRef<'lua>) { pub(crate) unsafe fn push_ref<'lua>(&'lua self, lref: &LuaRef<'lua>) {
assert!( assert!(
lref.lua.main_state == self.main_state, Arc::ptr_eq(&lref.lua.extra, &self.extra),
"Lua instance passed Value created from a different main Lua state" "Lua instance passed Value created from a different main Lua state"
); );
let extra = mlua_expect!(self.extra.lock(), "extra is poisoned"); let extra = mlua_expect!(self.extra.lock(), "extra is poisoned");

View File

@ -485,28 +485,25 @@ pub unsafe extern "C" fn error_traceback(state: *mut ffi::lua_State) -> c_int {
} }
// Does not call lua_checkstack, uses 1 stack space. // Does not call lua_checkstack, uses 1 stack space.
pub unsafe fn get_main_state(state: *mut ffi::lua_State) -> *mut ffi::lua_State { pub unsafe fn get_main_state(state: *mut ffi::lua_State) -> Option<*mut ffi::lua_State> {
#[cfg(any(feature = "lua54", feature = "lua53", feature = "lua52"))] #[cfg(any(feature = "lua54", feature = "lua53", feature = "lua52"))]
{ {
ffi::lua_rawgeti(state, ffi::LUA_REGISTRYINDEX, ffi::LUA_RIDX_MAINTHREAD); ffi::lua_rawgeti(state, ffi::LUA_REGISTRYINDEX, ffi::LUA_RIDX_MAINTHREAD);
let main_state = ffi::lua_tothread(state, -1); let main_state = ffi::lua_tothread(state, -1);
ffi::lua_pop(state, 1); ffi::lua_pop(state, 1);
main_state Some(main_state)
} }
#[cfg(feature = "lua51")] #[cfg(any(feature = "lua51", feature = "luajit"))]
{ {
// Check the current state first // Check the current state first
let is_main_state = ffi::lua_pushthread(state) == 1; let is_main_state = ffi::lua_pushthread(state) == 1;
ffi::lua_pop(state, 1); ffi::lua_pop(state, 1);
if is_main_state { if is_main_state {
state Some(state)
} else { } else {
// The function below is a dirty hack and uses Lua private internals None
ffi::lua_getmainstate(state)
} }
} }
#[cfg(feature = "luajit")]
state
} }
// Pushes a WrappedError to the top of the stack. Uses two stack spaces and does not call // Pushes a WrappedError to the top of the stack. Uses two stack spaces and does not call

View File

@ -27,7 +27,7 @@ fn line_counts() -> Result<()> {
hook_output.lock().unwrap().push(debug.curr_line()); hook_output.lock().unwrap().push(debug.curr_line());
Ok(()) Ok(())
}, },
); )?;
lua.load( lua.load(
r#" r#"
local x = 2 + 3 local x = 2 + 3
@ -62,7 +62,7 @@ fn function_calls() -> Result<()> {
hook_output.lock().unwrap().push((name, what)); hook_output.lock().unwrap().push((name, what));
Ok(()) Ok(())
}, },
); )?;
lua.load( lua.load(
r#" r#"
@ -84,7 +84,7 @@ fn function_calls() -> Result<()> {
} }
#[test] #[test]
fn error_within_hook() { fn error_within_hook() -> Result<()> {
let lua = Lua::new(); let lua = Lua::new();
lua.set_hook( lua.set_hook(
HookTriggers { HookTriggers {
@ -96,7 +96,7 @@ fn error_within_hook() {
"Something happened in there!".to_string(), "Something happened in there!".to_string(),
)) ))
}, },
); )?;
let err = lua let err = lua
.load("x = 1") .load("x = 1")
@ -110,10 +110,12 @@ fn error_within_hook() {
}, },
_ => panic!("wrong error kind caught"), _ => panic!("wrong error kind caught"),
}; };
Ok(())
} }
#[test] #[test]
fn limit_execution_instructions() { fn limit_execution_instructions() -> Result<()> {
let lua = Lua::new(); let lua = Lua::new();
let mut max_instructions = 10000; let mut max_instructions = 10000;
@ -130,9 +132,9 @@ fn limit_execution_instructions() {
Ok(()) Ok(())
} }
}, },
); )?;
lua.globals().set("x", Value::Integer(0)).unwrap(); lua.globals().set("x", Value::Integer(0))?;
let _ = lua let _ = lua
.load( .load(
r#" r#"
@ -143,10 +145,12 @@ fn limit_execution_instructions() {
) )
.exec() .exec()
.expect_err("instruction limit didn't occur"); .expect_err("instruction limit didn't occur");
Ok(())
} }
#[test] #[test]
fn hook_removal() { fn hook_removal() -> Result<()> {
let lua = Lua::new(); let lua = Lua::new();
lua.set_hook( lua.set_hook(
@ -159,15 +163,17 @@ fn hook_removal() {
"this hook should've been removed by this time".to_string(), "this hook should've been removed by this time".to_string(),
)) ))
}, },
); )?;
assert!(lua.load("local x = 1").exec().is_err()); assert!(lua.load("local x = 1").exec().is_err());
lua.remove_hook(); lua.remove_hook();
assert!(lua.load("local x = 1").exec().is_ok()); assert!(lua.load("local x = 1").exec().is_ok());
Ok(())
} }
#[test] #[test]
fn hook_swap_within_hook() { fn hook_swap_within_hook() -> Result<()> {
thread_local! { thread_local! {
static TL_LUA: RefCell<Option<Lua>> = RefCell::new(None); static TL_LUA: RefCell<Option<Lua>> = RefCell::new(None);
} }
@ -183,7 +189,7 @@ fn hook_swap_within_hook() {
..Default::default() ..Default::default()
}, },
move |lua, _debug| { move |lua, _debug| {
lua.globals().set("ok", 1i64).unwrap(); lua.globals().set("ok", 1i64)?;
TL_LUA.with(|tl| { TL_LUA.with(|tl| {
tl.borrow().as_ref().unwrap().set_hook( tl.borrow().as_ref().unwrap().set_hook(
HookTriggers { HookTriggers {
@ -205,26 +211,24 @@ fn hook_swap_within_hook() {
}); });
Ok(()) Ok(())
}, },
); )
}); })
Ok(())
}, },
); )
}); })?;
TL_LUA.with(|tl| { TL_LUA.with(|tl| {
let tl = tl.borrow(); let tl = tl.borrow();
let lua = tl.as_ref().unwrap(); let lua = tl.as_ref().unwrap();
assert!(lua lua.load(
.load(
r#" r#"
local x = 1 local x = 1
x = 2 x = 2
local y = 3 local y = 3
"#, "#,
) )
.exec() .exec()?;
.is_ok()); assert_eq!(lua.globals().get::<_, i64>("ok")?, 2);
assert_eq!(lua.globals().get::<_, i64>("ok").unwrap_or(-1), 2); Ok(())
}); })
} }