diff --git a/src/function.rs b/src/function.rs index 1b7e6fe..141fcbe 100644 --- a/src/function.rs +++ b/src/function.rs @@ -160,3 +160,9 @@ impl<'lua> Function<'lua> { } } } + +impl<'lua> PartialEq for Function<'lua> { + fn eq(&self, other: &Self) -> bool { + self.0 == other.0 + } +} diff --git a/src/table.rs b/src/table.rs index 0c3e143..bb45ea5 100644 --- a/src/table.rs +++ b/src/table.rs @@ -166,6 +166,62 @@ impl<'lua> Table<'lua> { self.get::<_, Function>(key)?.call(args) } + /// Compares two tables for equality. + /// + /// Tables are compared by reference first. + /// If they are not primitively equals, then mlua will try to invoke the `__eq` metamethod. + /// mlua will check `self` first for the metamethod, then `other` if not found. + /// + /// # Examples + /// + /// Compare two tables using `__eq` metamethod: + /// + /// ``` + /// # use mlua::{Lua, Result, Table}; + /// # fn main() -> Result<()> { + /// # let lua = Lua::new(); + /// let table1 = lua.create_table()?; + /// table1.set(1, "value")?; + /// + /// let table2 = lua.create_table()?; + /// table2.set(2, "value")?; + /// + /// let always_equals_mt = lua.create_table()?; + /// always_equals_mt.set("__eq", lua.create_function(|_, (_t1, _t2): (Table, Table)| Ok(true))?)?; + /// table2.set_metatable(Some(always_equals_mt)); + /// + /// assert!(table1.equals(&table1.clone())?); + /// assert!(table1.equals(&table2)?); + /// # Ok(()) + /// # } + /// ``` + pub fn equals>(&self, other: T) -> Result { + let other = other.as_ref(); + if self == other { + return Ok(true); + } + + // Compare using __eq metamethod if exists + // First, check the self for the metamethod. + // If self does not define it, then check the other table. + if let Some(mt) = self.get_metatable() { + if mt.contains_key("__eq")? { + return mt + .get::<_, Function>("__eq")? + .call((self.clone(), other.clone())); + } + } + if let Some(mt) = other.get_metatable() { + if mt.contains_key("__eq")? { + return mt + .get::<_, Function>("__eq")? + .call((self.clone(), other.clone())); + } + } + + Ok(false) + } + /// Removes a key from the table, returning the value at the key /// if the key was previously in the table. pub fn raw_remove>(&self, key: K) -> Result<()> { @@ -368,6 +424,19 @@ impl<'lua> Table<'lua> { } } +impl<'lua> PartialEq for Table<'lua> { + fn eq(&self, other: &Self) -> bool { + self.0 == other.0 + } +} + +impl<'lua> AsRef> for Table<'lua> { + #[inline] + fn as_ref(&self) -> &Self { + self + } +} + /// An iterator over the pairs of a Lua table. /// /// This struct is created by the [`Table::pairs`] method. diff --git a/src/thread.rs b/src/thread.rs index 97b0362..2452fec 100644 --- a/src/thread.rs +++ b/src/thread.rs @@ -143,3 +143,9 @@ impl<'lua> Thread<'lua> { } } } + +impl<'lua> PartialEq for Thread<'lua> { + fn eq(&self, other: &Self) -> bool { + self.0 == other.0 + } +} diff --git a/src/types.rs b/src/types.rs index 29e55b1..e88d0c1 100644 --- a/src/types.rs +++ b/src/types.rs @@ -5,6 +5,7 @@ use std::{fmt, mem, ptr}; use crate::error::Result; use crate::ffi; use crate::lua::Lua; +use crate::util::{assert_stack, StackGuard}; use crate::value::MultiValue; /// Type of Lua integer numbers. @@ -92,3 +93,16 @@ impl<'lua> Drop for LuaRef<'lua> { self.lua.drop_ref(self) } } + +impl<'lua> PartialEq for LuaRef<'lua> { + fn eq(&self, other: &Self) -> bool { + let lua = self.lua; + unsafe { + let _sg = StackGuard::new(lua.state); + assert_stack(lua.state, 2); + lua.push_ref(&self); + lua.push_ref(&other); + ffi::lua_rawequal(lua.state, -1, -2) == 1 + } + } +} diff --git a/src/userdata.rs b/src/userdata.rs index 8177ad9..32817b6 100644 --- a/src/userdata.rs +++ b/src/userdata.rs @@ -2,7 +2,9 @@ use std::cell::{Ref, RefCell, RefMut}; use crate::error::{Error, Result}; use crate::ffi; +use crate::function::Function; use crate::lua::Lua; +use crate::table::Table; use crate::types::LuaRef; use crate::util::{assert_stack, get_userdata, StackGuard}; use crate::value::{FromLua, FromLuaMulti, ToLua, ToLuaMulti}; @@ -398,6 +400,42 @@ impl<'lua> AnyUserData<'lua> { V::from_lua(res, lua) } + fn get_metatable(&self) -> Result> { + unsafe { + let lua = self.0.lua; + let _sg = StackGuard::new(lua.state); + assert_stack(lua.state, 3); + + lua.push_ref(&self.0); + + if ffi::lua_getmetatable(lua.state, -1) == 0 { + return Err(Error::UserDataTypeMismatch); + } + + Ok(Table(lua.pop_ref())) + } + } + + pub(crate) fn equals>(&self, other: T) -> Result { + let other = other.as_ref(); + if self == other { + return Ok(true); + } + + let mt = self.get_metatable()?; + if mt != other.get_metatable()? { + return Ok(false); + } + + if mt.contains_key("__eq")? { + return mt + .get::<_, Function>("__eq")? + .call((self.clone(), other.clone())); + } + + Ok(false) + } + fn inspect<'a, T, R, F>(&'a self, func: F) -> Result where T: 'static + UserData, @@ -428,3 +466,16 @@ impl<'lua> AnyUserData<'lua> { } } } + +impl<'lua> PartialEq for AnyUserData<'lua> { + fn eq(&self, other: &Self) -> bool { + self.0 == other.0 + } +} + +impl<'lua> AsRef> for AnyUserData<'lua> { + #[inline] + fn as_ref(&self) -> &Self { + self + } +} diff --git a/src/value.rs b/src/value.rs index c3eab89..1ab5399 100644 --- a/src/value.rs +++ b/src/value.rs @@ -2,6 +2,7 @@ use std::iter::{self, FromIterator}; use std::{slice, str, vec}; use crate::error::{Error, Result}; +use crate::ffi; use crate::function::Function; use crate::lua::Lua; use crate::string::String; @@ -61,6 +62,51 @@ impl<'lua> Value<'lua> { Value::UserData(_) | Value::Error(_) => "userdata", } } + + /// Compares two values for equality. + /// + /// Equality comparisons do not convert strings to numbers or vice versa. + /// Tables, Functions, Threads, and Userdata are compared by reference: + /// two objects are considered equal only if they are the same object. + /// + /// If Tables or Userdata have `__eq` metamethod then mlua will try to invoke it. + /// The first value is checked first. If that value does not define a metamethod + /// for `__eq`, then mlua will check the second value. + /// Then mlua calls the metamethod with the two values as arguments, if found. + pub fn equals>(&self, other: T) -> Result { + match (self, other.as_ref()) { + (Value::Table(a), Value::Table(b)) => a.equals(b), + (Value::UserData(a), Value::UserData(b)) => a.equals(b), + _ => Ok(self == other.as_ref()), + } + } +} + +impl<'lua> PartialEq for Value<'lua> { + fn eq(&self, other: &Self) -> bool { + match (self, other) { + (Value::Nil, Value::Nil) => true, + (Value::Boolean(a), Value::Boolean(b)) => a == b, + (Value::LightUserData(a), Value::LightUserData(b)) => a == b, + (Value::Integer(a), Value::Integer(b)) => *a == *b, + (Value::Integer(a), Value::Number(b)) => *a as ffi::lua_Number == *b, + (Value::Number(a), Value::Integer(b)) => *a == *b as ffi::lua_Number, + (Value::Number(a), Value::Number(b)) => *a == *b, + (Value::String(a), Value::String(b)) => a == b, + (Value::Table(a), Value::Table(b)) => a == b, + (Value::Function(a), Value::Function(b)) => a == b, + (Value::Thread(a), Value::Thread(b)) => a == b, + (Value::UserData(a), Value::UserData(b)) => a == b, + _ => false, + } + } +} + +impl<'lua> AsRef> for Value<'lua> { + #[inline] + fn as_ref(&self) -> &Self { + self + } } /// Trait for types convertible to `Value`. diff --git a/tests/memory.rs b/tests/memory.rs index 7c141cd..0bcdeb6 100644 --- a/tests/memory.rs +++ b/tests/memory.rs @@ -51,15 +51,15 @@ fn test_gc_error() { match lua .load( r#" - val = nil - table = {} - setmetatable(table, { - __gc = function() - error("gcwascalled") - end - }) - table = nil - collectgarbage("collect") + val = nil + table = {} + setmetatable(table, { + __gc = function() + error("gcwascalled") + end + }) + table = nil + collectgarbage("collect") "#, ) .exec() diff --git a/tests/table.rs b/tests/table.rs index 14ccb87..1d788e2 100644 --- a/tests/table.rs +++ b/tests/table.rs @@ -148,6 +148,40 @@ fn test_metatable() -> Result<()> { Ok(()) } +#[test] +fn test_table_eq() -> Result<()> { + let lua = Lua::new(); + let globals = lua.globals(); + + lua.load( + r#" + table1 = {1} + table2 = {1} + table3 = table1 + table4 = {1} + + setmetatable(table4, { + __eq = function(a, b) return a[1] == b[1] end + }) + "#, + ) + .exec()?; + + let table1 = globals.get::<_, Table>("table1")?; + let table2 = globals.get::<_, Table>("table2")?; + let table3 = globals.get::<_, Table>("table3")?; + let table4 = globals.get::<_, Table>("table4")?; + + assert!(table1 != table2); + assert!(!table1.equals(&table2)?); + assert!(table1 == table3); + assert!(table1.equals(&table3)?); + assert!(table1 != table4); + assert!(table1.equals(&table4)?); + + Ok(()) +} + #[test] fn test_table_error() -> Result<()> { let lua = Lua::new(); diff --git a/tests/tests.rs b/tests/tests.rs index 2e71dc1..8ef9d48 100644 --- a/tests/tests.rs +++ b/tests/tests.rs @@ -92,13 +92,13 @@ fn test_lua_multi() -> Result<()> { lua.load( r#" - function concat(arg1, arg2) - return arg1 .. arg2 - end + function concat(arg1, arg2) + return arg1 .. arg2 + end - function mreturn() - return 1, 2, 3, 4, 5, 6 - end + function mreturn() + return 1, 2, 3, 4, 5, 6 + end "#, ) .exec()?; diff --git a/tests/thread.rs b/tests/thread.rs index 3344b7c..d3266ba 100644 --- a/tests/thread.rs +++ b/tests/thread.rs @@ -20,13 +20,13 @@ fn test_thread() -> Result<()> { let thread = lua.create_thread( lua.load( r#" - function (s) - local sum = s - for i = 1,4 do - sum = sum + coroutine.yield(sum) - end - return sum + function (s) + local sum = s + for i = 1,4 do + sum = sum + coroutine.yield(sum) end + return sum + end "#, ) .eval()?, @@ -47,11 +47,11 @@ fn test_thread() -> Result<()> { let accumulate = lua.create_thread( lua.load( r#" - function (sum) - while true do - sum = sum + coroutine.yield(sum) - end + function (sum) + while true do + sum = sum + coroutine.yield(sum) end + end "#, ) .eval::()?, diff --git a/tests/userdata.rs b/tests/userdata.rs index 4f01669..34aa8b4 100644 --- a/tests/userdata.rs +++ b/tests/userdata.rs @@ -13,7 +13,7 @@ use std::sync::Arc; use mlua::{ AnyUserData, ExternalError, Function, Lua, MetaMethod, Result, String, UserData, - UserDataMethods, + UserDataMethods, Value, }; #[test] @@ -96,6 +96,9 @@ fn test_metamethods() -> Result<()> { MetaMethod::Sub, |_, (lhs, rhs): (MyUserData, MyUserData)| Ok(MyUserData(lhs.0 - rhs.0)), ); + methods.add_meta_function(MetaMethod::Eq, |_, (lhs, rhs): (MyUserData, MyUserData)| { + Ok(lhs.0 == rhs.0) + }); methods.add_meta_method(MetaMethod::Index, |_, data, index: String| { if index.to_str()? == "inner" { Ok(data.0) @@ -122,6 +125,7 @@ fn test_metamethods() -> Result<()> { let globals = lua.globals(); globals.set("userdata1", MyUserData(7))?; globals.set("userdata2", MyUserData(3))?; + globals.set("userdata3", MyUserData(3))?; assert_eq!( lua.load("userdata1 + userdata2").eval::()?.0, 10 @@ -151,6 +155,13 @@ fn test_metamethods() -> Result<()> { assert_eq!(ipairs_it.call::<_, i64>(())?, 28); assert!(lua.load("userdata2.nonexist_field").eval::<()>().is_err()); + let userdata2: Value = globals.get("userdata2")?; + let userdata3: Value = globals.get("userdata3")?; + + assert!(lua.load("userdata2 == userdata3").eval::()?); + assert!(userdata2 != userdata3); // because references are differ + assert!(userdata2.equals(userdata3)?); + Ok(()) } @@ -175,18 +186,18 @@ fn test_gc_userdata() -> Result<()> { assert!(lua .load( r#" - local tbl = setmetatable({ - userdata = userdata - }, { __gc = function(self) - -- resurrect userdata - hatch = self.userdata - end }) + local tbl = setmetatable({ + userdata = userdata + }, { __gc = function(self) + -- resurrect userdata + hatch = self.userdata + end }) - tbl = nil - userdata = nil -- make table and userdata collectable - collectgarbage("collect") - hatch:access() - "# + tbl = nil + userdata = nil -- make table and userdata collectable + collectgarbage("collect") + hatch:access() + "# ) .exec() .is_err()); diff --git a/tests/value.rs b/tests/value.rs new file mode 100644 index 0000000..9beb68c --- /dev/null +++ b/tests/value.rs @@ -0,0 +1,57 @@ +use mlua::{Lua, Result, Value}; + +#[test] +fn test_value_eq() -> Result<()> { + let lua = Lua::new(); + let globals = lua.globals(); + + lua.load( + r#" + table1 = {1} + table2 = {1} + string1 = "hello" + string2 = "hello" + num1 = 1 + num2 = 1.0 + num3 = "1" + func1 = function() end + func2 = func1 + func3 = function() end + thread1 = coroutine.create(function() end) + thread2 = thread1 + + setmetatable(table1, { + __eq = function(a, b) return a[1] == b[1] end + }) + "#, + ) + .exec()?; + + let table1: Value = globals.get("table1")?; + let table2: Value = globals.get("table2")?; + let string1: Value = globals.get("string1")?; + let string2: Value = globals.get("string2")?; + let num1: Value = globals.get("num1")?; + let num2: Value = globals.get("num2")?; + let num3: Value = globals.get("num3")?; + let func1: Value = globals.get("func1")?; + let func2: Value = globals.get("func2")?; + let func3: Value = globals.get("func3")?; + let thread1: Value = globals.get("thread1")?; + let thread2: Value = globals.get("thread2")?; + + assert!(table1 != table2); + assert!(table1.equals(table2)?); + assert!(string1 == string2); + assert!(string1.equals(string2)?); + assert!(num1 == num2); + assert!(num1.equals(num2)?); + assert!(num1 != num3); + assert!(func1 == func2); + assert!(func1 != func3); + assert!(!func1.equals(func3)?); + assert!(thread1 == thread2); + assert!(thread1.equals(thread2)?); + + Ok(()) +}