From 84fe5f7f761e5a9669ae00df3f6e48ef2814272c Mon Sep 17 00:00:00 2001 From: Alex Orlenko Date: Wed, 7 Jul 2021 12:54:19 +0100 Subject: [PATCH] Make `protect_lua` as a smart macro to choose from C/closure --- src/function.rs | 6 +++--- src/lua.rs | 46 ++++++++++++++++++++--------------------- src/macros.rs | 15 ++++++++++++++ src/scope.rs | 6 +++--- src/serde/mod.rs | 4 ++-- src/serde/ser.rs | 6 +++--- src/table.rs | 23 ++++++++++----------- src/thread.rs | 4 ++-- src/util.rs | 53 +++++++++++++++++++++++++++++++++++++----------- 9 files changed, 102 insertions(+), 61 deletions(-) diff --git a/src/function.rs b/src/function.rs index af4672f..7b8ef02 100644 --- a/src/function.rs +++ b/src/function.rs @@ -5,7 +5,7 @@ use std::slice; use crate::error::{Error, Result}; use crate::ffi; use crate::types::LuaRef; -use crate::util::{assert_stack, check_stack, error_traceback, pop_error, protect_lua, StackGuard}; +use crate::util::{assert_stack, check_stack, error_traceback, pop_error, StackGuard}; use crate::value::{FromLuaMulti, MultiValue, ToLuaMulti}; #[cfg(feature = "async")] @@ -198,8 +198,8 @@ impl<'lua> Function<'lua> { for arg in args { lua.push_value(arg)?; } - protect_lua(lua.state, nargs + 2, 1, |state| { - ffi::lua_pushcclosure(state, bind_call_impl, nargs + 2); + protect_lua!(lua.state, nargs + 2, 1, state => { + ffi::lua_pushcclosure(state, bind_call_impl, ffi::lua_gettop(state)); })?; Ok(Function(lua.pop_ref())) diff --git a/src/lua.rs b/src/lua.rs index b81e442..a0c094e 100644 --- a/src/lua.rs +++ b/src/lua.rs @@ -28,7 +28,7 @@ use crate::userdata::{ use crate::util::{ self, assert_stack, callback_error, check_stack, get_destructed_userdata_metatable, get_gc_metatable_for, get_gc_userdata, get_main_state, get_userdata, get_wrapped_error, - init_error_registry, init_gc_metatable_for, init_userdata_metatable, pop_error, protect_lua, + init_error_registry, init_gc_metatable_for, init_userdata_metatable, pop_error, push_gc_userdata, push_string, push_table, push_userdata, push_wrapped_error, rawset_field, safe_pcall, safe_xpcall, StackGuard, WrappedError, WrappedPanic, }; @@ -408,8 +408,8 @@ impl Lua { // Create empty Waker slot push_gc_userdata::>(state, None)?; - let waker_key = &WAKER_REGISTRY_KEY as *const u8 as *const c_void; - protect_lua(state, 1, 0, |state| { + protect_lua!(state, 1, 0, state => { + let waker_key = &WAKER_REGISTRY_KEY as *const u8 as *const c_void; ffi::lua_rawsetp(state, ffi::LUA_REGISTRYINDEX, waker_key); })?; } @@ -420,7 +420,7 @@ impl Lua { // Create ref stack thread and place it in the registry to prevent it from being garbage // collected. - let ref_thread = protect_lua(state, 0, 0, |state| { + let ref_thread = protect_lua!(state, 0, 0, |state| { let thread = ffi::lua_newthread(state); ffi::luaL_ref(state, ffi::LUA_REGISTRYINDEX); thread @@ -453,9 +453,9 @@ impl Lua { push_gc_userdata(main_state, Arc::downgrade(&extra)), "Error while storing extra data", ); - let extra_key = &EXTRA_REGISTRY_KEY as *const u8 as *const c_void; mlua_expect!( - protect_lua(main_state, 1, 0, |state| { + protect_lua!(main_state, 1, 0, state => { + let extra_key = &EXTRA_REGISTRY_KEY as *const u8 as *const c_void; ffi::lua_rawsetp(state, ffi::LUA_REGISTRYINDEX, extra_key) }), "Error while storing extra data", @@ -535,7 +535,7 @@ impl Lua { let _sg = StackGuard::new(self.state); check_stack(self.state, 3)?; - protect_lua(self.state, 0, 1, |state| { + protect_lua!(self.state, 0, 1, state => { ffi::luaL_getsubtable(state, ffi::LUA_REGISTRYINDEX, cstr!("_LOADED")); })?; let loaded = Table(self.pop_ref()); @@ -735,9 +735,7 @@ impl Lua { let state = self.main_state.unwrap_or(self.state); unsafe { check_stack(state, 3)?; - protect_lua(state, 0, 0, |state| { - ffi::lua_gc(state, ffi::LUA_GCCOLLECT, 0); - }) + protect_lua!(state, 0, 0, state => ffi::lua_gc(state, ffi::LUA_GCCOLLECT, 0)) } } @@ -756,7 +754,7 @@ impl Lua { let state = self.main_state.unwrap_or(self.state); unsafe { check_stack(state, 3)?; - protect_lua(state, 0, 0, |state| { + protect_lua!(state, 0, 0, |state| { ffi::lua_gc(state, ffi::LUA_GCSTEP, kbytes) != 0 }) } @@ -933,7 +931,7 @@ impl Lua { unsafe { let _sg = StackGuard::new(self.state); check_stack(self.state, 3)?; - push_table(self.state, 0, 0)?; + protect_lua!(self.state, 0, 1, state => ffi::lua_newtable(state))?; Ok(Table(self.pop_ref())) } } @@ -968,7 +966,7 @@ impl Lua { for (k, v) in iter { self.push_value(k.to_lua(self)?)?; self.push_value(v.to_lua(self)?)?; - protect_lua(self.state, 3, 1, |state| ffi::lua_rawset(state, -3))?; + protect_lua!(self.state, 3, 1, state => ffi::lua_rawset(state, -3))?; } Ok(Table(self.pop_ref())) @@ -990,7 +988,7 @@ impl Lua { push_table(self.state, lower_bound as c_int, 0)?; for (i, v) in iter.enumerate() { self.push_value(v.to_lua(self)?)?; - protect_lua(self.state, 2, 1, |state| { + protect_lua!(self.state, 2, 1, |state| { ffi::lua_rawseti(state, -2, (i + 1) as Integer); })?; } @@ -1087,7 +1085,7 @@ impl Lua { pub unsafe fn create_c_function(&self, func: ffi::lua_CFunction) -> Result { let _sg = StackGuard::new(self.state); check_stack(self.state, 3)?; - protect_lua(self.state, 0, 1, |state| { + protect_lua!(self.state, 0, 1, |state| { ffi::lua_pushcfunction(state, func); })?; Ok(Function(self.pop_ref())) @@ -1163,7 +1161,7 @@ impl Lua { let _sg = StackGuard::new(self.state); check_stack(self.state, 3)?; - let thread_state = protect_lua(self.state, 0, 1, |state| ffi::lua_newthread(state))?; + let thread_state = protect_lua!(self.state, 0, 1, |state| ffi::lua_newthread(state))?; self.push_ref(&func.0); ffi::lua_xmove(self.state, thread_state, 1); @@ -1278,7 +1276,7 @@ impl Lua { check_stack(self.state, 4)?; self.push_value(v)?; - let res = protect_lua(self.state, 1, 1, |state| { + let res = protect_lua!(self.state, 1, 1, |state| { ffi::lua_tolstring(state, -1, ptr::null_mut()) })?; if !res.is_null() { @@ -1432,7 +1430,7 @@ impl Lua { check_stack(self.state, 4)?; self.push_value(t)?; - let registry_id = protect_lua(self.state, 1, 0, |state| { + let registry_id = protect_lua!(self.state, 1, 0, |state| { ffi::luaL_ref(state, ffi::LUA_REGISTRYINDEX) })?; @@ -1708,7 +1706,7 @@ impl Lua { rawset_field(self.state, -2, k.validate()?.name())?; } // Add special `__mlua_type_id` field - let type_id_ptr = protect_lua(self.state, 0, 1, |state| { + let type_id_ptr = protect_lua!(self.state, 0, 1, |state| { ffi::lua_newuserdata(state, mem::size_of::()) as *mut TypeId })?; ptr::write(type_id_ptr, type_id); @@ -1774,7 +1772,7 @@ impl Lua { let ptr = ffi::lua_topointer(self.state, -1); ffi::lua_pushvalue(self.state, -1); - let id = protect_lua(self.state, 1, 0, |state| { + let id = protect_lua!(self.state, 1, 0, |state| { ffi::luaL_ref(state, ffi::LUA_REGISTRYINDEX) })?; @@ -1881,7 +1879,7 @@ impl Lua { let lua = self.clone(); let func = mem::transmute(func); push_gc_userdata(self.state, CallbackUpvalue { lua, func })?; - protect_lua(self.state, 1, 1, |state| { + protect_lua!(self.state, 1, 1, state => { ffi::lua_pushcclosure(state, call_callback, 1); })?; @@ -1933,7 +1931,7 @@ impl Lua { let fut = ((*upvalue).func)(lua, args); let lua = lua.clone(); push_gc_userdata(state, AsyncPollUpvalue { lua, fut })?; - protect_lua(state, 1, 1, |state| { + protect_lua!(state, 1, 1, state => { ffi::lua_pushcclosure(state, poll_future, 1); })?; @@ -1999,7 +1997,7 @@ impl Lua { let lua = self.clone(); let func = mem::transmute(func); push_gc_userdata(self.state, AsyncCallbackUpvalue { lua, func })?; - protect_lua(self.state, 1, 1, |state| { + protect_lua!(self.state, 1, 1, state => { ffi::lua_pushcclosure(state, call_callback, 1); })?; @@ -2476,7 +2474,7 @@ unsafe fn load_from_std_lib(state: *mut ffi::lua_State, libs: StdLib) -> Result< glb: c_int, ) -> Result<()> { let modname = mlua_expect!(CString::new(modname.as_ref()), "modname contains nil bytes"); - protect_lua(state, 0, 1, |state| { + protect_lua!(state, 0, 1, |state| { ffi::luaL_requiref(state, modname.as_ptr() as *const c_char, openf, glb) }) } diff --git a/src/macros.rs b/src/macros.rs index 89f54da..798983a 100644 --- a/src/macros.rs +++ b/src/macros.rs @@ -94,3 +94,18 @@ macro_rules! require_module_feature { compile_error!("Feature `module` must be enabled in the `mlua` crate"); }; } + +macro_rules! protect_lua { + ($state:expr, $nargs:expr, $nresults:expr, $f:expr) => { + crate::util::protect_lua_closure($state, $nargs, $nresults, $f) + }; + + ($state:expr, $nargs:expr, $nresults:expr, $state_inner:ident => $code:expr) => {{ + unsafe extern "C" fn do_call($state_inner: *mut ffi::lua_State) -> ::std::os::raw::c_int { + $code; + $nresults + } + + crate::util::protect_lua_c($state, $nargs, do_call) + }}; +} diff --git a/src/scope.rs b/src/scope.rs index 985db3e..0a1ef77 100644 --- a/src/scope.rs +++ b/src/scope.rs @@ -18,8 +18,8 @@ use crate::userdata::{ AnyUserData, MetaMethod, UserData, UserDataCell, UserDataFields, UserDataMethods, }; use crate::util::{ - assert_stack, check_stack, get_userdata, init_userdata_metatable, protect_lua, push_table, - rawset_field, take_userdata, StackGuard, + assert_stack, check_stack, get_userdata, init_userdata_metatable, push_table, rawset_field, + take_userdata, StackGuard, }; use crate::value::{FromLua, FromLuaMulti, MultiValue, ToLua, ToLuaMulti, Value}; @@ -322,7 +322,7 @@ impl<'lua, 'scope> Scope<'lua, 'scope> { let _sg = StackGuard::new(lua.state); check_stack(lua.state, 13)?; - let data_ptr = protect_lua(lua.state, 0, 1, |state| { + let data_ptr = protect_lua!(lua.state, 0, 1, |state| { ffi::lua_newuserdata(state, mem::size_of::>>>()) })?; // Prepare metatable, add meta methods first and then meta fields diff --git a/src/serde/mod.rs b/src/serde/mod.rs index 27ce492..1252fbf 100644 --- a/src/serde/mod.rs +++ b/src/serde/mod.rs @@ -10,7 +10,7 @@ use crate::ffi; use crate::lua::Lua; use crate::table::Table; use crate::types::LightUserData; -use crate::util::{assert_stack, check_stack, protect_lua, StackGuard}; +use crate::util::{assert_stack, check_stack, StackGuard}; use crate::value::Value; /// Trait for serializing/deserializing Lua values using Serde. @@ -201,7 +201,7 @@ impl<'lua> LuaSerdeExt<'lua> for Lua { // Uses 6 stack spaces and calls checkstack. pub(crate) unsafe fn init_metatables(state: *mut ffi::lua_State) -> Result<()> { check_stack(state, 3)?; - protect_lua(state, 0, 0, |state| { + protect_lua!(state, 0, 0, state => { ffi::lua_createtable(state, 0, 1); ffi::lua_pushstring(state, cstr!("__metatable")); diff --git a/src/serde/ser.rs b/src/serde/ser.rs index 4c6423c..aa06a58 100644 --- a/src/serde/ser.rs +++ b/src/serde/ser.rs @@ -9,7 +9,7 @@ use crate::lua::Lua; use crate::string::String; use crate::table::Table; use crate::types::Integer; -use crate::util::{check_stack, protect_lua, StackGuard}; +use crate::util::{check_stack, StackGuard}; use crate::value::{ToLua, Value}; /// A struct for serializing Rust values into Lua values. @@ -322,8 +322,8 @@ impl<'lua> ser::SerializeSeq for SerializeVec<'lua> { lua.push_ref(&self.table.0); lua.push_value(value)?; - let len = ffi::lua_rawlen(lua.state, -2) as Integer; - protect_lua(lua.state, 2, 0, |state| { + protect_lua!(lua.state, 2, 0, state => { + let len = ffi::lua_rawlen(state, -2) as Integer; ffi::lua_rawseti(state, -2, len + 1); }) } diff --git a/src/table.rs b/src/table.rs index af9aaaa..cd3985b 100644 --- a/src/table.rs +++ b/src/table.rs @@ -10,7 +10,7 @@ use crate::error::{Error, Result}; use crate::ffi; use crate::function::Function; use crate::types::{Integer, LuaRef}; -use crate::util::{assert_stack, check_stack, protect_lua, StackGuard}; +use crate::util::{assert_stack, check_stack, StackGuard}; use crate::value::{FromLua, FromLuaMulti, Nil, ToLua, ToLuaMulti, Value}; #[cfg(feature = "async")] @@ -67,7 +67,7 @@ impl<'lua> Table<'lua> { lua.push_ref(&self.0); lua.push_value(key)?; lua.push_value(value)?; - protect_lua(lua.state, 3, 0, |state| ffi::lua_settable(state, -3)) + protect_lua!(lua.state, 3, 0, state => ffi::lua_settable(state, -3)) } } @@ -105,7 +105,7 @@ impl<'lua> Table<'lua> { lua.push_ref(&self.0); lua.push_value(key)?; - protect_lua(lua.state, 2, 1, |state| ffi::lua_gettable(state, -2))?; + protect_lua!(lua.state, 2, 1, state => ffi::lua_gettable(state, -2))?; lua.pop_value() }; @@ -123,9 +123,8 @@ impl<'lua> Table<'lua> { lua.push_ref(&self.0); lua.push_value(key)?; - protect_lua(lua.state, 2, 1, |state| { - ffi::lua_gettable(state, -2) != ffi::LUA_TNIL - }) + protect_lua!(lua.state, 2, 1, state => ffi::lua_gettable(state, -2))?; + Ok(ffi::lua_isnil(lua.state, -1) == 0) } } @@ -198,7 +197,7 @@ impl<'lua> Table<'lua> { lua.push_ref(&self.0); lua.push_value(key)?; lua.push_value(value)?; - protect_lua(lua.state, 3, 0, |state| ffi::lua_rawset(state, -3)) + protect_lua!(lua.state, 3, 0, state => ffi::lua_rawset(state, -3)) } } @@ -236,7 +235,7 @@ impl<'lua> Table<'lua> { lua.push_ref(&self.0); lua.push_value(value)?; - protect_lua(lua.state, 2, 0, |state| { + protect_lua!(lua.state, 2, 0, |state| { for i in (idx..=size).rev() { // table[i+1] = table[i] ffi::lua_rawgeti(state, -2, i); @@ -268,7 +267,7 @@ impl<'lua> Table<'lua> { check_stack(lua.state, 4)?; lua.push_ref(&self.0); - protect_lua(lua.state, 1, 0, |state| { + protect_lua!(lua.state, 1, 0, |state| { for i in idx..size { ffi::lua_rawgeti(state, -1, i + 1); ffi::lua_rawseti(state, -2, i); @@ -294,7 +293,7 @@ impl<'lua> Table<'lua> { check_stack(lua.state, 4)?; lua.push_ref(&self.0); - protect_lua(lua.state, 1, 0, |state| ffi::luaL_len(state, -1)) + protect_lua!(lua.state, 1, 0, |state| ffi::luaL_len(state, -1)) } } @@ -671,7 +670,7 @@ where lua.push_ref(&self.table); lua.push_value(prev_key)?; - let next = protect_lua(lua.state, 2, ffi::LUA_MULTRET, |state| { + let next = protect_lua!(lua.state, 2, ffi::LUA_MULTRET, |state| { ffi::lua_next(state, -2) })?; if next != 0 { @@ -732,7 +731,7 @@ where let res = if self.raw { ffi::lua_rawgeti(lua.state, -1, index) } else { - protect_lua(lua.state, 1, 1, |state| ffi::lua_geti(state, -1, index))? + protect_lua!(lua.state, 1, 1, |state| ffi::lua_geti(state, -1, index))? }; match res { ffi::LUA_TNIL if index > self.len.unwrap_or(0) => Ok(None), diff --git a/src/thread.rs b/src/thread.rs index 632c341..403eb78 100644 --- a/src/thread.rs +++ b/src/thread.rs @@ -4,7 +4,7 @@ use std::os::raw::c_int; use crate::error::{Error, Result}; use crate::ffi; use crate::types::LuaRef; -use crate::util::{assert_stack, check_stack, error_traceback, pop_error, protect_lua, StackGuard}; +use crate::util::{assert_stack, check_stack, error_traceback, pop_error, StackGuard}; use crate::value::{FromLuaMulti, MultiValue, ToLuaMulti}; #[cfg(any(feature = "lua54", all(feature = "luajit", feature = "vendored"), doc))] @@ -135,7 +135,7 @@ impl<'lua> Thread<'lua> { let ret = ffi::lua_resume(thread_state, lua.state, nargs, &mut nresults as *mut c_int); if ret != ffi::LUA_OK && ret != ffi::LUA_YIELD { - protect_lua(lua.state, 0, 0, |_| error_traceback(thread_state))?; + protect_lua!(lua.state, 0, 0, |_| error_traceback(thread_state))?; return Err(pop_error(thread_state, ret)); } diff --git a/src/util.rs b/src/util.rs index 8eab90f..b55b109 100644 --- a/src/util.rs +++ b/src/util.rs @@ -83,6 +83,35 @@ impl Drop for StackGuard { } } +// Call a function that calls into the Lua API and may trigger a Lua error (longjmp) in a safe way. +// Wraps the inner function in a call to `lua_pcall`, so the inner function only has access to a +// limited lua stack. `nargs` is the same as the the parameter to `lua_pcall`, and `nresults` is +// always `LUA_MULTRET`. Provided function must *not* panic, and since it will generally be lonjmping, +// should not contain any values that implements Drop. +// Internally uses 2 extra stack spaces, and does not call checkstack. +pub unsafe fn protect_lua_c( + state: *mut ffi::lua_State, + nargs: c_int, + f: unsafe extern "C" fn(*mut ffi::lua_State) -> c_int, +) -> Result<()> { + let stack_start = ffi::lua_gettop(state) - nargs; + + ffi::lua_pushcfunction(state, error_traceback); + ffi::lua_pushcfunction(state, f); + if nargs > 0 { + ffi::lua_rotate(state, stack_start + 1, 2); + } + + let ret = ffi::lua_pcall(state, nargs, ffi::LUA_MULTRET, stack_start + 1); + ffi::lua_remove(state, stack_start + 1); + + if ret == ffi::LUA_OK { + Ok(()) + } else { + Err(pop_error(state, ret)) + } +} + // Call a function that calls into the Lua API and may trigger a Lua error (longjmp) in a safe way. // Wraps the inner function in a call to `lua_pcall`, so the inner function only has access to a // limited lua stack. `nargs` and `nresults` are similar to the parameters of `lua_pcall`, but the @@ -90,7 +119,7 @@ impl Drop for StackGuard { // values are assumed to match the `nresults` param. Provided function must *not* panic, and since it // will generally be lonjmping, should not contain any values that implements Drop. // Internally uses 3 extra stack spaces, and does not call checkstack. -pub unsafe fn protect_lua( +pub unsafe fn protect_lua_closure( state: *mut ffi::lua_State, nargs: c_int, nresults: c_int, @@ -212,14 +241,14 @@ pub unsafe fn push_string + ?Sized>( s: &S, ) -> Result<()> { let s = s.as_ref(); - protect_lua(state, 0, 1, |state| { + protect_lua!(state, 0, 1, |state| { ffi::lua_pushlstring(state, s.as_ptr() as *const c_char, s.len()); }) } // Uses 3 stack spaces pub unsafe fn push_table(state: *mut ffi::lua_State, narr: c_int, nrec: c_int) -> Result<()> { - protect_lua(state, 0, 1, |state| ffi::lua_createtable(state, narr, nrec)) + protect_lua!(state, 0, 1, |state| ffi::lua_createtable(state, narr, nrec)) } // Uses 4 stack spaces @@ -229,7 +258,7 @@ where { let field = field.as_ref(); ffi::lua_pushvalue(state, table); - protect_lua(state, 2, 0, |state| { + protect_lua!(state, 2, 0, |state| { ffi::lua_pushlstring(state, field.as_ptr() as *const c_char, field.len()); ffi::lua_rotate(state, -3, 2); ffi::lua_rawset(state, -3) @@ -238,7 +267,7 @@ where // Internally uses 3 stack spaces, does not call checkstack. pub unsafe fn push_userdata(state: *mut ffi::lua_State, t: T) -> Result<()> { - let ud = protect_lua(state, 0, 1, |state| { + let ud = protect_lua!(state, 0, 1, |state| { ffi::lua_newuserdata(state, mem::size_of::()) as *mut T })?; ptr::write(ud, t); @@ -408,7 +437,7 @@ pub unsafe fn init_userdata_metatable( ffi::lua_pushnil(state); } } - protect_lua(state, 3, 1, |state| { + protect_lua!(state, 3, 1, state => { ffi::lua_pushcclosure(state, meta_index_impl, 3) })?; } @@ -424,7 +453,7 @@ pub unsafe fn init_userdata_metatable( match newindex_type { ffi::LUA_TNIL | ffi::LUA_TTABLE | ffi::LUA_TFUNCTION => { ffi::lua_pushvalue(state, field_setters); - protect_lua(state, 2, 1, |state| { + protect_lua!(state, 2, 1, state => { ffi::lua_pushcclosure(state, meta_newindex_impl, 2) })?; } @@ -680,7 +709,7 @@ pub unsafe fn init_gc_metatable_for( f(state)?; } - protect_lua(state, 1, 0, |state| { + protect_lua!(state, 1, 0, |state| { ffi::lua_rawsetp(state, ffi::LUA_REGISTRYINDEX, ref_addr as *mut c_void) })?; @@ -840,16 +869,16 @@ pub unsafe fn init_error_registry(state: *mut ffi::lua_State) -> Result<()> { } ffi::lua_pop(state, 1); - let destructed_metatable_key = &DESTRUCTED_USERDATA_METATABLE as *const u8 as *const c_void; - protect_lua(state, 1, 0, |state| { + protect_lua!(state, 1, 0, state => { + let destructed_metatable_key = &DESTRUCTED_USERDATA_METATABLE as *const u8 as *const c_void; ffi::lua_rawsetp(state, ffi::LUA_REGISTRYINDEX, destructed_metatable_key) })?; // Create error print buffer init_gc_metatable_for::(state, None)?; push_gc_userdata(state, String::new())?; - let err_buf_key = &ERROR_PRINT_BUFFER_KEY as *const u8 as *const c_void; - protect_lua(state, 1, 0, |state| { + protect_lua!(state, 1, 0, state => { + let err_buf_key = &ERROR_PRINT_BUFFER_KEY as *const u8 as *const c_void; ffi::lua_rawsetp(state, ffi::LUA_REGISTRYINDEX, err_buf_key) })?;