From d06890afc6cf251042a9b2a3e0db0988f25d4f9d Mon Sep 17 00:00:00 2001 From: kyren Date: Thu, 8 Mar 2018 11:40:24 -0500 Subject: [PATCH] Simplify stack_guard / stack_err_guard The expected change is always zero, because stack_guard / stack_err_guard are always used at `rlua` entry / exit points. --- src/conversion.rs | 6 ++--- src/function.rs | 4 ++-- src/lua.rs | 42 +++++++++++++++++------------------ src/string.rs | 2 +- src/table.rs | 22 +++++++++---------- src/tests/thread.rs | 1 - src/thread.rs | 4 ++-- src/userdata.rs | 6 ++--- src/util.rs | 53 ++++++++++++++++++++------------------------- 9 files changed, 65 insertions(+), 75 deletions(-) diff --git a/src/conversion.rs b/src/conversion.rs index 5d09f5d..a6749f5 100644 --- a/src/conversion.rs +++ b/src/conversion.rs @@ -270,16 +270,14 @@ impl<'lua, T: FromLua<'lua>> FromLua<'lua> for Vec { } impl<'lua, K: Eq + Hash + ToLua<'lua>, V: ToLua<'lua>, S: BuildHasher> ToLua<'lua> - for HashMap -{ + for HashMap { fn to_lua(self, lua: &'lua Lua) -> Result> { Ok(Value::Table(lua.create_table_from(self)?)) } } impl<'lua, K: Eq + Hash + FromLua<'lua>, V: FromLua<'lua>, S: BuildHasher + Default> FromLua<'lua> - for HashMap -{ + for HashMap { fn from_lua(value: Value<'lua>, _: &'lua Lua) -> Result { if let Value::Table(table) = value { table.pairs().collect() diff --git a/src/function.rs b/src/function.rs index 3cc9e1d..246e852 100644 --- a/src/function.rs +++ b/src/function.rs @@ -63,7 +63,7 @@ impl<'lua> Function<'lua> { pub fn call, R: FromLuaMulti<'lua>>(&self, args: A) -> Result { let lua = self.0.lua; unsafe { - stack_err_guard(lua.state, 0, || { + stack_err_guard(lua.state, || { let args = args.to_lua_multi(lua)?; let nargs = args.len() as c_int; check_stack_err(lua.state, nargs + 3)?; @@ -144,7 +144,7 @@ impl<'lua> Function<'lua> { let lua = self.0.lua; unsafe { - stack_err_guard(lua.state, 0, || { + stack_err_guard(lua.state, || { let args = args.to_lua_multi(lua)?; let nargs = args.len() as c_int; diff --git a/src/lua.rs b/src/lua.rs index 3c2f71a..8b8ec3c 100644 --- a/src/lua.rs +++ b/src/lua.rs @@ -91,7 +91,7 @@ impl Lua { /// Equivalent to Lua's `load` function. pub fn load(&self, source: &str, name: Option<&str>) -> Result { unsafe { - stack_err_guard(self.state, 0, || { + stack_err_guard(self.state, || { check_stack(self.state, 1); match if let Some(name) = name { @@ -156,7 +156,7 @@ impl Lua { /// Pass a `&str` slice to Lua, creating and returning an interned Lua string. pub fn create_string(&self, s: &str) -> Result { unsafe { - stack_err_guard(self.state, 0, || { + stack_err_guard(self.state, || { check_stack(self.state, 4); push_string(self.state, s)?; Ok(String(self.pop_ref(self.state))) @@ -167,7 +167,7 @@ impl Lua { /// Creates and returns a new table. pub fn create_table(&self) -> Result { unsafe { - stack_err_guard(self.state, 0, || { + stack_err_guard(self.state, || { check_stack(self.state, 4); protect_lua_call(self.state, 0, 1, |state| { ffi::lua_newtable(state); @@ -185,7 +185,7 @@ impl Lua { I: IntoIterator, { unsafe { - stack_err_guard(self.state, 0, || { + stack_err_guard(self.state, || { check_stack(self.state, 6); protect_lua_call(self.state, 0, 1, |state| { ffi::lua_newtable(state); @@ -305,7 +305,7 @@ impl Lua { /// Equivalent to `coroutine.create`. pub fn create_thread<'lua>(&'lua self, func: Function<'lua>) -> Result> { unsafe { - stack_err_guard(self.state, 0, move || { + stack_err_guard(self.state, move || { check_stack(self.state, 2); let thread_state = @@ -328,7 +328,7 @@ impl Lua { /// Returns a handle to the global environment. pub fn globals(&self) -> Table { unsafe { - stack_guard(self.state, 0, move || { + stack_guard(self.state, move || { check_stack(self.state, 2); ffi::lua_rawgeti(self.state, ffi::LUA_REGISTRYINDEX, ffi::LUA_RIDX_GLOBALS); Table(self.pop_ref(self.state)) @@ -374,7 +374,7 @@ impl Lua { match v { Value::String(s) => Ok(s), v => unsafe { - stack_err_guard(self.state, 0, || { + stack_err_guard(self.state, || { check_stack(self.state, 4); let ty = v.type_name(); self.push_value(self.state, v); @@ -403,7 +403,7 @@ impl Lua { match v { Value::Integer(i) => Ok(i), v => unsafe { - stack_guard(self.state, 0, || { + stack_guard(self.state, || { check_stack(self.state, 2); let ty = v.type_name(); self.push_value(self.state, v); @@ -432,7 +432,7 @@ impl Lua { match v { Value::Number(n) => Ok(n), v => unsafe { - stack_guard(self.state, 0, || { + stack_guard(self.state, || { check_stack(self.state, 2); let ty = v.type_name(); self.push_value(self.state, v); @@ -486,7 +486,7 @@ impl Lua { t: T, ) -> Result<()> { unsafe { - stack_err_guard(self.state, 0, || { + stack_err_guard(self.state, || { check_stack(self.state, 5); push_string(self.state, name)?; @@ -507,7 +507,7 @@ impl Lua { /// [`set_named_registry_value`]: #method.set_named_registry_value pub fn named_registry_value<'lua, T: FromLua<'lua>>(&'lua self, name: &str) -> Result { unsafe { - stack_err_guard(self.state, 0, || { + stack_err_guard(self.state, || { check_stack(self.state, 4); push_string(self.state, name)?; @@ -535,7 +535,7 @@ impl Lua { /// state. pub fn create_registry_value<'lua, T: ToLua<'lua>>(&'lua self, t: T) -> Result { unsafe { - stack_guard(self.state, 0, || { + stack_guard(self.state, || { check_stack(self.state, 2); self.push_value(self.state, t.to_lua(self)?); @@ -564,7 +564,7 @@ impl Lua { return Err(Error::MismatchedRegistryKey); } - stack_err_guard(self.state, 0, || { + stack_err_guard(self.state, || { check_stack(self.state, 2); ffi::lua_rawgeti( self.state, @@ -776,7 +776,7 @@ impl Lua { } } - stack_err_guard(self.state, 0, move || { + stack_err_guard(self.state, move || { check_stack(self.state, 5); if let Some(table_id) = (*self.extra()).registered_userdata.get(&TypeId::of::()) { @@ -922,7 +922,7 @@ impl Lua { // Ignores or `unwrap()`s 'm' errors, because this is assuming that nothing in the lua // standard library will have a `__gc` metamethod error. - stack_guard(state, 0, || { + stack_guard(state, || { // Do not open the debug library, it can be used to cause unsafety. ffi::luaL_requiref(state, cstr!("_G"), ffi::luaopen_base, 1); ffi::luaL_requiref(state, cstr!("coroutine"), ffi::luaopen_coroutine, 1); @@ -1030,7 +1030,7 @@ impl Lua { } unsafe { - stack_err_guard(self.state, 0, move || { + stack_err_guard(self.state, move || { check_stack(self.state, 2); push_userdata::(self.state, func)?; @@ -1056,7 +1056,7 @@ impl Lua { T: UserData, { unsafe { - stack_err_guard(self.state, 0, move || { + stack_err_guard(self.state, move || { check_stack(self.state, 3); push_userdata::>(self.state, RefCell::new(data))?; @@ -1104,7 +1104,7 @@ impl<'scope> Scope<'scope> { let mut destructors = self.destructors.borrow_mut(); let registry_id = f.0.registry_id; destructors.push(Box::new(move |state| { - stack_guard(state, 0, || { + stack_guard(state, || { check_stack(state, 2); ffi::lua_rawgeti( @@ -1130,8 +1130,8 @@ impl<'scope> Scope<'scope> { /// Wraps a Rust mutable closure, creating a callable Lua function handle to it. /// - /// This is a version of [`Lua::create_function_mut`] that creates a callback which expires on scope - /// drop. See [`Lua::scope`] for more details. + /// This is a version of [`Lua::create_function_mut`] that creates a callback which expires on + /// scope drop. See [`Lua::scope`] for more details. /// /// [`Lua::create_function_mut`]: struct.Lua.html#method.create_function_mut /// [`Lua::scope`]: struct.Lua.html#method.scope @@ -1168,7 +1168,7 @@ impl<'scope> Scope<'scope> { let mut destructors = self.destructors.borrow_mut(); let registry_id = u.0.registry_id; destructors.push(Box::new(move |state| { - stack_guard(state, 0, || { + stack_guard(state, || { check_stack(state, 1); ffi::lua_rawgeti( state, diff --git a/src/string.rs b/src/string.rs index 98b6699..ba127ef 100644 --- a/src/string.rs +++ b/src/string.rs @@ -69,7 +69,7 @@ impl<'lua> String<'lua> { pub fn as_bytes_with_nul(&self) -> &[u8] { let lua = self.0.lua; unsafe { - stack_guard(lua.state, 0, || { + stack_guard(lua.state, || { check_stack(lua.state, 1); lua.push_ref(lua.state, &self.0); rlua_assert!( diff --git a/src/table.rs b/src/table.rs index adf156d..0b7875c 100644 --- a/src/table.rs +++ b/src/table.rs @@ -51,7 +51,7 @@ impl<'lua> Table<'lua> { pub fn set, V: ToLua<'lua>>(&self, key: K, value: V) -> Result<()> { let lua = self.0.lua; unsafe { - stack_err_guard(lua.state, 0, || { + stack_err_guard(lua.state, || { check_stack(lua.state, 6); lua.push_ref(lua.state, &self.0); lua.push_value(lua.state, key.to_lua(lua)?); @@ -94,7 +94,7 @@ impl<'lua> Table<'lua> { pub fn get, V: FromLua<'lua>>(&self, key: K) -> Result { let lua = self.0.lua; unsafe { - stack_err_guard(lua.state, 0, || { + stack_err_guard(lua.state, || { check_stack(lua.state, 5); lua.push_ref(lua.state, &self.0); lua.push_value(lua.state, key.to_lua(lua)?); @@ -108,7 +108,7 @@ impl<'lua> Table<'lua> { pub fn contains_key>(&self, key: K) -> Result { let lua = self.0.lua; unsafe { - stack_err_guard(lua.state, 0, || { + stack_err_guard(lua.state, || { check_stack(lua.state, 5); lua.push_ref(lua.state, &self.0); lua.push_value(lua.state, key.to_lua(lua)?); @@ -124,7 +124,7 @@ impl<'lua> Table<'lua> { pub fn raw_set, V: ToLua<'lua>>(&self, key: K, value: V) -> Result<()> { let lua = self.0.lua; unsafe { - stack_err_guard(lua.state, 0, || { + stack_err_guard(lua.state, || { check_stack(lua.state, 6); lua.push_ref(lua.state, &self.0); lua.push_value(lua.state, key.to_lua(lua)?); @@ -141,7 +141,7 @@ impl<'lua> Table<'lua> { pub fn raw_get, V: FromLua<'lua>>(&self, key: K) -> Result { let lua = self.0.lua; unsafe { - stack_err_guard(lua.state, 0, || { + stack_err_guard(lua.state, || { check_stack(lua.state, 3); lua.push_ref(lua.state, &self.0); lua.push_value(lua.state, key.to_lua(lua)?); @@ -161,7 +161,7 @@ impl<'lua> Table<'lua> { pub fn len(&self) -> Result { let lua = self.0.lua; unsafe { - stack_err_guard(lua.state, 0, || { + stack_err_guard(lua.state, || { check_stack(lua.state, 4); lua.push_ref(lua.state, &self.0); protect_lua_call(lua.state, 1, 0, |state| ffi::luaL_len(state, -1)) @@ -173,7 +173,7 @@ impl<'lua> Table<'lua> { pub fn raw_len(&self) -> Integer { let lua = self.0.lua; unsafe { - stack_guard(lua.state, 0, || { + stack_guard(lua.state, || { check_stack(lua.state, 1); lua.push_ref(lua.state, &self.0); let len = ffi::lua_rawlen(lua.state, -1); @@ -189,7 +189,7 @@ impl<'lua> Table<'lua> { pub fn get_metatable(&self) -> Option> { let lua = self.0.lua; unsafe { - stack_guard(lua.state, 0, || { + stack_guard(lua.state, || { check_stack(lua.state, 1); lua.push_ref(lua.state, &self.0); if ffi::lua_getmetatable(lua.state, -1) == 0 { @@ -211,7 +211,7 @@ impl<'lua> Table<'lua> { pub fn set_metatable(&self, metatable: Option>) { let lua = self.0.lua; unsafe { - stack_guard(lua.state, 0, move || { + stack_guard(lua.state, move || { check_stack(lua.state, 1); lua.push_ref(lua.state, &self.0); if let Some(metatable) = metatable { @@ -346,7 +346,7 @@ where let lua = self.table.lua; unsafe { - stack_guard(lua.state, 0, || { + stack_guard(lua.state, || { check_stack(lua.state, 6); lua.push_ref(lua.state, &self.table); @@ -408,7 +408,7 @@ where let lua = self.table.lua; unsafe { - stack_guard(lua.state, 0, || { + stack_guard(lua.state, || { check_stack(lua.state, 5); lua.push_ref(lua.state, &self.table); diff --git a/src/tests/thread.rs b/src/tests/thread.rs index 2d483a0..ccfb397 100644 --- a/src/tests/thread.rs +++ b/src/tests/thread.rs @@ -1,4 +1,3 @@ - use std::panic::catch_unwind; use {Error, Function, Lua, Thread, ThreadStatus}; diff --git a/src/thread.rs b/src/thread.rs index ab5b97a..1c8d753 100644 --- a/src/thread.rs +++ b/src/thread.rs @@ -78,7 +78,7 @@ impl<'lua> Thread<'lua> { { let lua = self.0.lua; unsafe { - stack_err_guard(lua.state, 0, || { + stack_err_guard(lua.state, || { check_stack(lua.state, 1); lua.push_ref(lua.state, &self.0); @@ -120,7 +120,7 @@ impl<'lua> Thread<'lua> { pub fn status(&self) -> ThreadStatus { let lua = self.0.lua; unsafe { - stack_guard(lua.state, 0, || { + stack_guard(lua.state, || { check_stack(lua.state, 1); lua.push_ref(lua.state, &self.0); diff --git a/src/userdata.rs b/src/userdata.rs index b20561a..d40eb1f 100644 --- a/src/userdata.rs +++ b/src/userdata.rs @@ -415,7 +415,7 @@ impl<'lua> AnyUserData<'lua> { { unsafe { let lua = self.0.lua; - stack_err_guard(lua.state, 0, move || { + stack_err_guard(lua.state, move || { check_stack(lua.state, 3); lua.push_ref(lua.state, &self.0); @@ -451,7 +451,7 @@ impl<'lua> AnyUserData<'lua> { pub fn set_user_value>(&self, v: V) -> Result<()> { let lua = self.0.lua; unsafe { - stack_err_guard(lua.state, 0, || { + stack_err_guard(lua.state, || { check_stack(lua.state, 2); lua.push_ref(lua.state, &self.0); lua.push_value(lua.state, v.to_lua(lua)?); @@ -468,7 +468,7 @@ impl<'lua> AnyUserData<'lua> { pub fn get_user_value>(&self) -> Result { let lua = self.0.lua; unsafe { - stack_err_guard(lua.state, 0, || { + stack_err_guard(lua.state, || { check_stack(lua.state, 3); lua.push_ref(lua.state, &self.0); ffi::lua_getuservalue(lua.state, -1); diff --git a/src/util.rs b/src/util.rs index 9295553..ba4dda8 100644 --- a/src/util.rs +++ b/src/util.rs @@ -27,16 +27,13 @@ pub unsafe fn check_stack_err(state: *mut ffi::lua_State, amount: c_int) -> Resu } } -// Run an operation on a lua_State and check that the stack change is what is expected. If the -// stack change does not match, resets the stack and panics. If the given operation panics, tries -// to restore the stack to its previous state before resuming the panic. -pub unsafe fn stack_guard(state: *mut ffi::lua_State, change: c_int, op: F) -> R +// Run an operation on a lua_State and ensure that there are no stack leaks and the stack is +// restored on panic. +pub unsafe fn stack_guard(state: *mut ffi::lua_State, op: F) -> R where F: FnOnce() -> R, { let begin = ffi::lua_gettop(state); - let expected = begin + change; - rlua_assert!(expected >= 0, "too many stack values would be popped"); let res = match catch_unwind(AssertUnwindSafe(op)) { Ok(r) => r, @@ -50,30 +47,26 @@ where }; let top = ffi::lua_gettop(state); - if top != expected { - if top > begin { - ffi::lua_settop(state, begin); - } - rlua_panic!("expected stack to be {}, got {}", expected, top); + if top > begin { + ffi::lua_settop(state, begin); + rlua_panic!("expected stack to be {}, got {}", begin, top); + } else if top < begin { + rlua_abort!("{} too many stack values popped", begin - top); } res } -// Run an operation on a lua_State and automatically clean up the stack before returning. Takes the -// lua_State, the expected stack size change, and an operation to run. If the operation results in -// success, then the stack is inspected to make sure the change in stack size matches the expected -// change and otherwise this is a logic error and will panic. If the operation results in an error, -// the stack is shrunk to the value before the call. If the operation results in an error and the -// stack is smaller than the value before the call, then this is unrecoverable and this will panic. -// If this function panics, it will clear the stack before panicking. -pub unsafe fn stack_err_guard(state: *mut ffi::lua_State, change: c_int, op: F) -> Result +// Run an operation on a lua_State and automatically clean up the stack on error. Takes the +// lua_State and an operation to run. If the operation results in success, then the stack is +// inspected to make sure there is not a stack leak, and otherwise this is a logic error and will +// panic. If the operation results in an error, or if the operation panics, the stack is shrunk to +// the value before the call. +pub unsafe fn stack_err_guard(state: *mut ffi::lua_State, op: F) -> Result where F: FnOnce() -> Result, { let begin = ffi::lua_gettop(state); - let expected = begin + change; - rlua_assert!(expected >= 0, "too many stack values would be popped"); let res = match catch_unwind(AssertUnwindSafe(op)) { Ok(r) => r, @@ -88,17 +81,17 @@ where let top = ffi::lua_gettop(state); if res.is_ok() { - if top != expected { - if top > begin { - ffi::lua_settop(state, begin); - } - rlua_panic!("expected stack to be {}, got {}", expected, top); + if top > begin { + ffi::lua_settop(state, begin); + rlua_panic!("expected stack to be {}, got {}", begin, top); + } else if top < begin { + rlua_abort!("{} too many stack values popped", begin - top); } } else { - if top > expected { - ffi::lua_settop(state, expected); - } else if top < expected { - rlua_panic!("{} too many stack values popped", top - begin - change); + if top > begin { + ffi::lua_settop(state, begin); + } else if top < begin { + rlua_abort!("{} too many stack values popped", begin - top); } } res