From 56a30fba2e96481c50f1178956dfcb153d5f8bad Mon Sep 17 00:00:00 2001 From: kyren Date: Sat, 17 Jun 2017 23:50:40 -0400 Subject: [PATCH] Slightly more palatable coroutine API --- src/error.rs | 3 +++ src/lua.rs | 61 ++++++++++++++++++++++++++-------------------------- src/tests.rs | 29 ++++++++++++++----------- 3 files changed, 51 insertions(+), 42 deletions(-) diff --git a/src/error.rs b/src/error.rs index fa0f830..f117863 100644 --- a/src/error.rs +++ b/src/error.rs @@ -39,6 +39,9 @@ error_chain! { IncompleteStatement(err: String) { display("Incomplete lua statement {}", err) } + CoroutineInactive { + display("Cannot resume inactive coroutine") + } } foreign_links { diff --git a/src/lua.rs b/src/lua.rs index de872bd..c694b03 100644 --- a/src/lua.rs +++ b/src/lua.rs @@ -512,14 +512,16 @@ impl<'lua> LuaThread<'lua> { /// are passed to its main function. /// /// If the thread is no longer in `Active` state (meaning it has finished execution or - /// encountered an error), returns `None`. Otherwise, returns `Some` as follows: + /// encountered an error), this will return Err(CoroutineInactive), + /// otherwise will return Ok as follows: /// /// If the thread calls `coroutine.yield`, returns the values passed to `yield`. If the thread /// `return`s values from its main function, returns those. - pub fn resume, R: FromLuaMulti<'lua>>( - &self, - args: A, - ) -> LuaResult> { + pub fn resume(&self, args: A) -> LuaResult + where + A: ToLuaMulti<'lua>, + R: FromLuaMulti<'lua>, + { let lua = self.0.lua; unsafe { stack_guard(lua.state, 0, || { @@ -529,32 +531,31 @@ impl<'lua> LuaThread<'lua> { let thread_state = ffi::lua_tothread(lua.state, -1); let status = ffi::lua_status(thread_state); - if status == ffi::LUA_YIELD || ffi::lua_gettop(thread_state) > 0 { - ffi::lua_pop(lua.state, 1); - - let args = args.to_lua_multi(lua)?; - let nargs = args.len() as c_int; - check_stack(thread_state, nargs)?; - - for arg in args { - lua.push_value(thread_state, arg)?; - } - - handle_error( - lua.state, - resume_with_traceback(thread_state, lua.state, nargs), - )?; - - let nresults = ffi::lua_gettop(thread_state); - let mut results = LuaMultiValue::new(); - for _ in 0..nresults { - results.push_front(lua.pop_value(thread_state)?); - } - R::from_lua_multi(results, lua).map(Some) - } else { - ffi::lua_pop(lua.state, 1); - Ok(None) + if status != ffi::LUA_YIELD && ffi::lua_gettop(thread_state) == 0 { + return Err(LuaErrorKind::CoroutineInactive.into()); } + + ffi::lua_pop(lua.state, 1); + + let args = args.to_lua_multi(lua)?; + let nargs = args.len() as c_int; + check_stack(thread_state, nargs)?; + + for arg in args { + lua.push_value(thread_state, arg)?; + } + + handle_error( + lua.state, + resume_with_traceback(thread_state, lua.state, nargs), + )?; + + let nresults = ffi::lua_gettop(thread_state); + let mut results = LuaMultiValue::new(); + for _ in 0..nresults { + results.push_front(lua.pop_value(thread_state)?); + } + R::from_lua_multi(results, lua) }) } } diff --git a/src/tests.rs b/src/tests.rs index 58493b5..265668a 100644 --- a/src/tests.rs +++ b/src/tests.rs @@ -458,12 +458,12 @@ fn test_error() { match lua_error.call::<_, ()>(()) { Err(LuaError(LuaErrorKind::ScriptError(_), _)) => {} Err(_) => panic!("error is not ScriptError kind"), - _ => panic!("error not thrown"), + _ => panic!("error not returned"), } match rust_error.call::<_, ()>(()) { Err(LuaError(LuaErrorKind::CallbackError(_), _)) => {} Err(_) => panic!("error is not CallbackError kind"), - _ => panic!("error not thrown"), + _ => panic!("error not returned"), } test_pcall.call::<_, ()>(()).unwrap(); @@ -537,15 +537,15 @@ fn test_thread() { ).unwrap(); assert_eq!(thread.status().unwrap(), LuaThreadStatus::Active); - assert_eq!(thread.resume::<_, i64>(0).unwrap(), Some(0)); + assert_eq!(thread.resume::<_, i64>(0).unwrap(), 0); assert_eq!(thread.status().unwrap(), LuaThreadStatus::Active); - assert_eq!(thread.resume::<_, i64>(1).unwrap(), Some(1)); + assert_eq!(thread.resume::<_, i64>(1).unwrap(), 1); assert_eq!(thread.status().unwrap(), LuaThreadStatus::Active); - assert_eq!(thread.resume::<_, i64>(2).unwrap(), Some(3)); + assert_eq!(thread.resume::<_, i64>(2).unwrap(), 3); assert_eq!(thread.status().unwrap(), LuaThreadStatus::Active); - assert_eq!(thread.resume::<_, i64>(3).unwrap(), Some(6)); + assert_eq!(thread.resume::<_, i64>(3).unwrap(), 6); assert_eq!(thread.status().unwrap(), LuaThreadStatus::Active); - assert_eq!(thread.resume::<_, i64>(4).unwrap(), Some(10)); + assert_eq!(thread.resume::<_, i64>(4).unwrap(), 10); assert_eq!(thread.status().unwrap(), LuaThreadStatus::Dead); let accumulate = lua.create_thread( @@ -563,7 +563,7 @@ fn test_thread() { for i in 0..4 { accumulate.resume::<_, ()>(i).unwrap(); } - assert_eq!(accumulate.resume::<_, i64>(4).unwrap(), Some(10)); + assert_eq!(accumulate.resume::<_, i64>(4).unwrap(), 10); assert_eq!(accumulate.status().unwrap(), LuaThreadStatus::Active); assert!(accumulate.resume::<_, ()>("error").is_err()); assert_eq!(accumulate.status().unwrap(), LuaThreadStatus::Error); @@ -578,7 +578,7 @@ fn test_thread() { "#, ).unwrap(); assert_eq!(thread.status().unwrap(), LuaThreadStatus::Active); - assert_eq!(thread.resume::<_, i64>(()).unwrap(), Some(42)); + assert_eq!(thread.resume::<_, i64>(()).unwrap(), 42); let thread: LuaThread = lua.eval( r#" @@ -591,9 +591,14 @@ fn test_thread() { "#, ).unwrap(); - assert_eq!(thread.resume::<_, u32>(42).unwrap(), Some(123)); - assert_eq!(thread.resume::<_, u32>(43).unwrap(), Some(987)); - assert_eq!(thread.resume::<_, u32>(()).unwrap(), None); + assert_eq!(thread.resume::<_, u32>(42).unwrap(), 123); + assert_eq!(thread.resume::<_, u32>(43).unwrap(), 987); + + match thread.resume::<_, u32>(()) { + Err(LuaError(LuaErrorKind::CoroutineInactive, _)) => {} + Err(_) => panic!("resuming dead coroutine error is not CoroutineInactive kind"), + _ => panic!("resuming dead coroutine did not return error"), + } } #[test]