diff --git a/Cargo.toml b/Cargo.toml index 6a72514..98cf545 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -53,10 +53,9 @@ luajit-src = { version = "210.0.0", optional = true } rustyline = "6.0" criterion = "0.3" trybuild = "1.0" +futures = "0.3.4" hyper = "0.13" tokio = { version = "0.2.18", features = ["full"] } -futures-executor = "0.3.4" -futures-util = "0.3.4" futures-timer = "3.0" [[bench]] diff --git a/src/function.rs b/src/function.rs index 141fcbe..7692575 100644 --- a/src/function.rs +++ b/src/function.rs @@ -9,6 +9,9 @@ use crate::util::{ }; use crate::value::{FromLuaMulti, MultiValue, ToLuaMulti}; +#[cfg(feature = "async")] +use futures_core::future::LocalBoxFuture; + /// Handle to an internal Lua function. #[derive(Clone, Debug)] pub struct Function<'lua>(pub(crate) LuaRef<'lua>); @@ -86,6 +89,44 @@ impl<'lua> Function<'lua> { R::from_lua_multi(results, lua) } + /// Returns a Feature that, when polled, calls `self`, passing `args` as function arguments, + /// and drives the execution. + /// + /// Internaly it wraps the function to an AsyncThread. + /// + /// # Examples + /// + /// ``` + /// use std::time::Duration; + /// use futures_timer::Delay; + /// # use mlua::{Lua, Result}; + /// # #[tokio::main] + /// # async fn main() -> Result<()> { + /// # let lua = Lua::new(); + /// let sleep = lua.create_async_function(move |_lua, n: u64| async move { + /// Delay::new(Duration::from_millis(n)).await; + /// Ok(()) + /// })?; + /// + /// sleep.call_async(10).await?; + /// + /// # Ok(()) + /// # } + /// ``` + #[cfg(feature = "async")] + pub fn call_async<'fut, A, R>(&self, args: A) -> LocalBoxFuture<'fut, Result> + where + 'lua: 'fut, + A: ToLuaMulti<'lua>, + R: FromLuaMulti<'lua> + 'fut, + { + let lua = self.0.lua; + match lua.create_thread(self.clone()) { + Ok(t) => Box::pin(t.into_async(args)), + Err(e) => Box::pin(futures_util::future::err(e)), + } + } + /// Returns a function that, when called, calls `self`, passing `args` as the first set of /// arguments. /// diff --git a/src/lua.rs b/src/lua.rs index a0ca44f..9ba6afe 100644 --- a/src/lua.rs +++ b/src/lua.rs @@ -482,39 +482,37 @@ impl Lua { /// Wraps a Rust async function or closure, creating a callable Lua function handle to it. /// /// While executing the function Rust will poll Future and if the result is not ready, call - /// `lua_yield()` returning internal representation of a `Poll::Pending` value. + /// `yield()` passing internal representation of a `Poll::Pending` value. /// - /// The function must be called inside [`Thread`] coroutine to be able to suspend its execution. - /// An executor could be used together with [`ThreadStream`] and mlua will use a provided Waker + /// The function must be called inside Lua coroutine ([`Thread`]) to be able to suspend its execution. + /// An executor should be used to poll [`AsyncThread`] and mlua will take a provided Waker /// in that case. Otherwise noop waker will be used if try to call the function outside of Rust /// executors. /// + /// The family of `call_async()` functions takes care about creating [`Thread`]. + /// /// # Examples /// /// Non blocking sleep: /// /// ``` /// use std::time::Duration; - /// use futures_executor::block_on; /// use futures_timer::Delay; - /// # use mlua::{Lua, Result, Thread}; + /// use mlua::{Lua, Result}; /// /// async fn sleep(_lua: &Lua, n: u64) -> Result<&'static str> { - /// Delay::new(Duration::from_secs(n)).await; + /// Delay::new(Duration::from_millis(n)).await; /// Ok("done") /// } /// - /// # fn main() -> Result<()> { - /// # let lua = Lua::new(); - /// lua.globals().set("async_sleep", lua.create_async_function(sleep)?)?; - /// let thr = lua.load("coroutine.create(function(n) return async_sleep(n) end)").eval::()?; - /// let res: String = block_on(async { - /// thr.into_async(1).await // Sleep 1 second - /// })?; - /// - /// assert_eq!(res, "done"); - /// # Ok(()) - /// # } + /// #[tokio::main] + /// async fn main() -> Result<()> { + /// let lua = Lua::new(); + /// lua.globals().set("sleep", lua.create_async_function(sleep)?)?; + /// let res: String = lua.load("return sleep(...)").call_async(100).await?; // Sleep 100ms + /// assert_eq!(res, "done"); + /// Ok(()) + /// } /// ``` /// /// [`Thread`]: struct.Thread.html @@ -1347,6 +1345,19 @@ impl<'lua, 'a> Chunk<'lua, 'a> { Ok(()) } + /// Asynchronously execute this chunk of code. + /// + /// See [`Chunk::exec`] for more details. + /// + /// [`Chunk::exec`]: struct.Chunk.html#method.exec + #[cfg(feature = "async")] + pub fn exec_async<'fut>(self) -> LocalBoxFuture<'fut, Result<()>> + where + 'lua: 'fut, + { + self.call_async(()) + } + /// Evaluate the chunk as either an expression or block. /// /// If the chunk can be parsed as an expression, this loads and executes the chunk and returns @@ -1356,18 +1367,39 @@ impl<'lua, 'a> Chunk<'lua, 'a> { // First, try interpreting the lua as an expression by adding // "return", then as a statement. This is the same thing the // actual lua repl does. - let mut expression_source = b"return ".to_vec(); - expression_source.extend(self.source); - if let Ok(function) = - self.lua - .load_chunk(&expression_source, self.name.as_ref(), self.env.clone()) - { + if let Ok(function) = self.lua.load_chunk( + &self.expression_source(), + self.name.as_ref(), + self.env.clone(), + ) { function.call(()) } else { self.call(()) } } + /// Asynchronously evaluate the chunk as either an expression or block. + /// + /// See [`Chunk::eval`] for more details. + /// + /// [`Chunk::eval`]: struct.Chunk.html#method.eval + #[cfg(feature = "async")] + pub fn eval_async<'fut, R>(self) -> LocalBoxFuture<'fut, Result> + where + 'lua: 'fut, + R: FromLuaMulti<'lua> + 'fut, + { + if let Ok(function) = self.lua.load_chunk( + &self.expression_source(), + self.name.as_ref(), + self.env.clone(), + ) { + function.call_async(()) + } else { + self.call_async(()) + } + } + /// Load the chunk function and call it with the given arguemnts. /// /// This is equivalent to `into_function` and calling the resulting function. @@ -1375,6 +1407,24 @@ impl<'lua, 'a> Chunk<'lua, 'a> { self.into_function()?.call(args) } + /// Load the chunk function and asynchronously call it with the given arguemnts. + /// + /// See [`Chunk::call`] for more details. + /// + /// [`Chunk::call`]: struct.Chunk.html#method.call + #[cfg(feature = "async")] + pub fn call_async<'fut, A, R>(self, args: A) -> LocalBoxFuture<'fut, Result> + where + 'lua: 'fut, + A: ToLuaMulti<'lua>, + R: FromLuaMulti<'lua> + 'fut, + { + match self.into_function() { + Ok(f) => f.call_async(args), + Err(e) => Box::pin(futures_util::future::err(e)), + } + } + /// Load this chunk into a regular `Function`. /// /// This simply compiles the chunk without actually executing it. @@ -1382,6 +1432,13 @@ impl<'lua, 'a> Chunk<'lua, 'a> { self.lua .load_chunk(self.source, self.name.as_ref(), self.env) } + + fn expression_source(&self) -> Vec { + let mut buf = Vec::with_capacity(b"return ".len() + self.source.len()); + buf.extend(b"return "); + buf.extend(self.source); + buf + } } unsafe fn load_from_std_lib(state: *mut ffi::lua_State, libs: StdLib) { diff --git a/src/thread.rs b/src/thread.rs index 95bc96d..ce1dd57 100644 --- a/src/thread.rs +++ b/src/thread.rs @@ -183,13 +183,13 @@ impl<'lua> Thread<'lua> { /// # Examples /// /// ``` - /// # use mlua::{Error, Lua, Result, Thread}; - /// use futures_executor::block_on; - /// use futures_util::stream::TryStreamExt; - /// # fn main() -> Result<()> { + /// # use mlua::{Lua, Result, Thread}; + /// use futures::stream::TryStreamExt; + /// # #[tokio::main] + /// # async fn main() -> Result<()> { /// # let lua = Lua::new(); /// let thread: Thread = lua.load(r#" - /// coroutine.create(function(sum) + /// coroutine.create(function (sum) /// for i = 1,10 do /// sum = sum + i /// coroutine.yield(sum) @@ -198,16 +198,13 @@ impl<'lua> Thread<'lua> { /// end) /// "#).eval()?; /// - /// let result = block_on(async { - /// let mut s = thread.into_async::<_, i64>(1); - /// let mut sum = 0; - /// while let Some(n) = s.try_next().await? { - /// sum += n; - /// } - /// Ok::<_, Error>(sum) - /// })?; + /// let mut stream = thread.into_async::<_, i64>(1); + /// let mut sum = 0; + /// while let Some(n) = stream.try_next().await? { + /// sum += n; + /// } /// - /// assert_eq!(result, 286); + /// assert_eq!(sum, 286); /// /// # Ok(()) /// # } diff --git a/tests/async.rs b/tests/async.rs new file mode 100644 index 0000000..e02eb12 --- /dev/null +++ b/tests/async.rs @@ -0,0 +1,173 @@ +#![cfg(feature = "async")] + +use std::rc::Rc; +use std::time::Duration; + +use futures_util::stream::TryStreamExt; + +use mlua::{Error, Function, Lua, Result}; + +#[tokio::test] +async fn test_async_function() -> Result<()> { + let lua = Lua::new(); + + let f = lua + .create_async_function(|_lua, (a, b, c): (i64, i64, i64)| async move { Ok((a + b) * c) })?; + lua.globals().set("f", f)?; + + let res: i64 = lua.load("f(1, 2, 3)").eval_async().await?; + assert_eq!(res, 9); + + Ok(()) +} + +#[tokio::test] +async fn test_async_sleep() -> Result<()> { + let lua = Lua::new(); + + let sleep = lua.create_async_function(move |_lua, n: u64| async move { + futures_timer::Delay::new(Duration::from_millis(n)).await; + Ok(format!("elapsed:{}ms", n)) + })?; + lua.globals().set("sleep", sleep)?; + + let res: String = lua.load(r"return sleep(...)").call_async(100).await?; + assert_eq!(res, "elapsed:100ms"); + + Ok(()) +} + +#[tokio::test] +async fn test_async_call() -> Result<()> { + let lua = Lua::new(); + + let sleep = lua.create_async_function(|_lua, name: String| async move { + futures_timer::Delay::new(Duration::from_millis(10)).await; + Ok(format!("hello, {}!", name)) + })?; + + match sleep.call::<_, ()>("alex") { + Err(Error::RuntimeError(_)) => {} + _ => panic!( + "non-async executing async function must fail on the yield stage with RuntimeError" + ), + }; + + assert_eq!(sleep.call_async::<_, String>("alex").await?, "hello, alex!"); + + // Executing non-async functions using async call is allowed + let sum = lua.create_function(|_lua, (a, b): (i64, i64)| return Ok(a + b))?; + assert_eq!(sum.call_async::<_, i64>((5, 1)).await?, 6); + + Ok(()) +} + +#[tokio::test] +async fn test_async_bind_call() -> Result<()> { + let lua = Lua::new(); + + let less = lua.create_async_function(|_lua, (a, b): (i64, i64)| async move { Ok(a < b) })?; + + let less_bound = less.bind(0)?; + lua.globals().set("f", less_bound)?; + + assert_eq!(lua.load("f(-1)").eval_async::().await?, false); + assert_eq!(lua.load("f(1)").eval_async::().await?, true); + + Ok(()) +} + +#[tokio::test] +async fn test_async_handle_yield() -> Result<()> { + let lua = Lua::new(); + + let sum = lua.create_async_function(|_lua, (a, b): (i64, i64)| async move { + futures_timer::Delay::new(Duration::from_millis(100)).await; + Ok(a + b) + })?; + + lua.globals().set("sleep_sum", sum)?; + + let res: String = lua + .load( + r#" + sum = sleep_sum(6, 7) + assert(sum == 13) + coroutine.yield("in progress") + return "done" + "#, + ) + .call_async(()) + .await?; + + assert_eq!(res, "done"); + + let min = lua + .load( + r#" + function (a, b) + coroutine.yield("ignore me") + if a < b then return a else return b end + end + "#, + ) + .eval::()?; + assert_eq!(min.call_async::<_, i64>((-1, 1)).await?, -1); + + Ok(()) +} + +#[tokio::test] +async fn test_async_thread_stream() -> Result<()> { + let lua = Lua::new(); + + let thread = lua.create_thread( + lua.load( + r#" + function (sum) + for i = 1,10 do + sum = sum + i + coroutine.yield(sum) + end + return sum + end + "#, + ) + .eval()?, + )?; + + let mut stream = thread.into_async::<_, i64>(1); + let mut sum = 0; + while let Some(n) = stream.try_next().await? { + sum += n; + } + + assert_eq!(sum, 286); + + Ok(()) +} + +#[tokio::test] +async fn test_async_thread() -> Result<()> { + let lua = Lua::new(); + + let cnt = Rc::new(100); // sleep 100ms + let cnt2 = cnt.clone(); + let f = lua.create_async_function(move |_lua, ()| { + let cnt3 = cnt2.clone(); + async move { + futures_timer::Delay::new(Duration::from_millis(*cnt3.as_ref())).await; + Ok("done") + } + })?; + + let res: String = lua.create_thread(f)?.into_async(()).await?; + + assert_eq!(res, "done"); + + assert_eq!(Rc::strong_count(&cnt), 2); + lua.gc_collect()?; // thread_s is non-resumable and subject to garbage collection + assert_eq!(Rc::strong_count(&cnt), 1); + + Ok(()) +} diff --git a/tests/function.rs b/tests/function.rs index ef2f9fb..cc1c9ab 100644 --- a/tests/function.rs +++ b/tests/function.rs @@ -1,10 +1,4 @@ -#![allow(unused_imports)] - -use std::{string::String as StdString, time::Duration}; - -use futures_executor::block_on; - -use mlua::{Error, Function, Lua, Result, String, Thread}; +use mlua::{Function, Lua, Result, String}; #[test] fn test_function() -> Result<()> { @@ -81,34 +75,3 @@ fn test_rust_function() -> Result<()> { Ok(()) } - -#[cfg(feature = "async")] -#[tokio::test] -async fn test_async_function() -> Result<()> { - let lua = Lua::new(); - - let f = lua.create_async_function(move |_lua, n: u64| async move { - futures_timer::Delay::new(Duration::from_secs(n)).await; - Ok("hello") - })?; - lua.globals().set("rust_async_sleep", f)?; - - let thread = lua - .load( - r#" - coroutine.create(function () - ret = rust_async_sleep(1) - assert(ret == "hello") - coroutine.yield() - return "world" - end) - "#, - ) - .eval::()?; - - let fut = thread.into_async(()); - let ret: StdString = fut.await?; - assert_eq!(ret, "world"); - - Ok(()) -} diff --git a/tests/thread.rs b/tests/thread.rs index 985fd88..666a7e6 100644 --- a/tests/thread.rs +++ b/tests/thread.rs @@ -1,11 +1,4 @@ -#![allow(unused_imports)] - use std::panic::catch_unwind; -use std::rc::Rc; -use std::time::Duration; - -use futures_executor::block_on; -use futures_util::stream::TryStreamExt; use mlua::{Error, Function, Lua, Result, Thread, ThreadStatus}; @@ -100,38 +93,6 @@ fn test_thread() -> Result<()> { Ok(()) } -#[cfg(feature = "async")] -#[tokio::test] -async fn test_thread_stream() -> Result<()> { - let lua = Lua::new(); - - let thread = lua.create_thread( - lua.load( - r#" - function (s) - local sum = s - for i = 1,10 do - sum = sum + i - coroutine.yield(sum) - end - return sum - end - "#, - ) - .eval()?, - )?; - - let mut s = thread.into_async::<_, i64>(0); - let mut sum = 0; - while let Some(n) = s.try_next().await? { - sum += n; - } - - assert_eq!(sum, 275); - - Ok(()) -} - #[test] fn coroutine_from_closure() -> Result<()> { let lua = Lua::new(); @@ -167,30 +128,3 @@ fn coroutine_panic() { Err(p) => assert!(*p.downcast::<&str>().unwrap() == "test_panic"), } } - -#[cfg(feature = "async")] -#[tokio::test] -async fn test_thread_async() -> Result<()> { - let lua = Lua::new(); - - let cnt = Rc::new(1); // sleep 1 second - let cnt2 = cnt.clone(); - let f = lua.create_async_function(move |_lua, ()| { - let cnt3 = cnt2.clone(); - async move { - futures_timer::Delay::new(Duration::from_secs(*cnt3.as_ref())).await; - Ok("hello") - } - })?; - - let mut thread_s = lua.create_thread(f)?.into_async(()); - let val: String = thread_s.try_next().await?.unwrap_or_default(); - - // thread_s is non-resumable and subject to garbage collection - - lua.gc_collect()?; - assert_eq!(Rc::strong_count(&cnt), 1); - assert_eq!(val, "hello"); - - Ok(()) -}