From 2a8c5c7f82e1b8cc2f9c0ef572756e5c16ae90b1 Mon Sep 17 00:00:00 2001 From: Alex Orlenko Date: Sun, 15 May 2022 01:15:31 +0100 Subject: [PATCH] Refactor `Function::bind` implementation. Make it possible to bind async function arguments. Fixes #161 --- src/function.rs | 45 ++++++++++++++++++++++++--------------------- tests/async.rs | 5 ++++- 2 files changed, 28 insertions(+), 22 deletions(-) diff --git a/src/function.rs b/src/function.rs index d1878d7..8e415d7 100644 --- a/src/function.rs +++ b/src/function.rs @@ -178,24 +178,17 @@ impl<'lua> Function<'lua> { /// # } /// ``` pub fn bind>(&self, args: A) -> Result> { - unsafe extern "C" fn bind_call_impl(state: *mut ffi::lua_State) -> c_int { + unsafe extern "C" fn args_wrapper_impl(state: *mut ffi::lua_State) -> c_int { let nargs = ffi::lua_gettop(state); - let nbinds = ffi::lua_tointeger(state, ffi::lua_upvalueindex(2)) as c_int; - ffi::luaL_checkstack(state, nbinds + 2, ptr::null()); - - ffi::lua_settop(state, nargs + nbinds + 1); - ffi::lua_rotate(state, -(nargs + nbinds + 1), nbinds + 1); - - ffi::lua_pushvalue(state, ffi::lua_upvalueindex(1)); - ffi::lua_replace(state, 1); + let nbinds = ffi::lua_tointeger(state, ffi::lua_upvalueindex(1)) as c_int; + ffi::luaL_checkstack(state, nbinds, ptr::null()); for i in 0..nbinds { - ffi::lua_pushvalue(state, ffi::lua_upvalueindex(i + 3)); - ffi::lua_replace(state, i + 2); + ffi::lua_pushvalue(state, ffi::lua_upvalueindex(i + 2)); } + ffi::lua_rotate(state, 1, nbinds); - ffi::lua_call(state, nargs + nbinds, ffi::LUA_MULTRET); - ffi::lua_gettop(state) + nargs + nbinds } let lua = self.0.lua; @@ -203,25 +196,35 @@ impl<'lua> Function<'lua> { let args = args.to_lua_multi(lua)?; let nargs = args.len() as c_int; - if nargs + 2 > ffi::LUA_MAX_UPVALUES { + if nargs + 1 > ffi::LUA_MAX_UPVALUES { return Err(Error::BindError); } - unsafe { + let args_wrapper = unsafe { let _sg = StackGuard::new(lua.state); - check_stack(lua.state, nargs + 5)?; + check_stack(lua.state, nargs + 3)?; - lua.push_ref(&self.0); ffi::lua_pushinteger(lua.state, nargs as ffi::lua_Integer); for arg in args { lua.push_value(arg)?; } - protect_lua!(lua.state, nargs + 2, 1, fn(state) { - ffi::lua_pushcclosure(state, bind_call_impl, ffi::lua_gettop(state)); + protect_lua!(lua.state, nargs + 1, 1, fn(state) { + ffi::lua_pushcclosure(state, args_wrapper_impl, ffi::lua_gettop(state)); })?; - Ok(Function(lua.pop_ref())) - } + Function(lua.pop_ref()) + }; + + lua.load( + r#" + local func, args_wrapper = ... + return function(...) + return func(args_wrapper(...)) + end + "#, + ) + .set_name("_mlua_bind")? + .call((self.clone(), args_wrapper)) } /// Returns information about the function. diff --git a/tests/async.rs b/tests/async.rs index 9200da6..fcfb9c7 100644 --- a/tests/async.rs +++ b/tests/async.rs @@ -75,7 +75,10 @@ async fn test_async_call() -> Result<()> { async fn test_async_bind_call() -> Result<()> { let lua = Lua::new(); - let sum = lua.create_async_function(|_lua, (a, b): (i64, i64)| async move { Ok(a + b) })?; + let sum = lua.create_async_function(|_lua, (a, b): (i64, i64)| async move { + tokio::task::yield_now().await; + Ok(a + b) + })?; let plus_10 = sum.bind(10)?; lua.globals().set("plus_10", plus_10)?;