Implement PartialEq trait for Value (and subtypes)
Add equals() method to compare values optionally invoking __eq.
This commit is contained in:
parent
831161bfda
commit
5eec0ef56b
|
@ -160,3 +160,9 @@ impl<'lua> Function<'lua> {
|
|||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<'lua> PartialEq for Function<'lua> {
|
||||
fn eq(&self, other: &Self) -> bool {
|
||||
self.0 == other.0
|
||||
}
|
||||
}
|
||||
|
|
69
src/table.rs
69
src/table.rs
|
@ -166,6 +166,62 @@ impl<'lua> Table<'lua> {
|
|||
self.get::<_, Function>(key)?.call(args)
|
||||
}
|
||||
|
||||
/// Compares two tables for equality.
|
||||
///
|
||||
/// Tables are compared by reference first.
|
||||
/// If they are not primitively equals, then mlua will try to invoke the `__eq` metamethod.
|
||||
/// mlua will check `self` first for the metamethod, then `other` if not found.
|
||||
///
|
||||
/// # Examples
|
||||
///
|
||||
/// Compare two tables using `__eq` metamethod:
|
||||
///
|
||||
/// ```
|
||||
/// # use mlua::{Lua, Result, Table};
|
||||
/// # fn main() -> Result<()> {
|
||||
/// # let lua = Lua::new();
|
||||
/// let table1 = lua.create_table()?;
|
||||
/// table1.set(1, "value")?;
|
||||
///
|
||||
/// let table2 = lua.create_table()?;
|
||||
/// table2.set(2, "value")?;
|
||||
///
|
||||
/// let always_equals_mt = lua.create_table()?;
|
||||
/// always_equals_mt.set("__eq", lua.create_function(|_, (_t1, _t2): (Table, Table)| Ok(true))?)?;
|
||||
/// table2.set_metatable(Some(always_equals_mt));
|
||||
///
|
||||
/// assert!(table1.equals(&table1.clone())?);
|
||||
/// assert!(table1.equals(&table2)?);
|
||||
/// # Ok(())
|
||||
/// # }
|
||||
/// ```
|
||||
pub fn equals<T: AsRef<Self>>(&self, other: T) -> Result<bool> {
|
||||
let other = other.as_ref();
|
||||
if self == other {
|
||||
return Ok(true);
|
||||
}
|
||||
|
||||
// Compare using __eq metamethod if exists
|
||||
// First, check the self for the metamethod.
|
||||
// If self does not define it, then check the other table.
|
||||
if let Some(mt) = self.get_metatable() {
|
||||
if mt.contains_key("__eq")? {
|
||||
return mt
|
||||
.get::<_, Function>("__eq")?
|
||||
.call((self.clone(), other.clone()));
|
||||
}
|
||||
}
|
||||
if let Some(mt) = other.get_metatable() {
|
||||
if mt.contains_key("__eq")? {
|
||||
return mt
|
||||
.get::<_, Function>("__eq")?
|
||||
.call((self.clone(), other.clone()));
|
||||
}
|
||||
}
|
||||
|
||||
Ok(false)
|
||||
}
|
||||
|
||||
/// Removes a key from the table, returning the value at the key
|
||||
/// if the key was previously in the table.
|
||||
pub fn raw_remove<K: ToLua<'lua>>(&self, key: K) -> Result<()> {
|
||||
|
@ -368,6 +424,19 @@ impl<'lua> Table<'lua> {
|
|||
}
|
||||
}
|
||||
|
||||
impl<'lua> PartialEq for Table<'lua> {
|
||||
fn eq(&self, other: &Self) -> bool {
|
||||
self.0 == other.0
|
||||
}
|
||||
}
|
||||
|
||||
impl<'lua> AsRef<Table<'lua>> for Table<'lua> {
|
||||
#[inline]
|
||||
fn as_ref(&self) -> &Self {
|
||||
self
|
||||
}
|
||||
}
|
||||
|
||||
/// An iterator over the pairs of a Lua table.
|
||||
///
|
||||
/// This struct is created by the [`Table::pairs`] method.
|
||||
|
|
|
@ -143,3 +143,9 @@ impl<'lua> Thread<'lua> {
|
|||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<'lua> PartialEq for Thread<'lua> {
|
||||
fn eq(&self, other: &Self) -> bool {
|
||||
self.0 == other.0
|
||||
}
|
||||
}
|
||||
|
|
14
src/types.rs
14
src/types.rs
|
@ -5,6 +5,7 @@ use std::{fmt, mem, ptr};
|
|||
use crate::error::Result;
|
||||
use crate::ffi;
|
||||
use crate::lua::Lua;
|
||||
use crate::util::{assert_stack, StackGuard};
|
||||
use crate::value::MultiValue;
|
||||
|
||||
/// Type of Lua integer numbers.
|
||||
|
@ -92,3 +93,16 @@ impl<'lua> Drop for LuaRef<'lua> {
|
|||
self.lua.drop_ref(self)
|
||||
}
|
||||
}
|
||||
|
||||
impl<'lua> PartialEq for LuaRef<'lua> {
|
||||
fn eq(&self, other: &Self) -> bool {
|
||||
let lua = self.lua;
|
||||
unsafe {
|
||||
let _sg = StackGuard::new(lua.state);
|
||||
assert_stack(lua.state, 2);
|
||||
lua.push_ref(&self);
|
||||
lua.push_ref(&other);
|
||||
ffi::lua_rawequal(lua.state, -1, -2) == 1
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -2,7 +2,9 @@ use std::cell::{Ref, RefCell, RefMut};
|
|||
|
||||
use crate::error::{Error, Result};
|
||||
use crate::ffi;
|
||||
use crate::function::Function;
|
||||
use crate::lua::Lua;
|
||||
use crate::table::Table;
|
||||
use crate::types::LuaRef;
|
||||
use crate::util::{assert_stack, get_userdata, StackGuard};
|
||||
use crate::value::{FromLua, FromLuaMulti, ToLua, ToLuaMulti};
|
||||
|
@ -398,6 +400,42 @@ impl<'lua> AnyUserData<'lua> {
|
|||
V::from_lua(res, lua)
|
||||
}
|
||||
|
||||
fn get_metatable(&self) -> Result<Table<'lua>> {
|
||||
unsafe {
|
||||
let lua = self.0.lua;
|
||||
let _sg = StackGuard::new(lua.state);
|
||||
assert_stack(lua.state, 3);
|
||||
|
||||
lua.push_ref(&self.0);
|
||||
|
||||
if ffi::lua_getmetatable(lua.state, -1) == 0 {
|
||||
return Err(Error::UserDataTypeMismatch);
|
||||
}
|
||||
|
||||
Ok(Table(lua.pop_ref()))
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn equals<T: AsRef<Self>>(&self, other: T) -> Result<bool> {
|
||||
let other = other.as_ref();
|
||||
if self == other {
|
||||
return Ok(true);
|
||||
}
|
||||
|
||||
let mt = self.get_metatable()?;
|
||||
if mt != other.get_metatable()? {
|
||||
return Ok(false);
|
||||
}
|
||||
|
||||
if mt.contains_key("__eq")? {
|
||||
return mt
|
||||
.get::<_, Function>("__eq")?
|
||||
.call((self.clone(), other.clone()));
|
||||
}
|
||||
|
||||
Ok(false)
|
||||
}
|
||||
|
||||
fn inspect<'a, T, R, F>(&'a self, func: F) -> Result<R>
|
||||
where
|
||||
T: 'static + UserData,
|
||||
|
@ -428,3 +466,16 @@ impl<'lua> AnyUserData<'lua> {
|
|||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<'lua> PartialEq for AnyUserData<'lua> {
|
||||
fn eq(&self, other: &Self) -> bool {
|
||||
self.0 == other.0
|
||||
}
|
||||
}
|
||||
|
||||
impl<'lua> AsRef<AnyUserData<'lua>> for AnyUserData<'lua> {
|
||||
#[inline]
|
||||
fn as_ref(&self) -> &Self {
|
||||
self
|
||||
}
|
||||
}
|
||||
|
|
46
src/value.rs
46
src/value.rs
|
@ -2,6 +2,7 @@ use std::iter::{self, FromIterator};
|
|||
use std::{slice, str, vec};
|
||||
|
||||
use crate::error::{Error, Result};
|
||||
use crate::ffi;
|
||||
use crate::function::Function;
|
||||
use crate::lua::Lua;
|
||||
use crate::string::String;
|
||||
|
@ -61,6 +62,51 @@ impl<'lua> Value<'lua> {
|
|||
Value::UserData(_) | Value::Error(_) => "userdata",
|
||||
}
|
||||
}
|
||||
|
||||
/// Compares two values for equality.
|
||||
///
|
||||
/// Equality comparisons do not convert strings to numbers or vice versa.
|
||||
/// Tables, Functions, Threads, and Userdata are compared by reference:
|
||||
/// two objects are considered equal only if they are the same object.
|
||||
///
|
||||
/// If Tables or Userdata have `__eq` metamethod then mlua will try to invoke it.
|
||||
/// The first value is checked first. If that value does not define a metamethod
|
||||
/// for `__eq`, then mlua will check the second value.
|
||||
/// Then mlua calls the metamethod with the two values as arguments, if found.
|
||||
pub fn equals<T: AsRef<Self>>(&self, other: T) -> Result<bool> {
|
||||
match (self, other.as_ref()) {
|
||||
(Value::Table(a), Value::Table(b)) => a.equals(b),
|
||||
(Value::UserData(a), Value::UserData(b)) => a.equals(b),
|
||||
_ => Ok(self == other.as_ref()),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<'lua> PartialEq for Value<'lua> {
|
||||
fn eq(&self, other: &Self) -> bool {
|
||||
match (self, other) {
|
||||
(Value::Nil, Value::Nil) => true,
|
||||
(Value::Boolean(a), Value::Boolean(b)) => a == b,
|
||||
(Value::LightUserData(a), Value::LightUserData(b)) => a == b,
|
||||
(Value::Integer(a), Value::Integer(b)) => *a == *b,
|
||||
(Value::Integer(a), Value::Number(b)) => *a as ffi::lua_Number == *b,
|
||||
(Value::Number(a), Value::Integer(b)) => *a == *b as ffi::lua_Number,
|
||||
(Value::Number(a), Value::Number(b)) => *a == *b,
|
||||
(Value::String(a), Value::String(b)) => a == b,
|
||||
(Value::Table(a), Value::Table(b)) => a == b,
|
||||
(Value::Function(a), Value::Function(b)) => a == b,
|
||||
(Value::Thread(a), Value::Thread(b)) => a == b,
|
||||
(Value::UserData(a), Value::UserData(b)) => a == b,
|
||||
_ => false,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<'lua> AsRef<Value<'lua>> for Value<'lua> {
|
||||
#[inline]
|
||||
fn as_ref(&self) -> &Self {
|
||||
self
|
||||
}
|
||||
}
|
||||
|
||||
/// Trait for types convertible to `Value`.
|
||||
|
|
|
@ -148,6 +148,40 @@ fn test_metatable() -> Result<()> {
|
|||
Ok(())
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_table_eq() -> Result<()> {
|
||||
let lua = Lua::new();
|
||||
let globals = lua.globals();
|
||||
|
||||
lua.load(
|
||||
r#"
|
||||
table1 = {1}
|
||||
table2 = {1}
|
||||
table3 = table1
|
||||
table4 = {1}
|
||||
|
||||
setmetatable(table4, {
|
||||
__eq = function(a, b) return a[1] == b[1] end
|
||||
})
|
||||
"#,
|
||||
)
|
||||
.exec()?;
|
||||
|
||||
let table1 = globals.get::<_, Table>("table1")?;
|
||||
let table2 = globals.get::<_, Table>("table2")?;
|
||||
let table3 = globals.get::<_, Table>("table3")?;
|
||||
let table4 = globals.get::<_, Table>("table4")?;
|
||||
|
||||
assert!(table1 != table2);
|
||||
assert!(!table1.equals(&table2)?);
|
||||
assert!(table1 == table3);
|
||||
assert!(table1.equals(&table3)?);
|
||||
assert!(table1 != table4);
|
||||
assert!(table1.equals(&table4)?);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_table_error() -> Result<()> {
|
||||
let lua = Lua::new();
|
||||
|
|
|
@ -13,7 +13,7 @@ use std::sync::Arc;
|
|||
|
||||
use mlua::{
|
||||
AnyUserData, ExternalError, Function, Lua, MetaMethod, Result, String, UserData,
|
||||
UserDataMethods,
|
||||
UserDataMethods, Value,
|
||||
};
|
||||
|
||||
#[test]
|
||||
|
@ -96,6 +96,9 @@ fn test_metamethods() -> Result<()> {
|
|||
MetaMethod::Sub,
|
||||
|_, (lhs, rhs): (MyUserData, MyUserData)| Ok(MyUserData(lhs.0 - rhs.0)),
|
||||
);
|
||||
methods.add_meta_function(MetaMethod::Eq, |_, (lhs, rhs): (MyUserData, MyUserData)| {
|
||||
Ok(lhs.0 == rhs.0)
|
||||
});
|
||||
methods.add_meta_method(MetaMethod::Index, |_, data, index: String| {
|
||||
if index.to_str()? == "inner" {
|
||||
Ok(data.0)
|
||||
|
@ -122,6 +125,7 @@ fn test_metamethods() -> Result<()> {
|
|||
let globals = lua.globals();
|
||||
globals.set("userdata1", MyUserData(7))?;
|
||||
globals.set("userdata2", MyUserData(3))?;
|
||||
globals.set("userdata3", MyUserData(3))?;
|
||||
assert_eq!(
|
||||
lua.load("userdata1 + userdata2").eval::<MyUserData>()?.0,
|
||||
10
|
||||
|
@ -151,6 +155,13 @@ fn test_metamethods() -> Result<()> {
|
|||
assert_eq!(ipairs_it.call::<_, i64>(())?, 28);
|
||||
assert!(lua.load("userdata2.nonexist_field").eval::<()>().is_err());
|
||||
|
||||
let userdata2: Value = globals.get("userdata2")?;
|
||||
let userdata3: Value = globals.get("userdata3")?;
|
||||
|
||||
assert!(lua.load("userdata2 == userdata3").eval::<bool>()?);
|
||||
assert!(userdata2 != userdata3); // because references are differ
|
||||
assert!(userdata2.equals(userdata3)?);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
|
|
|
@ -0,0 +1,57 @@
|
|||
use mlua::{Lua, Result, Value};
|
||||
|
||||
#[test]
|
||||
fn test_value_eq() -> Result<()> {
|
||||
let lua = Lua::new();
|
||||
let globals = lua.globals();
|
||||
|
||||
lua.load(
|
||||
r#"
|
||||
table1 = {1}
|
||||
table2 = {1}
|
||||
string1 = "hello"
|
||||
string2 = "hello"
|
||||
num1 = 1
|
||||
num2 = 1.0
|
||||
num3 = "1"
|
||||
func1 = function() end
|
||||
func2 = func1
|
||||
func3 = function() end
|
||||
thread1 = coroutine.create(function() end)
|
||||
thread2 = thread1
|
||||
|
||||
setmetatable(table1, {
|
||||
__eq = function(a, b) return a[1] == b[1] end
|
||||
})
|
||||
"#,
|
||||
)
|
||||
.exec()?;
|
||||
|
||||
let table1: Value = globals.get("table1")?;
|
||||
let table2: Value = globals.get("table2")?;
|
||||
let string1: Value = globals.get("string1")?;
|
||||
let string2: Value = globals.get("string2")?;
|
||||
let num1: Value = globals.get("num1")?;
|
||||
let num2: Value = globals.get("num2")?;
|
||||
let num3: Value = globals.get("num3")?;
|
||||
let func1: Value = globals.get("func1")?;
|
||||
let func2: Value = globals.get("func2")?;
|
||||
let func3: Value = globals.get("func3")?;
|
||||
let thread1: Value = globals.get("thread1")?;
|
||||
let thread2: Value = globals.get("thread2")?;
|
||||
|
||||
assert!(table1 != table2);
|
||||
assert!(table1.equals(table2)?);
|
||||
assert!(string1 == string2);
|
||||
assert!(string1.equals(string2)?);
|
||||
assert!(num1 == num2);
|
||||
assert!(num1.equals(num2)?);
|
||||
assert!(num1 != num3);
|
||||
assert!(func1 == func2);
|
||||
assert!(func1 != func3);
|
||||
assert!(!func1.equals(func3)?);
|
||||
assert!(thread1 == thread2);
|
||||
assert!(thread1.equals(thread2)?);
|
||||
|
||||
Ok(())
|
||||
}
|
Loading…
Reference in New Issue