diff --git a/src/error.rs b/src/error.rs index 2717530..fbb6a1d 100644 --- a/src/error.rs +++ b/src/error.rs @@ -26,17 +26,25 @@ pub enum Error { /// /// The Lua VM returns this error when there is an error running a `__gc` metamethod. GarbageCollectorError(String), - /// A callback has triggered Lua code that has called the same callback again. + /// A mutable callback has triggered Lua code that has called the same mutable callback again. /// - /// This is an error because `rlua` callbacks are FnMut and thus can only be mutably borrowed - /// once. - RecursiveCallback, + /// This is an error because a mutable callback can only be borrowed mutably once. + RecursiveMutCallback, /// Either a callback or a userdata method has been called, but the callback or userdata has /// been destructed. /// /// This can happen either due to to being destructed in a previous __gc, or due to being /// destructed from exiting a `Lua::scope` call. CallbackDestructed, + /// Not enough stack space to place arguments to Lua functions or return values from callbacks. + /// + /// Due to the way `rlua` works, it should not be directly possible to run out of stack space + /// during normal use. The only way that this error can be triggered is if a `Function` is + /// called with a huge number of arguments, or a rust callback returns a huge number of return + /// values. + StackError, + /// Too many arguments to `Function::bind` + BindError, /// A Rust value could not be converted to a Lua value. ToLuaConversionError { /// Name of the Rust type that could not be converted. @@ -123,11 +131,19 @@ impl fmt::Display for Error { Error::GarbageCollectorError(ref msg) => { write!(fmt, "garbage collector error: {}", msg) } - Error::RecursiveCallback => write!(fmt, "callback called recursively"), + Error::RecursiveMutCallback => write!(fmt, "mutable callback called recursively"), Error::CallbackDestructed => write!( fmt, "a destructed callback or destructed userdata method was called" ), + Error::StackError => write!( + fmt, + "out of Lua stack, too many arguments to a Lua function or too many return values from a callback" + ), + Error::BindError => write!( + fmt, + "too many arguments to Function::bind" + ), Error::ToLuaConversionError { from, to, diff --git a/src/function.rs b/src/function.rs index 9f4f617..5537a45 100644 --- a/src/function.rs +++ b/src/function.rs @@ -1,3 +1,4 @@ +use std::ptr; use std::os::raw::c_int; use ffi; @@ -65,7 +66,7 @@ impl<'lua> Function<'lua> { stack_err_guard(lua.state, 0, || { let args = args.to_lua_multi(lua)?; let nargs = args.len() as c_int; - check_stack(lua.state, nargs + 3); + check_stack_err(lua.state, nargs + 3)?; ffi::lua_pushcfunction(lua.state, error_traceback); let stack_start = ffi::lua_gettop(lua.state); @@ -124,7 +125,7 @@ impl<'lua> Function<'lua> { unsafe extern "C" fn bind_call_impl(state: *mut ffi::lua_State) -> c_int { let nargs = ffi::lua_gettop(state); let nbinds = ffi::lua_tointeger(state, ffi::lua_upvalueindex(2)) as c_int; - check_stack(state, nbinds + 2); + ffi::luaL_checkstack(state, nbinds + 2, ptr::null()); ffi::lua_settop(state, nargs + nbinds + 1); ffi::lua_rotate(state, -(nargs + nbinds + 1), nbinds + 1); @@ -144,10 +145,16 @@ impl<'lua> Function<'lua> { let lua = self.0.lua; unsafe { stack_err_guard(lua.state, 0, || { + const MAX_LUA_UPVALUES: c_int = 255; + let args = args.to_lua_multi(lua)?; let nargs = args.len() as c_int; - check_stack(lua.state, nargs + 3); + if nargs > MAX_LUA_UPVALUES { + return Err(Error::BindError); + } + + check_stack_err(lua.state, nargs + 3)?; lua.push_ref(lua.state, &self.0); ffi::lua_pushinteger(lua.state, nargs as ffi::lua_Integer); for arg in args { diff --git a/src/lua.rs b/src/lua.rs index 1abc8b0..455a4b4 100644 --- a/src/lua.rs +++ b/src/lua.rs @@ -1,6 +1,5 @@ use std::{mem, process, ptr, str}; use std::sync::{Arc, Mutex}; -use std::ops::DerefMut; use std::cell::RefCell; use std::ffi::CString; use std::any::{Any, TypeId}; @@ -32,7 +31,7 @@ pub struct Lua { /// !Send, and callbacks that are !Send and not 'static. pub struct Scope<'lua, 'scope> { lua: &'lua Lua, - destructors: RefCell Box>>>, + destructors: RefCell Box>>>, // 'scope lifetime must be invariant _scope: PhantomData<&'scope mut &'scope ()>, } @@ -265,18 +264,31 @@ impl Lua { /// /// [`ToLua`]: trait.ToLua.html /// [`ToLuaMulti`]: trait.ToLuaMulti.html - pub fn create_function<'lua, 'callback, A, R, F>( + pub fn create_function<'lua, 'callback, A, R, F>(&'lua self, func: F) -> Result> + where + A: FromLuaMulti<'callback>, + R: ToLuaMulti<'callback>, + F: 'static + Send + Fn(&'callback Lua, A) -> Result, + { + self.create_callback_function(Box::new(move |lua, args| { + func(lua, A::from_lua_multi(args, lua)?)?.to_lua_multi(lua) + })) + } + + pub fn create_function_mut<'lua, 'callback, A, R, F>( &'lua self, - mut func: F, + func: F, ) -> Result> where A: FromLuaMulti<'callback>, R: ToLuaMulti<'callback>, F: 'static + Send + FnMut(&'callback Lua, A) -> Result, { - self.create_callback_function(Box::new(move |lua, args| { - func(lua, A::from_lua_multi(args, lua)?)?.to_lua_multi(lua) - })) + let func = RefCell::new(func); + self.create_function(move |lua, args| { + (&mut *func.try_borrow_mut() + .map_err(|_| Error::RecursiveMutCallback)?)(lua, args) + }) } /// Wraps a Lua function into a new thread (or coroutine). @@ -354,7 +366,7 @@ impl Lua { Value::String(s) => Ok(s), v => unsafe { stack_err_guard(self.state, 0, || { - check_stack(self.state, 2); + check_stack(self.state, 4); let ty = v.type_name(); self.push_value(self.state, v); let s = @@ -383,7 +395,7 @@ impl Lua { Value::Integer(i) => Ok(i), v => unsafe { stack_guard(self.state, 0, || { - check_stack(self.state, 1); + check_stack(self.state, 2); let ty = v.type_name(); self.push_value(self.state, v); let mut isint = 0; @@ -412,7 +424,7 @@ impl Lua { Value::Number(n) => Ok(n), v => unsafe { stack_guard(self.state, 0, || { - check_stack(self.state, 1); + check_stack(self.state, 2); let ty = v.type_name(); self.push_value(self.state, v); let mut isnum = 0; @@ -511,7 +523,7 @@ impl Lua { pub fn create_registry_value<'lua, T: ToLua<'lua>>(&'lua self, t: T) -> Result { unsafe { stack_guard(self.state, 0, || { - check_stack(self.state, 1); + check_stack(self.state, 2); self.push_value(self.state, t.to_lua(self)?); let registry_id = gc_guard(self.state, || { @@ -592,7 +604,7 @@ impl Lua { } } - // Uses 1 stack space, does not call checkstack + // Uses 2 stack spaces, does not call checkstack pub(crate) unsafe fn push_value(&self, state: *mut ffi::lua_State, value: Value) { match value { Value::Nil => { @@ -730,7 +742,7 @@ impl Lua { // Used if both an __index metamethod is set and regular methods, checks methods table // first, then __index metamethod. unsafe extern "C" fn meta_index_impl(state: *mut ffi::lua_State) -> c_int { - check_stack(state, 2); + ffi::luaL_checkstack(state, 2, ptr::null()); ffi::lua_pushvalue(state, -1); ffi::lua_gettable(state, ffi::lua_upvalueindex(1)); @@ -922,7 +934,7 @@ impl Lua { ffi::lua_newtable(state); push_string(state, "__gc").unwrap(); - ffi::lua_pushcfunction(state, userdata_destructor::>); + ffi::lua_pushcfunction(state, userdata_destructor::); ffi::lua_rawset(state, -3); push_string(state, "__metatable").unwrap(); @@ -977,10 +989,7 @@ impl Lua { return Err(Error::CallbackDestructed); } - let func = get_userdata::>(state, ffi::lua_upvalueindex(1)); - let mut func = (*func) - .try_borrow_mut() - .map_err(|_| Error::RecursiveCallback)?; + let func = get_userdata::(state, ffi::lua_upvalueindex(1)); let nargs = ffi::lua_gettop(state); let mut args = MultiValue::new(); @@ -989,10 +998,10 @@ impl Lua { args.push_front(lua.pop_value(state)); } - let results = func.deref_mut()(&lua, args)?; + let results = (*func)(&lua, args)?; let nresults = results.len() as c_int; - check_stack(state, nresults); + check_stack_err(state, nresults)?; for r in results { lua.push_value(state, r); @@ -1006,7 +1015,7 @@ impl Lua { stack_err_guard(self.state, 0, move || { check_stack(self.state, 2); - push_userdata::>(self.state, RefCell::new(func))?; + push_userdata::(self.state, func)?; ffi::lua_pushlightuserdata( self.state, @@ -1053,15 +1062,15 @@ impl Lua { } impl<'lua, 'scope> Scope<'lua, 'scope> { - pub fn create_function<'callback, A, R, F>(&self, mut func: F) -> Result> + pub fn create_function<'callback, A, R, F>(&self, func: F) -> Result> where A: FromLuaMulti<'callback>, R: ToLuaMulti<'callback>, - F: 'scope + FnMut(&'callback Lua, A) -> Result, + F: 'scope + Fn(&'callback Lua, A) -> Result, { unsafe { let f: Box< - FnMut(&'callback Lua, MultiValue<'callback>) -> Result>, + Fn(&'callback Lua, MultiValue<'callback>) -> Result>, > = Box::new(move |lua, args| { func(lua, A::from_lua_multi(args, lua)?)?.to_lua_multi(lua) }); @@ -1082,7 +1091,7 @@ impl<'lua, 'scope> Scope<'lua, 'scope> { ffi::luaL_unref(state, ffi::LUA_REGISTRYINDEX, registry_id); ffi::lua_getupvalue(state, -1, 1); - let ud = take_userdata::>(state); + let ud = take_userdata::(state); ffi::lua_pushnil(state); ffi::lua_setupvalue(state, -2, 1); @@ -1094,6 +1103,19 @@ impl<'lua, 'scope> Scope<'lua, 'scope> { } } + pub fn create_function_mut<'callback, A, R, F>(&self, func: F) -> Result> + where + A: FromLuaMulti<'callback>, + R: ToLuaMulti<'callback>, + F: 'scope + FnMut(&'callback Lua, A) -> Result, + { + let func = RefCell::new(func); + self.create_function(move |lua, args| { + (&mut *func.try_borrow_mut() + .map_err(|_| Error::RecursiveMutCallback)?)(lua, args) + }) + } + pub fn create_userdata(&self, data: T) -> Result> where T: UserData, @@ -1128,7 +1150,7 @@ impl<'lua, 'scope> Drop for Scope<'lua, 'scope> { let to_drop = self.destructors .get_mut() .drain(..) - .map(|mut destructor| destructor(state)) + .map(|destructor| destructor(state)) .collect::>(); drop(to_drop); } diff --git a/src/table.rs b/src/table.rs index 29af395..cadb077 100644 --- a/src/table.rs +++ b/src/table.rs @@ -125,7 +125,7 @@ impl<'lua> Table<'lua> { let lua = self.0.lua; unsafe { stack_err_guard(lua.state, 0, || { - check_stack(lua.state, 3); + check_stack(lua.state, 6); lua.push_ref(lua.state, &self.0); lua.push_value(lua.state, key.to_lua(lua)?); lua.push_value(lua.state, value.to_lua(lua)?); @@ -142,7 +142,7 @@ impl<'lua> Table<'lua> { let lua = self.0.lua; unsafe { stack_err_guard(lua.state, 0, || { - check_stack(lua.state, 2); + check_stack(lua.state, 3); lua.push_ref(lua.state, &self.0); lua.push_value(lua.state, key.to_lua(lua)?); ffi::lua_rawget(lua.state, -2); diff --git a/src/tests.rs b/src/tests.rs index 63e5d26..1cd5970 100644 --- a/src/tests.rs +++ b/src/tests.rs @@ -1,5 +1,5 @@ -use std::fmt; -use std::error; +use std::{error, fmt}; +use std::iter::FromIterator; use std::rc::Rc; use std::cell::Cell; use std::sync::Arc; @@ -432,11 +432,11 @@ fn test_pcall_xpcall() { } #[test] -fn test_recursive_callback_error() { +fn test_recursive_mut_callback_error() { let lua = Lua::new(); let mut v = Some(Box::new(123)); - let f = lua.create_function::<_, (), _>(move |lua, mutate: bool| { + let f = lua.create_function_mut::<_, (), _>(move |lua, mutate: bool| { if mutate { v = None; } else { @@ -459,7 +459,7 @@ fn test_recursive_callback_error() { { Err(Error::CallbackError { ref cause, .. }) => match *cause.as_ref() { Error::CallbackError { ref cause, .. } => match *cause.as_ref() { - Error::RecursiveCallback { .. } => {} + Error::RecursiveMutCallback { .. } => {} ref other => panic!("incorrect result: {:?}", other), }, ref other => panic!("incorrect result: {:?}", other), @@ -527,7 +527,7 @@ fn test_registry_value() { let lua = Lua::new(); let mut r = Some(lua.create_registry_value::(42).unwrap()); - let f = lua.create_function(move |lua, ()| { + let f = lua.create_function_mut(move |lua, ()| { if let Some(r) = r.take() { assert_eq!(lua.registry_value::(&r)?, 42); lua.remove_registry_value(r).unwrap(); @@ -666,7 +666,7 @@ fn scope_capture() { let lua = Lua::new(); lua.scope(|scope| { scope - .create_function(|_, ()| { + .create_function_mut(|_, ()| { i = 42; Ok(()) }) @@ -677,6 +677,67 @@ fn scope_capture() { assert_eq!(i, 42); } +#[test] +fn too_many_returns() { + let lua = Lua::new(); + let f = lua.create_function(|_, ()| Ok(Variadic::from_iter(1..1000000))) + .unwrap(); + assert!(f.call::<_, Vec>(()).is_err()); +} + +#[test] +fn too_many_arguments() { + let lua = Lua::new(); + lua.exec::<()>("function test(...) end", None).unwrap(); + let args = Variadic::from_iter(1..1000000); + assert!( + lua.globals() + .get::<_, Function>("test") + .unwrap() + .call::<_, ()>(args) + .is_err() + ); +} + +#[test] +fn too_many_recursions() { + let lua = Lua::new(); + + let f = lua.create_function(move |lua, ()| { + lua.globals().get::<_, Function>("f")?.call::<_, ()>(()) + }).unwrap(); + lua.globals().set("f", f).unwrap(); + + assert!( + lua.globals() + .get::<_, Function>("f") + .unwrap() + .call::<_, ()>(()) + .is_err() + ); +} + +#[test] +fn too_many_binds() { + let lua = Lua::new(); + let globals = lua.globals(); + lua.exec::<()>( + r#" + function f(...) + end + "#, + None, + ).unwrap(); + + let concat = globals.get::<_, Function>("f").unwrap(); + assert!(concat.bind(Variadic::from_iter(1..1000000)).is_err()); + assert!( + concat + .call::<_, ()>(Variadic::from_iter(1..1000000)) + .is_err() + ); +} + // TODO: Need to use compiletest-rs or similar to make sure these don't compile. /* #[test] diff --git a/src/thread.rs b/src/thread.rs index 59ab2e7..1ab62b2 100644 --- a/src/thread.rs +++ b/src/thread.rs @@ -93,7 +93,7 @@ impl<'lua> Thread<'lua> { let args = args.to_lua_multi(lua)?; let nargs = args.len() as c_int; - check_stack(thread_state, nargs); + check_stack_err(thread_state, nargs + 1)?; for arg in args { lua.push_value(thread_state, arg); diff --git a/src/types.rs b/src/types.rs index 5766f72..1fb9f30 100644 --- a/src/types.rs +++ b/src/types.rs @@ -39,8 +39,7 @@ impl Drop for RegistryKey { } } -pub(crate) type Callback<'lua> = - Box) -> Result>>; +pub(crate) type Callback<'lua> = Box) -> Result>>; pub(crate) struct LuaRef<'lua> { pub lua: &'lua Lua, diff --git a/src/userdata.rs b/src/userdata.rs index 8c7a92f..a870d71 100644 --- a/src/userdata.rs +++ b/src/userdata.rs @@ -89,7 +89,7 @@ impl<'lua, T: UserData> UserDataMethods<'lua, T> { where A: FromLuaMulti<'lua>, R: ToLuaMulti<'lua>, - M: 'static + Send + for<'a> FnMut(&'lua Lua, &'a T, A) -> Result, + M: 'static + Send + for<'a> Fn(&'lua Lua, &'a T, A) -> Result, { self.methods .insert(name.to_owned(), Self::box_method(method)); @@ -121,7 +121,7 @@ impl<'lua, T: UserData> UserDataMethods<'lua, T> { where A: FromLuaMulti<'lua>, R: ToLuaMulti<'lua>, - F: 'static + Send + FnMut(&'lua Lua, A) -> Result, + F: 'static + Send + Fn(&'lua Lua, A) -> Result, { self.methods .insert(name.to_owned(), Self::box_function(function)); @@ -139,7 +139,7 @@ impl<'lua, T: UserData> UserDataMethods<'lua, T> { where A: FromLuaMulti<'lua>, R: ToLuaMulti<'lua>, - M: 'static + Send + for<'a> FnMut(&'lua Lua, &'a T, A) -> Result, + M: 'static + Send + for<'a> Fn(&'lua Lua, &'a T, A) -> Result, { self.meta_methods.insert(meta, Self::box_method(method)); } @@ -170,25 +170,25 @@ impl<'lua, T: UserData> UserDataMethods<'lua, T> { where A: FromLuaMulti<'lua>, R: ToLuaMulti<'lua>, - F: 'static + Send + FnMut(&'lua Lua, A) -> Result, + F: 'static + Send + Fn(&'lua Lua, A) -> Result, { self.meta_methods.insert(meta, Self::box_function(function)); } - fn box_function(mut function: F) -> Callback<'lua> + fn box_function(function: F) -> Callback<'lua> where A: FromLuaMulti<'lua>, R: ToLuaMulti<'lua>, - F: 'static + Send + FnMut(&'lua Lua, A) -> Result, + F: 'static + Send + Fn(&'lua Lua, A) -> Result, { Box::new(move |lua, args| function(lua, A::from_lua_multi(args, lua)?)?.to_lua_multi(lua)) } - fn box_method(mut method: M) -> Callback<'lua> + fn box_method(method: M) -> Callback<'lua> where A: FromLuaMulti<'lua>, R: ToLuaMulti<'lua>, - M: 'static + Send + for<'a> FnMut(&'lua Lua, &'a T, A) -> Result, + M: 'static + Send + for<'a> Fn(&'lua Lua, &'a T, A) -> Result, { Box::new(move |lua, mut args| { if let Some(front) = args.pop_front() { @@ -205,17 +205,21 @@ impl<'lua, T: UserData> UserDataMethods<'lua, T> { }) } - fn box_method_mut(mut method: M) -> Callback<'lua> + fn box_method_mut(method: M) -> Callback<'lua> where A: FromLuaMulti<'lua>, R: ToLuaMulti<'lua>, M: 'static + Send + for<'a> FnMut(&'lua Lua, &'a mut T, A) -> Result, { + let method = RefCell::new(method); Box::new(move |lua, mut args| { if let Some(front) = args.pop_front() { let userdata = AnyUserData::from_lua(front, lua)?; let mut userdata = userdata.borrow_mut::()?; - method(lua, &mut userdata, A::from_lua_multi(args, lua)?)?.to_lua_multi(lua) + let mut method = method + .try_borrow_mut() + .map_err(|_| Error::RecursiveMutCallback)?; + (&mut *method)(lua, &mut userdata, A::from_lua_multi(args, lua)?)?.to_lua_multi(lua) } else { Err(Error::FromLuaConversionError { from: "missing argument", diff --git a/src/util.rs b/src/util.rs index a7035e6..0020d52 100644 --- a/src/util.rs +++ b/src/util.rs @@ -8,8 +8,8 @@ use std::panic::{catch_unwind, resume_unwind, UnwindSafe}; use ffi; use error::{Error, Result}; -// Checks that Lua has enough free stack space for future stack operations. -// On failure, this will clear the stack and panic. +// Checks that Lua has enough free stack space for future stack operations. On failure, this will +// clear the stack and panic. pub unsafe fn check_stack(state: *mut ffi::lua_State, amount: c_int) { lua_internal_assert!( state, @@ -18,6 +18,16 @@ pub unsafe fn check_stack(state: *mut ffi::lua_State, amount: c_int) { ); } +// Similar to `check_stack`, but returns `Error::StackError` on failure. Useful for user controlled +// sizes, which should not cause a panic. +pub unsafe fn check_stack_err(state: *mut ffi::lua_State, amount: c_int) -> Result<()> { + if ffi::lua_checkstack(state, amount) == 0 { + Err(Error::StackError) + } else { + Ok(()) + } +} + // Run an operation on a lua_State and check that the stack change is what is // expected. If the stack change does not match, clears the stack and panics. pub unsafe fn stack_guard(state: *mut ffi::lua_State, change: c_int, op: F) -> R @@ -279,10 +289,12 @@ where match catch_unwind(f) { Ok(Ok(r)) => r, Ok(Err(err)) => { + ffi::luaL_checkstack(state, 2, ptr::null()); push_wrapped_error(state, err); ffi::lua_error(state) } Err(p) => { + ffi::luaL_checkstack(state, 2, ptr::null()); push_wrapped_panic(state, p); ffi::lua_error(state) } @@ -293,6 +305,8 @@ where // Error::CallbackError with a traceback, if it is some lua type, prints the error along with a // traceback, and if it is a WrappedPanic, does not modify it. pub unsafe extern "C" fn error_traceback(state: *mut ffi::lua_State) -> c_int { + ffi::luaL_checkstack(state, 2, ptr::null()); + if let Some(error) = pop_wrapped_error(state) { ffi::luaL_traceback(state, state, ptr::null(), 0); let traceback = CStr::from_ptr(ffi::lua_tostring(state, -1)) @@ -386,10 +400,9 @@ pub unsafe fn main_state(state: *mut ffi::lua_State) -> *mut ffi::lua_State { main_state } -// Pushes a WrappedError::Error to the top of the stack +// Pushes a WrappedError::Error to the top of the stack. Uses two stack spaces and does not call +// lua_checkstack. pub unsafe fn push_wrapped_error(state: *mut ffi::lua_State, err: Error) { - ffi::luaL_checkstack(state, 2, ptr::null()); - gc_guard(state, || { let ud = ffi::lua_newuserdata(state, mem::size_of::()) as *mut WrappedError; ptr::write(ud, WrappedError(err)) @@ -432,10 +445,9 @@ pub unsafe fn gc_guard R>(state: *mut ffi::lua_State, f: F) -> struct WrappedError(pub Error); struct WrappedPanic(pub Option>); -// Pushes a WrappedError::Panic to the top of the stack +// Pushes a WrappedError::Panic to the top of the stack. Uses two stack spaces and does not call +// lua_checkstack. unsafe fn push_wrapped_panic(state: *mut ffi::lua_State, panic: Box) { - ffi::luaL_checkstack(state, 2, ptr::null()); - gc_guard(state, || { let ud = ffi::lua_newuserdata(state, mem::size_of::()) as *mut WrappedPanic; ptr::write(ud, WrappedPanic(Some(panic))) @@ -598,6 +610,7 @@ unsafe fn get_destructed_userdata_metatable(state: *mut ffi::lua_State) -> c_int static DESTRUCTED_USERDATA_METATABLE: u8 = 0; unsafe extern "C" fn destructed_error(state: *mut ffi::lua_State) -> c_int { + ffi::luaL_checkstack(state, 2, ptr::null()); push_wrapped_error(state, Error::CallbackDestructed); ffi::lua_error(state) }