diff --git a/src/lua.rs b/src/lua.rs index b8e0744..0522539 100644 --- a/src/lua.rs +++ b/src/lua.rs @@ -6,7 +6,6 @@ use std::ffi::CString; use std::any::TypeId; use std::marker::PhantomData; use std::collections::{HashMap, VecDeque}; -use std::collections::hash_map::Entry as HashMapEntry; use std::os::raw::{c_char, c_int, c_void}; use std::process; @@ -497,18 +496,12 @@ impl Lua { &LUA_USERDATA_REGISTRY_KEY as *const u8 as *mut c_void, ); - push_userdata::>>( - state, - RefCell::new(HashMap::new()), - ); + push_userdata::>(state, HashMap::new()); ffi::lua_newtable(state); push_string(state, "__gc"); - ffi::lua_pushcfunction( - state, - userdata_destructor::>>, - ); + ffi::lua_pushcfunction(state, userdata_destructor::>); ffi::lua_rawset(state, -3); ffi::lua_setmetatable(state, -2); @@ -525,7 +518,7 @@ impl Lua { ffi::lua_newtable(state); push_string(state, "__gc"); - ffi::lua_pushcfunction(state, userdata_destructor::); + ffi::lua_pushcfunction(state, userdata_destructor::>); ffi::lua_rawset(state, -3); push_string(state, "__metatable"); @@ -907,7 +900,15 @@ impl Lua { ephemeral: true, }; - let func = &mut *get_userdata::(state, ffi::lua_upvalueindex(1)); + let func = get_userdata::>(state, ffi::lua_upvalueindex(1)); + let mut func = if let Ok(func) = (*func).try_borrow_mut() { + func + } else { + lua_panic!( + state, + "recursive callback function call would mutably borrow function twice" + ); + }; let nargs = ffi::lua_gettop(state); let mut args = MultiValue::new(); @@ -915,7 +916,7 @@ impl Lua { args.push_front(lua.pop_value(state)); } - let results = func(&lua, args)?; + let results = func.deref_mut()(&lua, args)?; let nresults = results.len() as c_int; for r in results { @@ -930,7 +931,7 @@ impl Lua { stack_guard(self.state, 0, move || { check_stack(self.state, 2); - push_userdata::(self.state, func); + push_userdata::>(self.state, RefCell::new(func)); ffi::lua_pushlightuserdata( self.state, @@ -1103,100 +1104,97 @@ impl Lua { &LUA_USERDATA_REGISTRY_KEY as *const u8 as *mut c_void, ); ffi::lua_gettable(self.state, ffi::LUA_REGISTRYINDEX); - let registered_userdata = - &mut *get_userdata::>>(self.state, -1); - let mut map = (*registered_userdata).borrow_mut(); + let registered_userdata = get_userdata::>(self.state, -1); ffi::lua_pop(self.state, 1); - match map.entry(TypeId::of::()) { - HashMapEntry::Occupied(entry) => *entry.get(), - HashMapEntry::Vacant(entry) => { - ffi::lua_newtable(self.state); + if let Some(table_id) = (*registered_userdata).get(&TypeId::of::()) { + return *table_id; + } - let mut methods = UserDataMethods { - methods: HashMap::new(), - meta_methods: HashMap::new(), - _type: PhantomData, + let mut methods = UserDataMethods { + methods: HashMap::new(), + meta_methods: HashMap::new(), + _type: PhantomData, + }; + T::add_methods(&mut methods); + + ffi::lua_newtable(self.state); + + let has_methods = !methods.methods.is_empty(); + + if has_methods { + push_string(self.state, "__index"); + ffi::lua_newtable(self.state); + + for (k, m) in methods.methods { + push_string(self.state, &k); + self.push_value( + self.state, + Value::Function(self.create_callback_function(m)), + ); + ffi::lua_rawset(self.state, -3); + } + + ffi::lua_rawset(self.state, -3); + } + + for (k, m) in methods.meta_methods { + if k == MetaMethod::Index && has_methods { + push_string(self.state, "__index"); + ffi::lua_pushvalue(self.state, -1); + ffi::lua_gettable(self.state, -3); + self.push_value( + self.state, + Value::Function(self.create_callback_function(m)), + ); + ffi::lua_pushcclosure(self.state, meta_index_impl, 2); + ffi::lua_rawset(self.state, -3); + } else { + let name = match k { + MetaMethod::Add => "__add", + MetaMethod::Sub => "__sub", + MetaMethod::Mul => "__mul", + MetaMethod::Div => "__div", + MetaMethod::Mod => "__mod", + MetaMethod::Pow => "__pow", + MetaMethod::Unm => "__unm", + MetaMethod::IDiv => "__idiv", + MetaMethod::BAnd => "__band", + MetaMethod::BOr => "__bor", + MetaMethod::BXor => "__bxor", + MetaMethod::BNot => "__bnot", + MetaMethod::Shl => "__shl", + MetaMethod::Shr => "__shr", + MetaMethod::Concat => "__concat", + MetaMethod::Len => "__len", + MetaMethod::Eq => "__eq", + MetaMethod::Lt => "__lt", + MetaMethod::Le => "__le", + MetaMethod::Index => "__index", + MetaMethod::NewIndex => "__newindex", + MetaMethod::Call => "__call", + MetaMethod::ToString => "__tostring", }; - T::add_methods(&mut methods); - - let has_methods = !methods.methods.is_empty(); - - if has_methods { - push_string(self.state, "__index"); - ffi::lua_newtable(self.state); - - for (k, m) in methods.methods { - push_string(self.state, &k); - self.push_value( - self.state, - Value::Function(self.create_callback_function(m)), - ); - ffi::lua_rawset(self.state, -3); - } - - ffi::lua_rawset(self.state, -3); - } - - for (k, m) in methods.meta_methods { - if k == MetaMethod::Index && has_methods { - push_string(self.state, "__index"); - ffi::lua_pushvalue(self.state, -1); - ffi::lua_gettable(self.state, -3); - self.push_value( - self.state, - Value::Function(self.create_callback_function(m)), - ); - ffi::lua_pushcclosure(self.state, meta_index_impl, 2); - ffi::lua_rawset(self.state, -3); - } else { - let name = match k { - MetaMethod::Add => "__add", - MetaMethod::Sub => "__sub", - MetaMethod::Mul => "__mul", - MetaMethod::Div => "__div", - MetaMethod::Mod => "__mod", - MetaMethod::Pow => "__pow", - MetaMethod::Unm => "__unm", - MetaMethod::IDiv => "__idiv", - MetaMethod::BAnd => "__band", - MetaMethod::BOr => "__bor", - MetaMethod::BXor => "__bxor", - MetaMethod::BNot => "__bnot", - MetaMethod::Shl => "__shl", - MetaMethod::Shr => "__shr", - MetaMethod::Concat => "__concat", - MetaMethod::Len => "__len", - MetaMethod::Eq => "__eq", - MetaMethod::Lt => "__lt", - MetaMethod::Le => "__le", - MetaMethod::Index => "__index", - MetaMethod::NewIndex => "__newindex", - MetaMethod::Call => "__call", - MetaMethod::ToString => "__tostring", - }; - push_string(self.state, name); - self.push_value( - self.state, - Value::Function(self.create_callback_function(m)), - ); - ffi::lua_rawset(self.state, -3); - } - } - - push_string(self.state, "__gc"); - ffi::lua_pushcfunction(self.state, userdata_destructor::>); + push_string(self.state, name); + self.push_value( + self.state, + Value::Function(self.create_callback_function(m)), + ); ffi::lua_rawset(self.state, -3); - - push_string(self.state, "__metatable"); - ffi::lua_pushboolean(self.state, 0); - ffi::lua_rawset(self.state, -3); - - let id = ffi::luaL_ref(self.state, ffi::LUA_REGISTRYINDEX); - entry.insert(id); - id } } + + push_string(self.state, "__gc"); + ffi::lua_pushcfunction(self.state, userdata_destructor::>); + ffi::lua_rawset(self.state, -3); + + push_string(self.state, "__metatable"); + ffi::lua_pushboolean(self.state, 0); + ffi::lua_rawset(self.state, -3); + + let id = ffi::luaL_ref(self.state, ffi::LUA_REGISTRYINDEX); + (*registered_userdata).insert(TypeId::of::(), id); + id }) } } diff --git a/src/tests.rs b/src/tests.rs index 20a8d20..6e7ce53 100644 --- a/src/tests.rs +++ b/src/tests.rs @@ -560,6 +560,39 @@ fn test_pcall_xpcall() { .call::<_, ()>(()); } +#[test] +#[should_panic] +fn test_recursive_callback_panic() { + let lua = Lua::new(); + + let mut v = Some(Box::new(123)); + let f = lua.create_function::<_, (), _>(move |lua, mutate: bool| { + if mutate { + v = None; + } else { + // Produce a mutable reference + let r = v.as_mut().unwrap(); + // Whoops, this will recurse into the function and produce another mutable reference! + lua.globals() + .get::<_, Function>("f") + .unwrap() + .call::<_, ()>(true) + .unwrap(); + println!("Should not get here, mutable aliasing has occurred!"); + println!("value at {:p}", r as *mut _); + println!("value is {}", r); + } + + Ok(()) + }); + lua.globals().set("f", f).unwrap(); + lua.globals() + .get::<_, Function>("f") + .unwrap() + .call::<_, ()>(false) + .unwrap(); +} + // TODO: Need to use compiletest-rs or similar to make sure these don't compile. /* #[test] diff --git a/src/userdata.rs b/src/userdata.rs index 3aee9bc..ebef232 100644 --- a/src/userdata.rs +++ b/src/userdata.rs @@ -536,18 +536,18 @@ mod tests { lua.eval::<()>( r#" - local tbl = setmetatable({ - userdata = userdata - }, { __gc = function(self) - -- resurrect userdata - hatch = self.userdata - end }) + local tbl = setmetatable({ + userdata = userdata + }, { __gc = function(self) + -- resurrect userdata + hatch = self.userdata + end }) - tbl = nil - userdata = nil -- make table and userdata collectable - collectgarbage("collect") - hatch:access() - "#, + tbl = nil + userdata = nil -- make table and userdata collectable + collectgarbage("collect") + hatch:access() + "#, None, ).unwrap(); } diff --git a/src/util.rs b/src/util.rs index bce5166..8d03036 100644 --- a/src/util.rs +++ b/src/util.rs @@ -257,8 +257,8 @@ pub unsafe fn handle_error(state: *mut ffi::lua_State, err: c_int) -> Result<()> Err(err) } else if is_wrapped_panic(state, -1) { - let panic = &mut *get_userdata::(state, -1); - if let Some(p) = panic.0.take() { + let panic = get_userdata::(state, -1); + if let Some(p) = (*panic).0.take() { ffi::lua_settop(state, 0); resume_unwind(p); } else { @@ -317,17 +317,16 @@ pub unsafe fn push_string(state: *mut ffi::lua_State, s: &str) { ffi::lua_pushlstring(state, s.as_ptr() as *const c_char, s.len()); } -pub unsafe fn push_userdata(state: *mut ffi::lua_State, t: T) -> *mut T { +pub unsafe fn push_userdata(state: *mut ffi::lua_State, t: T) { let ud = ffi::lua_newuserdata(state, mem::size_of::>()) as *mut Option; - ptr::write(ud, None); - *ud = Some(t); - (*ud).as_mut().unwrap() + ptr::write(ud, Some(t)); } pub unsafe fn get_userdata(state: *mut ffi::lua_State, index: c_int) -> *mut T { let ud = ffi::lua_touserdata(state, index) as *mut Option; lua_assert!(state, !ud.is_null()); - (*ud).as_mut().expect("access of expired userdata") + lua_assert!(state, (*ud).is_some(), "access of expired userdata"); + (*ud).as_mut().unwrap() } pub unsafe extern "C" fn userdata_destructor(state: *mut ffi::lua_State) -> c_int { @@ -552,8 +551,8 @@ pub struct WrappedPanic(pub Option>); pub unsafe fn push_wrapped_error(state: *mut ffi::lua_State, err: Error) { unsafe extern "C" fn error_tostring(state: *mut ffi::lua_State) -> c_int { callback_error(state, || if is_wrapped_error(state, -1) { - let error = &*get_userdata::(state, -1); - push_string(state, &error.0.to_string()); + let error = get_userdata::(state, -1); + push_string(state, &(*error).0.to_string()); ffi::lua_remove(state, -2); Ok(1)