From 7cb9c4f39ca3e58e4a135dcad2aa60c5620d9a7e Mon Sep 17 00:00:00 2001 From: Alex Orlenko Date: Wed, 3 Mar 2021 22:32:22 +0000 Subject: [PATCH] Fix bug in returning nil-prefixed multi values from async function --- src/lua.rs | 18 +++++++++++------- src/table.rs | 10 ++++++---- tests/async.rs | 18 ++++++++++++++++++ 3 files changed, 35 insertions(+), 11 deletions(-) diff --git a/src/lua.rs b/src/lua.rs index d4972e8..a157c06 100644 --- a/src/lua.rs +++ b/src/lua.rs @@ -1741,11 +1741,14 @@ impl Lua { Ok(2) } Poll::Ready(results) => { - let results = lua.create_sequence_from(results?)?; - check_stack(state, 2)?; + let results = results?; + let nresults = results.len() as Integer; + let results = lua.create_sequence_from(results)?; + check_stack(state, 3)?; ffi::lua_pushboolean(state, 1); lua.push_value(Value::Table(results))?; - Ok(2) + lua.push_value(Value::Integer(nresults))?; + Ok(3) } } }) @@ -1772,9 +1775,10 @@ impl Lua { env.set("yield", coroutine.get::<_, Function>("yield")?)?; env.set( "unpack", - self.create_function(|_, tbl: Table| { + self.create_function(|_, (tbl, len): (Table, Integer)| { Ok(MultiValue::from_vec( - tbl.sequence_values().collect::>>()?, + tbl.raw_sequence_values_by_len(Some(len)) + .collect::>>()?, )) })?, )?; @@ -1785,9 +1789,9 @@ impl Lua { poll = get_poll(...) local poll, yield, unpack = poll, yield, unpack while true do - ready, res = poll() + local ready, res, nres = poll() if ready then - return unpack(res) + return unpack(res, nres) end yield(res) end diff --git a/src/table.rs b/src/table.rs index f294faf..9c57ca8 100644 --- a/src/table.rs +++ b/src/table.rs @@ -474,9 +474,11 @@ impl<'lua> Table<'lua> { } } - #[cfg(feature = "serialize")] - pub(crate) fn raw_sequence_values_by_len>(self) -> TableSequence<'lua, V> { - let len = self.raw_len(); + pub(crate) fn raw_sequence_values_by_len>( + self, + len: Option, + ) -> TableSequence<'lua, V> { + let len = len.unwrap_or_else(|| self.raw_len()); TableSequence { table: self.0, index: Some(1), @@ -641,7 +643,7 @@ impl<'lua> Serialize for Table<'lua> { let len = self.raw_len() as usize; if len > 0 || self.is_array() { let mut seq = serializer.serialize_seq(Some(len))?; - for v in self.clone().raw_sequence_values_by_len::() { + for v in self.clone().raw_sequence_values_by_len::(None) { let v = v.map_err(serde::ser::Error::custom)?; seq.serialize_element(&v)?; } diff --git a/tests/async.rs b/tests/async.rs index e2496f1..7d87b49 100644 --- a/tests/async.rs +++ b/tests/async.rs @@ -136,6 +136,24 @@ async fn test_async_handle_yield() -> Result<()> { Ok(()) } +#[tokio::test] +async fn test_async_multi_return_nil() -> Result<()> { + let lua = Lua::new(); + lua.globals().set( + "func", + lua.create_async_function(|_, _: ()| async { Ok((Option::::None, "error")) })?, + )?; + + lua.load( + r#" + local ok, err = func() + assert(err == "error") + "#, + ) + .exec_async() + .await +} + #[tokio::test] async fn test_async_return_async_closure() -> Result<()> { let lua = Lua::new();