From 47d4ea62ff152efab8a98104c81ff92716ed74b6 Mon Sep 17 00:00:00 2001 From: kyren Date: Mon, 5 Jun 2017 00:03:39 -0400 Subject: [PATCH] Handle unprotected lua errors SOMEWHAT more elegantly There should be drastically less ways to cause unprotected lua errors now, as the LuaTable functions which were trivial to cause unprotected errors are now protected. Unfortunately, they are protected in a pretty slow, terrible way right now, but it at least works. Also, set the atpanic function in lua to call a proper rust panic instead. --- src/ffi.rs | 3 ++ src/lua.rs | 95 +++++++++++++++++++++++++++++++------------- src/tests.rs | 108 ++++++++++++++++++++++++++++++++++++++------------- src/util.rs | 68 +++++++++++++++++++++++++++----- 4 files changed, 211 insertions(+), 63 deletions(-) diff --git a/src/ffi.rs b/src/ffi.rs index 0bdf0d0..1d210e9 100644 --- a/src/ffi.rs +++ b/src/ffi.rs @@ -102,6 +102,7 @@ extern "C" { pub fn lua_newthread(state: *mut lua_State) -> *mut lua_State; pub fn lua_settable(state: *mut lua_State, index: c_int); + pub fn lua_rawset(state: *mut lua_State, index: c_int); pub fn lua_setmetatable(state: *mut lua_State, index: c_int); pub fn lua_len(state: *mut lua_State, index: c_int); @@ -109,6 +110,7 @@ extern "C" { pub fn lua_rawequal(state: *mut lua_State, index1: c_int, index2: c_int) -> c_int; pub fn lua_error(state: *mut lua_State) -> !; + pub fn lua_atpanic(state: *mut lua_State, panic: lua_CFunction) -> lua_CFunction; pub fn luaL_newstate() -> *mut lua_State; pub fn luaL_openlibs(state: *mut lua_State); @@ -125,6 +127,7 @@ extern "C" { state: *mut lua_State, msg: *const c_char, level: c_int); + pub fn luaL_len(push_state: *mut lua_State, index: c_int) -> lua_Integer; } pub unsafe fn lua_pop(state: *mut lua_State, n: c_int) { diff --git a/src/lua.rs b/src/lua.rs index 8ecfd20..cf668a8 100644 --- a/src/lua.rs +++ b/src/lua.rs @@ -158,7 +158,24 @@ impl<'lua> LuaTable<'lua> { lua.push_ref(lua.state, &self.0); lua.push_value(lua.state, key.to_lua(lua)?)?; lua.push_value(lua.state, value.to_lua(lua)?)?; - ffi::lua_settable(lua.state, -3); + error_guard(lua.state, 3, 0, |state| { + ffi::lua_settable(state, -3); + Ok(()) + })?; + Ok(()) + }) + } + } + + pub fn raw_set, V: ToLua<'lua>>(&self, key: K, value: V) -> LuaResult<()> { + let lua = self.0.lua; + unsafe { + stack_guard(lua.state, 0, || { + check_stack(lua.state, 3)?; + lua.push_ref(lua.state, &self.0); + lua.push_value(lua.state, key.to_lua(lua)?)?; + lua.push_value(lua.state, value.to_lua(lua)?)?; + ffi::lua_rawset(lua.state, -3); ffi::lua_pop(lua.state, 1); Ok(()) }) @@ -166,6 +183,24 @@ impl<'lua> LuaTable<'lua> { } pub fn get, V: FromLua<'lua>>(&self, key: K) -> LuaResult { + let lua = self.0.lua; + unsafe { + stack_guard(lua.state, 0, || { + check_stack(lua.state, 2)?; + lua.push_ref(lua.state, &self.0); + lua.push_value(lua.state, key.to_lua(lua)?)?; + error_guard(lua.state, 2, 2, |state| { + ffi::lua_gettable(state, -2); + Ok(()) + })?; + let res = V::from_lua(lua.pop_value(lua.state)?, lua)?; + ffi::lua_pop(lua.state, 1); + Ok(res) + }) + } + } + + pub fn raw_get, V: FromLua<'lua>>(&self, key: K) -> LuaResult { let lua = self.0.lua; unsafe { stack_guard(lua.state, 0, || { @@ -187,10 +222,7 @@ impl<'lua> LuaTable<'lua> { stack_guard(lua.state, 0, || { check_stack(lua.state, 1)?; lua.push_ref(lua.state, &self.0); - ffi::lua_len(lua.state, -1); - let len = ffi::lua_tointeger(lua.state, -1); - ffi::lua_pop(lua.state, 2); - Ok(len) + error_guard(lua.state, 1, 0, |state| Ok(ffi::luaL_len(state, -1))) }) } } @@ -232,10 +264,7 @@ impl<'lua> LuaTable<'lua> { check_stack(lua.state, 4)?; lua.push_ref(lua.state, &self.0); - ffi::lua_len(lua.state, -1); - let len = ffi::lua_tointeger(lua.state, -1); - ffi::lua_pop(lua.state, 1); - + let len = error_guard(lua.state, 1, 1, |state| Ok(ffi::luaL_len(state, -1)))?; ffi::lua_pushnil(lua.state); while ffi::lua_next(lua.state, -2) != 0 { @@ -290,7 +319,7 @@ impl<'lua> LuaFunction<'lua> { stack_guard(lua.state, 0, || { let args = args.to_lua_multi(lua)?; let nargs = args.len() as c_int; - check_stack(lua.state, nargs + 1)?; + check_stack(lua.state, nargs + 3)?; let stack_start = ffi::lua_gettop(lua.state); lua.push_ref(lua.state, &self.0); @@ -615,6 +644,15 @@ impl Lua { pub fn new() -> Lua { unsafe { let state = ffi::luaL_newstate(); + unsafe extern "C" fn panic_function(state: *mut ffi::lua_State) -> c_int { + if let Some(s) = ffi::lua_tostring(state, -1).as_ref() { + panic!("rlua - unprotected error in call to Lua API ({})", s) + } else { + panic!("rlua - unprotected error in call to Lua API ") + } + } + + ffi::lua_atpanic(state, panic_function); ffi::luaL_openlibs(state); stack_guard(state, 0, || { @@ -631,11 +669,11 @@ impl Lua { push_string(state, "__gc"); ffi::lua_pushcfunction(state, destructor::>>); - ffi::lua_settable(state, -3); + ffi::lua_rawset(state, -3); ffi::lua_setmetatable(state, -2); - ffi::lua_settable(state, ffi::LUA_REGISTRYINDEX); + ffi::lua_rawset(state, ffi::LUA_REGISTRYINDEX); Ok(()) }) .unwrap(); @@ -649,13 +687,13 @@ impl Lua { push_string(state, "__gc"); ffi::lua_pushcfunction(state, destructor::); - ffi::lua_settable(state, -3); + ffi::lua_rawset(state, -3); push_string(state, "__metatable"); ffi::lua_pushboolean(state, 0); - ffi::lua_settable(state, -3); + ffi::lua_rawset(state, -3); - ffi::lua_settable(state, ffi::LUA_REGISTRYINDEX); + ffi::lua_rawset(state, ffi::LUA_REGISTRYINDEX); Ok(()) }) .unwrap(); @@ -665,11 +703,11 @@ impl Lua { push_string(state, "pcall"); ffi::lua_pushcfunction(state, safe_pcall); - ffi::lua_settable(state, -3); + ffi::lua_rawset(state, -3); push_string(state, "xpcall"); ffi::lua_pushcfunction(state, safe_xpcall); - ffi::lua_settable(state, -3); + ffi::lua_rawset(state, -3); ffi::lua_pop(state, 1); Ok(()) @@ -680,7 +718,7 @@ impl Lua { ffi::lua_pushlightuserdata(state, &TOP_STATE_REGISTRY_KEY as *const u8 as *mut c_void); ffi::lua_pushlightuserdata(state, state as *mut c_void); - ffi::lua_settable(state, ffi::LUA_REGISTRYINDEX); + ffi::lua_rawset(state, ffi::LUA_REGISTRYINDEX); Ok(()) }) .unwrap(); @@ -710,6 +748,7 @@ impl Lua { ptr::null()) })?; + check_stack(self.state, 2)?; handle_error(self.state, pcall_with_traceback(self.state, 0, 0)) }) } @@ -738,6 +777,7 @@ impl Lua { handle_error(self.state, res)?; + check_stack(self.state, 2)?; handle_error(self.state, pcall_with_traceback(self.state, 0, ffi::LUA_MULTRET))?; @@ -849,7 +889,7 @@ impl Lua { ffi::lua_rawgeti(self.state, ffi::LUA_REGISTRYINDEX, ffi::LUA_RIDX_GLOBALS); self.push_value(self.state, key.to_lua(self)?)?; self.push_value(self.state, value.to_lua(self)?)?; - ffi::lua_settable(self.state, -3); + ffi::lua_rawset(self.state, -3); ffi::lua_pop(self.state, 1); Ok(()) }) @@ -1170,10 +1210,10 @@ impl Lua { push_string(self.state, &k); self.push_value(self.state, LuaValue::Function(self.create_callback_function(m)?))?; - ffi::lua_settable(self.state, -3); + ffi::lua_rawset(self.state, -3); } - ffi::lua_settable(self.state, -3); + ffi::lua_rawset(self.state, -3); } check_stack(self.state, methods.meta_methods.len() as c_int * 2)?; @@ -1185,7 +1225,7 @@ impl Lua { self.push_value(self.state, LuaValue::Function(self.create_callback_function(m)?))?; ffi::lua_pushcclosure(self.state, meta_index_impl, 2); - ffi::lua_settable(self.state, -3); + ffi::lua_rawset(self.state, -3); } else { let name = match k { LuaMetaMethod::Add => "__add", @@ -1207,17 +1247,17 @@ impl Lua { push_string(self.state, name); self.push_value(self.state, LuaValue::Function(self.create_callback_function(m)?))?; - ffi::lua_settable(self.state, -3); + ffi::lua_rawset(self.state, -3); } } push_string(self.state, "__gc"); ffi::lua_pushcfunction(self.state, destructor::>); - ffi::lua_settable(self.state, -3); + ffi::lua_rawset(self.state, -3); push_string(self.state, "__metatable"); ffi::lua_pushboolean(self.state, 0); - ffi::lua_settable(self.state, -3); + ffi::lua_rawset(self.state, -3); let id = ffi::luaL_ref(self.state, ffi::LUA_REGISTRYINDEX); entry.insert(id); @@ -1232,8 +1272,9 @@ static LUA_USERDATA_REGISTRY_KEY: u8 = 0; static FUNCTION_METATABLE_REGISTRY_KEY: u8 = 0; static TOP_STATE_REGISTRY_KEY: u8 = 0; -// If the return code is not LUA_OK, pops the error off of the stack and returns Err. If the error -// was actually a rust panic, clears the current lua stack and panics. +// If the return code indicates an error, pops the error off of the stack and +// returns Err. If the error was actually a rust panic, clears the current lua +// stack and panics. unsafe fn handle_error(state: *mut ffi::lua_State, ret: c_int) -> LuaResult<()> { if ret != ffi::LUA_OK && ret != ffi::LUA_YIELD { Err(pop_error(state)) diff --git a/src/tests.rs b/src/tests.rs index ca0c5ed..1832616 100644 --- a/src/tests.rs +++ b/src/tests.rs @@ -18,10 +18,12 @@ fn test_set_get() { #[test] fn test_load() { let lua = Lua::new(); - lua.load(r#" + lua.load( + r#" res = 'foo'..'bar' "#, - None) + None, + ) .unwrap(); assert_eq!(lua.get::<_, String>("res").unwrap(), "foobar"); } @@ -52,12 +54,14 @@ fn test_table() { assert_eq!(table2.get::<_, String>("foo").unwrap(), "bar"); assert_eq!(table1.get::<_, String>("baz").unwrap(), "baf"); - lua.load(r#" + lua.load( + r#" table1 = {1, 2, 3, 4, 5} table2 = {} table3 = {1, 2, nil, 4, 5} "#, - None) + None, + ) .unwrap(); let table1 = lua.get::<_, LuaTable>("table1").unwrap(); @@ -78,12 +82,14 @@ fn test_table() { #[test] fn test_function() { let lua = Lua::new(); - lua.load(r#" + lua.load( + r#" function concat(arg1, arg2) return arg1 .. arg2 end "#, - None) + None, + ) .unwrap(); let concat = lua.get::<_, LuaFunction>("concat").unwrap(); @@ -94,7 +100,8 @@ fn test_function() { #[test] fn test_bind() { let lua = Lua::new(); - lua.load(r#" + lua.load( + r#" function concat(...) local res = "" for _, s in pairs({...}) do @@ -103,7 +110,8 @@ fn test_bind() { return res end "#, - None) + None, + ) .unwrap(); let mut concat = lua.get::<_, LuaFunction>("concat").unwrap(); @@ -117,7 +125,8 @@ fn test_bind() { #[test] fn test_rust_function() { let lua = Lua::new(); - lua.load(r#" + lua.load( + r#" function lua_function() return rust_function() end @@ -125,7 +134,8 @@ fn test_rust_function() { -- Test to make sure chunk return is ignored return 1 "#, - None) + None, + ) .unwrap(); let lua_function = lua.get::<_, LuaFunction>("lua_function").unwrap(); @@ -174,7 +184,8 @@ fn test_methods() { let lua = Lua::new(); let userdata = lua.create_userdata(UserData(42)).unwrap(); lua.set("userdata", userdata.clone()).unwrap(); - lua.load(r#" + lua.load( + r#" function get_it() return userdata:get_value() end @@ -183,7 +194,8 @@ fn test_methods() { return userdata:set_value(i) end "#, - None) + None, + ) .unwrap(); let get = lua.get::<_, LuaFunction>("get_it").unwrap(); let set = lua.get::<_, LuaFunction>("set_it").unwrap(); @@ -234,12 +246,14 @@ fn test_metamethods() { #[test] fn test_scope() { let lua = Lua::new(); - lua.load(r#" + lua.load( + r#" touter = { tin = {1, 2, 3} } "#, - None) + None, + ) .unwrap(); // Make sure that table gets do not borrow the table, but instead just borrow lua. @@ -268,7 +282,8 @@ fn test_scope() { #[test] fn test_lua_multi() { let lua = Lua::new(); - lua.load(r#" + lua.load( + r#" function concat(arg1, arg2) return arg1 .. arg2 end @@ -277,7 +292,8 @@ fn test_lua_multi() { return 1, 2, 3, 4, 5, 6 end "#, - None) + None, + ) .unwrap(); let concat = lua.get::<_, LuaFunction>("concat").unwrap(); @@ -295,12 +311,14 @@ fn test_lua_multi() { #[test] fn test_coercion() { let lua = Lua::new(); - lua.load(r#" + lua.load( + r#" int = 123 str = "123" num = 123.0 "#, - None) + None, + ) .unwrap(); assert_eq!(lua.get::<_, String>("int").unwrap(), "123"); @@ -330,7 +348,8 @@ fn test_error() { } let lua = Lua::new(); - lua.load(r#" + lua.load( + r#" function no_error() end @@ -368,7 +387,8 @@ fn test_error() { understand_recursion() end "#, - None) + None, + ) .unwrap(); let rust_error_function = @@ -400,12 +420,14 @@ fn test_error() { match catch_unwind(|| -> LuaResult<()> { let lua = Lua::new(); - lua.load(r#" + lua.load( + r#" function rust_panic() pcall(function () rust_panic_function() end) end "#, - None)?; + None, + )?; let rust_panic_function = lua.create_function(|_, _| { panic!("expected panic, this panic should be caught in rust") })?; @@ -422,12 +444,14 @@ fn test_error() { match catch_unwind(|| -> LuaResult<()> { let lua = Lua::new(); - lua.load(r#" + lua.load( + r#" function rust_panic() xpcall(function() rust_panic_function() end, function() end) end "#, - None)?; + None, + )?; let rust_panic_function = lua.create_function(|_, _| { panic!("expected panic, this panic should be caught in rust") })?; @@ -497,10 +521,12 @@ fn test_thread() { #[test] fn test_lightuserdata() { let lua = Lua::new(); - lua.load(r#"function id(a) + lua.load( + r#"function id(a) return a end"#, - None) + None, + ) .unwrap(); let res = lua.get::<_, LuaFunction>("id") .unwrap() @@ -508,3 +534,33 @@ fn test_lightuserdata() { .unwrap(); assert_eq!(res, LightUserData(42 as *mut c_void)); } + +#[test] +fn test_table_error() { + let lua = Lua::new(); + lua.load( + r#" + table = {} + setmetatable(table, { + __index = function() + error("lua error") + end, + __newindex = function() + error("lua error") + end, + __len = function() + error("lua error") + end + }) + "#, + None, + ) + .unwrap(); + + let bad_table: LuaTable = lua.get("table").unwrap(); + assert!(bad_table.set("key", 1).is_err()); + assert!(bad_table.get::<_, i32>("key").is_err()); + assert!(bad_table.length().is_err()); + assert!(bad_table.raw_set("key", 1).is_ok()); + assert!(bad_table.raw_get::<_, i32>("key").is_ok()); +} diff --git a/src/util.rs b/src/util.rs index a58a699..2ba263c 100644 --- a/src/util.rs +++ b/src/util.rs @@ -14,6 +14,14 @@ macro_rules! cstr { ); } +pub unsafe fn check_stack(state: *mut ffi::lua_State, amount: c_int) -> LuaResult<()> { + if ffi::lua_checkstack(state, amount) == 0 { + Err("out of lua stack space".into()) + } else { + Ok(()) + } +} + // Run an operation on a lua_State and automatically clean up the stack before returning. Takes // the lua_State, the expected stack size change, and an operation to run. If the operation // results in success, then the stack is inspected to make sure the change in stack size matches @@ -46,12 +54,52 @@ pub unsafe fn stack_guard(state: *mut ffi::lua_State, change: c_int, op: F res } -pub unsafe fn check_stack(state: *mut ffi::lua_State, amount: c_int) -> LuaResult<()> { - if ffi::lua_checkstack(state, amount) == 0 { - Err("out of lua stack space".into()) - } else { - Ok(()) +// Call the given rust function in a protected lua context, similar to pcall. +// The stack given to the protected function is a separate protected stack. This +// catches all calls to lua_error, but ffi functions that can call lua_error are +// still longjmps, and have all the same dangers as longjmps, so extreme care +// must still be taken in code that uses this function. Does not call +// lua_checkstack, and uses 2 extra stack spaces. +pub unsafe fn error_guard(state: *mut ffi::lua_State, + nargs: c_int, + nresults: c_int, + func: F) + -> LuaResult + where F: FnOnce(*mut ffi::lua_State) -> LuaResult + UnwindSafe +{ + unsafe extern "C" fn call_impl(state: *mut ffi::lua_State) -> c_int + where F: FnOnce(*mut ffi::lua_State) -> c_int + { + let func = ffi::lua_touserdata(state, -1) as *mut F; + let func = mem::replace(&mut *func, mem::uninitialized()); + ffi::lua_pop(state, 1); + func(state) } + + pub unsafe fn cpcall(state: *mut ffi::lua_State, + nargs: c_int, + nresults: c_int, + mut func: F) + -> LuaResult<()> + where F: FnOnce(*mut ffi::lua_State) -> c_int + { + ffi::lua_pushcfunction(state, call_impl::); + ffi::lua_insert(state, -(nargs + 1)); + ffi::lua_pushlightuserdata(state, &mut func as *mut F as *mut c_void); + mem::forget(func); + if pcall_with_traceback(state, nargs + 1, nresults) != ffi::LUA_OK { + Err(pop_error(state)) + } else { + Ok(()) + } + } + + let mut res = None; + cpcall(state, nargs, nresults, |state| { + res = Some(callback_error(state, || func(state))); + ffi::lua_gettop(state) + })?; + Ok(res.unwrap()) } pub unsafe fn push_string(state: *mut ffi::lua_State, s: &str) { @@ -129,13 +177,13 @@ pub unsafe fn pop_error(state: *mut ffi::lua_State) -> LuaError { } else { ffi::lua_pop(state, 1); - LuaErrorKind::ScriptError("".to_owned()) - .into() + LuaErrorKind::ScriptError("".to_owned()).into() } } // ffi::lua_pcall with a message handler that gives a nice traceback. If the caught error is -// actually a LuaError, will simply pass the error along. +// actually a LuaError, will simply pass the error along. Does not call +// checkstack, and uses 2 extra stack spaces. pub unsafe fn pcall_with_traceback(state: *mut ffi::lua_State, nargs: c_int, nresults: c_int) @@ -145,7 +193,7 @@ pub unsafe fn pcall_with_traceback(state: *mut ffi::lua_State, if !is_panic_error(state, 1) { let error = pop_error(state); ffi::luaL_traceback(state, state, ptr::null(), 0); - let traceback = CStr::from_ptr(ffi::lua_tolstring(state, 1, ptr::null_mut())) + let traceback = CStr::from_ptr(ffi::lua_tolstring(state, -1, ptr::null_mut())) .to_str() .unwrap() .to_owned(); @@ -180,7 +228,7 @@ pub unsafe fn resume_with_traceback(state: *mut ffi::lua_State, if !is_panic_error(state, 1) { let error = pop_error(state); ffi::luaL_traceback(from, state, ptr::null(), 0); - let traceback = CStr::from_ptr(ffi::lua_tolstring(from, 1, ptr::null_mut())) + let traceback = CStr::from_ptr(ffi::lua_tolstring(from, -1, ptr::null_mut())) .to_str() .unwrap() .to_owned();