diff --git a/src/error.rs b/src/error.rs index 421518a..903b189 100644 --- a/src/error.rs +++ b/src/error.rs @@ -131,6 +131,18 @@ pub enum Error { /// [`AnyUserData`]: struct.AnyUserData.html /// [`UserData`]: trait.UserData.html UserDataBorrowMutError, + /// A [`MetaMethod`] operation is restricted (typically for `__gc` or `__metatable`). + /// + /// [`MetaMethod`]: enum.MetaMethod.html + MetaMethodRestricted(StdString), + /// A [`MetaMethod`] (eg. `__index` or `__newindex`) has invalid type. + /// + /// [`MetaMethod`]: enum.MetaMethod.html + MetaMethodTypeError { + method: StdString, + type_name: &'static str, + message: Option, + }, /// A `RegistryKey` produced from a different Lua state was used. MismatchedRegistryKey, /// A Rust callback returned `Err`, raising the contained `Error` as a Lua error. @@ -203,22 +215,14 @@ impl fmt::Display for Error { fmt, "too many arguments to Function::bind" ), - Error::ToLuaConversionError { - from, - to, - ref message, - } => { + Error::ToLuaConversionError { from, to, ref message } => { write!(fmt, "error converting {} to Lua {}", from, to)?; match *message { None => Ok(()), Some(ref message) => write!(fmt, " ({})", message), } } - Error::FromLuaConversionError { - from, - to, - ref message, - } => { + Error::FromLuaConversionError { from, to, ref message } => { write!(fmt, "error converting Lua {} to {}", from, to)?; match *message { None => Ok(()), @@ -230,6 +234,14 @@ impl fmt::Display for Error { Error::UserDataDestructed => write!(fmt, "userdata has been destructed"), Error::UserDataBorrowError => write!(fmt, "userdata already mutably borrowed"), Error::UserDataBorrowMutError => write!(fmt, "userdata already borrowed"), + Error::MetaMethodRestricted(ref method) => write!(fmt, "metamethod {} is restricted", method), + Error::MetaMethodTypeError { ref method, type_name, ref message } => { + write!(fmt, "metamethod {} has unsupported type {}", method, type_name)?; + match *message { + None => Ok(()), + Some(ref message) => write!(fmt, " ({})", message), + } + } Error::MismatchedRegistryKey => { write!(fmt, "RegistryKey used from different Lua state") } diff --git a/src/lib.rs b/src/lib.rs index 1506b11..855d3e8 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -23,6 +23,7 @@ //! //! The [`UserData`] trait can be implemented by user-defined types to make them available to Lua. //! Methods and operators to be used from Lua can be added using the [`UserDataMethods`] API. +//! Fields are supported using the [`UserDataFields`] API. //! //! # Serde support //! @@ -59,6 +60,7 @@ //! [`FromLuaMulti`]: trait.FromLuaMulti.html //! [`Function`]: struct.Function.html //! [`UserData`]: trait.UserData.html +//! [`UserDataFields`]: trait.UserDataFields.html //! [`UserDataMethods`]: trait.UserDataMethods.html //! [`LuaSerdeExt`]: serde/trait.LuaSerdeExt.html //! [`Value`]: enum.Value.html @@ -109,7 +111,7 @@ pub use crate::string::String; pub use crate::table::{Table, TableExt, TablePairs, TableSequence}; pub use crate::thread::{Thread, ThreadStatus}; pub use crate::types::{Integer, LightUserData, Number, RegistryKey}; -pub use crate::userdata::{AnyUserData, MetaMethod, UserData, UserDataMethods}; +pub use crate::userdata::{AnyUserData, MetaMethod, UserData, UserDataFields, UserDataMethods}; pub use crate::value::{FromLua, FromLuaMulti, MultiValue, Nil, ToLua, ToLuaMulti, Value}; #[cfg(feature = "async")] diff --git a/src/lua.rs b/src/lua.rs index a6a84a5..e06635c 100644 --- a/src/lua.rs +++ b/src/lua.rs @@ -21,7 +21,9 @@ use crate::types::{ Callback, HookCallback, Integer, LightUserData, LuaRef, MaybeSend, Number, RegistryKey, UserDataCell, }; -use crate::userdata::{AnyUserData, MetaMethod, UserData, UserDataMethods, UserDataWrapped}; +use crate::userdata::{ + AnyUserData, MetaMethod, UserData, UserDataFields, UserDataMethods, UserDataWrapped, +}; use crate::util::{ assert_stack, callback_error, check_stack, get_gc_userdata, get_main_state, get_userdata, get_wrapped_error, init_error_registry, init_gc_metatable_for, init_userdata_metatable, @@ -1523,37 +1525,86 @@ impl Lua { } let _sg = StackGuard::new(self.state); - assert_stack(self.state, 8); + assert_stack(self.state, 10); + let mut fields = StaticUserDataFields::default(); let mut methods = StaticUserDataMethods::default(); + T::add_fields(&mut fields); T::add_methods(&mut methods); + // Prepare metatable, add meta methods first and then meta fields protect_lua_closure(self.state, 0, 1, |state| { ffi::lua_newtable(state); })?; for (k, m) in methods.meta_methods { - push_string(self.state, k.name())?; + push_string(self.state, k.validate()?.name())?; self.push_value(Value::Function(self.create_callback(m)?))?; protect_lua_closure(self.state, 3, 1, |state| { ffi::lua_rawset(state, -3); })?; } + for (k, f) in fields.meta_fields { + push_string(self.state, k.validate()?.name())?; + self.push_value(f(self)?)?; + protect_lua_closure(self.state, 3, 1, |state| { + ffi::lua_rawset(state, -3); + })?; + } + let metatable_index = ffi::lua_absindex(self.state, -1); + + let mut extra_tables_count = 0; + + let mut field_getters_index = None; + let has_field_getters = fields.field_getters.len() > 0; + if has_field_getters { + protect_lua_closure(self.state, 0, 1, |state| { + ffi::lua_newtable(state); + })?; + for (k, m) in fields.field_getters { + push_string(self.state, &k)?; + self.push_value(Value::Function(self.create_callback(m)?))?; + + protect_lua_closure(self.state, 3, 1, |state| { + ffi::lua_rawset(state, -3); + })?; + } + field_getters_index = Some(ffi::lua_absindex(self.state, -1)); + extra_tables_count += 1; + } + + let mut field_setters_index = None; + let has_field_setters = fields.field_setters.len() > 0; + if has_field_setters { + protect_lua_closure(self.state, 0, 1, |state| { + ffi::lua_newtable(state); + })?; + for (k, m) in fields.field_setters { + push_string(self.state, &k)?; + self.push_value(Value::Function(self.create_callback(m)?))?; + + protect_lua_closure(self.state, 3, 1, |state| { + ffi::lua_rawset(state, -3); + })?; + } + field_setters_index = Some(ffi::lua_absindex(self.state, -1)); + extra_tables_count += 1; + } + + let mut methods_index = None; #[cfg(feature = "async")] - let no_methods = methods.methods.is_empty() && methods.async_methods.is_empty(); + let has_methods = methods.methods.len() > 0 || methods.async_methods.len() > 0; #[cfg(not(feature = "async"))] - let no_methods = methods.methods.is_empty(); - - if no_methods { - init_userdata_metatable::>(self.state, -1, None)?; - } else { + let has_methods = methods.methods.len() > 0; + if has_methods { protect_lua_closure(self.state, 0, 1, |state| { ffi::lua_newtable(state); })?; for (k, m) in methods.methods { push_string(self.state, &k)?; self.push_value(Value::Function(self.create_callback(m)?))?; + protect_lua_closure(self.state, 3, 1, |state| { ffi::lua_rawset(state, -3); })?; @@ -1562,15 +1613,26 @@ impl Lua { for (k, m) in methods.async_methods { push_string(self.state, &k)?; self.push_value(Value::Function(self.create_async_callback(m)?))?; + protect_lua_closure(self.state, 3, 1, |state| { ffi::lua_rawset(state, -3); })?; } - - init_userdata_metatable::>(self.state, -2, Some(-1))?; - ffi::lua_pop(self.state, 1); + methods_index = Some(ffi::lua_absindex(self.state, -1)); + extra_tables_count += 1; } + init_userdata_metatable::>( + self.state, + metatable_index, + field_getters_index, + field_setters_index, + methods_index, + )?; + + // Pop extra tables to get metatable on top of the stack + ffi::lua_pop(self.state, extra_tables_count); + let ptr = ffi::lua_topointer(self.state, -1); let id = protect_lua_closure(self.state, 1, 0, |state| { ffi::luaL_ref(state, ffi::LUA_REGISTRYINDEX) @@ -2317,41 +2379,48 @@ impl<'lua, T: 'static + UserData> UserDataMethods<'lua, T> for StaticUserDataMet .push((name.as_ref().to_vec(), Self::box_async_function(function))); } - fn add_meta_method(&mut self, meta: MetaMethod, method: M) + fn add_meta_method(&mut self, meta: S, method: M) where + S: Into, A: FromLuaMulti<'lua>, R: ToLuaMulti<'lua>, M: 'static + MaybeSend + Fn(&'lua Lua, &T, A) -> Result, { - self.meta_methods.push((meta, Self::box_method(method))); + self.meta_methods + .push((meta.into(), Self::box_method(method))); } - fn add_meta_method_mut(&mut self, meta: MetaMethod, method: M) + fn add_meta_method_mut(&mut self, meta: S, method: M) where + S: Into, A: FromLuaMulti<'lua>, R: ToLuaMulti<'lua>, M: 'static + MaybeSend + FnMut(&'lua Lua, &mut T, A) -> Result, { - self.meta_methods.push((meta, Self::box_method_mut(method))); + self.meta_methods + .push((meta.into(), Self::box_method_mut(method))); } - fn add_meta_function(&mut self, meta: MetaMethod, function: F) + fn add_meta_function(&mut self, meta: S, function: F) where + S: Into, A: FromLuaMulti<'lua>, R: ToLuaMulti<'lua>, F: 'static + MaybeSend + Fn(&'lua Lua, A) -> Result, { - self.meta_methods.push((meta, Self::box_function(function))); + self.meta_methods + .push((meta.into(), Self::box_function(function))); } - fn add_meta_function_mut(&mut self, meta: MetaMethod, function: F) + fn add_meta_function_mut(&mut self, meta: S, function: F) where + S: Into, A: FromLuaMulti<'lua>, R: ToLuaMulti<'lua>, F: 'static + MaybeSend + FnMut(&'lua Lua, A) -> Result, { self.meta_methods - .push((meta, Self::box_function_mut(function))); + .push((meta.into(), Self::box_function_mut(function))); } } @@ -2473,3 +2542,104 @@ impl<'lua, T: 'static + UserData> StaticUserDataMethods<'lua, T> { }) } } + +struct StaticUserDataFields<'lua, T: 'static + UserData> { + field_getters: Vec<(Vec, Callback<'lua, 'static>)>, + field_setters: Vec<(Vec, Callback<'lua, 'static>)>, + meta_fields: Vec<( + MetaMethod, + Box Result> + 'static>, + )>, + _type: PhantomData, +} + +impl<'lua, T: 'static + UserData> Default for StaticUserDataFields<'lua, T> { + fn default() -> StaticUserDataFields<'lua, T> { + StaticUserDataFields { + field_getters: Vec::new(), + field_setters: Vec::new(), + meta_fields: Vec::new(), + _type: PhantomData, + } + } +} + +impl<'lua, T: 'static + UserData> UserDataFields<'lua, T> for StaticUserDataFields<'lua, T> { + fn add_field_method_get(&mut self, name: &S, method: M) + where + S: AsRef<[u8]> + ?Sized, + R: ToLua<'lua>, + M: 'static + MaybeSend + Fn(&'lua Lua, &T) -> Result, + { + self.field_getters.push(( + name.as_ref().to_vec(), + StaticUserDataMethods::box_method(move |lua, data, ()| method(lua, data)), + )); + } + + fn add_field_method_set(&mut self, name: &S, method: M) + where + S: AsRef<[u8]> + ?Sized, + A: FromLua<'lua>, + M: 'static + MaybeSend + FnMut(&'lua Lua, &mut T, A) -> Result<()>, + { + self.field_setters.push(( + name.as_ref().to_vec(), + StaticUserDataMethods::box_method_mut(method), + )); + } + + fn add_field_function_get(&mut self, name: &S, function: F) + where + S: AsRef<[u8]> + ?Sized, + R: ToLua<'lua>, + F: 'static + MaybeSend + Fn(&'lua Lua, AnyUserData<'lua>) -> Result, + { + self.field_getters.push(( + name.as_ref().to_vec(), + StaticUserDataMethods::::box_function(move |lua, data| function(lua, data)), + )); + } + + fn add_field_function_set(&mut self, name: &S, mut function: F) + where + S: AsRef<[u8]> + ?Sized, + A: FromLua<'lua>, + F: 'static + MaybeSend + FnMut(&'lua Lua, AnyUserData<'lua>, A) -> Result<()>, + { + self.field_setters.push(( + name.as_ref().to_vec(), + StaticUserDataMethods::::box_function_mut(move |lua, (data, val)| { + function(lua, data, val) + }), + )); + } + + fn add_meta_field_with(&mut self, meta: S, f: F) + where + S: Into, + R: ToLua<'lua>, + F: 'static + MaybeSend + Fn(&'lua Lua) -> Result, + { + let meta = meta.into(); + self.meta_fields.push(( + meta.clone(), + Box::new(move |lua| { + let value = f(lua)?.to_lua(lua)?; + if meta == MetaMethod::Index || meta == MetaMethod::NewIndex { + match value { + Value::Nil | Value::Table(_) | Value::Function(_) => {} + _ => { + return Err(Error::MetaMethodTypeError { + method: meta.to_string(), + type_name: value.type_name(), + message: Some("expected nil, table or function".to_string()), + }) + } + } + } + Ok(value) + }), + )); + } +} diff --git a/src/scope.rs b/src/scope.rs index bed6265..52809b5 100644 --- a/src/scope.rs +++ b/src/scope.rs @@ -13,12 +13,14 @@ use crate::ffi; use crate::function::Function; use crate::lua::Lua; use crate::types::{Callback, LuaRef, MaybeSend, UserDataCell}; -use crate::userdata::{AnyUserData, MetaMethod, UserData, UserDataMethods, UserDataWrapped}; +use crate::userdata::{ + AnyUserData, MetaMethod, UserData, UserDataFields, UserDataMethods, UserDataWrapped, +}; use crate::util::{ assert_stack, init_userdata_metatable, protect_lua_closure, push_string, push_userdata, take_userdata, StackGuard, }; -use crate::value::{FromLuaMulti, MultiValue, ToLuaMulti, Value}; +use crate::value::{FromLua, FromLuaMulti, MultiValue, ToLua, ToLuaMulti, Value}; #[cfg(feature = "async")] use { @@ -304,7 +306,9 @@ impl<'lua, 'scope> Scope<'lua, 'scope> { } } + let mut ud_fields = NonStaticUserDataFields::default(); let mut ud_methods = NonStaticUserDataMethods::default(); + T::add_fields(&mut ud_fields); T::add_methods(&mut ud_methods); unsafe { @@ -325,37 +329,90 @@ impl<'lua, 'scope> Scope<'lua, 'scope> { })?; ffi::lua_setuservalue(lua.state, -2); + // Prepare metatable, add meta methods first and then meta fields protect_lua_closure(lua.state, 0, 1, move |state| { ffi::lua_newtable(state); })?; - for (k, m) in ud_methods.meta_methods { - push_string(lua.state, k.name())?; + push_string(lua.state, k.validate()?.name())?; lua.push_value(Value::Function(wrap_method(self, data.clone(), m)?))?; protect_lua_closure(lua.state, 3, 1, |state| { ffi::lua_rawset(state, -3); })?; } + for (k, f) in ud_fields.meta_fields { + push_string(lua.state, k.validate()?.name())?; + lua.push_value(f(mem::transmute(lua))?)?; - if ud_methods.methods.is_empty() { - init_userdata_metatable::<()>(lua.state, -1, None)?; - } else { + protect_lua_closure(lua.state, 3, 1, |state| { + ffi::lua_rawset(state, -3); + })?; + } + let metatable_index = ffi::lua_absindex(lua.state, -1); + + let mut field_getters_index = None; + if ud_fields.field_getters.len() > 0 { + protect_lua_closure(lua.state, 0, 1, |state| { + ffi::lua_newtable(state); + })?; + for (k, m) in ud_fields.field_getters { + push_string(lua.state, &k)?; + lua.push_value(Value::Function(wrap_method(self, data.clone(), m)?))?; + + protect_lua_closure(lua.state, 3, 1, |state| { + ffi::lua_rawset(state, -3); + })?; + } + field_getters_index = Some(ffi::lua_absindex(lua.state, -1)); + } + + let mut field_setters_index = None; + if ud_fields.field_setters.len() > 0 { + protect_lua_closure(lua.state, 0, 1, |state| { + ffi::lua_newtable(state); + })?; + for (k, m) in ud_fields.field_setters { + push_string(lua.state, &k)?; + lua.push_value(Value::Function(wrap_method(self, data.clone(), m)?))?; + + protect_lua_closure(lua.state, 3, 1, |state| { + ffi::lua_rawset(state, -3); + })?; + } + field_setters_index = Some(ffi::lua_absindex(lua.state, -1)); + } + + let mut methods_index = None; + if ud_methods.methods.len() > 0 { + // Create table used for methods lookup protect_lua_closure(lua.state, 0, 1, |state| { ffi::lua_newtable(state); })?; for (k, m) in ud_methods.methods { push_string(lua.state, &k)?; lua.push_value(Value::Function(wrap_method(self, data.clone(), m)?))?; + protect_lua_closure(lua.state, 3, 1, |state| { ffi::lua_rawset(state, -3); })?; } - - init_userdata_metatable::<()>(lua.state, -2, Some(-1))?; - ffi::lua_pop(lua.state, 1); + methods_index = Some(ffi::lua_absindex(lua.state, -1)); } + init_userdata_metatable::<()>( + lua.state, + metatable_index, + field_getters_index, + field_setters_index, + methods_index, + )?; + + let count = field_getters_index.map(|_| 1).unwrap_or(0) + + field_setters_index.map(|_| 1).unwrap_or(0) + + methods_index.map(|_| 1).unwrap_or(0); + ffi::lua_pop(lua.state, count); + let mt_id = ffi::lua_topointer(lua.state, -1); ffi::lua_setmetatable(lua.state, -2); @@ -604,59 +661,166 @@ impl<'lua, T: UserData> UserDataMethods<'lua, T> for NonStaticUserDataMethods<'l mlua_panic!("asynchronous functions are not supported for non-static userdata") } - fn add_meta_method(&mut self, meta: MetaMethod, method: M) + fn add_meta_method(&mut self, meta: S, method: M) where + S: Into, A: FromLuaMulti<'lua>, R: ToLuaMulti<'lua>, M: 'static + MaybeSend + Fn(&'lua Lua, &T, A) -> Result, { self.meta_methods.push(( - meta, + meta.into(), NonStaticMethod::Method(Box::new(move |lua, ud, args| { method(lua, ud, A::from_lua_multi(args, lua)?)?.to_lua_multi(lua) })), )); } - fn add_meta_method_mut(&mut self, meta: MetaMethod, mut method: M) + fn add_meta_method_mut(&mut self, meta: S, mut method: M) where + S: Into, A: FromLuaMulti<'lua>, R: ToLuaMulti<'lua>, M: 'static + MaybeSend + FnMut(&'lua Lua, &mut T, A) -> Result, { self.meta_methods.push(( - meta, + meta.into(), NonStaticMethod::MethodMut(Box::new(move |lua, ud, args| { method(lua, ud, A::from_lua_multi(args, lua)?)?.to_lua_multi(lua) })), )); } - fn add_meta_function(&mut self, meta: MetaMethod, function: F) + fn add_meta_function(&mut self, meta: S, function: F) where + S: Into, A: FromLuaMulti<'lua>, R: ToLuaMulti<'lua>, F: 'static + MaybeSend + Fn(&'lua Lua, A) -> Result, { self.meta_methods.push(( - meta, + meta.into(), NonStaticMethod::Function(Box::new(move |lua, args| { function(lua, A::from_lua_multi(args, lua)?)?.to_lua_multi(lua) })), )); } - fn add_meta_function_mut(&mut self, meta: MetaMethod, mut function: F) + fn add_meta_function_mut(&mut self, meta: S, mut function: F) where + S: Into, A: FromLuaMulti<'lua>, R: ToLuaMulti<'lua>, F: 'static + MaybeSend + FnMut(&'lua Lua, A) -> Result, { self.meta_methods.push(( - meta, + meta.into(), NonStaticMethod::FunctionMut(Box::new(move |lua, args| { function(lua, A::from_lua_multi(args, lua)?)?.to_lua_multi(lua) })), )); } } + +struct NonStaticUserDataFields<'lua, T: UserData> { + field_getters: Vec<(Vec, NonStaticMethod<'lua, T>)>, + field_setters: Vec<(Vec, NonStaticMethod<'lua, T>)>, + meta_fields: Vec<(MetaMethod, Box Result>>)>, +} + +impl<'lua, T: UserData> Default for NonStaticUserDataFields<'lua, T> { + fn default() -> NonStaticUserDataFields<'lua, T> { + NonStaticUserDataFields { + field_getters: Vec::new(), + field_setters: Vec::new(), + meta_fields: Vec::new(), + } + } +} + +impl<'lua, T: UserData> UserDataFields<'lua, T> for NonStaticUserDataFields<'lua, T> { + fn add_field_method_get(&mut self, name: &S, method: M) + where + S: AsRef<[u8]> + ?Sized, + R: ToLua<'lua>, + M: 'static + MaybeSend + Fn(&'lua Lua, &T) -> Result, + { + self.field_getters.push(( + name.as_ref().to_vec(), + NonStaticMethod::Method(Box::new(move |lua, ud, _| { + method(lua, ud)?.to_lua_multi(lua) + })), + )); + } + + fn add_field_method_set(&mut self, name: &S, mut method: M) + where + S: AsRef<[u8]> + ?Sized, + A: FromLua<'lua>, + M: 'static + MaybeSend + FnMut(&'lua Lua, &mut T, A) -> Result<()>, + { + self.field_setters.push(( + name.as_ref().to_vec(), + NonStaticMethod::MethodMut(Box::new(move |lua, ud, args| { + method(lua, ud, A::from_lua_multi(args, lua)?)?.to_lua_multi(lua) + })), + )); + } + + fn add_field_function_get(&mut self, name: &S, function: F) + where + S: AsRef<[u8]> + ?Sized, + R: ToLua<'lua>, + F: 'static + MaybeSend + Fn(&'lua Lua, AnyUserData<'lua>) -> Result, + { + self.field_getters.push(( + name.as_ref().to_vec(), + NonStaticMethod::Function(Box::new(move |lua, args| { + function(lua, AnyUserData::from_lua_multi(args, lua)?)?.to_lua_multi(lua) + })), + )); + } + + fn add_field_function_set(&mut self, name: &S, mut function: F) + where + S: AsRef<[u8]> + ?Sized, + A: FromLua<'lua>, + F: 'static + MaybeSend + FnMut(&'lua Lua, AnyUserData<'lua>, A) -> Result<()>, + { + self.field_setters.push(( + name.as_ref().to_vec(), + NonStaticMethod::FunctionMut(Box::new(move |lua, args| { + let (ud, val) = <_>::from_lua_multi(args, lua)?; + function(lua, ud, val)?.to_lua_multi(lua) + })), + )); + } + + fn add_meta_field_with(&mut self, meta: S, f: F) + where + S: Into, + F: 'static + MaybeSend + Fn(&'lua Lua) -> Result, + R: ToLua<'lua>, + { + let meta = meta.into(); + self.meta_fields.push(( + meta.clone(), + Box::new(move |lua| { + let value = f(lua)?.to_lua(lua)?; + if meta == MetaMethod::Index || meta == MetaMethod::NewIndex { + match value { + Value::Nil | Value::Table(_) | Value::Function(_) => {} + _ => { + return Err(Error::MetaMethodTypeError { + method: meta.to_string(), + type_name: value.type_name(), + message: Some("expected nil, table or function".to_string()), + }) + } + } + } + Ok(value) + }), + )); + } +} diff --git a/src/userdata.rs b/src/userdata.rs index a6131a1..e4b4397 100644 --- a/src/userdata.rs +++ b/src/userdata.rs @@ -1,4 +1,7 @@ use std::cell::{Ref, RefMut}; +use std::fmt; +use std::hash::{Hash, Hasher}; +use std::string::String as StdString; #[cfg(feature = "async")] use std::future::Future; @@ -13,7 +16,7 @@ use crate::error::{Error, Result}; use crate::ffi; use crate::function::Function; use crate::lua::Lua; -use crate::table::Table; +use crate::table::{Table, TablePairs}; use crate::types::{LuaRef, MaybeSend, UserDataCell}; use crate::util::{assert_stack, get_destructed_userdata_metatable, get_userdata, StackGuard}; use crate::value::{FromLua, FromLuaMulti, ToLua, ToLuaMulti, Value}; @@ -24,7 +27,7 @@ use crate::value::{FromLua, FromLuaMulti, ToLua, ToLuaMulti, Value}; /// generally no need to do so: [`UserData`] implementors can instead just implement `Drop`. /// /// [`UserData`]: trait.UserData.html -#[derive(Debug, Copy, Clone, Eq, PartialEq, Hash)] +#[derive(Debug, Clone)] pub enum MetaMethod { /// The `+` operator. Add, @@ -105,58 +108,147 @@ pub enum MetaMethod { /// [lua_doc]: https://www.lua.org/manual/5.4/manual.html#3.3.8 #[cfg(any(feature = "lua54", doc))] Close, + /// A custom metamethod. + /// + /// Must not be in the protected list: `__gc`, `__metatable`. + Custom(StdString), +} + +impl PartialEq for MetaMethod { + fn eq(&self, other: &Self) -> bool { + self.name() == other.name() + } +} + +impl Eq for MetaMethod {} + +impl Hash for MetaMethod { + fn hash(&self, state: &mut H) { + self.name().hash(state); + } +} + +impl fmt::Display for MetaMethod { + fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(fmt, "{}", self.name()) + } } impl MetaMethod { - pub(crate) fn name(self) -> &'static [u8] { + pub(crate) fn name(&self) -> &str { match self { - MetaMethod::Add => b"__add", - MetaMethod::Sub => b"__sub", - MetaMethod::Mul => b"__mul", - MetaMethod::Div => b"__div", - MetaMethod::Mod => b"__mod", - MetaMethod::Pow => b"__pow", - MetaMethod::Unm => b"__unm", + MetaMethod::Add => "__add", + MetaMethod::Sub => "__sub", + MetaMethod::Mul => "__mul", + MetaMethod::Div => "__div", + MetaMethod::Mod => "__mod", + MetaMethod::Pow => "__pow", + MetaMethod::Unm => "__unm", #[cfg(any(feature = "lua54", feature = "lua53"))] - MetaMethod::IDiv => b"__idiv", + MetaMethod::IDiv => "__idiv", #[cfg(any(feature = "lua54", feature = "lua53"))] - MetaMethod::BAnd => b"__band", + MetaMethod::BAnd => "__band", #[cfg(any(feature = "lua54", feature = "lua53"))] - MetaMethod::BOr => b"__bor", + MetaMethod::BOr => "__bor", #[cfg(any(feature = "lua54", feature = "lua53"))] - MetaMethod::BXor => b"__bxor", + MetaMethod::BXor => "__bxor", #[cfg(any(feature = "lua54", feature = "lua53"))] - MetaMethod::BNot => b"__bnot", + MetaMethod::BNot => "__bnot", #[cfg(any(feature = "lua54", feature = "lua53"))] - MetaMethod::Shl => b"__shl", + MetaMethod::Shl => "__shl", #[cfg(any(feature = "lua54", feature = "lua53"))] - MetaMethod::Shr => b"__shr", + MetaMethod::Shr => "__shr", - MetaMethod::Concat => b"__concat", - MetaMethod::Len => b"__len", - MetaMethod::Eq => b"__eq", - MetaMethod::Lt => b"__lt", - MetaMethod::Le => b"__le", - MetaMethod::Index => b"__index", - MetaMethod::NewIndex => b"__newindex", - MetaMethod::Call => b"__call", - MetaMethod::ToString => b"__tostring", + MetaMethod::Concat => "__concat", + MetaMethod::Len => "__len", + MetaMethod::Eq => "__eq", + MetaMethod::Lt => "__lt", + MetaMethod::Le => "__le", + MetaMethod::Index => "__index", + MetaMethod::NewIndex => "__newindex", + MetaMethod::Call => "__call", + MetaMethod::ToString => "__tostring", #[cfg(any(feature = "lua54", feature = "lua53", feature = "lua52"))] - MetaMethod::Pairs => b"__pairs", + MetaMethod::Pairs => "__pairs", #[cfg(feature = "lua54")] - MetaMethod::Close => b"__close", + MetaMethod::Close => "__close", + + MetaMethod::Custom(ref name) => name, } } + + pub(crate) fn validate(self) -> Result { + match self { + MetaMethod::Custom(name) if name == "__gc" => Err(Error::MetaMethodRestricted(name)), + MetaMethod::Custom(name) if name == "__metatable" => { + Err(Error::MetaMethodRestricted(name)) + } + _ => Ok(self), + } + } +} + +impl From for MetaMethod { + fn from(name: StdString) -> Self { + match name.as_str() { + "__add" => MetaMethod::Add, + "__sub" => MetaMethod::Sub, + "__mul" => MetaMethod::Mul, + "__div" => MetaMethod::Div, + "__mod" => MetaMethod::Mod, + "__pow" => MetaMethod::Pow, + "__unm" => MetaMethod::Unm, + + #[cfg(any(feature = "lua54", feature = "lua53"))] + "__idiv" => MetaMethod::IDiv, + #[cfg(any(feature = "lua54", feature = "lua53"))] + "__band" => MetaMethod::BAnd, + #[cfg(any(feature = "lua54", feature = "lua53"))] + "__bor" => MetaMethod::BOr, + #[cfg(any(feature = "lua54", feature = "lua53"))] + "__bxor" => MetaMethod::BXor, + #[cfg(any(feature = "lua54", feature = "lua53"))] + "__bnot" => MetaMethod::BNot, + #[cfg(any(feature = "lua54", feature = "lua53"))] + "__shl" => MetaMethod::Shl, + #[cfg(any(feature = "lua54", feature = "lua53"))] + "__shr" => MetaMethod::Shr, + + "__concat" => MetaMethod::Concat, + "__len" => MetaMethod::Len, + "__eq" => MetaMethod::Eq, + "__lt" => MetaMethod::Lt, + "__le" => MetaMethod::Le, + "__index" => MetaMethod::Index, + "__newindex" => MetaMethod::NewIndex, + "__call" => MetaMethod::Call, + "__tostring" => MetaMethod::ToString, + + #[cfg(any(feature = "lua54", feature = "lua53", feature = "lua52"))] + "__pairs" => MetaMethod::Pairs, + + #[cfg(feature = "lua54")] + "__close" => MetaMethod::Close, + + _ => MetaMethod::Custom(name), + } + } +} + +impl From<&str> for MetaMethod { + fn from(name: &str) -> Self { + MetaMethod::from(name.to_owned()) + } } /// Method registry for [`UserData`] implementors. /// /// [`UserData`]: trait.UserData.html pub trait UserDataMethods<'lua, T: UserData> { - /// Add a method which accepts a `&T` as the first parameter. + /// Add a regular method which accepts a `&T` as the first parameter. /// /// Regular methods are implemented by overriding the `__index` metamethod and returning the /// accessed method. This allows them to be used with the expected `userdata:method()` syntax. @@ -165,7 +257,7 @@ pub trait UserDataMethods<'lua, T: UserData> { /// be used as a fall-back if no regular method is found. fn add_method(&mut self, name: &S, method: M) where - S: ?Sized + AsRef<[u8]>, + S: AsRef<[u8]> + ?Sized, A: FromLuaMulti<'lua>, R: ToLuaMulti<'lua>, M: 'static + MaybeSend + Fn(&'lua Lua, &T, A) -> Result; @@ -177,7 +269,7 @@ pub trait UserDataMethods<'lua, T: UserData> { /// [`add_method`]: #method.add_method fn add_method_mut(&mut self, name: &S, method: M) where - S: ?Sized + AsRef<[u8]>, + S: AsRef<[u8]> + ?Sized, A: FromLuaMulti<'lua>, R: ToLuaMulti<'lua>, M: 'static + MaybeSend + FnMut(&'lua Lua, &mut T, A) -> Result; @@ -195,24 +287,25 @@ pub trait UserDataMethods<'lua, T: UserData> { fn add_async_method(&mut self, name: &S, method: M) where T: Clone, - S: ?Sized + AsRef<[u8]>, + S: AsRef<[u8]> + ?Sized, A: FromLuaMulti<'lua>, R: ToLuaMulti<'lua>, M: 'static + MaybeSend + Fn(&'lua Lua, T, A) -> MR, MR: 'lua + Future>; /// Add a regular method as a function which accepts generic arguments, the first argument will - /// be a `UserData` of type T if the method is called with Lua method syntax: + /// be a [`AnyUserData`] of type `T` if the method is called with Lua method syntax: /// `my_userdata:my_method(arg1, arg2)`, or it is passed in as the first argument: /// `my_userdata.my_method(my_userdata, arg1, arg2)`. /// /// Prefer to use [`add_method`] or [`add_method_mut`] as they are easier to use. /// + /// [`AnyUserData`]: struct.AnyUserData.html /// [`add_method`]: #method.add_method /// [`add_method_mut`]: #method.add_method_mut fn add_function(&mut self, name: &S, function: F) where - S: ?Sized + AsRef<[u8]>, + S: AsRef<[u8]> + ?Sized, A: FromLuaMulti<'lua>, R: ToLuaMulti<'lua>, F: 'static + MaybeSend + Fn(&'lua Lua, A) -> Result; @@ -224,7 +317,7 @@ pub trait UserDataMethods<'lua, T: UserData> { /// [`add_function`]: #method.add_function fn add_function_mut(&mut self, name: &S, function: F) where - S: ?Sized + AsRef<[u8]>, + S: AsRef<[u8]> + ?Sized, A: FromLuaMulti<'lua>, R: ToLuaMulti<'lua>, F: 'static + MaybeSend + FnMut(&'lua Lua, A) -> Result; @@ -242,7 +335,7 @@ pub trait UserDataMethods<'lua, T: UserData> { fn add_async_function(&mut self, name: &S, function: F) where T: Clone, - S: ?Sized + AsRef<[u8]>, + S: AsRef<[u8]> + ?Sized, A: FromLuaMulti<'lua>, R: ToLuaMulti<'lua>, F: 'static + MaybeSend + Fn(&'lua Lua, A) -> FR, @@ -256,8 +349,9 @@ pub trait UserDataMethods<'lua, T: UserData> { /// side has a metatable. To prevent this, use [`add_meta_function`]. /// /// [`add_meta_function`]: #method.add_meta_function - fn add_meta_method(&mut self, meta: MetaMethod, method: M) + fn add_meta_method(&mut self, meta: S, method: M) where + S: Into, A: FromLuaMulti<'lua>, R: ToLuaMulti<'lua>, M: 'static + MaybeSend + Fn(&'lua Lua, &T, A) -> Result; @@ -270,8 +364,9 @@ pub trait UserDataMethods<'lua, T: UserData> { /// side has a metatable. To prevent this, use [`add_meta_function`]. /// /// [`add_meta_function`]: #method.add_meta_function - fn add_meta_method_mut(&mut self, meta: MetaMethod, method: M) + fn add_meta_method_mut(&mut self, meta: S, method: M) where + S: Into, A: FromLuaMulti<'lua>, R: ToLuaMulti<'lua>, M: 'static + MaybeSend + FnMut(&'lua Lua, &mut T, A) -> Result; @@ -281,8 +376,9 @@ pub trait UserDataMethods<'lua, T: UserData> { /// Metamethods for binary operators can be triggered if either the left or right argument to /// the binary operator has a metatable, so the first argument here is not necessarily a /// userdata of type `T`. - fn add_meta_function(&mut self, meta: MetaMethod, function: F) + fn add_meta_function(&mut self, meta: S, function: F) where + S: Into, A: FromLuaMulti<'lua>, R: ToLuaMulti<'lua>, F: 'static + MaybeSend + Fn(&'lua Lua, A) -> Result; @@ -292,13 +388,85 @@ pub trait UserDataMethods<'lua, T: UserData> { /// This is a version of [`add_meta_function`] that accepts a FnMut argument. /// /// [`add_meta_function`]: #method.add_meta_function - fn add_meta_function_mut(&mut self, meta: MetaMethod, function: F) + fn add_meta_function_mut(&mut self, meta: S, function: F) where + S: Into, A: FromLuaMulti<'lua>, R: ToLuaMulti<'lua>, F: 'static + MaybeSend + FnMut(&'lua Lua, A) -> Result; } +/// Field registry for [`UserData`] implementors. +/// +/// [`UserData`]: trait.UserData.html +pub trait UserDataFields<'lua, T: UserData> { + /// Add a regular field getter as a method which accepts a `&T` as the parameter. + /// + /// Regular field getters are implemented by overriding the `__index` metamethod and returning the + /// accessed field. This allows them to be used with the expected `userdata.field` syntax. + /// + /// If `add_meta_method` is used to set the `__index` metamethod, the `__index` metamethod will + /// be used as a fall-back if no regular field or method are found. + fn add_field_method_get(&mut self, name: &S, method: M) + where + S: AsRef<[u8]> + ?Sized, + R: ToLua<'lua>, + M: 'static + MaybeSend + Fn(&'lua Lua, &T) -> Result; + + /// Add a regular field setter as a method which accepts a `&mut T` as the first parameter. + /// + /// Regular field setters are implemented by overriding the `__newindex` metamethod and setting the + /// accessed field. This allows them to be used with the expected `userdata.field = value` syntax. + /// + /// If `add_meta_method` is used to set the `__newindex` metamethod, the `__newindex` metamethod will + /// be used as a fall-back if no regular field is found. + fn add_field_method_set(&mut self, name: &S, method: M) + where + S: AsRef<[u8]> + ?Sized, + A: FromLua<'lua>, + M: 'static + MaybeSend + FnMut(&'lua Lua, &mut T, A) -> Result<()>; + + /// Add a regular field getter as a function which accepts a generic [`AnyUserData`] of type `T` + /// argument. + /// + /// Prefer to use [`add_field_method_get`] as it is easier to use. + /// + /// [`AnyUserData`]: struct.AnyUserData.html + /// [`add_field_method_get`]: #method.add_field_method_get + fn add_field_function_get(&mut self, name: &S, function: F) + where + S: AsRef<[u8]> + ?Sized, + R: ToLua<'lua>, + F: 'static + MaybeSend + Fn(&'lua Lua, AnyUserData<'lua>) -> Result; + + /// Add a regular field setter as a function which accepts a generic [`AnyUserData`] of type `T` + /// first argument. + /// + /// Prefer to use [`add_field_method_set`] as it is easier to use. + /// + /// [`AnyUserData`]: struct.AnyUserData.html + /// [`add_field_method_set`]: #method.add_field_method_set + fn add_field_function_set(&mut self, name: &S, function: F) + where + S: AsRef<[u8]> + ?Sized, + A: FromLua<'lua>, + F: 'static + MaybeSend + FnMut(&'lua Lua, AnyUserData<'lua>, A) -> Result<()>; + + /// Add a metamethod value computed from `f`. + /// + /// This will initialize the metamethod value from `f` on `UserData` creation. + /// + /// # Note + /// + /// `mlua` will trigger an error on an attempt to define a protected metamethod, + /// like `__gc` or `__metatable`. + fn add_meta_field_with(&mut self, meta: S, f: F) + where + S: Into, + F: 'static + MaybeSend + Fn(&'lua Lua) -> Result, + R: ToLua<'lua>; +} + /// Trait for custom userdata types. /// /// By implementing this trait, a struct becomes eligible for use inside Lua code. Implementations @@ -322,21 +490,21 @@ pub trait UserDataMethods<'lua, T: UserData> { /// # } /// ``` /// -/// Custom methods and operators can be provided by implementing `add_methods` (refer to -/// [`UserDataMethods`] for more information): +/// Custom fields, methods and operators can be provided by implementing `add_fields` or `add_methods` +/// (refer to [`UserDataFields`] and [`UserDataMethods`] for more information): /// /// ``` -/// # use mlua::{Lua, MetaMethod, Result, UserData, UserDataMethods}; +/// # use mlua::{Lua, MetaMethod, Result, UserData, UserDataFields, UserDataMethods}; /// # fn main() -> Result<()> { /// # let lua = Lua::new(); /// struct MyUserData(i32); /// /// impl UserData for MyUserData { -/// fn add_methods<'lua, M: UserDataMethods<'lua, Self>>(methods: &mut M) { -/// methods.add_method("get", |_, this, _: ()| { -/// Ok(this.0) -/// }); +/// fn add_fields<'lua, F: UserDataFields<'lua, Self>>(fields: &mut F) { +/// fields.add_field_method_get("val", |_, this| Ok(this.0)); +/// } /// +/// fn add_methods<'lua, M: UserDataMethods<'lua, Self>>(methods: &mut M) { /// methods.add_method_mut("add", |_, this, value: i32| { /// this.0 += value; /// Ok(()) @@ -351,9 +519,9 @@ pub trait UserDataMethods<'lua, T: UserData> { /// lua.globals().set("myobject", MyUserData(123))?; /// /// lua.load(r#" -/// assert(myobject:get() == 123) +/// assert(myobject.val == 123) /// myobject:add(7) -/// assert(myobject:get() == 130) +/// assert(myobject.val == 130) /// assert(myobject + 10 == 140) /// "#).exec()?; /// # Ok(()) @@ -362,8 +530,12 @@ pub trait UserDataMethods<'lua, T: UserData> { /// /// [`ToLua`]: trait.ToLua.html /// [`FromLua`]: trait.FromLua.html +/// [`UserDataFields`]: trait.UserDataFields.html /// [`UserDataMethods`]: trait.UserDataMethods.html pub trait UserData: Sized { + /// Adds custom fields specific to this userdata. + fn add_fields<'lua, F: UserDataFields<'lua, Self>>(_fields: &mut F) {} + /// Adds custom methods and operators specific to this userdata. fn add_methods<'lua, M: UserDataMethods<'lua, Self>>(_methods: &mut M) {} } @@ -537,11 +709,30 @@ impl<'lua> AnyUserData<'lua> { V::from_lua(res, lua) } - /// Checks for a metamethod in this `AnyUserData` + /// Returns a metatable of this `UserData`. + /// + /// Returned [`UserDataMetatable`] object wraps the original metatable and + /// allows to provide safe access to it methods. + /// + /// [`UserDataMetatable`]: struct.UserDataMetatable.html + pub fn get_metatable(&self) -> Result> { + self.get_raw_metatable().map(UserDataMetatable) + } + + /// Checks for a metamethod in this `AnyUserData`. + /// + /// This function is deprecated and will be removed in v0.7. + /// Please use [`get_metatable`] function instead. + /// + /// [`get_metatable`]: #method.get_metatable + #[deprecated( + since = "0.6.0", + note = "Please use the get_metatable function instead" + )] pub fn has_metamethod(&self, method: MetaMethod) -> Result { - match self.get_metatable() { + match self.get_raw_metatable() { Ok(mt) => { - let name = self.0.lua.create_string(method.name())?; + let name = self.0.lua.create_string(method.validate()?.name())?; if let Value::Nil = mt.raw_get(name)? { Ok(false) } else { @@ -553,7 +744,7 @@ impl<'lua> AnyUserData<'lua> { } } - fn get_metatable(&self) -> Result> { + fn get_raw_metatable(&self) -> Result> { unsafe { let lua = self.0.lua; let _sg = StackGuard::new(lua.state); @@ -571,12 +762,13 @@ impl<'lua> AnyUserData<'lua> { pub(crate) fn equals>(&self, other: T) -> Result { let other = other.as_ref(); + // Uses lua_rawequal() under the hood if self == other { return Ok(true); } - let mt = self.get_metatable()?; - if mt != other.get_metatable()? { + let mt = self.get_raw_metatable()?; + if mt != other.get_raw_metatable()? { return Ok(false); } @@ -640,6 +832,80 @@ impl<'lua> AsRef> for AnyUserData<'lua> { } } +/// Handle to a `UserData` metatable. +#[derive(Clone, Debug)] +pub struct UserDataMetatable<'lua>(pub(crate) Table<'lua>); + +impl<'lua> UserDataMetatable<'lua> { + /// Gets the value associated to `key` from the metatable. + /// + /// If no value is associated to `key`, returns the `Nil` value. + /// Access to restricted metamethods such as `__gc` or `__metatable` will cause an error. + pub fn get, V: FromLua<'lua>>(&self, key: K) -> Result { + self.0.raw_get(key.into().validate()?.name()) + } + + /// Sets a key-value pair in the metatable. + /// + /// If the value is `Nil`, this will effectively remove the `key`. + /// Access to restricted metamethods such as `__gc` or `__metatable` will cause an error. + /// Setting `__index` or `__newindex` metamethods is also restricted because their values are cached + /// for `mlua` internal usage. + pub fn set, V: ToLua<'lua>>(&self, key: K, value: V) -> Result<()> { + let key = key.into().validate()?; + // `__index` and `__newindex` cannot be changed in runtime, because values are cached + if key == MetaMethod::Index || key == MetaMethod::NewIndex { + return Err(Error::MetaMethodRestricted(key.to_string())); + } + self.0.raw_set(key.name(), value) + } + + /// Checks whether the metatable contains a non-nil value for `key`. + pub fn contains>(&self, key: K) -> Result { + self.0.contains_key(key.into().validate()?.name()) + } + + /// Consumes this metatable and returns an iterator over the pairs of the metatable. + /// + /// The pairs are wrapped in a [`Result`], since they are lazily converted to `V` type. + /// + /// [`Result`]: type.Result.html + pub fn pairs, V: FromLua<'lua>>(self) -> UserDataMetatablePairs<'lua, V> { + UserDataMetatablePairs(self.0.pairs()) + } +} + +/// An iterator over the pairs of a [`UserData`] metatable. +/// +/// It skips restricted metamethods, such as `__gc` or `__metatable`. +/// +/// This struct is created by the [`UserDataMetatable::pairs`] method. +/// +/// [`UserData`]: trait.UserData.html +/// [`UserDataMetatable::pairs`]: struct.UserDataMetatable.html#method.pairs +pub struct UserDataMetatablePairs<'lua, V>(TablePairs<'lua, StdString, V>); + +impl<'lua, V> Iterator for UserDataMetatablePairs<'lua, V> +where + V: FromLua<'lua>, +{ + type Item = Result<(MetaMethod, V)>; + + fn next(&mut self) -> Option { + loop { + match self.0.next()? { + Ok((key, value)) => { + // Skip restricted metamethods + if let Ok(metamethod) = MetaMethod::from(key).validate() { + break Some(Ok((metamethod, value))); + } + } + Err(e) => break Some(Err(e)), + } + } + } +} + #[cfg(feature = "serialize")] impl<'lua> Serialize for AnyUserData<'lua> { fn serialize(&self, serializer: S) -> StdResult diff --git a/src/util.rs b/src/util.rs index b2c2b61..8b29a91 100644 --- a/src/util.rs +++ b/src/util.rs @@ -293,56 +293,149 @@ pub unsafe fn get_gc_userdata(state: *mut ffi::lua_State, index: c_int) } // Populates the given table with the appropriate members to be a userdata metatable for the given -// type. This function takes the given table at the `metatable` index, and adds an appropriate __gc -// member to it for the given type and a __metatable entry to protect the table from script access. -// The function also, if given a `members` table index, will set up an __index metamethod to return -// the appropriate member on __index. Additionally, if there is already an __index entry on the -// given metatable, instead of simply overwriting the __index, instead the created __index method -// will capture the previous one, and use it as a fallback only if the given key is not found in the -// provided members table. Internally uses 6 stack spaces and does not call checkstack. +// type. This function takes the given table at the `metatable` index, and adds an appropriate `__gc` +// member to it for the given type and a `__metatable` entry to protect the table from script access. +// The function also, if given a `field_getters` or `methods` tables, will create an `__index` metamethod +// (capturing previous one) to lookup in `field_getters` first, then `methods` and falling back to the +// captured `__index` if no matches found. +// The same is also applicable for `__newindex` metamethod and `field_setters` table. +// Internally uses 8 stack spaces and does not call checkstack. pub unsafe fn init_userdata_metatable( state: *mut ffi::lua_State, metatable: c_int, - members: Option, + field_getters: Option, + field_setters: Option, + methods: Option, ) -> Result<()> { - // Used if both an __index metamethod is set and regular methods, checks methods table - // first, then __index metamethod. + // Wrapper to lookup in `field_getters` first, then `methods`, ending original `__index`. + // Used only if `field_getters` or `methods` set. unsafe extern "C" fn meta_index_impl(state: *mut ffi::lua_State) -> c_int { + // stack: self, key ffi::luaL_checkstack(state, 2, ptr::null()); - ffi::lua_pushvalue(state, -1); - ffi::lua_gettable(state, ffi::lua_upvalueindex(2)); - if ffi::lua_isnil(state, -1) == 0 { - ffi::lua_insert(state, -3); - ffi::lua_pop(state, 2); - 1 - } else { - ffi::lua_pop(state, 1); - ffi::lua_pushvalue(state, ffi::lua_upvalueindex(1)); - ffi::lua_insert(state, -3); - ffi::lua_call(state, 2, 1); - 1 + // lookup in `field_getters` table + if ffi::lua_isnil(state, ffi::lua_upvalueindex(2)) == 0 { + ffi::lua_pushvalue(state, -1); // `key` arg + if ffi::lua_rawget(state, ffi::lua_upvalueindex(2)) != ffi::LUA_TNIL { + ffi::lua_insert(state, -3); // move function + ffi::lua_pop(state, 1); // remove `key` + ffi::lua_call(state, 1, 1); + return 1; + } + ffi::lua_pop(state, 1); // pop the nil value } + // lookup in `methods` table + if ffi::lua_isnil(state, ffi::lua_upvalueindex(3)) == 0 { + ffi::lua_pushvalue(state, -1); // `key` arg + if ffi::lua_rawget(state, ffi::lua_upvalueindex(3)) != ffi::LUA_TNIL { + ffi::lua_insert(state, -3); + ffi::lua_pop(state, 2); + return 1; + } + ffi::lua_pop(state, 1); // pop the nil value + } + + // lookup in `__index` + ffi::lua_pushvalue(state, ffi::lua_upvalueindex(1)); + match ffi::lua_type(state, -1) { + ffi::LUA_TNIL => { + ffi::lua_pop(state, 1); // pop the nil value + let field = ffi::lua_tostring(state, -1); + ffi::luaL_error(state, cstr!("attempt to get an unknown field '%s'"), field); + } + ffi::LUA_TTABLE => { + ffi::lua_insert(state, -2); + ffi::lua_gettable(state, -2); + } + ffi::LUA_TFUNCTION => { + ffi::lua_insert(state, -3); + ffi::lua_call(state, 2, 1); + } + _ => unreachable!(), + } + + 1 + } + + // Similar to `meta_index_impl`, checks `field_setters` table first, then `__newindex` metamethod. + // Used only if `field_setters` set. + unsafe extern "C" fn meta_newindex_impl(state: *mut ffi::lua_State) -> c_int { + // stack: self, key, value + ffi::luaL_checkstack(state, 2, ptr::null()); + + // lookup in `field_setters` table + ffi::lua_pushvalue(state, -2); // `key` arg + if ffi::lua_rawget(state, ffi::lua_upvalueindex(2)) != ffi::LUA_TNIL { + ffi::lua_remove(state, -3); // remove `key` + ffi::lua_insert(state, -3); // move function + ffi::lua_call(state, 2, 0); + return 0; + } + ffi::lua_pop(state, 1); // pop the nil value + + // lookup in `__newindex` + ffi::lua_pushvalue(state, ffi::lua_upvalueindex(1)); + match ffi::lua_type(state, -1) { + ffi::LUA_TNIL => { + ffi::lua_pop(state, 1); // pop the nil value + let field = ffi::lua_tostring(state, -2); + ffi::luaL_error(state, cstr!("attempt to set an unknown field '%s'"), field); + } + ffi::LUA_TTABLE => { + ffi::lua_insert(state, -3); + ffi::lua_settable(state, -3); + } + ffi::LUA_TFUNCTION => { + ffi::lua_insert(state, -4); + ffi::lua_call(state, 3, 0); + } + _ => unreachable!(), + } + + 0 } - let members = members.map(|i| ffi::lua_absindex(state, i)); ffi::lua_pushvalue(state, metatable); - if let Some(members) = members { + if field_getters.is_some() || methods.is_some() { push_string(state, "__index")?; - ffi::lua_pushvalue(state, -1); + ffi::lua_pushvalue(state, -1); let index_type = ffi::lua_rawget(state, -3); - if index_type == ffi::LUA_TNIL { - ffi::lua_pop(state, 1); - ffi::lua_pushvalue(state, members); - } else if index_type == ffi::LUA_TFUNCTION { - ffi::lua_pushvalue(state, members); - protect_lua_closure(state, 2, 1, |state| { - ffi::lua_pushcclosure(state, meta_index_impl, 2); - })?; - } else { - mlua_panic!("improper __index type {}", index_type); + match index_type { + ffi::LUA_TNIL | ffi::LUA_TTABLE | ffi::LUA_TFUNCTION => { + for &idx in &[field_getters, methods] { + if let Some(idx) = idx { + ffi::lua_pushvalue(state, idx); + } else { + ffi::lua_pushnil(state); + } + } + protect_lua_closure(state, 3, 1, |state| { + ffi::lua_pushcclosure(state, meta_index_impl, 3); + })?; + } + _ => mlua_panic!("improper __index type {}", index_type), + } + + protect_lua_closure(state, 3, 1, |state| { + ffi::lua_rawset(state, -3); + })?; + } + + if let Some(field_setters) = field_setters { + push_string(state, "__newindex")?; + + ffi::lua_pushvalue(state, -1); + let newindex_type = ffi::lua_rawget(state, -3); + match newindex_type { + ffi::LUA_TNIL | ffi::LUA_TTABLE | ffi::LUA_TFUNCTION => { + ffi::lua_pushvalue(state, field_setters); + protect_lua_closure(state, 2, 1, |state| { + ffi::lua_pushcclosure(state, meta_newindex_impl, 2); + })?; + } + _ => mlua_panic!("improper __newindex type {}", newindex_type), } protect_lua_closure(state, 3, 1, |state| { diff --git a/tests/scope.rs b/tests/scope.rs index fe7e4f4..95591be 100644 --- a/tests/scope.rs +++ b/tests/scope.rs @@ -2,7 +2,8 @@ use std::cell::Cell; use std::rc::Rc; use mlua::{ - AnyUserData, Error, Function, Lua, MetaMethod, Result, String, UserData, UserDataMethods, + AnyUserData, Error, Function, Lua, MetaMethod, Result, String, UserData, UserDataFields, + UserDataMethods, }; #[test] @@ -139,6 +140,41 @@ fn outer_lua_access() -> Result<()> { Ok(()) } +#[test] +fn scope_userdata_fields() -> Result<()> { + struct MyUserData<'a>(&'a Cell); + + impl<'a> UserData for MyUserData<'a> { + fn add_fields<'lua, F: UserDataFields<'lua, Self>>(fields: &mut F) { + fields.add_field_method_get("val", |_, data| Ok(data.0.get())); + fields.add_field_method_set("val", |_, data, val| { + data.0.set(val); + Ok(()) + }); + } + } + + let lua = Lua::new(); + + let i = Cell::new(42); + let f: Function = lua + .load( + r#" + function(u) + assert(u.val == 42) + u.val = 44 + end + "#, + ) + .eval()?; + + lua.scope(|scope| f.call::<_, ()>(scope.create_nonstatic_userdata(MyUserData(&i))?))?; + + assert_eq!(i.get(), 44); + + Ok(()) +} + #[test] fn scope_userdata_methods() -> Result<()> { struct MyUserData<'a>(&'a Cell); diff --git a/tests/userdata.rs b/tests/userdata.rs index 08159c5..3c45a6f 100644 --- a/tests/userdata.rs +++ b/tests/userdata.rs @@ -4,8 +4,8 @@ use std::sync::Arc; use std::sync::atomic::{AtomicI64, Ordering}; use mlua::{ - AnyUserData, ExternalError, Function, Lua, MetaMethod, Result, String, UserData, - UserDataMethods, Value, + AnyUserData, Error, ExternalError, Function, Lua, MetaMethod, Nil, Result, String, UserData, + UserDataFields, UserDataMethods, Value, }; #[test] @@ -155,10 +155,10 @@ fn test_metamethods() -> Result<()> { assert!(userdata2.equals(userdata3)?); let userdata1: AnyUserData = globals.get("userdata1")?; - assert!(userdata1.has_metamethod(MetaMethod::Add)?); - assert!(userdata1.has_metamethod(MetaMethod::Sub)?); - assert!(userdata1.has_metamethod(MetaMethod::Index)?); - assert!(!userdata1.has_metamethod(MetaMethod::Pow)?); + assert!(userdata1.get_metatable()?.contains(MetaMethod::Add)?); + assert!(userdata1.get_metatable()?.contains(MetaMethod::Sub)?); + assert!(userdata1.get_metatable()?.contains(MetaMethod::Index)?); + assert!(!userdata1.get_metatable()?.contains(MetaMethod::Pow)?); Ok(()) } @@ -250,7 +250,7 @@ fn test_gc_userdata() -> Result<()> { } #[test] -fn detroys_userdata() -> Result<()> { +fn test_destroy_userdata() -> Result<()> { struct MyUserdata(Arc<()>); impl UserData for MyUserdata {} @@ -272,7 +272,7 @@ fn detroys_userdata() -> Result<()> { } #[test] -fn user_value() -> Result<()> { +fn test_user_value() -> Result<()> { struct MyUserData; impl UserData for MyUserData {} @@ -335,3 +335,101 @@ fn test_functions() -> Result<()> { Ok(()) } + +#[test] +fn test_fields() -> Result<()> { + #[derive(Copy, Clone)] + struct MyUserData(i64); + + impl UserData for MyUserData { + fn add_fields<'lua, F: UserDataFields<'lua, Self>>(fields: &mut F) { + fields.add_field_method_get("val", |_, data| Ok(data.0)); + fields.add_field_method_set("val", |_, data, val| { + data.0 = val; + Ok(()) + }); + + fields.add_meta_field_with(MetaMethod::Index, |lua| { + let index = lua.create_table()?; + index.set("f", 321)?; + Ok(index) + }); + } + } + + let lua = Lua::new(); + let globals = lua.globals(); + globals.set("ud", MyUserData(7))?; + lua.load( + r#" + assert(ud.val == 7) + ud.val = 10 + assert(ud.val == 10) + assert(ud.f == 321) + "#, + ) + .exec()?; + + Ok(()) +} + +#[test] +fn test_metatable() -> Result<()> { + #[derive(Copy, Clone)] + struct MyUserData(i64); + + impl UserData for MyUserData { + fn add_fields<'lua, F: UserDataFields<'lua, Self>>(fields: &mut F) { + fields.add_meta_field_with("__type_name", |_| Ok("MyUserData")); + } + + fn add_methods<'lua, M: UserDataMethods<'lua, Self>>(methods: &mut M) { + methods.add_function("my_type_name", |_, data: AnyUserData| { + let metatable = data.get_metatable()?; + metatable.get::<_, String>("__type_name") + }); + } + } + + let lua = Lua::new(); + let globals = lua.globals(); + globals.set("ud", MyUserData(7))?; + lua.load( + r#" + assert(ud:my_type_name() == "MyUserData") + "#, + ) + .exec()?; + + let ud: AnyUserData = globals.get("ud")?; + let metatable = ud.get_metatable()?; + + match metatable.get::<_, Value>("__gc") { + Ok(_) => panic!("expected MetaMethodRestricted, got no error"), + Err(Error::MetaMethodRestricted(_)) => {} + Err(e) => panic!("expected MetaMethodRestricted, got {:?}", e), + } + + match metatable.set(MetaMethod::Index, Nil) { + Ok(_) => panic!("expected MetaMethodRestricted, got no error"), + Err(Error::MetaMethodRestricted(_)) => {} + Err(e) => panic!("expected MetaMethodRestricted, got {:?}", e), + } + + #[derive(Copy, Clone)] + struct MyUserData2(i64); + + impl UserData for MyUserData2 { + fn add_fields<'lua, F: UserDataFields<'lua, Self>>(fields: &mut F) { + fields.add_meta_field_with("__index", |_| Ok(1)); + } + } + + match lua.create_userdata(MyUserData2(1)) { + Ok(_) => panic!("expected MetaMethodTypeError, got no error"), + Err(Error::MetaMethodTypeError { .. }) => {} + Err(e) => panic!("expected MetaMethodTypeError, got {:?}", e), + } + + Ok(()) +}