diff --git a/src/conversion.rs b/src/conversion.rs index f9be6f6..949ab36 100644 --- a/src/conversion.rs +++ b/src/conversion.rs @@ -4,10 +4,12 @@ use std::string::String as StdString; use error::*; use types::{Integer, LightUserData, Number}; -use lua::*; use string::String; use table::Table; use userdata::{AnyUserData, UserData}; +use function::Function; +use thread::Thread; +use lua::{FromLua, Lua, Nil, ToLua, Value}; impl<'lua> ToLua<'lua> for Value<'lua> { fn to_lua(self, _: &'lua Lua) -> Result> { diff --git a/src/error.rs b/src/error.rs index a509432..4fa6912 100644 --- a/src/error.rs +++ b/src/error.rs @@ -107,7 +107,9 @@ impl fmt::Display for Error { match *self { Error::SyntaxError { ref message, .. } => write!(fmt, "syntax error: {}", message), Error::RuntimeError(ref msg) => write!(fmt, "runtime error: {}", msg), - Error::GarbageCollectorError(ref msg) => write!(fmt, "garbage collector error: {}", msg), + Error::GarbageCollectorError(ref msg) => { + write!(fmt, "garbage collector error: {}", msg) + } Error::ToLuaConversionError { from, to, diff --git a/src/function.rs b/src/function.rs new file mode 100644 index 0000000..f087f20 --- /dev/null +++ b/src/function.rs @@ -0,0 +1,239 @@ +use std::os::raw::c_int; + +use ffi; +use error::*; +use util::*; +use types::LuaRef; +use lua::{FromLuaMulti, MultiValue, ToLuaMulti}; + +/// Handle to an internal Lua function. +#[derive(Clone, Debug)] +pub struct Function<'lua>(pub(crate) LuaRef<'lua>); + +impl<'lua> Function<'lua> { + /// Calls the function, passing `args` as function arguments. + /// + /// The function's return values are converted to the generic type `R`. + /// + /// # Examples + /// + /// Call Lua's built-in `tostring` function: + /// + /// ``` + /// # extern crate rlua; + /// # use rlua::{Lua, Function, Result}; + /// # fn try_main() -> Result<()> { + /// let lua = Lua::new(); + /// let globals = lua.globals(); + /// + /// let tostring: Function = globals.get("tostring")?; + /// + /// assert_eq!(tostring.call::<_, String>(123)?, "123"); + /// + /// # Ok(()) + /// # } + /// # fn main() { + /// # try_main().unwrap(); + /// # } + /// ``` + /// + /// Call a function with multiple arguments: + /// + /// ``` + /// # extern crate rlua; + /// # use rlua::{Lua, Function, Result}; + /// # fn try_main() -> Result<()> { + /// let lua = Lua::new(); + /// + /// let sum: Function = lua.eval(r#" + /// function(a, b) + /// return a + b + /// end + /// "#, None)?; + /// + /// assert_eq!(sum.call::<_, u32>((3, 4))?, 3 + 4); + /// + /// # Ok(()) + /// # } + /// # fn main() { + /// # try_main().unwrap(); + /// # } + /// ``` + pub fn call, R: FromLuaMulti<'lua>>(&self, args: A) -> Result { + let lua = self.0.lua; + unsafe { + stack_err_guard(lua.state, 0, || { + let args = args.to_lua_multi(lua)?; + let nargs = args.len() as c_int; + check_stack(lua.state, nargs + 3); + + ffi::lua_pushcfunction(lua.state, error_traceback); + let stack_start = ffi::lua_gettop(lua.state); + lua.push_ref(lua.state, &self.0); + for arg in args { + lua.push_value(lua.state, arg); + } + let ret = ffi::lua_pcall(lua.state, nargs, ffi::LUA_MULTRET, stack_start); + if ret != ffi::LUA_OK { + return Err(pop_error(lua.state, ret)); + } + let nresults = ffi::lua_gettop(lua.state) - stack_start; + let mut results = MultiValue::new(); + check_stack(lua.state, 1); + for _ in 0..nresults { + results.push_front(lua.pop_value(lua.state)); + } + ffi::lua_pop(lua.state, 1); + R::from_lua_multi(results, lua) + }) + } + } + + /// Returns a function that, when called, calls `self`, passing `args` as the first set of + /// arguments. + /// + /// If any arguments are passed to the returned function, they will be passed after `args`. + /// + /// # Examples + /// + /// ``` + /// # extern crate rlua; + /// # use rlua::{Lua, Function, Result}; + /// # fn try_main() -> Result<()> { + /// let lua = Lua::new(); + /// + /// let sum: Function = lua.eval(r#" + /// function(a, b) + /// return a + b + /// end + /// "#, None)?; + /// + /// let bound_a = sum.bind(1)?; + /// assert_eq!(bound_a.call::<_, u32>(2)?, 1 + 2); + /// + /// let bound_a_and_b = sum.bind(13)?.bind(57)?; + /// assert_eq!(bound_a_and_b.call::<_, u32>(())?, 13 + 57); + /// + /// # Ok(()) + /// # } + /// # fn main() { + /// # try_main().unwrap(); + /// # } + /// ``` + pub fn bind>(&self, args: A) -> Result> { + unsafe extern "C" fn bind_call_impl(state: *mut ffi::lua_State) -> c_int { + let nargs = ffi::lua_gettop(state); + let nbinds = ffi::lua_tointeger(state, ffi::lua_upvalueindex(2)) as c_int; + check_stack(state, nbinds + 2); + + ffi::lua_settop(state, nargs + nbinds + 1); + ffi::lua_rotate(state, -(nargs + nbinds + 1), nbinds + 1); + + ffi::lua_pushvalue(state, ffi::lua_upvalueindex(1)); + ffi::lua_replace(state, 1); + + for i in 0..nbinds { + ffi::lua_pushvalue(state, ffi::lua_upvalueindex(i + 3)); + ffi::lua_replace(state, i + 2); + } + + ffi::lua_call(state, nargs + nbinds, ffi::LUA_MULTRET); + ffi::lua_gettop(state) + } + + let lua = self.0.lua; + unsafe { + stack_err_guard(lua.state, 0, || { + let args = args.to_lua_multi(lua)?; + let nargs = args.len() as c_int; + + check_stack(lua.state, nargs + 3); + lua.push_ref(lua.state, &self.0); + ffi::lua_pushinteger(lua.state, nargs as ffi::lua_Integer); + for arg in args { + lua.push_value(lua.state, arg); + } + + protect_lua_call(lua.state, nargs + 2, 1, |state| { + ffi::lua_pushcclosure(state, bind_call_impl, nargs + 2); + })?; + + Ok(Function(lua.pop_ref(lua.state))) + }) + } + } +} + +#[cfg(test)] +mod tests { + use super::Function; + use string::String; + use lua::Lua; + + #[test] + fn test_function() { + let lua = Lua::new(); + let globals = lua.globals(); + lua.exec::<()>( + r#" + function concat(arg1, arg2) + return arg1 .. arg2 + end + "#, + None, + ).unwrap(); + + let concat = globals.get::<_, Function>("concat").unwrap(); + assert_eq!(concat.call::<_, String>(("foo", "bar")).unwrap(), "foobar"); + } + + #[test] + fn test_bind() { + let lua = Lua::new(); + let globals = lua.globals(); + lua.exec::<()>( + r#" + function concat(...) + local res = "" + for _, s in pairs({...}) do + res = res..s + end + return res + end + "#, + None, + ).unwrap(); + + let mut concat = globals.get::<_, Function>("concat").unwrap(); + concat = concat.bind("foo").unwrap(); + concat = concat.bind("bar").unwrap(); + concat = concat.bind(("baz", "baf")).unwrap(); + assert_eq!( + concat.call::<_, String>(("hi", "wut")).unwrap(), + "foobarbazbafhiwut" + ); + } + + #[test] + fn test_rust_function() { + let lua = Lua::new(); + let globals = lua.globals(); + lua.exec::<()>( + r#" + function lua_function() + return rust_function() + end + + -- Test to make sure chunk return is ignored + return 1 + "#, + None, + ).unwrap(); + + let lua_function = globals.get::<_, Function>("lua_function").unwrap(); + let rust_function = lua.create_function(|_, ()| Ok("hello")).unwrap(); + + globals.set("rust_function", rust_function).unwrap(); + assert_eq!(lua_function.call::<_, String>(()).unwrap(), "hello"); + } +} diff --git a/src/lib.rs b/src/lib.rs index d5af254..cc54e27 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -53,6 +53,8 @@ mod conversion; mod multi; mod string; mod table; +mod function; +mod thread; mod userdata; #[cfg(test)] @@ -63,8 +65,9 @@ pub use types::{Integer, LightUserData, Number}; pub use multi::Variadic; pub use string::String; pub use table::{Table, TablePairs, TableSequence}; +pub use function::Function; +pub use thread::{Thread, ThreadStatus}; pub use userdata::{AnyUserData, MetaMethod, UserData, UserDataMethods}; -pub use lua::{FromLua, FromLuaMulti, Function, Lua, MultiValue, Nil, Thread, ThreadStatus, ToLua, - ToLuaMulti, Value}; +pub use lua::{FromLua, FromLuaMulti, Lua, MultiValue, Nil, ToLua, ToLuaMulti, Value}; pub mod prelude; diff --git a/src/lua.rs b/src/lua.rs index 471f717..35e4af7 100644 --- a/src/lua.rs +++ b/src/lua.rs @@ -17,6 +17,8 @@ use util::*; use types::{Callback, Integer, LightUserData, LuaRef, Number}; use string::String; use table::Table; +use function::Function; +use thread::Thread; use userdata::{AnyUserData, MetaMethod, UserData, UserDataMethods}; /// A dynamically typed Lua value. @@ -144,298 +146,6 @@ pub trait FromLuaMulti<'lua>: Sized { fn from_lua_multi(values: MultiValue<'lua>, lua: &'lua Lua) -> Result; } -/// Handle to an internal Lua function. -#[derive(Clone, Debug)] -pub struct Function<'lua>(LuaRef<'lua>); - -impl<'lua> Function<'lua> { - /// Calls the function, passing `args` as function arguments. - /// - /// The function's return values are converted to the generic type `R`. - /// - /// # Examples - /// - /// Call Lua's built-in `tostring` function: - /// - /// ``` - /// # extern crate rlua; - /// # use rlua::{Lua, Function, Result}; - /// # fn try_main() -> Result<()> { - /// let lua = Lua::new(); - /// let globals = lua.globals(); - /// - /// let tostring: Function = globals.get("tostring")?; - /// - /// assert_eq!(tostring.call::<_, String>(123)?, "123"); - /// - /// # Ok(()) - /// # } - /// # fn main() { - /// # try_main().unwrap(); - /// # } - /// ``` - /// - /// Call a function with multiple arguments: - /// - /// ``` - /// # extern crate rlua; - /// # use rlua::{Lua, Function, Result}; - /// # fn try_main() -> Result<()> { - /// let lua = Lua::new(); - /// - /// let sum: Function = lua.eval(r#" - /// function(a, b) - /// return a + b - /// end - /// "#, None)?; - /// - /// assert_eq!(sum.call::<_, u32>((3, 4))?, 3 + 4); - /// - /// # Ok(()) - /// # } - /// # fn main() { - /// # try_main().unwrap(); - /// # } - /// ``` - pub fn call, R: FromLuaMulti<'lua>>(&self, args: A) -> Result { - let lua = self.0.lua; - unsafe { - stack_err_guard(lua.state, 0, || { - let args = args.to_lua_multi(lua)?; - let nargs = args.len() as c_int; - check_stack(lua.state, nargs + 3); - - ffi::lua_pushcfunction(lua.state, error_traceback); - let stack_start = ffi::lua_gettop(lua.state); - lua.push_ref(lua.state, &self.0); - for arg in args { - lua.push_value(lua.state, arg); - } - let ret = ffi::lua_pcall(lua.state, nargs, ffi::LUA_MULTRET, stack_start); - if ret != ffi::LUA_OK { - return Err(pop_error(lua.state, ret)); - } - let nresults = ffi::lua_gettop(lua.state) - stack_start; - let mut results = MultiValue::new(); - check_stack(lua.state, 1); - for _ in 0..nresults { - results.push_front(lua.pop_value(lua.state)); - } - ffi::lua_pop(lua.state, 1); - R::from_lua_multi(results, lua) - }) - } - } - - /// Returns a function that, when called, calls `self`, passing `args` as the first set of - /// arguments. - /// - /// If any arguments are passed to the returned function, they will be passed after `args`. - /// - /// # Examples - /// - /// ``` - /// # extern crate rlua; - /// # use rlua::{Lua, Function, Result}; - /// # fn try_main() -> Result<()> { - /// let lua = Lua::new(); - /// - /// let sum: Function = lua.eval(r#" - /// function(a, b) - /// return a + b - /// end - /// "#, None)?; - /// - /// let bound_a = sum.bind(1)?; - /// assert_eq!(bound_a.call::<_, u32>(2)?, 1 + 2); - /// - /// let bound_a_and_b = sum.bind(13)?.bind(57)?; - /// assert_eq!(bound_a_and_b.call::<_, u32>(())?, 13 + 57); - /// - /// # Ok(()) - /// # } - /// # fn main() { - /// # try_main().unwrap(); - /// # } - /// ``` - pub fn bind>(&self, args: A) -> Result> { - unsafe extern "C" fn bind_call_impl(state: *mut ffi::lua_State) -> c_int { - let nargs = ffi::lua_gettop(state); - let nbinds = ffi::lua_tointeger(state, ffi::lua_upvalueindex(2)) as c_int; - check_stack(state, nbinds + 2); - - ffi::lua_settop(state, nargs + nbinds + 1); - ffi::lua_rotate(state, -(nargs + nbinds + 1), nbinds + 1); - - ffi::lua_pushvalue(state, ffi::lua_upvalueindex(1)); - ffi::lua_replace(state, 1); - - for i in 0..nbinds { - ffi::lua_pushvalue(state, ffi::lua_upvalueindex(i + 3)); - ffi::lua_replace(state, i + 2); - } - - ffi::lua_call(state, nargs + nbinds, ffi::LUA_MULTRET); - ffi::lua_gettop(state) - } - - let lua = self.0.lua; - unsafe { - stack_err_guard(lua.state, 0, || { - let args = args.to_lua_multi(lua)?; - let nargs = args.len() as c_int; - - check_stack(lua.state, nargs + 3); - lua.push_ref(lua.state, &self.0); - ffi::lua_pushinteger(lua.state, nargs as ffi::lua_Integer); - for arg in args { - lua.push_value(lua.state, arg); - } - - protect_lua_call(lua.state, nargs + 2, 1, |state| { - ffi::lua_pushcclosure(state, bind_call_impl, nargs + 2); - })?; - - Ok(Function(lua.pop_ref(lua.state))) - }) - } - } -} - -/// Status of a Lua thread (or coroutine). -#[derive(Debug, Copy, Clone, Eq, PartialEq)] -pub enum ThreadStatus { - /// The thread was just created, or is suspended because it has called `coroutine.yield`. - /// - /// If a thread is in this state, it can be resumed by calling [`Thread::resume`]. - /// - /// [`Thread::resume`]: struct.Thread.html#method.resume - Resumable, - /// Either the thread has finished executing, or the thread is currently running. - Unresumable, - /// The thread has raised a Lua error during execution. - Error, -} - -/// Handle to an internal Lua thread (or coroutine). -#[derive(Clone, Debug)] -pub struct Thread<'lua>(LuaRef<'lua>); - -impl<'lua> Thread<'lua> { - /// Resumes execution of this thread. - /// - /// Equivalent to `coroutine.resume`. - /// - /// Passes `args` as arguments to the thread. If the coroutine has called `coroutine.yield`, it - /// will return these arguments. Otherwise, the coroutine wasn't yet started, so the arguments - /// are passed to its main function. - /// - /// If the thread is no longer in `Active` state (meaning it has finished execution or - /// encountered an error), this will return `Err(CoroutineInactive)`, otherwise will return `Ok` - /// as follows: - /// - /// If the thread calls `coroutine.yield`, returns the values passed to `yield`. If the thread - /// `return`s values from its main function, returns those. - /// - /// # Examples - /// - /// ``` - /// # extern crate rlua; - /// # use rlua::{Lua, Thread, Error, Result}; - /// # fn try_main() -> Result<()> { - /// let lua = Lua::new(); - /// let thread: Thread = lua.eval(r#" - /// coroutine.create(function(arg) - /// assert(arg == 42) - /// local yieldarg = coroutine.yield(123) - /// assert(yieldarg == 43) - /// return 987 - /// end) - /// "#, None).unwrap(); - /// - /// assert_eq!(thread.resume::<_, u32>(42).unwrap(), 123); - /// assert_eq!(thread.resume::<_, u32>(43).unwrap(), 987); - /// - /// // The coroutine has now returned, so `resume` will fail - /// match thread.resume::<_, u32>(()) { - /// Err(Error::CoroutineInactive) => {}, - /// unexpected => panic!("unexpected result {:?}", unexpected), - /// } - /// # Ok(()) - /// # } - /// # fn main() { - /// # try_main().unwrap(); - /// # } - /// ``` - pub fn resume(&self, args: A) -> Result - where - A: ToLuaMulti<'lua>, - R: FromLuaMulti<'lua>, - { - let lua = self.0.lua; - unsafe { - stack_err_guard(lua.state, 0, || { - check_stack(lua.state, 1); - - lua.push_ref(lua.state, &self.0); - let thread_state = ffi::lua_tothread(lua.state, -1); - - let status = ffi::lua_status(thread_state); - if status != ffi::LUA_YIELD && ffi::lua_gettop(thread_state) == 0 { - return Err(Error::CoroutineInactive); - } - - ffi::lua_pop(lua.state, 1); - - let args = args.to_lua_multi(lua)?; - let nargs = args.len() as c_int; - check_stack(thread_state, nargs); - - for arg in args { - lua.push_value(thread_state, arg); - } - - let ret = ffi::lua_resume(thread_state, lua.state, nargs); - if ret != ffi::LUA_OK && ret != ffi::LUA_YIELD { - error_traceback(thread_state); - return Err(pop_error(thread_state, ret)); - } - - let nresults = ffi::lua_gettop(thread_state); - let mut results = MultiValue::new(); - check_stack(thread_state, 1); - for _ in 0..nresults { - results.push_front(lua.pop_value(thread_state)); - } - R::from_lua_multi(results, lua) - }) - } - } - - /// Gets the status of the thread. - pub fn status(&self) -> ThreadStatus { - let lua = self.0.lua; - unsafe { - stack_guard(lua.state, 0, || { - check_stack(lua.state, 1); - - lua.push_ref(lua.state, &self.0); - let thread_state = ffi::lua_tothread(lua.state, -1); - ffi::lua_pop(lua.state, 1); - - let status = ffi::lua_status(thread_state); - if status != ffi::LUA_OK && status != ffi::LUA_YIELD { - ThreadStatus::Error - } else if status == ffi::LUA_YIELD || ffi::lua_gettop(thread_state) > 0 { - ThreadStatus::Resumable - } else { - ThreadStatus::Unresumable - } - }) - } - } -} - /// Top level Lua struct which holds the Lua state itself. pub struct Lua { pub(crate) state: *mut ffi::lua_State, @@ -667,9 +377,8 @@ impl Lua { stack_err_guard(self.state, 0, move || { check_stack(self.state, 2); - let thread_state = protect_lua_call(self.state, 0, 1, |state| { - ffi::lua_newthread(state) - })?; + let thread_state = + protect_lua_call(self.state, 0, 1, |state| ffi::lua_newthread(state))?; self.push_ref(thread_state, &func.0); Ok(Thread(self.pop_ref(self.state))) @@ -723,9 +432,8 @@ impl Lua { check_stack(self.state, 2); let ty = v.type_name(); self.push_value(self.state, v); - let s = protect_lua_call(self.state, 1, 1, |state| { - ffi::lua_tostring(state, -1) - })?; + let s = + protect_lua_call(self.state, 1, 1, |state| ffi::lua_tostring(state, -1))?; if s.is_null() { ffi::lua_pop(self.state, 1); Err(Error::FromLuaConversionError { diff --git a/src/table.rs b/src/table.rs index fa2d506..f022f9f 100644 --- a/src/table.rs +++ b/src/table.rs @@ -435,7 +435,7 @@ where #[cfg(test)] mod tests { use super::Table; - use error::{Result}; + use error::Result; use lua::{Lua, Nil, Value}; #[test] @@ -565,7 +565,10 @@ mod tests { let table = lua.create_table().unwrap(); let metatable = lua.create_table().unwrap(); metatable - .set("__index", lua.create_function(|_, ()| Ok("index_value")).unwrap()) + .set( + "__index", + lua.create_function(|_, ()| Ok("index_value")).unwrap(), + ) .unwrap(); table.set_metatable(Some(metatable)); assert_eq!(table.get::<_, String>("any_key").unwrap(), "index_value"); diff --git a/src/tests.rs b/src/tests.rs index a1c82df..9401e7c 100644 --- a/src/tests.rs +++ b/src/tests.rs @@ -2,7 +2,7 @@ use std::fmt; use std::error; use std::panic::catch_unwind; -use {Error, ExternalError, Function, Lua, Result, Table, Thread, ThreadStatus, Value, Variadic}; +use {Error, ExternalError, Function, Lua, Result, Table, Value, Variadic}; #[test] fn test_load() { @@ -81,73 +81,6 @@ fn test_eval() { } } -#[test] -fn test_function() { - let lua = Lua::new(); - let globals = lua.globals(); - lua.exec::<()>( - r#" - function concat(arg1, arg2) - return arg1 .. arg2 - end - "#, - None, - ).unwrap(); - - let concat = globals.get::<_, Function>("concat").unwrap(); - assert_eq!(concat.call::<_, String>(("foo", "bar")).unwrap(), "foobar"); -} - -#[test] -fn test_bind() { - let lua = Lua::new(); - let globals = lua.globals(); - lua.exec::<()>( - r#" - function concat(...) - local res = "" - for _, s in pairs({...}) do - res = res..s - end - return res - end - "#, - None, - ).unwrap(); - - let mut concat = globals.get::<_, Function>("concat").unwrap(); - concat = concat.bind("foo").unwrap(); - concat = concat.bind("bar").unwrap(); - concat = concat.bind(("baz", "baf")).unwrap(); - assert_eq!( - concat.call::<_, String>(("hi", "wut")).unwrap(), - "foobarbazbafhiwut" - ); -} - -#[test] -fn test_rust_function() { - let lua = Lua::new(); - let globals = lua.globals(); - lua.exec::<()>( - r#" - function lua_function() - return rust_function() - end - - -- Test to make sure chunk return is ignored - return 1 - "#, - None, - ).unwrap(); - - let lua_function = globals.get::<_, Function>("lua_function").unwrap(); - let rust_function = lua.create_function(|_, ()| Ok("hello")).unwrap(); - - globals.set("rust_function", rust_function).unwrap(); - assert_eq!(lua_function.call::<_, String>(()).unwrap(), "hello"); -} - #[test] fn test_lua_multi() { let lua = Lua::new(); @@ -270,8 +203,9 @@ fn test_error() { None, ).unwrap(); - let rust_error_function = - lua.create_function(|_, ()| -> Result<()> { Err(TestError.to_lua_err()) }).unwrap(); + let rust_error_function = lua.create_function(|_, ()| -> Result<()> { + Err(TestError.to_lua_err()) + }).unwrap(); globals .set("rust_error_function", rust_error_function) .unwrap(); @@ -377,92 +311,6 @@ fn test_error() { }; } -#[test] -fn test_thread() { - let lua = Lua::new(); - let thread = lua.create_thread( - lua.eval::( - r#" - function (s) - local sum = s - for i = 1,4 do - sum = sum + coroutine.yield(sum) - end - return sum - end - "#, - None, - ).unwrap(), - ).unwrap(); - - assert_eq!(thread.status(), ThreadStatus::Resumable); - assert_eq!(thread.resume::<_, i64>(0).unwrap(), 0); - assert_eq!(thread.status(), ThreadStatus::Resumable); - assert_eq!(thread.resume::<_, i64>(1).unwrap(), 1); - assert_eq!(thread.status(), ThreadStatus::Resumable); - assert_eq!(thread.resume::<_, i64>(2).unwrap(), 3); - assert_eq!(thread.status(), ThreadStatus::Resumable); - assert_eq!(thread.resume::<_, i64>(3).unwrap(), 6); - assert_eq!(thread.status(), ThreadStatus::Resumable); - assert_eq!(thread.resume::<_, i64>(4).unwrap(), 10); - assert_eq!(thread.status(), ThreadStatus::Unresumable); - - let accumulate = lua.create_thread( - lua.eval::( - r#" - function (sum) - while true do - sum = sum + coroutine.yield(sum) - end - end - "#, - None, - ).unwrap(), - ).unwrap(); - - for i in 0..4 { - accumulate.resume::<_, ()>(i).unwrap(); - } - assert_eq!(accumulate.resume::<_, i64>(4).unwrap(), 10); - assert_eq!(accumulate.status(), ThreadStatus::Resumable); - assert!(accumulate.resume::<_, ()>("error").is_err()); - assert_eq!(accumulate.status(), ThreadStatus::Error); - - let thread = lua.eval::( - r#" - coroutine.create(function () - while true do - coroutine.yield(42) - end - end) - "#, - None, - ).unwrap(); - assert_eq!(thread.status(), ThreadStatus::Resumable); - assert_eq!(thread.resume::<_, i64>(()).unwrap(), 42); - - let thread: Thread = lua.eval( - r#" - coroutine.create(function(arg) - assert(arg == 42) - local yieldarg = coroutine.yield(123) - assert(yieldarg == 43) - return 987 - end) - "#, - None, - ).unwrap(); - - assert_eq!(thread.resume::<_, u32>(42).unwrap(), 123); - assert_eq!(thread.resume::<_, u32>(43).unwrap(), 987); - - match thread.resume::<_, u32>(()) { - Err(Error::CoroutineInactive) => {} - Err(_) => panic!("resuming dead coroutine error is not CoroutineInactive kind"), - _ => panic!("resuming dead coroutine did not return error"), - } -} - #[test] fn test_result_conversions() { let lua = Lua::new(); @@ -473,7 +321,8 @@ fn test_result_conversions() { "only through failure can we succeed".to_lua_err(), )) }).unwrap(); - let ok = lua.create_function(|_, ()| Ok(Ok::<_, Error>("!".to_owned()))).unwrap(); + let ok = lua.create_function(|_, ()| Ok(Ok::<_, Error>("!".to_owned()))) + .unwrap(); globals.set("err", err).unwrap(); globals.set("ok", ok).unwrap(); @@ -516,29 +365,6 @@ fn test_num_conversion() { assert!(globals.get::<_, i64>("n").is_err()); } -#[test] -fn coroutine_from_closure() { - let lua = Lua::new(); - let thrd_main = lua.create_function(|_, ()| Ok(())).unwrap(); - lua.globals().set("main", thrd_main).unwrap(); - let thrd: Thread = lua.eval("coroutine.create(main)", None).unwrap(); - thrd.resume::<_, ()>(()).unwrap(); -} - -#[test] -#[should_panic] -fn coroutine_panic() { - let lua = Lua::new(); - let thrd_main = lua.create_function(|lua, ()| { - // whoops, 'main' has a wrong type - let _coro: u32 = lua.globals().get("main").unwrap(); - Ok(()) - }).unwrap(); - lua.globals().set("main", thrd_main.clone()).unwrap(); - let thrd: Thread = lua.create_thread(thrd_main).unwrap(); - thrd.resume::<_, ()>(()).unwrap(); -} - #[test] fn test_pcall_xpcall() { let lua = Lua::new(); diff --git a/src/thread.rs b/src/thread.rs new file mode 100644 index 0000000..d784865 --- /dev/null +++ b/src/thread.rs @@ -0,0 +1,258 @@ +use std::os::raw::c_int; + +use ffi; +use error::*; +use util::*; +use types::LuaRef; +use lua::{FromLuaMulti, MultiValue, ToLuaMulti}; + +/// Status of a Lua thread (or coroutine). +#[derive(Debug, Copy, Clone, Eq, PartialEq)] +pub enum ThreadStatus { + /// The thread was just created, or is suspended because it has called `coroutine.yield`. + /// + /// If a thread is in this state, it can be resumed by calling [`Thread::resume`]. + /// + /// [`Thread::resume`]: struct.Thread.html#method.resume + Resumable, + /// Either the thread has finished executing, or the thread is currently running. + Unresumable, + /// The thread has raised a Lua error during execution. + Error, +} + +/// Handle to an internal Lua thread (or coroutine). +#[derive(Clone, Debug)] +pub struct Thread<'lua>(pub(crate) LuaRef<'lua>); + +impl<'lua> Thread<'lua> { + /// Resumes execution of this thread. + /// + /// Equivalent to `coroutine.resume`. + /// + /// Passes `args` as arguments to the thread. If the coroutine has called `coroutine.yield`, it + /// will return these arguments. Otherwise, the coroutine wasn't yet started, so the arguments + /// are passed to its main function. + /// + /// If the thread is no longer in `Active` state (meaning it has finished execution or + /// encountered an error), this will return `Err(CoroutineInactive)`, otherwise will return `Ok` + /// as follows: + /// + /// If the thread calls `coroutine.yield`, returns the values passed to `yield`. If the thread + /// `return`s values from its main function, returns those. + /// + /// # Examples + /// + /// ``` + /// # extern crate rlua; + /// # use rlua::{Lua, Thread, Error, Result}; + /// # fn try_main() -> Result<()> { + /// let lua = Lua::new(); + /// let thread: Thread = lua.eval(r#" + /// coroutine.create(function(arg) + /// assert(arg == 42) + /// local yieldarg = coroutine.yield(123) + /// assert(yieldarg == 43) + /// return 987 + /// end) + /// "#, None).unwrap(); + /// + /// assert_eq!(thread.resume::<_, u32>(42).unwrap(), 123); + /// assert_eq!(thread.resume::<_, u32>(43).unwrap(), 987); + /// + /// // The coroutine has now returned, so `resume` will fail + /// match thread.resume::<_, u32>(()) { + /// Err(Error::CoroutineInactive) => {}, + /// unexpected => panic!("unexpected result {:?}", unexpected), + /// } + /// # Ok(()) + /// # } + /// # fn main() { + /// # try_main().unwrap(); + /// # } + /// ``` + pub fn resume(&self, args: A) -> Result + where + A: ToLuaMulti<'lua>, + R: FromLuaMulti<'lua>, + { + let lua = self.0.lua; + unsafe { + stack_err_guard(lua.state, 0, || { + check_stack(lua.state, 1); + + lua.push_ref(lua.state, &self.0); + let thread_state = ffi::lua_tothread(lua.state, -1); + + let status = ffi::lua_status(thread_state); + if status != ffi::LUA_YIELD && ffi::lua_gettop(thread_state) == 0 { + return Err(Error::CoroutineInactive); + } + + ffi::lua_pop(lua.state, 1); + + let args = args.to_lua_multi(lua)?; + let nargs = args.len() as c_int; + check_stack(thread_state, nargs); + + for arg in args { + lua.push_value(thread_state, arg); + } + + let ret = ffi::lua_resume(thread_state, lua.state, nargs); + if ret != ffi::LUA_OK && ret != ffi::LUA_YIELD { + error_traceback(thread_state); + return Err(pop_error(thread_state, ret)); + } + + let nresults = ffi::lua_gettop(thread_state); + let mut results = MultiValue::new(); + check_stack(thread_state, 1); + for _ in 0..nresults { + results.push_front(lua.pop_value(thread_state)); + } + R::from_lua_multi(results, lua) + }) + } + } + + /// Gets the status of the thread. + pub fn status(&self) -> ThreadStatus { + let lua = self.0.lua; + unsafe { + stack_guard(lua.state, 0, || { + check_stack(lua.state, 1); + + lua.push_ref(lua.state, &self.0); + let thread_state = ffi::lua_tothread(lua.state, -1); + ffi::lua_pop(lua.state, 1); + + let status = ffi::lua_status(thread_state); + if status != ffi::LUA_OK && status != ffi::LUA_YIELD { + ThreadStatus::Error + } else if status == ffi::LUA_YIELD || ffi::lua_gettop(thread_state) > 0 { + ThreadStatus::Resumable + } else { + ThreadStatus::Unresumable + } + }) + } + } +} + +#[cfg(test)] +mod tests { + use super::{Thread, ThreadStatus}; + use error::Error; + use function::Function; + use lua::Lua; + + #[test] + fn test_thread() { + let lua = Lua::new(); + let thread = lua.create_thread( + lua.eval::( + r#" + function (s) + local sum = s + for i = 1,4 do + sum = sum + coroutine.yield(sum) + end + return sum + end + "#, + None, + ).unwrap(), + ).unwrap(); + + assert_eq!(thread.status(), ThreadStatus::Resumable); + assert_eq!(thread.resume::<_, i64>(0).unwrap(), 0); + assert_eq!(thread.status(), ThreadStatus::Resumable); + assert_eq!(thread.resume::<_, i64>(1).unwrap(), 1); + assert_eq!(thread.status(), ThreadStatus::Resumable); + assert_eq!(thread.resume::<_, i64>(2).unwrap(), 3); + assert_eq!(thread.status(), ThreadStatus::Resumable); + assert_eq!(thread.resume::<_, i64>(3).unwrap(), 6); + assert_eq!(thread.status(), ThreadStatus::Resumable); + assert_eq!(thread.resume::<_, i64>(4).unwrap(), 10); + assert_eq!(thread.status(), ThreadStatus::Unresumable); + + let accumulate = lua.create_thread( + lua.eval::( + r#" + function (sum) + while true do + sum = sum + coroutine.yield(sum) + end + end + "#, + None, + ).unwrap(), + ).unwrap(); + + for i in 0..4 { + accumulate.resume::<_, ()>(i).unwrap(); + } + assert_eq!(accumulate.resume::<_, i64>(4).unwrap(), 10); + assert_eq!(accumulate.status(), ThreadStatus::Resumable); + assert!(accumulate.resume::<_, ()>("error").is_err()); + assert_eq!(accumulate.status(), ThreadStatus::Error); + + let thread = lua.eval::( + r#" + coroutine.create(function () + while true do + coroutine.yield(42) + end + end) + "#, + None, + ).unwrap(); + assert_eq!(thread.status(), ThreadStatus::Resumable); + assert_eq!(thread.resume::<_, i64>(()).unwrap(), 42); + + let thread: Thread = lua.eval( + r#" + coroutine.create(function(arg) + assert(arg == 42) + local yieldarg = coroutine.yield(123) + assert(yieldarg == 43) + return 987 + end) + "#, + None, + ).unwrap(); + + assert_eq!(thread.resume::<_, u32>(42).unwrap(), 123); + assert_eq!(thread.resume::<_, u32>(43).unwrap(), 987); + + match thread.resume::<_, u32>(()) { + Err(Error::CoroutineInactive) => {} + Err(_) => panic!("resuming dead coroutine error is not CoroutineInactive kind"), + _ => panic!("resuming dead coroutine did not return error"), + } + } + + #[test] + fn coroutine_from_closure() { + let lua = Lua::new(); + let thrd_main = lua.create_function(|_, ()| Ok(())).unwrap(); + lua.globals().set("main", thrd_main).unwrap(); + let thrd: Thread = lua.eval("coroutine.create(main)", None).unwrap(); + thrd.resume::<_, ()>(()).unwrap(); + } + + #[test] + #[should_panic] + fn coroutine_panic() { + let lua = Lua::new(); + let thrd_main = lua.create_function(|lua, ()| { + // whoops, 'main' has a wrong type + let _coro: u32 = lua.globals().get("main").unwrap(); + Ok(()) + }).unwrap(); + lua.globals().set("main", thrd_main.clone()).unwrap(); + let thrd: Thread = lua.create_thread(thrd_main).unwrap(); + thrd.resume::<_, ()>(()).unwrap(); + } +} diff --git a/src/types.rs b/src/types.rs index a9cf2be..6d406ad 100644 --- a/src/types.rs +++ b/src/types.rs @@ -49,7 +49,8 @@ impl<'lua> Drop for LuaRef<'lua> { #[cfg(test)] mod tests { use super::LightUserData; - use lua::{Function, Lua}; + use function::Function; + use lua::Lua; use std::os::raw::c_void; diff --git a/src/userdata.rs b/src/userdata.rs index a07c6b4..48955bb 100644 --- a/src/userdata.rs +++ b/src/userdata.rs @@ -392,7 +392,8 @@ mod tests { use super::{MetaMethod, UserData, UserDataMethods}; use error::ExternalError; use string::String; - use lua::{Function, Lua}; + use function::Function; + use lua::Lua; #[test] fn test_user_data() { diff --git a/src/util.rs b/src/util.rs index 3507cc4..852155b 100644 --- a/src/util.rs +++ b/src/util.rs @@ -204,9 +204,7 @@ pub unsafe fn pop_error(state: *mut ffi::lua_State, err_code: c_int) -> Error { eprintln!("impossible Lua allocation error, aborting!"); process::abort() } - ffi::LUA_ERRGCMM => { - Error::GarbageCollectorError(err_string) - } + ffi::LUA_ERRGCMM => Error::GarbageCollectorError(err_string), _ => lua_panic!(state, "internal error: unrecognized lua error code"), } }