Refactor `Function::bind` implementation.

Make it possible to bind async function arguments.
Fixes #161
This commit is contained in:
Alex Orlenko 2022-05-15 01:15:31 +01:00
parent 6b2ceb60c4
commit 2a8c5c7f82
No known key found for this signature in database
GPG Key ID: 4C150C250863B96D
2 changed files with 28 additions and 22 deletions

View File

@ -178,24 +178,17 @@ impl<'lua> Function<'lua> {
/// # } /// # }
/// ``` /// ```
pub fn bind<A: ToLuaMulti<'lua>>(&self, args: A) -> Result<Function<'lua>> { pub fn bind<A: ToLuaMulti<'lua>>(&self, args: A) -> Result<Function<'lua>> {
unsafe extern "C" fn bind_call_impl(state: *mut ffi::lua_State) -> c_int { unsafe extern "C" fn args_wrapper_impl(state: *mut ffi::lua_State) -> c_int {
let nargs = ffi::lua_gettop(state); let nargs = ffi::lua_gettop(state);
let nbinds = ffi::lua_tointeger(state, ffi::lua_upvalueindex(2)) as c_int; let nbinds = ffi::lua_tointeger(state, ffi::lua_upvalueindex(1)) as c_int;
ffi::luaL_checkstack(state, nbinds + 2, ptr::null()); ffi::luaL_checkstack(state, nbinds, ptr::null());
ffi::lua_settop(state, nargs + nbinds + 1);
ffi::lua_rotate(state, -(nargs + nbinds + 1), nbinds + 1);
ffi::lua_pushvalue(state, ffi::lua_upvalueindex(1));
ffi::lua_replace(state, 1);
for i in 0..nbinds { for i in 0..nbinds {
ffi::lua_pushvalue(state, ffi::lua_upvalueindex(i + 3)); ffi::lua_pushvalue(state, ffi::lua_upvalueindex(i + 2));
ffi::lua_replace(state, i + 2);
} }
ffi::lua_rotate(state, 1, nbinds);
ffi::lua_call(state, nargs + nbinds, ffi::LUA_MULTRET); nargs + nbinds
ffi::lua_gettop(state)
} }
let lua = self.0.lua; let lua = self.0.lua;
@ -203,25 +196,35 @@ impl<'lua> Function<'lua> {
let args = args.to_lua_multi(lua)?; let args = args.to_lua_multi(lua)?;
let nargs = args.len() as c_int; let nargs = args.len() as c_int;
if nargs + 2 > ffi::LUA_MAX_UPVALUES { if nargs + 1 > ffi::LUA_MAX_UPVALUES {
return Err(Error::BindError); return Err(Error::BindError);
} }
unsafe { let args_wrapper = unsafe {
let _sg = StackGuard::new(lua.state); let _sg = StackGuard::new(lua.state);
check_stack(lua.state, nargs + 5)?; check_stack(lua.state, nargs + 3)?;
lua.push_ref(&self.0);
ffi::lua_pushinteger(lua.state, nargs as ffi::lua_Integer); ffi::lua_pushinteger(lua.state, nargs as ffi::lua_Integer);
for arg in args { for arg in args {
lua.push_value(arg)?; lua.push_value(arg)?;
} }
protect_lua!(lua.state, nargs + 2, 1, fn(state) { protect_lua!(lua.state, nargs + 1, 1, fn(state) {
ffi::lua_pushcclosure(state, bind_call_impl, ffi::lua_gettop(state)); ffi::lua_pushcclosure(state, args_wrapper_impl, ffi::lua_gettop(state));
})?; })?;
Ok(Function(lua.pop_ref())) Function(lua.pop_ref())
} };
lua.load(
r#"
local func, args_wrapper = ...
return function(...)
return func(args_wrapper(...))
end
"#,
)
.set_name("_mlua_bind")?
.call((self.clone(), args_wrapper))
} }
/// Returns information about the function. /// Returns information about the function.

View File

@ -75,7 +75,10 @@ async fn test_async_call() -> Result<()> {
async fn test_async_bind_call() -> Result<()> { async fn test_async_bind_call() -> Result<()> {
let lua = Lua::new(); let lua = Lua::new();
let sum = lua.create_async_function(|_lua, (a, b): (i64, i64)| async move { Ok(a + b) })?; let sum = lua.create_async_function(|_lua, (a, b): (i64, i64)| async move {
tokio::task::yield_now().await;
Ok(a + b)
})?;
let plus_10 = sum.bind(10)?; let plus_10 = sum.bind(10)?;
lua.globals().set("plus_10", plus_10)?; lua.globals().set("plus_10", plus_10)?;