diff --git a/examples/guided_tour.rs b/examples/guided_tour.rs index 85cecda..c7dab44 100644 --- a/examples/guided_tour.rs +++ b/examples/guided_tour.rs @@ -127,7 +127,7 @@ fn guided_tour() -> Result<()> { struct Vec2(f32, f32); impl UserData for Vec2 { - fn add_methods(methods: &mut UserDataMethods) { + fn add_methods<'lua, M: UserDataMethods<'lua, Self>>(methods: &mut M) { methods.add_method("magnitude", |_, vec, ()| { let mag_squared = vec.0 * vec.0 + vec.1 * vec.1; Ok(mag_squared.sqrt()) diff --git a/src/conversion.rs b/src/conversion.rs index b65eb4f..b66a9f0 100644 --- a/src/conversion.rs +++ b/src/conversion.rs @@ -112,13 +112,13 @@ impl<'lua> FromLua<'lua> for AnyUserData<'lua> { } } -impl<'lua, T: Send + UserData> ToLua<'lua> for T { +impl<'lua, T: 'static + Send + UserData> ToLua<'lua> for T { fn to_lua(self, lua: &'lua Lua) -> Result> { Ok(Value::UserData(lua.create_userdata(self)?)) } } -impl<'lua, T: UserData + Clone> FromLua<'lua> for T { +impl<'lua, T: 'static + UserData + Clone> FromLua<'lua> for T { fn from_lua(value: Value<'lua>, _: &'lua Lua) -> Result { match value { Value::UserData(ud) => Ok(ud.borrow::()?.clone()), diff --git a/src/lua.rs b/src/lua.rs index 4ea11c2..15abb91 100644 --- a/src/lua.rs +++ b/src/lua.rs @@ -4,6 +4,7 @@ use std::collections::HashMap; use std::ffi::CString; use std::marker::PhantomData; use std::os::raw::{c_char, c_int, c_void}; +use std::string::String as StdString; use std::sync::{Arc, Mutex}; use std::{mem, ptr, str}; @@ -20,8 +21,9 @@ use types::{Callback, Integer, LightUserData, LuaRef, Number, RegistryKey}; use userdata::{AnyUserData, MetaMethod, UserData, UserDataMethods}; use util::{ assert_stack, callback_error, check_stack, gc_guard, get_userdata, get_wrapped_error, - init_error_metatables, main_state, pop_error, protect_lua, protect_lua_closure, push_string, - push_userdata, push_wrapped_error, safe_pcall, safe_xpcall, userdata_destructor, StackGuard, + init_error_metatables, init_userdata_metatable, main_state, pop_error, protect_lua, + protect_lua_closure, push_string, push_userdata, push_wrapped_error, safe_pcall, safe_xpcall, + userdata_destructor, StackGuard, }; use value::{FromLua, FromLuaMulti, MultiValue, Nil, ToLua, ToLuaMulti, Value}; @@ -312,7 +314,7 @@ impl Lua { /// Create a Lua userdata object from a custom userdata type. pub fn create_userdata(&self, data: T) -> Result where - T: Send + UserData, + T: 'static + Send + UserData, { unsafe { self.make_userdata(data) } } @@ -328,15 +330,15 @@ impl Lua { } /// Calls the given function with a `Scope` parameter, giving the function the ability to create - /// userdata from rust types that are !Send, and rust callbacks that are !Send and not 'static. + /// userdata and callbacks from rust types that are !Send or non-'static. /// /// The lifetime of any function or userdata created through `Scope` lasts only until the /// completion of this method call, on completion all such created values are automatically /// dropped and Lua references to them are invalidated. If a script accesses a value created /// through `Scope` outside of this method, a Lua error will result. Since we can ensure the /// lifetime of values created through `Scope`, and we know that `Lua` cannot be sent to another - /// thread while `Scope` is live, it is safe to allow !Send datatypes and functions whose - /// lifetimes only outlive the scope lifetime. + /// thread while `Scope` is live, it is safe to allow !Send datatypes and whose lifetimes only + /// outlive the scope lifetime. /// /// Handles that `Lua::scope` produces have a `'lua` lifetime of the scope parameter, to prevent /// the handles from escaping the callback. However, this is not the only way for values to @@ -768,27 +770,7 @@ impl Lua { } } - pub(crate) unsafe fn userdata_metatable(&self) -> Result { - // Used if both an __index metamethod is set and regular methods, checks methods table - // first, then __index metamethod. - unsafe extern "C" fn meta_index_impl(state: *mut ffi::lua_State) -> c_int { - ffi::luaL_checkstack(state, 2, ptr::null()); - - ffi::lua_pushvalue(state, -1); - ffi::lua_gettable(state, ffi::lua_upvalueindex(1)); - if ffi::lua_isnil(state, -1) == 0 { - ffi::lua_insert(state, -3); - ffi::lua_pop(state, 2); - 1 - } else { - ffi::lua_pop(state, 1); - ffi::lua_pushvalue(state, ffi::lua_upvalueindex(2)); - ffi::lua_insert(state, -3); - ffi::lua_call(state, 2, 1); - 1 - } - } - + pub(crate) unsafe fn userdata_metatable(&self) -> Result { if let Some(table_id) = (*extra_data(self.state)) .registered_userdata .get(&TypeId::of::()) @@ -797,27 +779,29 @@ impl Lua { } let _sg = StackGuard::new(self.state); - assert_stack(self.state, 6); + assert_stack(self.state, 8); - let mut methods = UserDataMethods { - methods: HashMap::new(), - meta_methods: HashMap::new(), - _type: PhantomData, - }; + let mut methods = StaticUserDataMethods::default(); T::add_methods(&mut methods); protect_lua_closure(self.state, 0, 1, |state| { ffi::lua_newtable(state); })?; + for (k, m) in methods.meta_methods { + push_string(self.state, k.name())?; + self.push_value(Value::Function(self.create_callback(m)?)); - let has_methods = !methods.methods.is_empty(); + protect_lua_closure(self.state, 3, 1, |state| { + ffi::lua_rawset(state, -3); + })?; + } - if has_methods { - push_string(self.state, "__index")?; + if methods.methods.is_empty() { + init_userdata_metatable::>(self.state, -1, None)?; + } else { protect_lua_closure(self.state, 0, 1, |state| { ffi::lua_newtable(state); })?; - for (k, m) in methods.methods { push_string(self.state, &k)?; self.push_value(Value::Function(self.create_callback(m)?)); @@ -826,70 +810,10 @@ impl Lua { })?; } - protect_lua_closure(self.state, 3, 1, |state| { - ffi::lua_rawset(state, -3); - })?; + init_userdata_metatable::>(self.state, -2, Some(-1))?; + ffi::lua_pop(self.state, 1); } - for (k, m) in methods.meta_methods { - if k == MetaMethod::Index && has_methods { - push_string(self.state, "__index")?; - ffi::lua_pushvalue(self.state, -1); - ffi::lua_gettable(self.state, -3); - self.push_value(Value::Function(self.create_callback(m)?)); - protect_lua_closure(self.state, 2, 1, |state| { - ffi::lua_pushcclosure(state, meta_index_impl, 2); - })?; - - protect_lua_closure(self.state, 3, 1, |state| { - ffi::lua_rawset(state, -3); - })?; - } else { - let name = match k { - MetaMethod::Add => "__add", - MetaMethod::Sub => "__sub", - MetaMethod::Mul => "__mul", - MetaMethod::Div => "__div", - MetaMethod::Mod => "__mod", - MetaMethod::Pow => "__pow", - MetaMethod::Unm => "__unm", - MetaMethod::IDiv => "__idiv", - MetaMethod::BAnd => "__band", - MetaMethod::BOr => "__bor", - MetaMethod::BXor => "__bxor", - MetaMethod::BNot => "__bnot", - MetaMethod::Shl => "__shl", - MetaMethod::Shr => "__shr", - MetaMethod::Concat => "__concat", - MetaMethod::Len => "__len", - MetaMethod::Eq => "__eq", - MetaMethod::Lt => "__lt", - MetaMethod::Le => "__le", - MetaMethod::Index => "__index", - MetaMethod::NewIndex => "__newindex", - MetaMethod::Call => "__call", - MetaMethod::ToString => "__tostring", - }; - push_string(self.state, name)?; - self.push_value(Value::Function(self.create_callback(m)?)); - protect_lua_closure(self.state, 3, 1, |state| { - ffi::lua_rawset(state, -3); - })?; - } - } - - push_string(self.state, "__gc")?; - ffi::lua_pushcfunction(self.state, userdata_destructor::>); - protect_lua_closure(self.state, 3, 1, |state| { - ffi::lua_rawset(state, -3); - })?; - - push_string(self.state, "__metatable")?; - ffi::lua_pushboolean(self.state, 0); - protect_lua_closure(self.state, 3, 1, |state| { - ffi::lua_rawset(state, -3); - })?; - let id = gc_guard(self.state, || { ffi::luaL_ref(self.state, ffi::LUA_REGISTRYINDEX) }); @@ -899,6 +823,14 @@ impl Lua { Ok(id) } + // Creates a Function out of a Callback containing a 'static Fn. This is safe ONLY because the + // Fn is 'static, otherwise it could capture 'callback arguments improperly. Without ATCs, we + // cannot easily deal with the "correct" callback type of: + // + // Box Fn(&'lua Lua, MultiValue<'lua>) -> Result>)> + // + // So we instead use a caller provided lifetime, which without the 'static requirement would be + // unsafe. pub(crate) fn create_callback<'lua, 'callback>( &'lua self, func: Callback<'callback, 'static>, @@ -965,7 +897,7 @@ impl Lua { // Does not require Send bounds, which can lead to unsafety. pub(crate) unsafe fn make_userdata(&self, data: T) -> Result where - T: UserData, + T: 'static + UserData, { let _sg = StackGuard::new(self.state); assert_stack(self.state, 4); @@ -1130,3 +1062,170 @@ unsafe fn ref_stack_pop(extra: *mut ExtraData) -> c_int { } static FUNCTION_METATABLE_REGISTRY_KEY: u8 = 0; + +struct StaticUserDataMethods<'lua, T: 'static + UserData> { + methods: HashMap>, + meta_methods: HashMap>, + _type: PhantomData, +} + +impl<'lua, T: 'static + UserData> Default for StaticUserDataMethods<'lua, T> { + fn default() -> StaticUserDataMethods<'lua, T> { + StaticUserDataMethods { + methods: HashMap::new(), + meta_methods: HashMap::new(), + _type: PhantomData, + } + } +} + +impl<'lua, T: 'static + UserData> UserDataMethods<'lua, T> for StaticUserDataMethods<'lua, T> { + fn add_method(&mut self, name: &str, method: M) + where + A: FromLuaMulti<'lua>, + R: ToLuaMulti<'lua>, + M: 'static + Send + Fn(&'lua Lua, &T, A) -> Result, + { + self.methods + .insert(name.to_owned(), Self::box_method(method)); + } + + fn add_method_mut(&mut self, name: &str, method: M) + where + A: FromLuaMulti<'lua>, + R: ToLuaMulti<'lua>, + M: 'static + Send + FnMut(&'lua Lua, &mut T, A) -> Result, + { + self.methods + .insert(name.to_owned(), Self::box_method_mut(method)); + } + + fn add_function(&mut self, name: &str, function: F) + where + A: FromLuaMulti<'lua>, + R: ToLuaMulti<'lua>, + F: 'static + Send + Fn(&'lua Lua, A) -> Result, + { + self.methods + .insert(name.to_owned(), Self::box_function(function)); + } + + fn add_function_mut(&mut self, name: &str, function: F) + where + A: FromLuaMulti<'lua>, + R: ToLuaMulti<'lua>, + F: 'static + Send + FnMut(&'lua Lua, A) -> Result, + { + self.methods + .insert(name.to_owned(), Self::box_function_mut(function)); + } + + fn add_meta_method(&mut self, meta: MetaMethod, method: M) + where + A: FromLuaMulti<'lua>, + R: ToLuaMulti<'lua>, + M: 'static + Send + Fn(&'lua Lua, &T, A) -> Result, + { + self.meta_methods.insert(meta, Self::box_method(method)); + } + + fn add_meta_method_mut(&mut self, meta: MetaMethod, method: M) + where + A: FromLuaMulti<'lua>, + R: ToLuaMulti<'lua>, + M: 'static + Send + FnMut(&'lua Lua, &mut T, A) -> Result, + { + self.meta_methods.insert(meta, Self::box_method_mut(method)); + } + + fn add_meta_function(&mut self, meta: MetaMethod, function: F) + where + A: FromLuaMulti<'lua>, + R: ToLuaMulti<'lua>, + F: 'static + Send + Fn(&'lua Lua, A) -> Result, + { + self.meta_methods.insert(meta, Self::box_function(function)); + } + + fn add_meta_function_mut(&mut self, meta: MetaMethod, function: F) + where + A: FromLuaMulti<'lua>, + R: ToLuaMulti<'lua>, + F: 'static + Send + FnMut(&'lua Lua, A) -> Result, + { + self.meta_methods + .insert(meta, Self::box_function_mut(function)); + } +} + +impl<'lua, T: 'static + UserData> StaticUserDataMethods<'lua, T> { + fn box_method(method: M) -> Callback<'lua, 'static> + where + A: FromLuaMulti<'lua>, + R: ToLuaMulti<'lua>, + M: 'static + Send + Fn(&'lua Lua, &T, A) -> Result, + { + Box::new(move |lua, mut args| { + if let Some(front) = args.pop_front() { + let userdata = AnyUserData::from_lua(front, lua)?; + let userdata = userdata.borrow::()?; + method(lua, &userdata, A::from_lua_multi(args, lua)?)?.to_lua_multi(lua) + } else { + Err(Error::FromLuaConversionError { + from: "missing argument", + to: "userdata", + message: None, + }) + } + }) + } + + fn box_method_mut(method: M) -> Callback<'lua, 'static> + where + A: FromLuaMulti<'lua>, + R: ToLuaMulti<'lua>, + M: 'static + Send + FnMut(&'lua Lua, &mut T, A) -> Result, + { + let method = RefCell::new(method); + Box::new(move |lua, mut args| { + if let Some(front) = args.pop_front() { + let userdata = AnyUserData::from_lua(front, lua)?; + let mut userdata = userdata.borrow_mut::()?; + let mut method = method + .try_borrow_mut() + .map_err(|_| Error::RecursiveMutCallback)?; + (&mut *method)(lua, &mut userdata, A::from_lua_multi(args, lua)?)?.to_lua_multi(lua) + } else { + Err(Error::FromLuaConversionError { + from: "missing argument", + to: "userdata", + message: None, + }) + } + }) + } + + fn box_function(function: F) -> Callback<'lua, 'static> + where + A: FromLuaMulti<'lua>, + R: ToLuaMulti<'lua>, + F: 'static + Send + Fn(&'lua Lua, A) -> Result, + { + Box::new(move |lua, args| function(lua, A::from_lua_multi(args, lua)?)?.to_lua_multi(lua)) + } + + fn box_function_mut(function: F) -> Callback<'lua, 'static> + where + A: FromLuaMulti<'lua>, + R: ToLuaMulti<'lua>, + F: 'static + Send + FnMut(&'lua Lua, A) -> Result, + { + let function = RefCell::new(function); + Box::new(move |lua, args| { + let function = &mut *function + .try_borrow_mut() + .map_err(|_| Error::RecursiveMutCallback)?; + function(lua, A::from_lua_multi(args, lua)?)?.to_lua_multi(lua) + }) + } +} diff --git a/src/scope.rs b/src/scope.rs index efda276..5820c3e 100644 --- a/src/scope.rs +++ b/src/scope.rs @@ -1,16 +1,23 @@ use std::any::Any; use std::cell::RefCell; +use std::collections::HashMap; use std::marker::PhantomData; use std::mem; +use std::os::raw::c_void; +use std::rc::Rc; +use std::string::String as StdString; use error::{Error, Result}; use ffi; use function::Function; use lua::Lua; use types::Callback; -use userdata::{AnyUserData, UserData}; -use util::{assert_stack, take_userdata, StackGuard}; -use value::{FromLuaMulti, ToLuaMulti}; +use userdata::{AnyUserData, MetaMethod, UserData, UserDataMethods}; +use util::{ + assert_stack, init_userdata_metatable, protect_lua_closure, push_string, push_userdata, + take_userdata, StackGuard, +}; +use value::{FromLuaMulti, MultiValue, ToLuaMulti, Value}; /// Constructed by the [`Lua::scope`] method, allows temporarily passing to Lua userdata that is /// !Send, and callbacks that are !Send and not 'static. @@ -47,31 +54,19 @@ impl<'scope> Scope<'scope> { R: ToLuaMulti<'lua>, F: 'scope + Fn(&'lua Lua, A) -> Result, { + // Safe, because 'scope must outlive 'lua (due to Self containing 'scope), however the + // callback itself must be 'scope lifetime, so the function should not be able to capture + // anything of 'lua lifetime. 'scope can't be shortened due to being invariant, and the + // 'lua lifetime here can't be enlarged due to coming from a universal quantification in + // Lua::scope. + // + // I hope I got this explanation right, but in any case this is tested with compiletest_rs + // to make sure callbacks can't capture handles with lifetimes outside the scope, inside the + // scope, and owned inside the callback itself. unsafe { - let f = Box::new(move |lua, args| { + self.create_callback(Box::new(move |lua, args| { func(lua, A::from_lua_multi(args, lua)?)?.to_lua_multi(lua) - }); - let f = mem::transmute::, Callback<'lua, 'static>>(f); - let f = self.lua.create_callback(f)?; - - let mut destructors = self.destructors.borrow_mut(); - let f_destruct = f.0.clone(); - destructors.push(Box::new(move || { - let state = f_destruct.lua.state; - let _sg = StackGuard::new(state); - assert_stack(state, 2); - f_destruct.lua.push_ref(&f_destruct); - - ffi::lua_getupvalue(state, -1, 1); - let ud = take_userdata::(state); - - ffi::lua_pushnil(state); - ffi::lua_setupvalue(state, -2, 1); - - ffi::lua_pop(state, 1); - Box::new(ud) - })); - Ok(f) + })) } } @@ -100,15 +95,17 @@ impl<'scope> Scope<'scope> { /// Create a Lua userdata object from a custom userdata type. /// /// This is a version of [`Lua::create_userdata`] that creates a userdata which expires on scope - /// drop, and does not require that the userdata type be Send. See [`Lua::scope`] for more - /// details. + /// drop, and does not require that the userdata type be Send (but still requires that the + /// UserData be 'static). See [`Lua::scope`] for more details. /// /// [`Lua::create_userdata`]: struct.Lua.html#method.create_userdata /// [`Lua::scope`]: struct.Lua.html#method.scope - pub fn create_userdata<'lua, T>(&'lua self, data: T) -> Result> + pub fn create_static_userdata<'lua, T>(&'lua self, data: T) -> Result> where - T: UserData, + T: 'static + UserData, { + // Safe even though T may not be Send, because the parent Lua cannot be sent to another + // thread while the Scope is alive (or the returned AnyUserData handle even). unsafe { let u = self.lua.make_userdata(data)?; let mut destructors = self.destructors.borrow_mut(); @@ -123,6 +120,193 @@ impl<'scope> Scope<'scope> { Ok(u) } } + + /// Create a Lua userdata object from a custom userdata type. + /// + /// This is a version of [`Lua::create_userdata`] that creates a userdata which expires on scope + /// drop, and does not require that the userdata type be Send or 'static. See [`Lua::scope`] for + /// more details. + /// + /// Lifting the requirement that the UserData type be 'static comes with some important + /// limitations, so if you only need to eliminate the Send requirement, it is probably better to + /// use [`Scope::create_static_userdata`] instead. + /// + /// The main limitation that comes from using non-'static userdata is that the produced userdata + /// will no longer have a `TypeId` associated with it, becuase `TypeId` can only work for + /// 'static types. This means that it is impossible, once the userdata is created, to get a + /// reference to it back *out* of an `AnyUserData` handle. This also implies that the + /// "function" type methods that can be added via [`UserDataMethods`] (the ones that accept + /// `AnyUserData` as a first parameter) are vastly less useful. Also, there is no way to re-use + /// a single metatable for multiple non-'static types, so there is a higher cost associated with + /// creating the userdata metatable each time a new userdata is created. + /// + /// [`create_static_userdata`]: #method.create_static_userdata + /// [`Lua::create_userdata`]: struct.Lua.html#method.create_userdata + /// [`Lua::scope`]: struct.Lua.html#method.scope + /// [`UserDataMethods`]: trait.UserDataMethods.html + pub fn create_userdata<'lua, T>(&'lua self, data: T) -> Result> + where + T: 'scope + UserData, + { + let data = Rc::new(RefCell::new(data)); + + // 'callback outliving 'scope is a lie to make the types work out, required due to the + // inability to work with the "correct" universally quantified callback type. This is safe + // though, because actual method callbacks are all 'static so they can't capture 'callback + // handles anyway. + fn wrap_method<'scope, 'lua, 'callback: 'scope, T: 'scope>( + scope: &'lua Scope<'scope>, + data: Rc>, + method: NonStaticMethod<'callback, T>, + ) -> Result> { + // On methods that actually receive the userdata, we fake a type check on the passed in + // userdata, where we pretend there is a unique type per call to Scope::create_userdata. + // You can grab a method from a userdata and call it on a mismatched userdata type, + // which when using normal 'static userdata will fail with a type mismatch, but here + // without this check would proceed as though you had called the method on the original + // value (since we otherwise completely ignore the first argument). + let check_data = data.clone(); + let check_ud_type = move |lua: &Lua, value| { + if let Some(value) = value { + if let Value::UserData(u) = value { + unsafe { + let _sg = StackGuard::new(lua.state); + assert_stack(lua.state, 1); + lua.push_ref(&u.0); + ffi::lua_getuservalue(lua.state, -1); + return ffi::lua_touserdata(lua.state, -1) + == check_data.as_ptr() as *mut c_void; + } + } + } + + false + }; + + match method { + NonStaticMethod::Method(method) => { + let method_data = data.clone(); + let f = Box::new(move |lua, mut args: MultiValue<'callback>| { + if !check_ud_type(lua, args.pop_front()) { + return Err(Error::UserDataTypeMismatch); + } + let data = method_data + .try_borrow() + .map_err(|_| Error::UserDataBorrowError)?; + method(lua, &*data, args) + }); + unsafe { scope.create_callback(f) } + } + NonStaticMethod::MethodMut(method) => { + let method = RefCell::new(method); + let method_data = data.clone(); + let f = Box::new(move |lua, mut args: MultiValue<'callback>| { + if !check_ud_type(lua, args.pop_front()) { + return Err(Error::UserDataTypeMismatch); + } + let mut method = method + .try_borrow_mut() + .map_err(|_| Error::RecursiveMutCallback)?; + let mut data = method_data + .try_borrow_mut() + .map_err(|_| Error::UserDataBorrowMutError)?; + (&mut *method)(lua, &mut *data, args) + }); + unsafe { scope.create_callback(f) } + } + NonStaticMethod::Function(function) => unsafe { scope.create_callback(function) }, + NonStaticMethod::FunctionMut(function) => { + let function = RefCell::new(function); + let f = Box::new(move |lua, args| { + (&mut *function + .try_borrow_mut() + .map_err(|_| Error::RecursiveMutCallback)?)( + lua, args + ) + }); + unsafe { scope.create_callback(f) } + } + } + } + + let mut ud_methods = NonStaticUserDataMethods::default(); + T::add_methods(&mut ud_methods); + + unsafe { + let lua = self.lua; + let _sg = StackGuard::new(lua.state); + assert_stack(lua.state, 6); + + push_userdata(lua.state, ())?; + ffi::lua_pushlightuserdata(lua.state, data.as_ptr() as *mut c_void); + ffi::lua_setuservalue(lua.state, -2); + + protect_lua_closure(lua.state, 0, 1, move |state| { + ffi::lua_newtable(state); + })?; + + for (k, m) in ud_methods.meta_methods { + push_string(lua.state, k.name())?; + lua.push_value(Value::Function(wrap_method(self, data.clone(), m)?)); + + protect_lua_closure(lua.state, 3, 1, |state| { + ffi::lua_rawset(state, -3); + })?; + } + + if ud_methods.methods.is_empty() { + init_userdata_metatable::<()>(lua.state, -1, None)?; + } else { + protect_lua_closure(lua.state, 0, 1, |state| { + ffi::lua_newtable(state); + })?; + for (k, m) in ud_methods.methods { + push_string(lua.state, &k)?; + lua.push_value(Value::Function(wrap_method(self, data.clone(), m)?)); + protect_lua_closure(lua.state, 3, 1, |state| { + ffi::lua_rawset(state, -3); + })?; + } + + init_userdata_metatable::<()>(lua.state, -2, Some(-1))?; + ffi::lua_pop(lua.state, 1); + } + + ffi::lua_setmetatable(lua.state, -2); + + Ok(AnyUserData(lua.pop_ref())) + } + } + + // Unsafe, because the callback (since it is non-'static) can capture any value with 'callback + // scope, such as improperly holding onto an argument. So in order for this to be safe, the + // callback must NOT capture any arguments. + unsafe fn create_callback<'lua, 'callback>( + &'lua self, + f: Callback<'callback, 'scope>, + ) -> Result> { + let f = mem::transmute::, Callback<'callback, 'static>>(f); + let f = self.lua.create_callback(f)?; + + let mut destructors = self.destructors.borrow_mut(); + let f_destruct = f.0.clone(); + destructors.push(Box::new(move || { + let state = f_destruct.lua.state; + let _sg = StackGuard::new(state); + assert_stack(state, 2); + f_destruct.lua.push_ref(&f_destruct); + + ffi::lua_getupvalue(state, -1, 1); + let ud = take_userdata::(state); + + ffi::lua_pushnil(state); + ffi::lua_setupvalue(state, -2, 1); + + ffi::lua_pop(state, 1); + Box::new(ud) + })); + Ok(f) + } } impl<'scope> Drop for Scope<'scope> { @@ -140,3 +324,138 @@ impl<'scope> Drop for Scope<'scope> { drop(to_drop); } } + +enum NonStaticMethod<'lua, T> { + Method(Box) -> Result>>), + MethodMut(Box) -> Result>>), + Function(Box) -> Result>>), + FunctionMut(Box) -> Result>>), +} + +struct NonStaticUserDataMethods<'lua, T: UserData> { + methods: HashMap>, + meta_methods: HashMap>, +} + +impl<'lua, T: UserData> Default for NonStaticUserDataMethods<'lua, T> { + fn default() -> NonStaticUserDataMethods<'lua, T> { + NonStaticUserDataMethods { + methods: HashMap::new(), + meta_methods: HashMap::new(), + } + } +} + +impl<'lua, T: UserData> UserDataMethods<'lua, T> for NonStaticUserDataMethods<'lua, T> { + fn add_method(&mut self, name: &str, method: M) + where + A: FromLuaMulti<'lua>, + R: ToLuaMulti<'lua>, + M: 'static + Send + Fn(&'lua Lua, &T, A) -> Result, + { + self.methods.insert( + name.to_owned(), + NonStaticMethod::Method(Box::new(move |lua, ud, args| { + method(lua, ud, A::from_lua_multi(args, lua)?)?.to_lua_multi(lua) + })), + ); + } + + fn add_method_mut(&mut self, name: &str, mut method: M) + where + A: FromLuaMulti<'lua>, + R: ToLuaMulti<'lua>, + M: 'static + Send + FnMut(&'lua Lua, &mut T, A) -> Result, + { + self.methods.insert( + name.to_owned(), + NonStaticMethod::MethodMut(Box::new(move |lua, ud, args| { + method(lua, ud, A::from_lua_multi(args, lua)?)?.to_lua_multi(lua) + })), + ); + } + + fn add_function(&mut self, name: &str, function: F) + where + A: FromLuaMulti<'lua>, + R: ToLuaMulti<'lua>, + F: 'static + Send + Fn(&'lua Lua, A) -> Result, + { + self.methods.insert( + name.to_owned(), + NonStaticMethod::Function(Box::new(move |lua, args| { + function(lua, A::from_lua_multi(args, lua)?)?.to_lua_multi(lua) + })), + ); + } + + fn add_function_mut(&mut self, name: &str, mut function: F) + where + A: FromLuaMulti<'lua>, + R: ToLuaMulti<'lua>, + F: 'static + Send + FnMut(&'lua Lua, A) -> Result, + { + self.methods.insert( + name.to_owned(), + NonStaticMethod::FunctionMut(Box::new(move |lua, args| { + function(lua, A::from_lua_multi(args, lua)?)?.to_lua_multi(lua) + })), + ); + } + + fn add_meta_method(&mut self, meta: MetaMethod, method: M) + where + A: FromLuaMulti<'lua>, + R: ToLuaMulti<'lua>, + M: 'static + Send + Fn(&'lua Lua, &T, A) -> Result, + { + self.meta_methods.insert( + meta, + NonStaticMethod::Method(Box::new(move |lua, ud, args| { + method(lua, ud, A::from_lua_multi(args, lua)?)?.to_lua_multi(lua) + })), + ); + } + + fn add_meta_method_mut(&mut self, meta: MetaMethod, mut method: M) + where + A: FromLuaMulti<'lua>, + R: ToLuaMulti<'lua>, + M: 'static + Send + FnMut(&'lua Lua, &mut T, A) -> Result, + { + self.meta_methods.insert( + meta, + NonStaticMethod::MethodMut(Box::new(move |lua, ud, args| { + method(lua, ud, A::from_lua_multi(args, lua)?)?.to_lua_multi(lua) + })), + ); + } + + fn add_meta_function(&mut self, meta: MetaMethod, function: F) + where + A: FromLuaMulti<'lua>, + R: ToLuaMulti<'lua>, + F: 'static + Send + Fn(&'lua Lua, A) -> Result, + { + self.meta_methods.insert( + meta, + NonStaticMethod::Function(Box::new(move |lua, args| { + function(lua, A::from_lua_multi(args, lua)?)?.to_lua_multi(lua) + })), + ); + } + + fn add_meta_function_mut(&mut self, meta: MetaMethod, mut function: F) + where + A: FromLuaMulti<'lua>, + R: ToLuaMulti<'lua>, + F: 'static + Send + FnMut(&'lua Lua, A) -> Result, + { + self.meta_methods.insert( + meta, + NonStaticMethod::FunctionMut(Box::new(move |lua, args| { + function(lua, A::from_lua_multi(args, lua)?)?.to_lua_multi(lua) + })), + ); + } +} diff --git a/src/tests/scope.rs b/src/tests/scope.rs index 1e4c255..0d04acc 100644 --- a/src/tests/scope.rs +++ b/src/tests/scope.rs @@ -1,7 +1,7 @@ use std::cell::Cell; use std::rc::Rc; -use {Error, Function, Lua, String, UserData, UserDataMethods}; +use {Error, Function, Lua, MetaMethod, String, UserData, UserDataMethods}; #[test] fn scope_func() { @@ -40,7 +40,7 @@ fn scope_drop() { struct MyUserdata(Rc<()>); impl UserData for MyUserdata { - fn add_methods(methods: &mut UserDataMethods) { + fn add_methods<'lua, M: UserDataMethods<'lua, Self>>(methods: &mut M) { methods.add_method("method", |_, _, ()| Ok(())); } } @@ -51,7 +51,9 @@ fn scope_drop() { lua.globals() .set( "test", - scope.create_userdata(MyUserdata(rc.clone())).unwrap(), + scope + .create_static_userdata(MyUserdata(rc.clone())) + .unwrap(), ) .unwrap(); assert_eq!(Rc::strong_count(&rc), 2); @@ -98,3 +100,136 @@ fn outer_lua_access() { }); assert_eq!(table.get::<_, String>("a").unwrap(), "b"); } + +#[test] +fn scope_userdata_methods() { + struct MyUserData<'a>(&'a Cell); + + impl<'a> UserData for MyUserData<'a> { + fn add_methods<'lua, M: UserDataMethods<'lua, Self>>(methods: &mut M) { + methods.add_method("inc", |_, data, ()| { + data.0.set(data.0.get() + 1); + Ok(()) + }); + + methods.add_method("dec", |_, data, ()| { + data.0.set(data.0.get() - 1); + Ok(()) + }); + } + } + + let lua = Lua::new(); + + let i = Cell::new(42); + lua.scope(|scope| { + let f: Function = + lua.eval( + r#" + function(u) + u:inc() + u:inc() + u:inc() + u:dec() + end + "#, + None, + ).unwrap(); + + f.call::<_, ()>(scope.create_userdata(MyUserData(&i)).unwrap()) + .unwrap(); + }); + + assert_eq!(i.get(), 44); +} + +#[test] +fn scope_userdata_functions() { + struct MyUserData<'a>(&'a i64); + + impl<'a> UserData for MyUserData<'a> { + fn add_methods<'lua, M: UserDataMethods<'lua, Self>>(methods: &mut M) { + methods.add_meta_function(MetaMethod::Add, |lua, ()| { + let globals = lua.globals(); + globals.set("i", globals.get::<_, i64>("i")? + 1)?; + Ok(()) + }); + methods.add_meta_function(MetaMethod::Sub, |lua, ()| { + let globals = lua.globals(); + globals.set("i", globals.get::<_, i64>("i")? + 1)?; + Ok(()) + }); + } + } + + let lua = Lua::new(); + let f = + lua.exec::( + r#" + i = 0 + return function(u) + _ = u + u + _ = u - 1 + _ = 1 + u + end + "#, + None, + ).unwrap(); + + let dummy = 0; + lua.scope(|scope| { + f.call::<_, ()>(scope.create_userdata(MyUserData(&dummy)).unwrap()) + .unwrap(); + }); + + assert_eq!(lua.globals().get::<_, i64>("i").unwrap(), 3); +} + +#[test] +fn scope_userdata_mismatch() { + struct MyUserData<'a>(&'a Cell); + + impl<'a> UserData for MyUserData<'a> { + fn add_methods<'lua, M: UserDataMethods<'lua, Self>>(methods: &mut M) { + methods.add_method("inc", |_, data, ()| { + data.0.set(data.0.get() + 1); + Ok(()) + }); + } + } + + let lua = Lua::new(); + lua.exec::<()>( + r#" + function okay(a, b) + a.inc(a) + b.inc(b) + end + + function bad(a, b) + a.inc(b) + end + "#, + None, + ).unwrap(); + + let a = Cell::new(1); + let b = Cell::new(1); + + let okay: Function = lua.globals().get("okay").unwrap(); + let bad: Function = lua.globals().get("bad").unwrap(); + + lua.scope(|scope| { + let au = scope.create_userdata(MyUserData(&a)).unwrap(); + let bu = scope.create_userdata(MyUserData(&b)).unwrap(); + assert!(okay.call::<_, ()>((au.clone(), bu.clone())).is_ok()); + match bad.call::<_, ()>((au, bu)) { + Err(Error::CallbackError { ref cause, .. }) => match *cause.as_ref() { + Error::UserDataTypeMismatch => {} + ref other => panic!("wrong error type {:?}", other), + }, + Err(other) => panic!("wrong error type {:?}", other), + Ok(_) => panic!("incorrectly returned Ok"), + } + }); +} diff --git a/src/tests/userdata.rs b/src/tests/userdata.rs index 56d70af..46e5d0e 100644 --- a/src/tests/userdata.rs +++ b/src/tests/userdata.rs @@ -15,10 +15,10 @@ fn test_user_data() { let userdata1 = lua.create_userdata(UserData1(1)).unwrap(); let userdata2 = lua.create_userdata(UserData2(Box::new(2))).unwrap(); - assert!(userdata1.is::().unwrap()); - assert!(!userdata1.is::().unwrap()); - assert!(userdata2.is::().unwrap()); - assert!(!userdata2.is::().unwrap()); + assert!(userdata1.is::()); + assert!(!userdata1.is::()); + assert!(userdata2.is::()); + assert!(!userdata2.is::()); assert_eq!(userdata1.borrow::().unwrap().0, 1); assert_eq!(*userdata2.borrow::().unwrap().0, 2); @@ -29,7 +29,7 @@ fn test_methods() { struct MyUserData(i64); impl UserData for MyUserData { - fn add_methods(methods: &mut UserDataMethods) { + fn add_methods<'lua, M: UserDataMethods<'lua, Self>>(methods: &mut M) { methods.add_method("get_value", |_, data, ()| Ok(data.0)); methods.add_method_mut("set_value", |_, data, args| { data.0 = args; @@ -69,7 +69,7 @@ fn test_metamethods() { struct MyUserData(i64); impl UserData for MyUserData { - fn add_methods(methods: &mut UserDataMethods) { + fn add_methods<'lua, M: UserDataMethods<'lua, Self>>(methods: &mut M) { methods.add_method("get", |_, data, ()| Ok(data.0)); methods.add_meta_function( MetaMethod::Add, @@ -117,7 +117,7 @@ fn test_gc_userdata() { } impl UserData for MyUserdata { - fn add_methods(methods: &mut UserDataMethods) { + fn add_methods<'lua, M: UserDataMethods<'lua, Self>>(methods: &mut M) { methods.add_method("access", |_, this, ()| { assert!(this.id == 123); Ok(()) diff --git a/src/userdata.rs b/src/userdata.rs index 6bacdae..0a22e68 100644 --- a/src/userdata.rs +++ b/src/userdata.rs @@ -1,12 +1,9 @@ use std::cell::{Ref, RefCell, RefMut}; -use std::collections::HashMap; -use std::marker::PhantomData; -use std::string::String as StdString; use error::{Error, Result}; use ffi; use lua::Lua; -use types::{Callback, LuaRef}; +use types::LuaRef; use util::{assert_stack, get_userdata, StackGuard}; use value::{FromLua, FromLuaMulti, ToLua, ToLuaMulti}; @@ -68,16 +65,40 @@ pub enum MetaMethod { ToString, } +impl MetaMethod { + pub(crate) fn name(self) -> &'static str { + match self { + MetaMethod::Add => "__add", + MetaMethod::Sub => "__sub", + MetaMethod::Mul => "__mul", + MetaMethod::Div => "__div", + MetaMethod::Mod => "__mod", + MetaMethod::Pow => "__pow", + MetaMethod::Unm => "__unm", + MetaMethod::IDiv => "__idiv", + MetaMethod::BAnd => "__band", + MetaMethod::BOr => "__bor", + MetaMethod::BXor => "__bxor", + MetaMethod::BNot => "__bnot", + MetaMethod::Shl => "__shl", + MetaMethod::Shr => "__shr", + MetaMethod::Concat => "__concat", + MetaMethod::Len => "__len", + MetaMethod::Eq => "__eq", + MetaMethod::Lt => "__lt", + MetaMethod::Le => "__le", + MetaMethod::Index => "__index", + MetaMethod::NewIndex => "__newindex", + MetaMethod::Call => "__call", + MetaMethod::ToString => "__tostring", + } + } +} + /// Method registry for [`UserData`] implementors. /// /// [`UserData`]: trait.UserData.html -pub struct UserDataMethods<'lua, T> { - pub(crate) methods: HashMap>, - pub(crate) meta_methods: HashMap>, - pub(crate) _type: PhantomData, -} - -impl<'lua, T: UserData> UserDataMethods<'lua, T> { +pub trait UserDataMethods<'lua, T: UserData> { /// Add a method which accepts a `&T` as the first parameter. /// /// Regular methods are implemented by overriding the `__index` metamethod and returning the @@ -85,30 +106,22 @@ impl<'lua, T: UserData> UserDataMethods<'lua, T> { /// /// If `add_meta_method` is used to set the `__index` metamethod, the `__index` metamethod will /// be used as a fall-back if no regular method is found. - pub fn add_method(&mut self, name: &str, method: M) + fn add_method(&mut self, name: &str, method: M) where A: FromLuaMulti<'lua>, R: ToLuaMulti<'lua>, - M: 'static + Send + Fn(&'lua Lua, &T, A) -> Result, - { - self.methods - .insert(name.to_owned(), Self::box_method(method)); - } + M: 'static + Send + Fn(&'lua Lua, &T, A) -> Result; /// Add a regular method which accepts a `&mut T` as the first parameter. /// /// Refer to [`add_method`] for more information about the implementation. /// /// [`add_method`]: #method.add_method - pub fn add_method_mut(&mut self, name: &str, method: M) + fn add_method_mut(&mut self, name: &str, method: M) where A: FromLuaMulti<'lua>, R: ToLuaMulti<'lua>, - M: 'static + Send + FnMut(&'lua Lua, &mut T, A) -> Result, - { - self.methods - .insert(name.to_owned(), Self::box_method_mut(method)); - } + M: 'static + Send + FnMut(&'lua Lua, &mut T, A) -> Result; /// Add a regular method as a function which accepts generic arguments, the first argument will /// always be a `UserData` of type T. @@ -117,15 +130,11 @@ impl<'lua, T: UserData> UserDataMethods<'lua, T> { /// /// [`add_method`]: #method.add_method /// [`add_method_mut`]: #method.add_method_mut - pub fn add_function(&mut self, name: &str, function: F) + fn add_function(&mut self, name: &str, function: F) where A: FromLuaMulti<'lua>, R: ToLuaMulti<'lua>, - F: 'static + Send + Fn(&'lua Lua, A) -> Result, - { - self.methods - .insert(name.to_owned(), Self::box_function(function)); - } + F: 'static + Send + Fn(&'lua Lua, A) -> Result; /// Add a regular method as a mutable function which accepts generic arguments, the first /// argument will always be a `UserData` of type T. @@ -133,15 +142,11 @@ impl<'lua, T: UserData> UserDataMethods<'lua, T> { /// This is a version of [`add_function`] that accepts a FnMut argument. /// /// [`add_function`]: #method.add_function - pub fn add_function_mut(&mut self, name: &str, function: F) + fn add_function_mut(&mut self, name: &str, function: F) where A: FromLuaMulti<'lua>, R: ToLuaMulti<'lua>, - F: 'static + Send + FnMut(&'lua Lua, A) -> Result, - { - self.methods - .insert(name.to_owned(), Self::box_function_mut(function)); - } + F: 'static + Send + FnMut(&'lua Lua, A) -> Result; /// Add a metamethod which accepts a `&T` as the first parameter. /// @@ -151,14 +156,11 @@ impl<'lua, T: UserData> UserDataMethods<'lua, T> { /// side has a metatable. To prevent this, use [`add_meta_function`]. /// /// [`add_meta_function`]: #method.add_meta_function - pub fn add_meta_method(&mut self, meta: MetaMethod, method: M) + fn add_meta_method(&mut self, meta: MetaMethod, method: M) where A: FromLuaMulti<'lua>, R: ToLuaMulti<'lua>, - M: 'static + Send + Fn(&'lua Lua, &T, A) -> Result, - { - self.meta_methods.insert(meta, Self::box_method(method)); - } + M: 'static + Send + Fn(&'lua Lua, &T, A) -> Result; /// Add a metamethod as a function which accepts a `&mut T` as the first parameter. /// @@ -168,113 +170,33 @@ impl<'lua, T: UserData> UserDataMethods<'lua, T> { /// side has a metatable. To prevent this, use [`add_meta_function`]. /// /// [`add_meta_function`]: #method.add_meta_function - pub fn add_meta_method_mut(&mut self, meta: MetaMethod, method: M) + fn add_meta_method_mut(&mut self, meta: MetaMethod, method: M) where A: FromLuaMulti<'lua>, R: ToLuaMulti<'lua>, - M: 'static + Send + FnMut(&'lua Lua, &mut T, A) -> Result, - { - self.meta_methods.insert(meta, Self::box_method_mut(method)); - } + M: 'static + Send + FnMut(&'lua Lua, &mut T, A) -> Result; /// Add a metamethod which accepts generic arguments. /// /// Metamethods for binary operators can be triggered if either the left or right argument to /// the binary operator has a metatable, so the first argument here is not necessarily a /// userdata of type `T`. - pub fn add_meta_function(&mut self, meta: MetaMethod, function: F) + fn add_meta_function(&mut self, meta: MetaMethod, function: F) where A: FromLuaMulti<'lua>, R: ToLuaMulti<'lua>, - F: 'static + Send + Fn(&'lua Lua, A) -> Result, - { - self.meta_methods.insert(meta, Self::box_function(function)); - } + F: 'static + Send + Fn(&'lua Lua, A) -> Result; /// Add a metamethod as a mutable function which accepts generic arguments. /// /// This is a version of [`add_meta_function`] that accepts a FnMut argument. /// /// [`add_meta_function`]: #method.add_meta_function - pub fn add_meta_function_mut(&mut self, meta: MetaMethod, function: F) + fn add_meta_function_mut(&mut self, meta: MetaMethod, function: F) where A: FromLuaMulti<'lua>, R: ToLuaMulti<'lua>, - F: 'static + Send + FnMut(&'lua Lua, A) -> Result, - { - self.meta_methods - .insert(meta, Self::box_function_mut(function)); - } - - fn box_function(function: F) -> Callback<'lua, 'static> - where - A: FromLuaMulti<'lua>, - R: ToLuaMulti<'lua>, - F: 'static + Send + Fn(&'lua Lua, A) -> Result, - { - Box::new(move |lua, args| function(lua, A::from_lua_multi(args, lua)?)?.to_lua_multi(lua)) - } - - fn box_function_mut(function: F) -> Callback<'lua, 'static> - where - A: FromLuaMulti<'lua>, - R: ToLuaMulti<'lua>, - F: 'static + Send + FnMut(&'lua Lua, A) -> Result, - { - let function = RefCell::new(function); - Box::new(move |lua, args| { - let function = &mut *function - .try_borrow_mut() - .map_err(|_| Error::RecursiveMutCallback)?; - function(lua, A::from_lua_multi(args, lua)?)?.to_lua_multi(lua) - }) - } - - fn box_method(method: M) -> Callback<'lua, 'static> - where - A: FromLuaMulti<'lua>, - R: ToLuaMulti<'lua>, - M: 'static + Send + Fn(&'lua Lua, &T, A) -> Result, - { - Box::new(move |lua, mut args| { - if let Some(front) = args.pop_front() { - let userdata = AnyUserData::from_lua(front, lua)?; - let userdata = userdata.borrow::()?; - method(lua, &userdata, A::from_lua_multi(args, lua)?)?.to_lua_multi(lua) - } else { - Err(Error::FromLuaConversionError { - from: "missing argument", - to: "userdata", - message: None, - }) - } - }) - } - - fn box_method_mut(method: M) -> Callback<'lua, 'static> - where - A: FromLuaMulti<'lua>, - R: ToLuaMulti<'lua>, - M: 'static + Send + FnMut(&'lua Lua, &mut T, A) -> Result, - { - let method = RefCell::new(method); - Box::new(move |lua, mut args| { - if let Some(front) = args.pop_front() { - let userdata = AnyUserData::from_lua(front, lua)?; - let mut userdata = userdata.borrow_mut::()?; - let mut method = method - .try_borrow_mut() - .map_err(|_| Error::RecursiveMutCallback)?; - (&mut *method)(lua, &mut userdata, A::from_lua_multi(args, lua)?)?.to_lua_multi(lua) - } else { - Err(Error::FromLuaConversionError { - from: "missing argument", - to: "userdata", - message: None, - }) - } - }) - } + F: 'static + Send + FnMut(&'lua Lua, A) -> Result; } /// Trait for custom userdata types. @@ -315,7 +237,7 @@ impl<'lua, T: UserData> UserDataMethods<'lua, T> { /// struct MyUserData(i32); /// /// impl UserData for MyUserData { -/// fn add_methods(methods: &mut UserDataMethods) { +/// fn add_methods<'lua, M: UserDataMethods<'lua, Self>>(methods: &mut M) { /// methods.add_method("get", |_, this, _: ()| { /// Ok(this.0) /// }); @@ -350,10 +272,10 @@ impl<'lua, T: UserData> UserDataMethods<'lua, T> { /// /// [`ToLua`]: trait.ToLua.html /// [`FromLua`]: trait.FromLua.html -/// [`UserDataMethods`]: struct.UserDataMethods.html -pub trait UserData: 'static + Sized { +/// [`UserDataMethods`]: trait.UserDataMethods.html +pub trait UserData: Sized { /// Adds custom methods and operators specific to this userdata. - fn add_methods(_methods: &mut UserDataMethods) {} + fn add_methods<'lua, T: UserDataMethods<'lua, Self>>(_methods: &mut T) {} } /// Handle to an internal Lua userdata for any type that implements [`UserData`]. @@ -377,11 +299,11 @@ pub struct AnyUserData<'lua>(pub(crate) LuaRef<'lua>); impl<'lua> AnyUserData<'lua> { /// Checks whether the type of this userdata is `T`. - pub fn is(&self) -> Result { + pub fn is(&self) -> bool { match self.inspect(|_: &RefCell| Ok(())) { - Ok(()) => Ok(true), - Err(Error::UserDataTypeMismatch) => Ok(false), - Err(err) => Err(err), + Ok(()) => true, + Err(Error::UserDataTypeMismatch) => false, + Err(_) => unreachable!(), } } @@ -391,7 +313,7 @@ impl<'lua> AnyUserData<'lua> { /// /// Returns a `UserDataBorrowError` if the userdata is already mutably borrowed. Returns a /// `UserDataTypeMismatch` if the userdata is not of type `T`. - pub fn borrow(&self) -> Result> { + pub fn borrow(&self) -> Result> { self.inspect(|cell| Ok(cell.try_borrow().map_err(|_| Error::UserDataBorrowError)?)) } @@ -401,7 +323,7 @@ impl<'lua> AnyUserData<'lua> { /// /// Returns a `UserDataBorrowMutError` if the userdata is already borrowed. Returns a /// `UserDataTypeMismatch` if the userdata is not of type `T`. - pub fn borrow_mut(&self) -> Result> { + pub fn borrow_mut(&self) -> Result> { self.inspect(|cell| { Ok(cell .try_borrow_mut() @@ -444,7 +366,7 @@ impl<'lua> AnyUserData<'lua> { fn inspect<'a, T, R, F>(&'a self, func: F) -> Result where - T: UserData, + T: 'static + UserData, F: FnOnce(&'a RefCell) -> Result, { unsafe { diff --git a/src/util.rs b/src/util.rs index b650d9a..596e45a 100644 --- a/src/util.rs +++ b/src/util.rs @@ -250,6 +250,81 @@ pub unsafe fn take_userdata(state: *mut ffi::lua_State) -> T { ptr::read(ud) } +// Populates the given table with the appropriate members to be a userdata metatable for the given +// type. This function takes the given table at the `metatable` index, and adds an appropriate __gc +// member to it for the given type and a __metatable entry to protect the table from script access. +// The function also, if given a `members` table index, will set up an __index metamethod to return +// the appropriate member on __index. Additionally, if there is already an __index entry on the +// given metatable, instead of simply overwriting the __index, instead the created __index method +// will capture the previous one, and use it as a fallback only if the given key is not found in the +// provided members table. Internally uses 6 stack spaces and does not call checkstack. +pub unsafe fn init_userdata_metatable( + state: *mut ffi::lua_State, + metatable: c_int, + members: Option, +) -> Result<()> { + // Used if both an __index metamethod is set and regular methods, checks methods table + // first, then __index metamethod. + unsafe extern "C" fn meta_index_impl(state: *mut ffi::lua_State) -> c_int { + ffi::luaL_checkstack(state, 2, ptr::null()); + + ffi::lua_pushvalue(state, -1); + ffi::lua_gettable(state, ffi::lua_upvalueindex(2)); + if ffi::lua_isnil(state, -1) == 0 { + ffi::lua_insert(state, -3); + ffi::lua_pop(state, 2); + 1 + } else { + ffi::lua_pop(state, 1); + ffi::lua_pushvalue(state, ffi::lua_upvalueindex(1)); + ffi::lua_insert(state, -3); + ffi::lua_call(state, 2, 1); + 1 + } + } + + let members = members.map(|i| ffi::lua_absindex(state, i)); + ffi::lua_pushvalue(state, metatable); + + if let Some(members) = members { + push_string(state, "__index")?; + ffi::lua_pushvalue(state, -1); + + let index_type = ffi::lua_rawget(state, -3); + if index_type == ffi::LUA_TNIL { + ffi::lua_pop(state, 1); + ffi::lua_pushvalue(state, members); + } else if index_type == ffi::LUA_TFUNCTION { + ffi::lua_pushvalue(state, members); + protect_lua_closure(state, 2, 1, |state| { + ffi::lua_pushcclosure(state, meta_index_impl, 2); + })?; + } else { + rlua_panic!("improper __index type {}", index_type); + } + + protect_lua_closure(state, 3, 1, |state| { + ffi::lua_rawset(state, -3); + })?; + } + + push_string(state, "__gc")?; + ffi::lua_pushcfunction(state, userdata_destructor::); + protect_lua_closure(state, 3, 1, |state| { + ffi::lua_rawset(state, -3); + })?; + + push_string(state, "__metatable")?; + ffi::lua_pushboolean(state, 0); + protect_lua_closure(state, 3, 1, |state| { + ffi::lua_rawset(state, -3); + })?; + + ffi::lua_pop(state, 1); + + Ok(()) +} + pub unsafe extern "C" fn userdata_destructor(state: *mut ffi::lua_State) -> c_int { callback_error(state, || { take_userdata::(state);