Cleanups of userdata handling, particularly around callbacks

First, make sure that `add_methods` cannot trigger another userdata registry
insert, causing an unintended panic.  Second, remove `RefCell` surrounding
userdata hashmap, as this change makes it no longer needed.  Third, add a
`RefCell` around `Callback` because FnMut means that callbacks cannot recurse
into themselves, and panic appropriately when this happens.  This should
eventually be turned into an error.
This commit is contained in:
kyren 2017-10-14 18:26:09 -04:00
parent c5a4dfd7eb
commit 4b7a3403bc
4 changed files with 150 additions and 120 deletions

View File

@ -6,7 +6,6 @@ use std::ffi::CString;
use std::any::TypeId;
use std::marker::PhantomData;
use std::collections::{HashMap, VecDeque};
use std::collections::hash_map::Entry as HashMapEntry;
use std::os::raw::{c_char, c_int, c_void};
use std::process;
@ -497,18 +496,12 @@ impl Lua {
&LUA_USERDATA_REGISTRY_KEY as *const u8 as *mut c_void,
);
push_userdata::<RefCell<HashMap<TypeId, c_int>>>(
state,
RefCell::new(HashMap::new()),
);
push_userdata::<HashMap<TypeId, c_int>>(state, HashMap::new());
ffi::lua_newtable(state);
push_string(state, "__gc");
ffi::lua_pushcfunction(
state,
userdata_destructor::<RefCell<HashMap<TypeId, c_int>>>,
);
ffi::lua_pushcfunction(state, userdata_destructor::<HashMap<TypeId, c_int>>);
ffi::lua_rawset(state, -3);
ffi::lua_setmetatable(state, -2);
@ -525,7 +518,7 @@ impl Lua {
ffi::lua_newtable(state);
push_string(state, "__gc");
ffi::lua_pushcfunction(state, userdata_destructor::<Callback>);
ffi::lua_pushcfunction(state, userdata_destructor::<RefCell<Callback>>);
ffi::lua_rawset(state, -3);
push_string(state, "__metatable");
@ -907,7 +900,15 @@ impl Lua {
ephemeral: true,
};
let func = &mut *get_userdata::<Callback>(state, ffi::lua_upvalueindex(1));
let func = get_userdata::<RefCell<Callback>>(state, ffi::lua_upvalueindex(1));
let mut func = if let Ok(func) = (*func).try_borrow_mut() {
func
} else {
lua_panic!(
state,
"recursive callback function call would mutably borrow function twice"
);
};
let nargs = ffi::lua_gettop(state);
let mut args = MultiValue::new();
@ -915,7 +916,7 @@ impl Lua {
args.push_front(lua.pop_value(state));
}
let results = func(&lua, args)?;
let results = func.deref_mut()(&lua, args)?;
let nresults = results.len() as c_int;
for r in results {
@ -930,7 +931,7 @@ impl Lua {
stack_guard(self.state, 0, move || {
check_stack(self.state, 2);
push_userdata::<Callback>(self.state, func);
push_userdata::<RefCell<Callback>>(self.state, RefCell::new(func));
ffi::lua_pushlightuserdata(
self.state,
@ -1103,100 +1104,97 @@ impl Lua {
&LUA_USERDATA_REGISTRY_KEY as *const u8 as *mut c_void,
);
ffi::lua_gettable(self.state, ffi::LUA_REGISTRYINDEX);
let registered_userdata =
&mut *get_userdata::<RefCell<HashMap<TypeId, c_int>>>(self.state, -1);
let mut map = (*registered_userdata).borrow_mut();
let registered_userdata = get_userdata::<HashMap<TypeId, c_int>>(self.state, -1);
ffi::lua_pop(self.state, 1);
match map.entry(TypeId::of::<T>()) {
HashMapEntry::Occupied(entry) => *entry.get(),
HashMapEntry::Vacant(entry) => {
ffi::lua_newtable(self.state);
if let Some(table_id) = (*registered_userdata).get(&TypeId::of::<T>()) {
return *table_id;
}
let mut methods = UserDataMethods {
methods: HashMap::new(),
meta_methods: HashMap::new(),
_type: PhantomData,
let mut methods = UserDataMethods {
methods: HashMap::new(),
meta_methods: HashMap::new(),
_type: PhantomData,
};
T::add_methods(&mut methods);
ffi::lua_newtable(self.state);
let has_methods = !methods.methods.is_empty();
if has_methods {
push_string(self.state, "__index");
ffi::lua_newtable(self.state);
for (k, m) in methods.methods {
push_string(self.state, &k);
self.push_value(
self.state,
Value::Function(self.create_callback_function(m)),
);
ffi::lua_rawset(self.state, -3);
}
ffi::lua_rawset(self.state, -3);
}
for (k, m) in methods.meta_methods {
if k == MetaMethod::Index && has_methods {
push_string(self.state, "__index");
ffi::lua_pushvalue(self.state, -1);
ffi::lua_gettable(self.state, -3);
self.push_value(
self.state,
Value::Function(self.create_callback_function(m)),
);
ffi::lua_pushcclosure(self.state, meta_index_impl, 2);
ffi::lua_rawset(self.state, -3);
} else {
let name = match k {
MetaMethod::Add => "__add",
MetaMethod::Sub => "__sub",
MetaMethod::Mul => "__mul",
MetaMethod::Div => "__div",
MetaMethod::Mod => "__mod",
MetaMethod::Pow => "__pow",
MetaMethod::Unm => "__unm",
MetaMethod::IDiv => "__idiv",
MetaMethod::BAnd => "__band",
MetaMethod::BOr => "__bor",
MetaMethod::BXor => "__bxor",
MetaMethod::BNot => "__bnot",
MetaMethod::Shl => "__shl",
MetaMethod::Shr => "__shr",
MetaMethod::Concat => "__concat",
MetaMethod::Len => "__len",
MetaMethod::Eq => "__eq",
MetaMethod::Lt => "__lt",
MetaMethod::Le => "__le",
MetaMethod::Index => "__index",
MetaMethod::NewIndex => "__newindex",
MetaMethod::Call => "__call",
MetaMethod::ToString => "__tostring",
};
T::add_methods(&mut methods);
let has_methods = !methods.methods.is_empty();
if has_methods {
push_string(self.state, "__index");
ffi::lua_newtable(self.state);
for (k, m) in methods.methods {
push_string(self.state, &k);
self.push_value(
self.state,
Value::Function(self.create_callback_function(m)),
);
ffi::lua_rawset(self.state, -3);
}
ffi::lua_rawset(self.state, -3);
}
for (k, m) in methods.meta_methods {
if k == MetaMethod::Index && has_methods {
push_string(self.state, "__index");
ffi::lua_pushvalue(self.state, -1);
ffi::lua_gettable(self.state, -3);
self.push_value(
self.state,
Value::Function(self.create_callback_function(m)),
);
ffi::lua_pushcclosure(self.state, meta_index_impl, 2);
ffi::lua_rawset(self.state, -3);
} else {
let name = match k {
MetaMethod::Add => "__add",
MetaMethod::Sub => "__sub",
MetaMethod::Mul => "__mul",
MetaMethod::Div => "__div",
MetaMethod::Mod => "__mod",
MetaMethod::Pow => "__pow",
MetaMethod::Unm => "__unm",
MetaMethod::IDiv => "__idiv",
MetaMethod::BAnd => "__band",
MetaMethod::BOr => "__bor",
MetaMethod::BXor => "__bxor",
MetaMethod::BNot => "__bnot",
MetaMethod::Shl => "__shl",
MetaMethod::Shr => "__shr",
MetaMethod::Concat => "__concat",
MetaMethod::Len => "__len",
MetaMethod::Eq => "__eq",
MetaMethod::Lt => "__lt",
MetaMethod::Le => "__le",
MetaMethod::Index => "__index",
MetaMethod::NewIndex => "__newindex",
MetaMethod::Call => "__call",
MetaMethod::ToString => "__tostring",
};
push_string(self.state, name);
self.push_value(
self.state,
Value::Function(self.create_callback_function(m)),
);
ffi::lua_rawset(self.state, -3);
}
}
push_string(self.state, "__gc");
ffi::lua_pushcfunction(self.state, userdata_destructor::<RefCell<T>>);
push_string(self.state, name);
self.push_value(
self.state,
Value::Function(self.create_callback_function(m)),
);
ffi::lua_rawset(self.state, -3);
push_string(self.state, "__metatable");
ffi::lua_pushboolean(self.state, 0);
ffi::lua_rawset(self.state, -3);
let id = ffi::luaL_ref(self.state, ffi::LUA_REGISTRYINDEX);
entry.insert(id);
id
}
}
push_string(self.state, "__gc");
ffi::lua_pushcfunction(self.state, userdata_destructor::<RefCell<T>>);
ffi::lua_rawset(self.state, -3);
push_string(self.state, "__metatable");
ffi::lua_pushboolean(self.state, 0);
ffi::lua_rawset(self.state, -3);
let id = ffi::luaL_ref(self.state, ffi::LUA_REGISTRYINDEX);
(*registered_userdata).insert(TypeId::of::<T>(), id);
id
})
}
}

View File

@ -560,6 +560,39 @@ fn test_pcall_xpcall() {
.call::<_, ()>(());
}
#[test]
#[should_panic]
fn test_recursive_callback_panic() {
let lua = Lua::new();
let mut v = Some(Box::new(123));
let f = lua.create_function::<_, (), _>(move |lua, mutate: bool| {
if mutate {
v = None;
} else {
// Produce a mutable reference
let r = v.as_mut().unwrap();
// Whoops, this will recurse into the function and produce another mutable reference!
lua.globals()
.get::<_, Function>("f")
.unwrap()
.call::<_, ()>(true)
.unwrap();
println!("Should not get here, mutable aliasing has occurred!");
println!("value at {:p}", r as *mut _);
println!("value is {}", r);
}
Ok(())
});
lua.globals().set("f", f).unwrap();
lua.globals()
.get::<_, Function>("f")
.unwrap()
.call::<_, ()>(false)
.unwrap();
}
// TODO: Need to use compiletest-rs or similar to make sure these don't compile.
/*
#[test]

View File

@ -536,18 +536,18 @@ mod tests {
lua.eval::<()>(
r#"
local tbl = setmetatable({
userdata = userdata
}, { __gc = function(self)
-- resurrect userdata
hatch = self.userdata
end })
local tbl = setmetatable({
userdata = userdata
}, { __gc = function(self)
-- resurrect userdata
hatch = self.userdata
end })
tbl = nil
userdata = nil -- make table and userdata collectable
collectgarbage("collect")
hatch:access()
"#,
tbl = nil
userdata = nil -- make table and userdata collectable
collectgarbage("collect")
hatch:access()
"#,
None,
).unwrap();
}

View File

@ -257,8 +257,8 @@ pub unsafe fn handle_error(state: *mut ffi::lua_State, err: c_int) -> Result<()>
Err(err)
} else if is_wrapped_panic(state, -1) {
let panic = &mut *get_userdata::<WrappedPanic>(state, -1);
if let Some(p) = panic.0.take() {
let panic = get_userdata::<WrappedPanic>(state, -1);
if let Some(p) = (*panic).0.take() {
ffi::lua_settop(state, 0);
resume_unwind(p);
} else {
@ -317,17 +317,16 @@ pub unsafe fn push_string(state: *mut ffi::lua_State, s: &str) {
ffi::lua_pushlstring(state, s.as_ptr() as *const c_char, s.len());
}
pub unsafe fn push_userdata<T>(state: *mut ffi::lua_State, t: T) -> *mut T {
pub unsafe fn push_userdata<T>(state: *mut ffi::lua_State, t: T) {
let ud = ffi::lua_newuserdata(state, mem::size_of::<Option<T>>()) as *mut Option<T>;
ptr::write(ud, None);
*ud = Some(t);
(*ud).as_mut().unwrap()
ptr::write(ud, Some(t));
}
pub unsafe fn get_userdata<T>(state: *mut ffi::lua_State, index: c_int) -> *mut T {
let ud = ffi::lua_touserdata(state, index) as *mut Option<T>;
lua_assert!(state, !ud.is_null());
(*ud).as_mut().expect("access of expired userdata")
lua_assert!(state, (*ud).is_some(), "access of expired userdata");
(*ud).as_mut().unwrap()
}
pub unsafe extern "C" fn userdata_destructor<T>(state: *mut ffi::lua_State) -> c_int {
@ -552,8 +551,8 @@ pub struct WrappedPanic(pub Option<Box<Any + Send>>);
pub unsafe fn push_wrapped_error(state: *mut ffi::lua_State, err: Error) {
unsafe extern "C" fn error_tostring(state: *mut ffi::lua_State) -> c_int {
callback_error(state, || if is_wrapped_error(state, -1) {
let error = &*get_userdata::<WrappedError>(state, -1);
push_string(state, &error.0.to_string());
let error = get_userdata::<WrappedError>(state, -1);
push_string(state, &(*error).0.to_string());
ffi::lua_remove(state, -2);
Ok(1)