Handle unprotected lua errors SOMEWHAT more elegantly

There should be drastically less ways to cause unprotected lua errors now, as
the LuaTable functions which were trivial to cause unprotected errors are now
protected. Unfortunately, they are protected in a pretty slow, terrible way
right now, but it at least works.

Also, set the atpanic function in lua to call a proper rust panic instead.
This commit is contained in:
kyren 2017-06-05 00:03:39 -04:00
parent bd7ee783e3
commit 47d4ea62ff
4 changed files with 211 additions and 63 deletions

View File

@ -102,6 +102,7 @@ extern "C" {
pub fn lua_newthread(state: *mut lua_State) -> *mut lua_State;
pub fn lua_settable(state: *mut lua_State, index: c_int);
pub fn lua_rawset(state: *mut lua_State, index: c_int);
pub fn lua_setmetatable(state: *mut lua_State, index: c_int);
pub fn lua_len(state: *mut lua_State, index: c_int);
@ -109,6 +110,7 @@ extern "C" {
pub fn lua_rawequal(state: *mut lua_State, index1: c_int, index2: c_int) -> c_int;
pub fn lua_error(state: *mut lua_State) -> !;
pub fn lua_atpanic(state: *mut lua_State, panic: lua_CFunction) -> lua_CFunction;
pub fn luaL_newstate() -> *mut lua_State;
pub fn luaL_openlibs(state: *mut lua_State);
@ -125,6 +127,7 @@ extern "C" {
state: *mut lua_State,
msg: *const c_char,
level: c_int);
pub fn luaL_len(push_state: *mut lua_State, index: c_int) -> lua_Integer;
}
pub unsafe fn lua_pop(state: *mut lua_State, n: c_int) {

View File

@ -158,7 +158,24 @@ impl<'lua> LuaTable<'lua> {
lua.push_ref(lua.state, &self.0);
lua.push_value(lua.state, key.to_lua(lua)?)?;
lua.push_value(lua.state, value.to_lua(lua)?)?;
ffi::lua_settable(lua.state, -3);
error_guard(lua.state, 3, 0, |state| {
ffi::lua_settable(state, -3);
Ok(())
})?;
Ok(())
})
}
}
pub fn raw_set<K: ToLua<'lua>, V: ToLua<'lua>>(&self, key: K, value: V) -> LuaResult<()> {
let lua = self.0.lua;
unsafe {
stack_guard(lua.state, 0, || {
check_stack(lua.state, 3)?;
lua.push_ref(lua.state, &self.0);
lua.push_value(lua.state, key.to_lua(lua)?)?;
lua.push_value(lua.state, value.to_lua(lua)?)?;
ffi::lua_rawset(lua.state, -3);
ffi::lua_pop(lua.state, 1);
Ok(())
})
@ -166,6 +183,24 @@ impl<'lua> LuaTable<'lua> {
}
pub fn get<K: ToLua<'lua>, V: FromLua<'lua>>(&self, key: K) -> LuaResult<V> {
let lua = self.0.lua;
unsafe {
stack_guard(lua.state, 0, || {
check_stack(lua.state, 2)?;
lua.push_ref(lua.state, &self.0);
lua.push_value(lua.state, key.to_lua(lua)?)?;
error_guard(lua.state, 2, 2, |state| {
ffi::lua_gettable(state, -2);
Ok(())
})?;
let res = V::from_lua(lua.pop_value(lua.state)?, lua)?;
ffi::lua_pop(lua.state, 1);
Ok(res)
})
}
}
pub fn raw_get<K: ToLua<'lua>, V: FromLua<'lua>>(&self, key: K) -> LuaResult<V> {
let lua = self.0.lua;
unsafe {
stack_guard(lua.state, 0, || {
@ -187,10 +222,7 @@ impl<'lua> LuaTable<'lua> {
stack_guard(lua.state, 0, || {
check_stack(lua.state, 1)?;
lua.push_ref(lua.state, &self.0);
ffi::lua_len(lua.state, -1);
let len = ffi::lua_tointeger(lua.state, -1);
ffi::lua_pop(lua.state, 2);
Ok(len)
error_guard(lua.state, 1, 0, |state| Ok(ffi::luaL_len(state, -1)))
})
}
}
@ -232,10 +264,7 @@ impl<'lua> LuaTable<'lua> {
check_stack(lua.state, 4)?;
lua.push_ref(lua.state, &self.0);
ffi::lua_len(lua.state, -1);
let len = ffi::lua_tointeger(lua.state, -1);
ffi::lua_pop(lua.state, 1);
let len = error_guard(lua.state, 1, 1, |state| Ok(ffi::luaL_len(state, -1)))?;
ffi::lua_pushnil(lua.state);
while ffi::lua_next(lua.state, -2) != 0 {
@ -290,7 +319,7 @@ impl<'lua> LuaFunction<'lua> {
stack_guard(lua.state, 0, || {
let args = args.to_lua_multi(lua)?;
let nargs = args.len() as c_int;
check_stack(lua.state, nargs + 1)?;
check_stack(lua.state, nargs + 3)?;
let stack_start = ffi::lua_gettop(lua.state);
lua.push_ref(lua.state, &self.0);
@ -615,6 +644,15 @@ impl Lua {
pub fn new() -> Lua {
unsafe {
let state = ffi::luaL_newstate();
unsafe extern "C" fn panic_function(state: *mut ffi::lua_State) -> c_int {
if let Some(s) = ffi::lua_tostring(state, -1).as_ref() {
panic!("rlua - unprotected error in call to Lua API ({})", s)
} else {
panic!("rlua - unprotected error in call to Lua API <unprintable error>")
}
}
ffi::lua_atpanic(state, panic_function);
ffi::luaL_openlibs(state);
stack_guard(state, 0, || {
@ -631,11 +669,11 @@ impl Lua {
push_string(state, "__gc");
ffi::lua_pushcfunction(state, destructor::<RefCell<HashMap<TypeId, c_int>>>);
ffi::lua_settable(state, -3);
ffi::lua_rawset(state, -3);
ffi::lua_setmetatable(state, -2);
ffi::lua_settable(state, ffi::LUA_REGISTRYINDEX);
ffi::lua_rawset(state, ffi::LUA_REGISTRYINDEX);
Ok(())
})
.unwrap();
@ -649,13 +687,13 @@ impl Lua {
push_string(state, "__gc");
ffi::lua_pushcfunction(state, destructor::<LuaCallback>);
ffi::lua_settable(state, -3);
ffi::lua_rawset(state, -3);
push_string(state, "__metatable");
ffi::lua_pushboolean(state, 0);
ffi::lua_settable(state, -3);
ffi::lua_rawset(state, -3);
ffi::lua_settable(state, ffi::LUA_REGISTRYINDEX);
ffi::lua_rawset(state, ffi::LUA_REGISTRYINDEX);
Ok(())
})
.unwrap();
@ -665,11 +703,11 @@ impl Lua {
push_string(state, "pcall");
ffi::lua_pushcfunction(state, safe_pcall);
ffi::lua_settable(state, -3);
ffi::lua_rawset(state, -3);
push_string(state, "xpcall");
ffi::lua_pushcfunction(state, safe_xpcall);
ffi::lua_settable(state, -3);
ffi::lua_rawset(state, -3);
ffi::lua_pop(state, 1);
Ok(())
@ -680,7 +718,7 @@ impl Lua {
ffi::lua_pushlightuserdata(state,
&TOP_STATE_REGISTRY_KEY as *const u8 as *mut c_void);
ffi::lua_pushlightuserdata(state, state as *mut c_void);
ffi::lua_settable(state, ffi::LUA_REGISTRYINDEX);
ffi::lua_rawset(state, ffi::LUA_REGISTRYINDEX);
Ok(())
})
.unwrap();
@ -710,6 +748,7 @@ impl Lua {
ptr::null())
})?;
check_stack(self.state, 2)?;
handle_error(self.state, pcall_with_traceback(self.state, 0, 0))
})
}
@ -738,6 +777,7 @@ impl Lua {
handle_error(self.state, res)?;
check_stack(self.state, 2)?;
handle_error(self.state,
pcall_with_traceback(self.state, 0, ffi::LUA_MULTRET))?;
@ -849,7 +889,7 @@ impl Lua {
ffi::lua_rawgeti(self.state, ffi::LUA_REGISTRYINDEX, ffi::LUA_RIDX_GLOBALS);
self.push_value(self.state, key.to_lua(self)?)?;
self.push_value(self.state, value.to_lua(self)?)?;
ffi::lua_settable(self.state, -3);
ffi::lua_rawset(self.state, -3);
ffi::lua_pop(self.state, 1);
Ok(())
})
@ -1170,10 +1210,10 @@ impl Lua {
push_string(self.state, &k);
self.push_value(self.state,
LuaValue::Function(self.create_callback_function(m)?))?;
ffi::lua_settable(self.state, -3);
ffi::lua_rawset(self.state, -3);
}
ffi::lua_settable(self.state, -3);
ffi::lua_rawset(self.state, -3);
}
check_stack(self.state, methods.meta_methods.len() as c_int * 2)?;
@ -1185,7 +1225,7 @@ impl Lua {
self.push_value(self.state,
LuaValue::Function(self.create_callback_function(m)?))?;
ffi::lua_pushcclosure(self.state, meta_index_impl, 2);
ffi::lua_settable(self.state, -3);
ffi::lua_rawset(self.state, -3);
} else {
let name = match k {
LuaMetaMethod::Add => "__add",
@ -1207,17 +1247,17 @@ impl Lua {
push_string(self.state, name);
self.push_value(self.state,
LuaValue::Function(self.create_callback_function(m)?))?;
ffi::lua_settable(self.state, -3);
ffi::lua_rawset(self.state, -3);
}
}
push_string(self.state, "__gc");
ffi::lua_pushcfunction(self.state, destructor::<RefCell<T>>);
ffi::lua_settable(self.state, -3);
ffi::lua_rawset(self.state, -3);
push_string(self.state, "__metatable");
ffi::lua_pushboolean(self.state, 0);
ffi::lua_settable(self.state, -3);
ffi::lua_rawset(self.state, -3);
let id = ffi::luaL_ref(self.state, ffi::LUA_REGISTRYINDEX);
entry.insert(id);
@ -1232,8 +1272,9 @@ static LUA_USERDATA_REGISTRY_KEY: u8 = 0;
static FUNCTION_METATABLE_REGISTRY_KEY: u8 = 0;
static TOP_STATE_REGISTRY_KEY: u8 = 0;
// If the return code is not LUA_OK, pops the error off of the stack and returns Err. If the error
// was actually a rust panic, clears the current lua stack and panics.
// If the return code indicates an error, pops the error off of the stack and
// returns Err. If the error was actually a rust panic, clears the current lua
// stack and panics.
unsafe fn handle_error(state: *mut ffi::lua_State, ret: c_int) -> LuaResult<()> {
if ret != ffi::LUA_OK && ret != ffi::LUA_YIELD {
Err(pop_error(state))

View File

@ -18,10 +18,12 @@ fn test_set_get() {
#[test]
fn test_load() {
let lua = Lua::new();
lua.load(r#"
lua.load(
r#"
res = 'foo'..'bar'
"#,
None)
None,
)
.unwrap();
assert_eq!(lua.get::<_, String>("res").unwrap(), "foobar");
}
@ -52,12 +54,14 @@ fn test_table() {
assert_eq!(table2.get::<_, String>("foo").unwrap(), "bar");
assert_eq!(table1.get::<_, String>("baz").unwrap(), "baf");
lua.load(r#"
lua.load(
r#"
table1 = {1, 2, 3, 4, 5}
table2 = {}
table3 = {1, 2, nil, 4, 5}
"#,
None)
None,
)
.unwrap();
let table1 = lua.get::<_, LuaTable>("table1").unwrap();
@ -78,12 +82,14 @@ fn test_table() {
#[test]
fn test_function() {
let lua = Lua::new();
lua.load(r#"
lua.load(
r#"
function concat(arg1, arg2)
return arg1 .. arg2
end
"#,
None)
None,
)
.unwrap();
let concat = lua.get::<_, LuaFunction>("concat").unwrap();
@ -94,7 +100,8 @@ fn test_function() {
#[test]
fn test_bind() {
let lua = Lua::new();
lua.load(r#"
lua.load(
r#"
function concat(...)
local res = ""
for _, s in pairs({...}) do
@ -103,7 +110,8 @@ fn test_bind() {
return res
end
"#,
None)
None,
)
.unwrap();
let mut concat = lua.get::<_, LuaFunction>("concat").unwrap();
@ -117,7 +125,8 @@ fn test_bind() {
#[test]
fn test_rust_function() {
let lua = Lua::new();
lua.load(r#"
lua.load(
r#"
function lua_function()
return rust_function()
end
@ -125,7 +134,8 @@ fn test_rust_function() {
-- Test to make sure chunk return is ignored
return 1
"#,
None)
None,
)
.unwrap();
let lua_function = lua.get::<_, LuaFunction>("lua_function").unwrap();
@ -174,7 +184,8 @@ fn test_methods() {
let lua = Lua::new();
let userdata = lua.create_userdata(UserData(42)).unwrap();
lua.set("userdata", userdata.clone()).unwrap();
lua.load(r#"
lua.load(
r#"
function get_it()
return userdata:get_value()
end
@ -183,7 +194,8 @@ fn test_methods() {
return userdata:set_value(i)
end
"#,
None)
None,
)
.unwrap();
let get = lua.get::<_, LuaFunction>("get_it").unwrap();
let set = lua.get::<_, LuaFunction>("set_it").unwrap();
@ -234,12 +246,14 @@ fn test_metamethods() {
#[test]
fn test_scope() {
let lua = Lua::new();
lua.load(r#"
lua.load(
r#"
touter = {
tin = {1, 2, 3}
}
"#,
None)
None,
)
.unwrap();
// Make sure that table gets do not borrow the table, but instead just borrow lua.
@ -268,7 +282,8 @@ fn test_scope() {
#[test]
fn test_lua_multi() {
let lua = Lua::new();
lua.load(r#"
lua.load(
r#"
function concat(arg1, arg2)
return arg1 .. arg2
end
@ -277,7 +292,8 @@ fn test_lua_multi() {
return 1, 2, 3, 4, 5, 6
end
"#,
None)
None,
)
.unwrap();
let concat = lua.get::<_, LuaFunction>("concat").unwrap();
@ -295,12 +311,14 @@ fn test_lua_multi() {
#[test]
fn test_coercion() {
let lua = Lua::new();
lua.load(r#"
lua.load(
r#"
int = 123
str = "123"
num = 123.0
"#,
None)
None,
)
.unwrap();
assert_eq!(lua.get::<_, String>("int").unwrap(), "123");
@ -330,7 +348,8 @@ fn test_error() {
}
let lua = Lua::new();
lua.load(r#"
lua.load(
r#"
function no_error()
end
@ -368,7 +387,8 @@ fn test_error() {
understand_recursion()
end
"#,
None)
None,
)
.unwrap();
let rust_error_function =
@ -400,12 +420,14 @@ fn test_error() {
match catch_unwind(|| -> LuaResult<()> {
let lua = Lua::new();
lua.load(r#"
lua.load(
r#"
function rust_panic()
pcall(function () rust_panic_function() end)
end
"#,
None)?;
None,
)?;
let rust_panic_function = lua.create_function(|_, _| {
panic!("expected panic, this panic should be caught in rust")
})?;
@ -422,12 +444,14 @@ fn test_error() {
match catch_unwind(|| -> LuaResult<()> {
let lua = Lua::new();
lua.load(r#"
lua.load(
r#"
function rust_panic()
xpcall(function() rust_panic_function() end, function() end)
end
"#,
None)?;
None,
)?;
let rust_panic_function = lua.create_function(|_, _| {
panic!("expected panic, this panic should be caught in rust")
})?;
@ -497,10 +521,12 @@ fn test_thread() {
#[test]
fn test_lightuserdata() {
let lua = Lua::new();
lua.load(r#"function id(a)
lua.load(
r#"function id(a)
return a
end"#,
None)
None,
)
.unwrap();
let res = lua.get::<_, LuaFunction>("id")
.unwrap()
@ -508,3 +534,33 @@ fn test_lightuserdata() {
.unwrap();
assert_eq!(res, LightUserData(42 as *mut c_void));
}
#[test]
fn test_table_error() {
let lua = Lua::new();
lua.load(
r#"
table = {}
setmetatable(table, {
__index = function()
error("lua error")
end,
__newindex = function()
error("lua error")
end,
__len = function()
error("lua error")
end
})
"#,
None,
)
.unwrap();
let bad_table: LuaTable = lua.get("table").unwrap();
assert!(bad_table.set("key", 1).is_err());
assert!(bad_table.get::<_, i32>("key").is_err());
assert!(bad_table.length().is_err());
assert!(bad_table.raw_set("key", 1).is_ok());
assert!(bad_table.raw_get::<_, i32>("key").is_ok());
}

View File

@ -14,6 +14,14 @@ macro_rules! cstr {
);
}
pub unsafe fn check_stack(state: *mut ffi::lua_State, amount: c_int) -> LuaResult<()> {
if ffi::lua_checkstack(state, amount) == 0 {
Err("out of lua stack space".into())
} else {
Ok(())
}
}
// Run an operation on a lua_State and automatically clean up the stack before returning. Takes
// the lua_State, the expected stack size change, and an operation to run. If the operation
// results in success, then the stack is inspected to make sure the change in stack size matches
@ -46,14 +54,54 @@ pub unsafe fn stack_guard<F, R>(state: *mut ffi::lua_State, change: c_int, op: F
res
}
pub unsafe fn check_stack(state: *mut ffi::lua_State, amount: c_int) -> LuaResult<()> {
if ffi::lua_checkstack(state, amount) == 0 {
Err("out of lua stack space".into())
// Call the given rust function in a protected lua context, similar to pcall.
// The stack given to the protected function is a separate protected stack. This
// catches all calls to lua_error, but ffi functions that can call lua_error are
// still longjmps, and have all the same dangers as longjmps, so extreme care
// must still be taken in code that uses this function. Does not call
// lua_checkstack, and uses 2 extra stack spaces.
pub unsafe fn error_guard<F, R>(state: *mut ffi::lua_State,
nargs: c_int,
nresults: c_int,
func: F)
-> LuaResult<R>
where F: FnOnce(*mut ffi::lua_State) -> LuaResult<R> + UnwindSafe
{
unsafe extern "C" fn call_impl<F>(state: *mut ffi::lua_State) -> c_int
where F: FnOnce(*mut ffi::lua_State) -> c_int
{
let func = ffi::lua_touserdata(state, -1) as *mut F;
let func = mem::replace(&mut *func, mem::uninitialized());
ffi::lua_pop(state, 1);
func(state)
}
pub unsafe fn cpcall<F>(state: *mut ffi::lua_State,
nargs: c_int,
nresults: c_int,
mut func: F)
-> LuaResult<()>
where F: FnOnce(*mut ffi::lua_State) -> c_int
{
ffi::lua_pushcfunction(state, call_impl::<F>);
ffi::lua_insert(state, -(nargs + 1));
ffi::lua_pushlightuserdata(state, &mut func as *mut F as *mut c_void);
mem::forget(func);
if pcall_with_traceback(state, nargs + 1, nresults) != ffi::LUA_OK {
Err(pop_error(state))
} else {
Ok(())
}
}
let mut res = None;
cpcall(state, nargs, nresults, |state| {
res = Some(callback_error(state, || func(state)));
ffi::lua_gettop(state)
})?;
Ok(res.unwrap())
}
pub unsafe fn push_string(state: *mut ffi::lua_State, s: &str) {
ffi::lua_pushlstring(state, s.as_ptr() as *const c_char, s.len());
}
@ -129,13 +177,13 @@ pub unsafe fn pop_error(state: *mut ffi::lua_State) -> LuaError {
} else {
ffi::lua_pop(state, 1);
LuaErrorKind::ScriptError("<unprintable error>".to_owned())
.into()
LuaErrorKind::ScriptError("<unprintable error>".to_owned()).into()
}
}
// ffi::lua_pcall with a message handler that gives a nice traceback. If the caught error is
// actually a LuaError, will simply pass the error along.
// actually a LuaError, will simply pass the error along. Does not call
// checkstack, and uses 2 extra stack spaces.
pub unsafe fn pcall_with_traceback(state: *mut ffi::lua_State,
nargs: c_int,
nresults: c_int)
@ -145,7 +193,7 @@ pub unsafe fn pcall_with_traceback(state: *mut ffi::lua_State,
if !is_panic_error(state, 1) {
let error = pop_error(state);
ffi::luaL_traceback(state, state, ptr::null(), 0);
let traceback = CStr::from_ptr(ffi::lua_tolstring(state, 1, ptr::null_mut()))
let traceback = CStr::from_ptr(ffi::lua_tolstring(state, -1, ptr::null_mut()))
.to_str()
.unwrap()
.to_owned();
@ -180,7 +228,7 @@ pub unsafe fn resume_with_traceback(state: *mut ffi::lua_State,
if !is_panic_error(state, 1) {
let error = pop_error(state);
ffi::luaL_traceback(from, state, ptr::null(), 0);
let traceback = CStr::from_ptr(ffi::lua_tolstring(from, 1, ptr::null_mut()))
let traceback = CStr::from_ptr(ffi::lua_tolstring(from, -1, ptr::null_mut()))
.to_str()
.unwrap()
.to_owned();