diff --git a/src/ffi.rs b/src/ffi.rs index 0bdf0d0..1d210e9 100644 --- a/src/ffi.rs +++ b/src/ffi.rs @@ -102,6 +102,7 @@ extern "C" { pub fn lua_newthread(state: *mut lua_State) -> *mut lua_State; pub fn lua_settable(state: *mut lua_State, index: c_int); + pub fn lua_rawset(state: *mut lua_State, index: c_int); pub fn lua_setmetatable(state: *mut lua_State, index: c_int); pub fn lua_len(state: *mut lua_State, index: c_int); @@ -109,6 +110,7 @@ extern "C" { pub fn lua_rawequal(state: *mut lua_State, index1: c_int, index2: c_int) -> c_int; pub fn lua_error(state: *mut lua_State) -> !; + pub fn lua_atpanic(state: *mut lua_State, panic: lua_CFunction) -> lua_CFunction; pub fn luaL_newstate() -> *mut lua_State; pub fn luaL_openlibs(state: *mut lua_State); @@ -125,6 +127,7 @@ extern "C" { state: *mut lua_State, msg: *const c_char, level: c_int); + pub fn luaL_len(push_state: *mut lua_State, index: c_int) -> lua_Integer; } pub unsafe fn lua_pop(state: *mut lua_State, n: c_int) { diff --git a/src/lua.rs b/src/lua.rs index 8ecfd20..cf668a8 100644 --- a/src/lua.rs +++ b/src/lua.rs @@ -158,7 +158,24 @@ impl<'lua> LuaTable<'lua> { lua.push_ref(lua.state, &self.0); lua.push_value(lua.state, key.to_lua(lua)?)?; lua.push_value(lua.state, value.to_lua(lua)?)?; - ffi::lua_settable(lua.state, -3); + error_guard(lua.state, 3, 0, |state| { + ffi::lua_settable(state, -3); + Ok(()) + })?; + Ok(()) + }) + } + } + + pub fn raw_set, V: ToLua<'lua>>(&self, key: K, value: V) -> LuaResult<()> { + let lua = self.0.lua; + unsafe { + stack_guard(lua.state, 0, || { + check_stack(lua.state, 3)?; + lua.push_ref(lua.state, &self.0); + lua.push_value(lua.state, key.to_lua(lua)?)?; + lua.push_value(lua.state, value.to_lua(lua)?)?; + ffi::lua_rawset(lua.state, -3); ffi::lua_pop(lua.state, 1); Ok(()) }) @@ -166,6 +183,24 @@ impl<'lua> LuaTable<'lua> { } pub fn get, V: FromLua<'lua>>(&self, key: K) -> LuaResult { + let lua = self.0.lua; + unsafe { + stack_guard(lua.state, 0, || { + check_stack(lua.state, 2)?; + lua.push_ref(lua.state, &self.0); + lua.push_value(lua.state, key.to_lua(lua)?)?; + error_guard(lua.state, 2, 2, |state| { + ffi::lua_gettable(state, -2); + Ok(()) + })?; + let res = V::from_lua(lua.pop_value(lua.state)?, lua)?; + ffi::lua_pop(lua.state, 1); + Ok(res) + }) + } + } + + pub fn raw_get, V: FromLua<'lua>>(&self, key: K) -> LuaResult { let lua = self.0.lua; unsafe { stack_guard(lua.state, 0, || { @@ -187,10 +222,7 @@ impl<'lua> LuaTable<'lua> { stack_guard(lua.state, 0, || { check_stack(lua.state, 1)?; lua.push_ref(lua.state, &self.0); - ffi::lua_len(lua.state, -1); - let len = ffi::lua_tointeger(lua.state, -1); - ffi::lua_pop(lua.state, 2); - Ok(len) + error_guard(lua.state, 1, 0, |state| Ok(ffi::luaL_len(state, -1))) }) } } @@ -232,10 +264,7 @@ impl<'lua> LuaTable<'lua> { check_stack(lua.state, 4)?; lua.push_ref(lua.state, &self.0); - ffi::lua_len(lua.state, -1); - let len = ffi::lua_tointeger(lua.state, -1); - ffi::lua_pop(lua.state, 1); - + let len = error_guard(lua.state, 1, 1, |state| Ok(ffi::luaL_len(state, -1)))?; ffi::lua_pushnil(lua.state); while ffi::lua_next(lua.state, -2) != 0 { @@ -290,7 +319,7 @@ impl<'lua> LuaFunction<'lua> { stack_guard(lua.state, 0, || { let args = args.to_lua_multi(lua)?; let nargs = args.len() as c_int; - check_stack(lua.state, nargs + 1)?; + check_stack(lua.state, nargs + 3)?; let stack_start = ffi::lua_gettop(lua.state); lua.push_ref(lua.state, &self.0); @@ -615,6 +644,15 @@ impl Lua { pub fn new() -> Lua { unsafe { let state = ffi::luaL_newstate(); + unsafe extern "C" fn panic_function(state: *mut ffi::lua_State) -> c_int { + if let Some(s) = ffi::lua_tostring(state, -1).as_ref() { + panic!("rlua - unprotected error in call to Lua API ({})", s) + } else { + panic!("rlua - unprotected error in call to Lua API ") + } + } + + ffi::lua_atpanic(state, panic_function); ffi::luaL_openlibs(state); stack_guard(state, 0, || { @@ -631,11 +669,11 @@ impl Lua { push_string(state, "__gc"); ffi::lua_pushcfunction(state, destructor::>>); - ffi::lua_settable(state, -3); + ffi::lua_rawset(state, -3); ffi::lua_setmetatable(state, -2); - ffi::lua_settable(state, ffi::LUA_REGISTRYINDEX); + ffi::lua_rawset(state, ffi::LUA_REGISTRYINDEX); Ok(()) }) .unwrap(); @@ -649,13 +687,13 @@ impl Lua { push_string(state, "__gc"); ffi::lua_pushcfunction(state, destructor::); - ffi::lua_settable(state, -3); + ffi::lua_rawset(state, -3); push_string(state, "__metatable"); ffi::lua_pushboolean(state, 0); - ffi::lua_settable(state, -3); + ffi::lua_rawset(state, -3); - ffi::lua_settable(state, ffi::LUA_REGISTRYINDEX); + ffi::lua_rawset(state, ffi::LUA_REGISTRYINDEX); Ok(()) }) .unwrap(); @@ -665,11 +703,11 @@ impl Lua { push_string(state, "pcall"); ffi::lua_pushcfunction(state, safe_pcall); - ffi::lua_settable(state, -3); + ffi::lua_rawset(state, -3); push_string(state, "xpcall"); ffi::lua_pushcfunction(state, safe_xpcall); - ffi::lua_settable(state, -3); + ffi::lua_rawset(state, -3); ffi::lua_pop(state, 1); Ok(()) @@ -680,7 +718,7 @@ impl Lua { ffi::lua_pushlightuserdata(state, &TOP_STATE_REGISTRY_KEY as *const u8 as *mut c_void); ffi::lua_pushlightuserdata(state, state as *mut c_void); - ffi::lua_settable(state, ffi::LUA_REGISTRYINDEX); + ffi::lua_rawset(state, ffi::LUA_REGISTRYINDEX); Ok(()) }) .unwrap(); @@ -710,6 +748,7 @@ impl Lua { ptr::null()) })?; + check_stack(self.state, 2)?; handle_error(self.state, pcall_with_traceback(self.state, 0, 0)) }) } @@ -738,6 +777,7 @@ impl Lua { handle_error(self.state, res)?; + check_stack(self.state, 2)?; handle_error(self.state, pcall_with_traceback(self.state, 0, ffi::LUA_MULTRET))?; @@ -849,7 +889,7 @@ impl Lua { ffi::lua_rawgeti(self.state, ffi::LUA_REGISTRYINDEX, ffi::LUA_RIDX_GLOBALS); self.push_value(self.state, key.to_lua(self)?)?; self.push_value(self.state, value.to_lua(self)?)?; - ffi::lua_settable(self.state, -3); + ffi::lua_rawset(self.state, -3); ffi::lua_pop(self.state, 1); Ok(()) }) @@ -1170,10 +1210,10 @@ impl Lua { push_string(self.state, &k); self.push_value(self.state, LuaValue::Function(self.create_callback_function(m)?))?; - ffi::lua_settable(self.state, -3); + ffi::lua_rawset(self.state, -3); } - ffi::lua_settable(self.state, -3); + ffi::lua_rawset(self.state, -3); } check_stack(self.state, methods.meta_methods.len() as c_int * 2)?; @@ -1185,7 +1225,7 @@ impl Lua { self.push_value(self.state, LuaValue::Function(self.create_callback_function(m)?))?; ffi::lua_pushcclosure(self.state, meta_index_impl, 2); - ffi::lua_settable(self.state, -3); + ffi::lua_rawset(self.state, -3); } else { let name = match k { LuaMetaMethod::Add => "__add", @@ -1207,17 +1247,17 @@ impl Lua { push_string(self.state, name); self.push_value(self.state, LuaValue::Function(self.create_callback_function(m)?))?; - ffi::lua_settable(self.state, -3); + ffi::lua_rawset(self.state, -3); } } push_string(self.state, "__gc"); ffi::lua_pushcfunction(self.state, destructor::>); - ffi::lua_settable(self.state, -3); + ffi::lua_rawset(self.state, -3); push_string(self.state, "__metatable"); ffi::lua_pushboolean(self.state, 0); - ffi::lua_settable(self.state, -3); + ffi::lua_rawset(self.state, -3); let id = ffi::luaL_ref(self.state, ffi::LUA_REGISTRYINDEX); entry.insert(id); @@ -1232,8 +1272,9 @@ static LUA_USERDATA_REGISTRY_KEY: u8 = 0; static FUNCTION_METATABLE_REGISTRY_KEY: u8 = 0; static TOP_STATE_REGISTRY_KEY: u8 = 0; -// If the return code is not LUA_OK, pops the error off of the stack and returns Err. If the error -// was actually a rust panic, clears the current lua stack and panics. +// If the return code indicates an error, pops the error off of the stack and +// returns Err. If the error was actually a rust panic, clears the current lua +// stack and panics. unsafe fn handle_error(state: *mut ffi::lua_State, ret: c_int) -> LuaResult<()> { if ret != ffi::LUA_OK && ret != ffi::LUA_YIELD { Err(pop_error(state)) diff --git a/src/tests.rs b/src/tests.rs index ca0c5ed..1832616 100644 --- a/src/tests.rs +++ b/src/tests.rs @@ -18,10 +18,12 @@ fn test_set_get() { #[test] fn test_load() { let lua = Lua::new(); - lua.load(r#" + lua.load( + r#" res = 'foo'..'bar' "#, - None) + None, + ) .unwrap(); assert_eq!(lua.get::<_, String>("res").unwrap(), "foobar"); } @@ -52,12 +54,14 @@ fn test_table() { assert_eq!(table2.get::<_, String>("foo").unwrap(), "bar"); assert_eq!(table1.get::<_, String>("baz").unwrap(), "baf"); - lua.load(r#" + lua.load( + r#" table1 = {1, 2, 3, 4, 5} table2 = {} table3 = {1, 2, nil, 4, 5} "#, - None) + None, + ) .unwrap(); let table1 = lua.get::<_, LuaTable>("table1").unwrap(); @@ -78,12 +82,14 @@ fn test_table() { #[test] fn test_function() { let lua = Lua::new(); - lua.load(r#" + lua.load( + r#" function concat(arg1, arg2) return arg1 .. arg2 end "#, - None) + None, + ) .unwrap(); let concat = lua.get::<_, LuaFunction>("concat").unwrap(); @@ -94,7 +100,8 @@ fn test_function() { #[test] fn test_bind() { let lua = Lua::new(); - lua.load(r#" + lua.load( + r#" function concat(...) local res = "" for _, s in pairs({...}) do @@ -103,7 +110,8 @@ fn test_bind() { return res end "#, - None) + None, + ) .unwrap(); let mut concat = lua.get::<_, LuaFunction>("concat").unwrap(); @@ -117,7 +125,8 @@ fn test_bind() { #[test] fn test_rust_function() { let lua = Lua::new(); - lua.load(r#" + lua.load( + r#" function lua_function() return rust_function() end @@ -125,7 +134,8 @@ fn test_rust_function() { -- Test to make sure chunk return is ignored return 1 "#, - None) + None, + ) .unwrap(); let lua_function = lua.get::<_, LuaFunction>("lua_function").unwrap(); @@ -174,7 +184,8 @@ fn test_methods() { let lua = Lua::new(); let userdata = lua.create_userdata(UserData(42)).unwrap(); lua.set("userdata", userdata.clone()).unwrap(); - lua.load(r#" + lua.load( + r#" function get_it() return userdata:get_value() end @@ -183,7 +194,8 @@ fn test_methods() { return userdata:set_value(i) end "#, - None) + None, + ) .unwrap(); let get = lua.get::<_, LuaFunction>("get_it").unwrap(); let set = lua.get::<_, LuaFunction>("set_it").unwrap(); @@ -234,12 +246,14 @@ fn test_metamethods() { #[test] fn test_scope() { let lua = Lua::new(); - lua.load(r#" + lua.load( + r#" touter = { tin = {1, 2, 3} } "#, - None) + None, + ) .unwrap(); // Make sure that table gets do not borrow the table, but instead just borrow lua. @@ -268,7 +282,8 @@ fn test_scope() { #[test] fn test_lua_multi() { let lua = Lua::new(); - lua.load(r#" + lua.load( + r#" function concat(arg1, arg2) return arg1 .. arg2 end @@ -277,7 +292,8 @@ fn test_lua_multi() { return 1, 2, 3, 4, 5, 6 end "#, - None) + None, + ) .unwrap(); let concat = lua.get::<_, LuaFunction>("concat").unwrap(); @@ -295,12 +311,14 @@ fn test_lua_multi() { #[test] fn test_coercion() { let lua = Lua::new(); - lua.load(r#" + lua.load( + r#" int = 123 str = "123" num = 123.0 "#, - None) + None, + ) .unwrap(); assert_eq!(lua.get::<_, String>("int").unwrap(), "123"); @@ -330,7 +348,8 @@ fn test_error() { } let lua = Lua::new(); - lua.load(r#" + lua.load( + r#" function no_error() end @@ -368,7 +387,8 @@ fn test_error() { understand_recursion() end "#, - None) + None, + ) .unwrap(); let rust_error_function = @@ -400,12 +420,14 @@ fn test_error() { match catch_unwind(|| -> LuaResult<()> { let lua = Lua::new(); - lua.load(r#" + lua.load( + r#" function rust_panic() pcall(function () rust_panic_function() end) end "#, - None)?; + None, + )?; let rust_panic_function = lua.create_function(|_, _| { panic!("expected panic, this panic should be caught in rust") })?; @@ -422,12 +444,14 @@ fn test_error() { match catch_unwind(|| -> LuaResult<()> { let lua = Lua::new(); - lua.load(r#" + lua.load( + r#" function rust_panic() xpcall(function() rust_panic_function() end, function() end) end "#, - None)?; + None, + )?; let rust_panic_function = lua.create_function(|_, _| { panic!("expected panic, this panic should be caught in rust") })?; @@ -497,10 +521,12 @@ fn test_thread() { #[test] fn test_lightuserdata() { let lua = Lua::new(); - lua.load(r#"function id(a) + lua.load( + r#"function id(a) return a end"#, - None) + None, + ) .unwrap(); let res = lua.get::<_, LuaFunction>("id") .unwrap() @@ -508,3 +534,33 @@ fn test_lightuserdata() { .unwrap(); assert_eq!(res, LightUserData(42 as *mut c_void)); } + +#[test] +fn test_table_error() { + let lua = Lua::new(); + lua.load( + r#" + table = {} + setmetatable(table, { + __index = function() + error("lua error") + end, + __newindex = function() + error("lua error") + end, + __len = function() + error("lua error") + end + }) + "#, + None, + ) + .unwrap(); + + let bad_table: LuaTable = lua.get("table").unwrap(); + assert!(bad_table.set("key", 1).is_err()); + assert!(bad_table.get::<_, i32>("key").is_err()); + assert!(bad_table.length().is_err()); + assert!(bad_table.raw_set("key", 1).is_ok()); + assert!(bad_table.raw_get::<_, i32>("key").is_ok()); +} diff --git a/src/util.rs b/src/util.rs index a58a699..2ba263c 100644 --- a/src/util.rs +++ b/src/util.rs @@ -14,6 +14,14 @@ macro_rules! cstr { ); } +pub unsafe fn check_stack(state: *mut ffi::lua_State, amount: c_int) -> LuaResult<()> { + if ffi::lua_checkstack(state, amount) == 0 { + Err("out of lua stack space".into()) + } else { + Ok(()) + } +} + // Run an operation on a lua_State and automatically clean up the stack before returning. Takes // the lua_State, the expected stack size change, and an operation to run. If the operation // results in success, then the stack is inspected to make sure the change in stack size matches @@ -46,12 +54,52 @@ pub unsafe fn stack_guard(state: *mut ffi::lua_State, change: c_int, op: F res } -pub unsafe fn check_stack(state: *mut ffi::lua_State, amount: c_int) -> LuaResult<()> { - if ffi::lua_checkstack(state, amount) == 0 { - Err("out of lua stack space".into()) - } else { - Ok(()) +// Call the given rust function in a protected lua context, similar to pcall. +// The stack given to the protected function is a separate protected stack. This +// catches all calls to lua_error, but ffi functions that can call lua_error are +// still longjmps, and have all the same dangers as longjmps, so extreme care +// must still be taken in code that uses this function. Does not call +// lua_checkstack, and uses 2 extra stack spaces. +pub unsafe fn error_guard(state: *mut ffi::lua_State, + nargs: c_int, + nresults: c_int, + func: F) + -> LuaResult + where F: FnOnce(*mut ffi::lua_State) -> LuaResult + UnwindSafe +{ + unsafe extern "C" fn call_impl(state: *mut ffi::lua_State) -> c_int + where F: FnOnce(*mut ffi::lua_State) -> c_int + { + let func = ffi::lua_touserdata(state, -1) as *mut F; + let func = mem::replace(&mut *func, mem::uninitialized()); + ffi::lua_pop(state, 1); + func(state) } + + pub unsafe fn cpcall(state: *mut ffi::lua_State, + nargs: c_int, + nresults: c_int, + mut func: F) + -> LuaResult<()> + where F: FnOnce(*mut ffi::lua_State) -> c_int + { + ffi::lua_pushcfunction(state, call_impl::); + ffi::lua_insert(state, -(nargs + 1)); + ffi::lua_pushlightuserdata(state, &mut func as *mut F as *mut c_void); + mem::forget(func); + if pcall_with_traceback(state, nargs + 1, nresults) != ffi::LUA_OK { + Err(pop_error(state)) + } else { + Ok(()) + } + } + + let mut res = None; + cpcall(state, nargs, nresults, |state| { + res = Some(callback_error(state, || func(state))); + ffi::lua_gettop(state) + })?; + Ok(res.unwrap()) } pub unsafe fn push_string(state: *mut ffi::lua_State, s: &str) { @@ -129,13 +177,13 @@ pub unsafe fn pop_error(state: *mut ffi::lua_State) -> LuaError { } else { ffi::lua_pop(state, 1); - LuaErrorKind::ScriptError("".to_owned()) - .into() + LuaErrorKind::ScriptError("".to_owned()).into() } } // ffi::lua_pcall with a message handler that gives a nice traceback. If the caught error is -// actually a LuaError, will simply pass the error along. +// actually a LuaError, will simply pass the error along. Does not call +// checkstack, and uses 2 extra stack spaces. pub unsafe fn pcall_with_traceback(state: *mut ffi::lua_State, nargs: c_int, nresults: c_int) @@ -145,7 +193,7 @@ pub unsafe fn pcall_with_traceback(state: *mut ffi::lua_State, if !is_panic_error(state, 1) { let error = pop_error(state); ffi::luaL_traceback(state, state, ptr::null(), 0); - let traceback = CStr::from_ptr(ffi::lua_tolstring(state, 1, ptr::null_mut())) + let traceback = CStr::from_ptr(ffi::lua_tolstring(state, -1, ptr::null_mut())) .to_str() .unwrap() .to_owned(); @@ -180,7 +228,7 @@ pub unsafe fn resume_with_traceback(state: *mut ffi::lua_State, if !is_panic_error(state, 1) { let error = pop_error(state); ffi::luaL_traceback(from, state, ptr::null(), 0); - let traceback = CStr::from_ptr(ffi::lua_tolstring(from, 1, ptr::null_mut())) + let traceback = CStr::from_ptr(ffi::lua_tolstring(from, -1, ptr::null_mut())) .to_str() .unwrap() .to_owned();