From 1d4a135e8e14bfbf476924257568e6f4e9d9c470 Mon Sep 17 00:00:00 2001 From: Alex Orlenko Date: Thu, 22 Dec 2022 16:24:35 +0000 Subject: [PATCH] Add `Function::wrap`/`Function::wrap_mut`/`Function::wrap_async` to wrap functions into a type that implements `IntoLua` trait. This is useful to avoid calling `lua.create_function*` every time when `Function` handle is needed. --- src/conversion.rs | 25 ++++++++++++++++- src/function.rs | 69 +++++++++++++++++++++++++++++++++++++++++++++++ tests/async.rs | 14 ++++++++++ tests/function.rs | 33 +++++++++++++++++++++++ 4 files changed, 140 insertions(+), 1 deletion(-) diff --git a/src/conversion.rs b/src/conversion.rs index cc9227c..c4ff6c2 100644 --- a/src/conversion.rs +++ b/src/conversion.rs @@ -19,7 +19,14 @@ use crate::userdata::{AnyUserData, UserData}; use crate::value::{FromLua, IntoLua, Nil, Value}; #[cfg(feature = "unstable")] -use crate::{function::OwnedFunction, table::OwnedTable, userdata::OwnedAnyUserData}; +use crate::{ + function::{OwnedFunction, WrappedFunction}, + table::OwnedTable, + userdata::OwnedAnyUserData, +}; + +#[cfg(all(feature = "async", feature = "unstable"))] +use crate::function::WrappedAsyncFunction; impl<'lua> IntoLua<'lua> for Value<'lua> { #[inline] @@ -129,6 +136,22 @@ impl<'lua> FromLua<'lua> for OwnedFunction { } } +#[cfg(feature = "unstable")] +impl<'lua> IntoLua<'lua> for WrappedFunction<'lua> { + #[inline] + fn into_lua(self, lua: &'lua Lua) -> Result> { + lua.create_callback(self.0).map(Value::Function) + } +} + +#[cfg(all(feature = "async", feature = "unstable"))] +impl<'lua> IntoLua<'lua> for WrappedAsyncFunction<'lua> { + #[inline] + fn into_lua(self, lua: &'lua Lua) -> Result> { + lua.create_async_callback(self.0).map(Value::Function) + } +} + impl<'lua> IntoLua<'lua> for Thread<'lua> { #[inline] fn into_lua(self, _: &'lua Lua) -> Result> { diff --git a/src/function.rs b/src/function.rs index d282a50..3a0aa36 100644 --- a/src/function.rs +++ b/src/function.rs @@ -11,9 +11,20 @@ use crate::util::{ }; use crate::value::{FromLuaMulti, IntoLuaMulti}; +#[cfg(feature = "unstable")] +use { + crate::lua::Lua, + crate::types::{Callback, MaybeSend}, + crate::value::IntoLua, + std::cell::RefCell, +}; + #[cfg(feature = "async")] use {futures_core::future::LocalBoxFuture, futures_util::future}; +#[cfg(all(feature = "async", feature = "unstable"))] +use {crate::types::AsyncCallback, futures_core::Future, futures_util::TryFutureExt}; + /// Handle to an internal Lua function. #[derive(Clone, Debug)] pub struct Function<'lua>(pub(crate) LuaRef<'lua>); @@ -408,6 +419,64 @@ impl<'lua> PartialEq for Function<'lua> { } } +#[cfg(feature = "unstable")] +pub(crate) struct WrappedFunction<'lua>(pub(crate) Callback<'lua, 'static>); + +#[cfg(all(feature = "async", feature = "unstable"))] +pub(crate) struct WrappedAsyncFunction<'lua>(pub(crate) AsyncCallback<'lua, 'static>); + +#[cfg(feature = "unstable")] +#[cfg_attr(docsrs, doc(cfg(feature = "unstable")))] +impl<'lua> Function<'lua> { + /// Wraps a Rust function or closure, returning an opaque type that implements [`IntoLua`] trait. + #[inline] + pub fn wrap(func: F) -> impl IntoLua<'lua> + where + F: Fn(&'lua Lua, A) -> Result + MaybeSend + 'static, + A: FromLuaMulti<'lua>, + R: IntoLuaMulti<'lua>, + { + WrappedFunction(Box::new(move |lua, args| { + func(lua, A::from_lua_multi(args, lua)?)?.into_lua_multi(lua) + })) + } + + /// Wraps a Rust mutable closure, returning an opaque type that implements [`IntoLua`] trait. + #[inline] + pub fn wrap_mut(func: F) -> impl IntoLua<'lua> + where + F: FnMut(&'lua Lua, A) -> Result + MaybeSend + 'static, + A: FromLuaMulti<'lua>, + R: IntoLuaMulti<'lua>, + { + let func = RefCell::new(func); + WrappedFunction(Box::new(move |lua, args| { + let mut func = func + .try_borrow_mut() + .map_err(|_| Error::RecursiveMutCallback)?; + func(lua, A::from_lua_multi(args, lua)?)?.into_lua_multi(lua) + })) + } + + #[cfg(feature = "async")] + #[cfg_attr(docsrs, doc(cfg(feature = "async")))] + pub fn wrap_async(func: F) -> impl IntoLua<'lua> + where + F: Fn(&'lua Lua, A) -> FR + MaybeSend + 'static, + A: FromLuaMulti<'lua>, + FR: Future> + 'lua, + R: IntoLuaMulti<'lua>, + { + WrappedAsyncFunction(Box::new(move |lua, args| { + let args = match A::from_lua_multi(args, lua) { + Ok(args) => args, + Err(e) => return Box::pin(future::err(e)), + }; + Box::pin(func(lua, args).and_then(move |ret| future::ready(ret.into_lua_multi(lua)))) + })) + } +} + #[cfg(test)] mod assertions { use super::*; diff --git a/tests/async.rs b/tests/async.rs index fc7a631..48f3cb6 100644 --- a/tests/async.rs +++ b/tests/async.rs @@ -30,6 +30,20 @@ async fn test_async_function() -> Result<()> { Ok(()) } +#[cfg(feature = "unstable")] +#[tokio::test] +async fn test_async_function_wrap() -> Result<()> { + let lua = Lua::new(); + + let f = Function::wrap_async(|_, s: String| async move { Ok(s) }); + lua.globals().set("f", f)?; + + let res: String = lua.load(r#"f("hello")"#).eval_async().await?; + assert_eq!(res, "hello"); + + Ok(()) +} + #[tokio::test] async fn test_async_sleep() -> Result<()> { let lua = Lua::new(); diff --git a/tests/function.rs b/tests/function.rs index 6d4848b..b93fac8 100644 --- a/tests/function.rs +++ b/tests/function.rs @@ -167,3 +167,36 @@ fn test_function_info() -> Result<()> { Ok(()) } + +#[cfg(feature = "unstable")] +#[test] +fn test_function_wrap() -> Result<()> { + use mlua::Error; + + let lua = Lua::new(); + + lua.globals() + .set("f", Function::wrap(|_, s: String| Ok(s)))?; + lua.load(r#"assert(f("hello") == "hello")"#).exec().unwrap(); + + let mut _i = false; + lua.globals().set( + "f", + Function::wrap_mut(move |lua, ()| { + _i = true; + lua.globals().get::<_, Function>("f")?.call::<_, ()>(()) + }), + )?; + match lua.globals().get::<_, Function>("f")?.call::<_, ()>(()) { + Err(Error::CallbackError { ref cause, .. }) => match *cause.as_ref() { + Error::CallbackError { ref cause, .. } => match *cause.as_ref() { + Error::RecursiveMutCallback { .. } => {} + ref other => panic!("incorrect result: {other:?}"), + }, + ref other => panic!("incorrect result: {other:?}"), + }, + other => panic!("incorrect result: {other:?}"), + }; + + Ok(()) +}