diff --git a/src/conversion.rs b/src/conversion.rs index 6da46c7..a7265c0 100644 --- a/src/conversion.rs +++ b/src/conversion.rs @@ -2,6 +2,7 @@ use std::borrow::Cow; use std::collections::{BTreeMap, BTreeSet, HashMap, HashSet}; +use std::convert::TryInto; use std::ffi::{CStr, CString}; use std::hash::{BuildHasher, Hash}; use std::string::String as StdString; @@ -451,37 +452,36 @@ where } } -macro_rules! lua_convert_array { - ($($N:literal)+) => { - $( - impl<'lua, T> ToLua<'lua> for [T; $N] - where - T: Clone + ToLua<'lua>, - { - fn to_lua(self, lua: &'lua Lua) -> Result> { - (&self).to_lua(lua) - } - } - - impl<'lua, T> ToLua<'lua> for &[T; $N] - where - T: Clone + ToLua<'lua>, - { - fn to_lua(self, lua: &'lua Lua) -> Result> { - Ok(Value::Table( - lua.create_sequence_from(self.iter().cloned())?, - )) - } - } - )+ +impl<'lua, T, const N: usize> ToLua<'lua> for [T; N] +where + T: ToLua<'lua>, +{ + fn to_lua(self, lua: &'lua Lua) -> Result> { + Ok(Value::Table(lua.create_sequence_from(self)?)) } } -lua_convert_array! { - 0 1 2 3 4 5 6 7 8 9 - 10 11 12 13 14 15 16 17 18 19 - 20 21 22 23 24 25 26 27 28 29 - 30 31 32 +impl<'lua, T, const N: usize> FromLua<'lua> for [T; N] +where + T: FromLua<'lua>, +{ + fn from_lua(value: Value<'lua>, _: &'lua Lua) -> Result { + if let Value::Table(table) = value { + let vec = table.sequence_values().collect::>>()?; + vec.try_into() + .map_err(|vec: Vec| Error::FromLuaConversionError { + from: "Table", + to: "Array", + message: Some(format!("expected table of length {}, got {}", N, vec.len())), + }) + } else { + Err(Error::FromLuaConversionError { + from: value.type_name(), + to: "Array", + message: Some("expected table".to_string()), + }) + } + } } impl<'lua, T: ToLua<'lua>> ToLua<'lua> for Box<[T]> { diff --git a/tests/conversion.rs b/tests/conversion.rs index 01880d8..6d17d0a 100644 --- a/tests/conversion.rs +++ b/tests/conversion.rs @@ -3,7 +3,7 @@ use std::collections::{BTreeMap, BTreeSet, HashMap, HashSet}; use std::ffi::{CStr, CString}; use maplit::{btreemap, btreeset, hashmap, hashset}; -use mlua::{Lua, Result}; +use mlua::{Error, Lua, Result}; #[test] fn test_conv_vec() -> Result<()> { @@ -123,3 +123,18 @@ fn test_conv_boxed_slice() -> Result<()> { Ok(()) } + +#[test] +fn test_conv_array() -> Result<()> { + let lua = Lua::new(); + + let v = [1, 2, 3]; + lua.globals().set("v", v)?; + let v2: [i32; 3] = lua.globals().get("v")?; + assert_eq!(v, v2); + + let v2 = lua.globals().get::<_, [i32; 4]>("v"); + assert!(matches!(v2, Err(Error::FromLuaConversionError { .. }))); + + Ok(()) +} diff --git a/tests/table.rs b/tests/table.rs index 6378e5c..b22696d 100644 --- a/tests/table.rs +++ b/tests/table.rs @@ -139,14 +139,6 @@ fn test_table_sequence_from() -> Result<()> { vec![1, 2, 3] ); - assert_eq!( - get_table - .call::<_, Table>(&[1, 2, 3])? - .sequence_values() - .collect::>>()?, - vec![1, 2, 3] - ); - Ok(()) }