From a05f0d5cd025cbdabc13241416de85be2a4f73a4 Mon Sep 17 00:00:00 2001 From: kyren Date: Mon, 19 Mar 2018 15:16:40 -0400 Subject: [PATCH] Where possible, don't call to_lua / from_lua / to_lua_multi / from_lua_multi callbacks during Lua stack manipulation This should protect against being able to trigger a stack assert in Lua. Lua and associated types shoul be able to assume that LUA_MINSTACK stack slots are available on any user entry point. In the future, we could turn check_stack into something that only checked the Lua stack when debug_assertions is true. --- src/function.rs | 7 ++-- src/lua.rs | 24 ++++++++----- src/table.rs | 91 +++++++++++++++++++++++++++++++------------------ src/thread.rs | 9 ++--- src/userdata.rs | 65 ++++++++++++++++++----------------- 5 files changed, 114 insertions(+), 82 deletions(-) diff --git a/src/function.rs b/src/function.rs index f4484d9..5829ed1 100644 --- a/src/function.rs +++ b/src/function.rs @@ -67,7 +67,7 @@ impl<'lua> Function<'lua> { let args = args.to_lua_multi(lua)?; let nargs = args.len() as c_int; - unsafe { + let results = unsafe { let _sg = StackGuard::new(lua.state); check_stack_err(lua.state, nargs + 3)?; @@ -88,8 +88,9 @@ impl<'lua> Function<'lua> { results.push_front(lua.pop_value()); } ffi::lua_pop(lua.state, 1); - R::from_lua_multi(results, lua) - } + results + }; + R::from_lua_multi(results, lua) } /// Returns a function that, when called, calls `self`, passing `args` as the first set of diff --git a/src/lua.rs b/src/lua.rs index 9dddd4b..3d2dd30 100644 --- a/src/lua.rs +++ b/src/lua.rs @@ -184,7 +184,9 @@ impl Lua { { unsafe { let _sg = StackGuard::new(self.state); - check_stack(self.state, 5); + // `Lua` instance assumes that on any callback, the Lua stack has at least LUA_MINSTACK + // slots available to avoid panics. + check_stack_err(self.state, 5 + ffi::LUA_MINSTACK)?; unsafe extern "C" fn new_table(state: *mut ffi::lua_State) -> c_int { ffi::lua_newtable(state); @@ -479,12 +481,13 @@ impl Lua { name: &str, t: T, ) -> Result<()> { + let t = t.to_lua(self)?; unsafe { let _sg = StackGuard::new(self.state); check_stack(self.state, 5); push_string(self.state, name)?; - self.push_value(t.to_lua(self)?); + self.push_value(t); unsafe extern "C" fn set_registry(state: *mut ffi::lua_State) -> c_int { ffi::lua_rawset(state, ffi::LUA_REGISTRYINDEX); @@ -501,7 +504,7 @@ impl Lua { /// /// [`set_named_registry_value`]: #method.set_named_registry_value pub fn named_registry_value<'lua, T: FromLua<'lua>>(&'lua self, name: &str) -> Result { - unsafe { + let value = unsafe { let _sg = StackGuard::new(self.state); check_stack(self.state, 4); @@ -512,8 +515,9 @@ impl Lua { } protect_lua(self.state, 1, get_registry)?; - T::from_lua(self.pop_value(), self) - } + self.pop_value() + }; + T::from_lua(value, self) } /// Removes a named value in the Lua registry. @@ -530,11 +534,12 @@ impl Lua { /// This value will be available to rust from all `Lua` instances which share the same main /// state. pub fn create_registry_value<'lua, T: ToLua<'lua>>(&'lua self, t: T) -> Result { + let t = t.to_lua(self)?; unsafe { let _sg = StackGuard::new(self.state); check_stack(self.state, 2); - self.push_value(t.to_lua(self)?); + self.push_value(t); let registry_id = gc_guard(self.state, || { ffi::luaL_ref(self.state, ffi::LUA_REGISTRYINDEX) }); @@ -553,7 +558,7 @@ impl Lua { /// /// [`create_registry_value`]: #method.create_registry_value pub fn registry_value<'lua, T: FromLua<'lua>>(&'lua self, key: &RegistryKey) -> Result { - unsafe { + let value = unsafe { if !self.owns_registry_value(key) { return Err(Error::MismatchedRegistryKey); } @@ -566,8 +571,9 @@ impl Lua { ffi::LUA_REGISTRYINDEX, key.registry_id as ffi::lua_Integer, ); - T::from_lua(self.pop_value(), self) - } + self.pop_value() + }; + T::from_lua(value, self) } /// Removes a value from the Lua registry. diff --git a/src/table.rs b/src/table.rs index 3c36a11..608edf0 100644 --- a/src/table.rs +++ b/src/table.rs @@ -51,13 +51,15 @@ impl<'lua> Table<'lua> { /// [`raw_set`]: #method.raw_set pub fn set, V: ToLua<'lua>>(&self, key: K, value: V) -> Result<()> { let lua = self.0.lua; + let key = key.to_lua(lua)?; + let value = value.to_lua(lua)?; unsafe { let _sg = StackGuard::new(lua.state); check_stack(lua.state, 6); lua.push_ref(&self.0); - lua.push_value(key.to_lua(lua)?); - lua.push_value(value.to_lua(lua)?); + lua.push_value(key); + lua.push_value(value); unsafe extern "C" fn set_table(state: *mut ffi::lua_State) -> c_int { ffi::lua_settable(state, -3); @@ -97,32 +99,35 @@ impl<'lua> Table<'lua> { /// [`raw_get`]: #method.raw_get pub fn get, V: FromLua<'lua>>(&self, key: K) -> Result { let lua = self.0.lua; - unsafe { + let key = key.to_lua(lua)?; + let value = unsafe { let _sg = StackGuard::new(lua.state); check_stack(lua.state, 5); lua.push_ref(&self.0); - lua.push_value(key.to_lua(lua)?); + lua.push_value(key); unsafe extern "C" fn get_table(state: *mut ffi::lua_State) -> c_int { ffi::lua_gettable(state, -2); 1 } protect_lua(lua.state, 2, get_table)?; - - V::from_lua(lua.pop_value(), lua) - } + lua.pop_value() + }; + V::from_lua(value, lua) } /// Checks whether the table contains a non-nil value for `key`. pub fn contains_key>(&self, key: K) -> Result { let lua = self.0.lua; + let key = key.to_lua(lua)?; + unsafe { let _sg = StackGuard::new(lua.state); check_stack(lua.state, 5); lua.push_ref(&self.0); - lua.push_value(key.to_lua(lua)?); + lua.push_value(key); unsafe extern "C" fn get_table(state: *mut ffi::lua_State) -> c_int { ffi::lua_gettable(state, -2); @@ -138,13 +143,16 @@ impl<'lua> Table<'lua> { /// Sets a key-value pair without invoking metamethods. pub fn raw_set, V: ToLua<'lua>>(&self, key: K, value: V) -> Result<()> { let lua = self.0.lua; + let key = key.to_lua(lua)?; + let value = value.to_lua(lua)?; + unsafe { let _sg = StackGuard::new(lua.state); check_stack(lua.state, 6); lua.push_ref(&self.0); - lua.push_value(key.to_lua(lua)?); - lua.push_value(value.to_lua(lua)?); + lua.push_value(key); + lua.push_value(value); unsafe extern "C" fn raw_set(state: *mut ffi::lua_State) -> c_int { ffi::lua_rawset(state, -3); @@ -159,16 +167,17 @@ impl<'lua> Table<'lua> { /// Gets the value associated to `key` without invoking metamethods. pub fn raw_get, V: FromLua<'lua>>(&self, key: K) -> Result { let lua = self.0.lua; - unsafe { + let key = key.to_lua(lua)?; + let value = unsafe { let _sg = StackGuard::new(lua.state); check_stack(lua.state, 3); lua.push_ref(&self.0); - lua.push_value(key.to_lua(lua)?); + lua.push_value(key); ffi::lua_rawget(lua.state, -2); - let res = V::from_lua(lua.pop_value(), lua)?; - Ok(res) - } + lua.pop_value() + }; + V::from_lua(value, lua) } /// Returns the result of the Lua `#` operator. @@ -354,31 +363,39 @@ where if let Some(next_key) = self.next_key.take() { let lua = self.table.lua; - unsafe { - let _sg = StackGuard::new(lua.state); - check_stack(lua.state, 6); + let res = (|| { + let res = unsafe { + let _sg = StackGuard::new(lua.state); + check_stack(lua.state, 6); - lua.push_ref(&self.table); - lua.push_ref(&next_key); + lua.push_ref(&self.table); + lua.push_ref(&next_key); - match protect_lua_closure(lua.state, 2, ffi::LUA_MULTRET, |state| { - ffi::lua_next(state, -2) != 0 - }) { - Ok(false) => None, - Ok(true) => { + if protect_lua_closure(lua.state, 2, ffi::LUA_MULTRET, |state| { + ffi::lua_next(state, -2) != 0 + })? { ffi::lua_pushvalue(lua.state, -2); let key = lua.pop_value(); let value = lua.pop_value(); self.next_key = Some(lua.pop_ref()); - Some((|| { - let key = K::from_lua(key, lua)?; - let value = V::from_lua(value, lua)?; - Ok((key, value)) - })()) + Some((key, value)) + } else { + None } - Err(e) => Some(Err(e)), - } + }; + + Ok(if let Some((key, value)) = res { + Some((K::from_lua(key, lua)?, V::from_lua(value, lua)?)) + } else { + None + }) + })(); + + match res { + Ok(Some((key, value))) => Some(Ok((key, value))), + Ok(None) => None, + Err(e) => Some(Err(e)), } } else { None @@ -407,7 +424,7 @@ where if let Some(index) = self.index.take() { let lua = self.table.lua; - unsafe { + let res = unsafe { let _sg = StackGuard::new(lua.state); check_stack(lua.state, 5); @@ -418,10 +435,16 @@ where Ok(_) => { let value = lua.pop_value(); self.index = Some(index + 1); - Some(V::from_lua(value, lua)) + Some(Ok(value)) } Err(err) => Some(Err(err)), } + }; + + match res { + Some(Ok(r)) => Some(V::from_lua(r, lua)), + Some(Err(err)) => Some(Err(err)), + None => None, } } else { None diff --git a/src/thread.rs b/src/thread.rs index cb87873..cefde03 100644 --- a/src/thread.rs +++ b/src/thread.rs @@ -77,7 +77,8 @@ impl<'lua> Thread<'lua> { R: FromLuaMulti<'lua>, { let lua = self.0.lua; - unsafe { + let args = args.to_lua_multi(lua)?; + let results = unsafe { let _sg = StackGuard::new(lua.state); check_stack(lua.state, 1); @@ -91,7 +92,6 @@ impl<'lua> Thread<'lua> { ffi::lua_pop(lua.state, 1); - let args = args.to_lua_multi(lua)?; let nargs = args.len() as c_int; check_stack_err(lua.state, nargs)?; check_stack_err(thread_state, nargs + 1)?; @@ -115,8 +115,9 @@ impl<'lua> Thread<'lua> { for _ in 0..nresults { results.push_front(lua.pop_value()); } - R::from_lua_multi(results, lua) - } + results + }; + R::from_lua_multi(results, lua) } /// Gets the status of the thread. diff --git a/src/userdata.rs b/src/userdata.rs index 72e557a..94fd3d0 100644 --- a/src/userdata.rs +++ b/src/userdata.rs @@ -408,6 +408,39 @@ impl<'lua> AnyUserData<'lua> { }) } + /// Sets an associated value to this `AnyUserData`. + /// + /// The value may be any Lua value whatsoever, and can be retrieved with [`get_user_value`]. + /// + /// [`get_user_value`]: #method.get_user_value + pub fn set_user_value>(&self, v: V) -> Result<()> { + let lua = self.0.lua; + let v = v.to_lua(lua)?; + unsafe { + let _sg = StackGuard::new(lua.state); + check_stack(lua.state, 2); + lua.push_ref(&self.0); + lua.push_value(v); + ffi::lua_setuservalue(lua.state, -2); + Ok(()) + } + } + + /// Returns an associated value set by [`set_user_value`]. + /// + /// [`set_user_value`]: #method.set_user_value + pub fn get_user_value>(&self) -> Result { + let lua = self.0.lua; + let res = unsafe { + let _sg = StackGuard::new(lua.state); + check_stack(lua.state, 3); + lua.push_ref(&self.0); + ffi::lua_getuservalue(lua.state, -1); + lua.pop_value() + }; + V::from_lua(res, lua) + } + fn inspect<'a, T, R, F>(&'a self, func: F) -> Result where T: UserData, @@ -437,36 +470,4 @@ impl<'lua> AnyUserData<'lua> { } } } - - /// Sets an associated value to this `AnyUserData`. - /// - /// The value may be any Lua value whatsoever, and can be retrieved with [`get_user_value`]. - /// - /// [`get_user_value`]: #method.get_user_value - pub fn set_user_value>(&self, v: V) -> Result<()> { - let lua = self.0.lua; - unsafe { - let _sg = StackGuard::new(lua.state); - check_stack(lua.state, 2); - lua.push_ref(&self.0); - lua.push_value(v.to_lua(lua)?); - ffi::lua_setuservalue(lua.state, -2); - Ok(()) - } - } - - /// Returns an associated value set by [`set_user_value`]. - /// - /// [`set_user_value`]: #method.set_user_value - pub fn get_user_value>(&self) -> Result { - let lua = self.0.lua; - unsafe { - let _sg = StackGuard::new(lua.state); - check_stack(lua.state, 3); - lua.push_ref(&self.0); - ffi::lua_getuservalue(lua.state, -1); - let res = V::from_lua(lua.pop_value(), lua)?; - Ok(res) - } - } }