Slightly more palatable coroutine API

This commit is contained in:
kyren 2017-06-17 23:50:40 -04:00
parent 4d34ccdba1
commit 56a30fba2e
3 changed files with 51 additions and 42 deletions

View File

@ -39,6 +39,9 @@ error_chain! {
IncompleteStatement(err: String) { IncompleteStatement(err: String) {
display("Incomplete lua statement {}", err) display("Incomplete lua statement {}", err)
} }
CoroutineInactive {
display("Cannot resume inactive coroutine")
}
} }
foreign_links { foreign_links {

View File

@ -512,14 +512,16 @@ impl<'lua> LuaThread<'lua> {
/// are passed to its main function. /// are passed to its main function.
/// ///
/// If the thread is no longer in `Active` state (meaning it has finished execution or /// If the thread is no longer in `Active` state (meaning it has finished execution or
/// encountered an error), returns `None`. Otherwise, returns `Some` as follows: /// encountered an error), this will return Err(CoroutineInactive),
/// otherwise will return Ok as follows:
/// ///
/// If the thread calls `coroutine.yield`, returns the values passed to `yield`. If the thread /// If the thread calls `coroutine.yield`, returns the values passed to `yield`. If the thread
/// `return`s values from its main function, returns those. /// `return`s values from its main function, returns those.
pub fn resume<A: ToLuaMulti<'lua>, R: FromLuaMulti<'lua>>( pub fn resume<A, R>(&self, args: A) -> LuaResult<R>
&self, where
args: A, A: ToLuaMulti<'lua>,
) -> LuaResult<Option<R>> { R: FromLuaMulti<'lua>,
{
let lua = self.0.lua; let lua = self.0.lua;
unsafe { unsafe {
stack_guard(lua.state, 0, || { stack_guard(lua.state, 0, || {
@ -529,7 +531,10 @@ impl<'lua> LuaThread<'lua> {
let thread_state = ffi::lua_tothread(lua.state, -1); let thread_state = ffi::lua_tothread(lua.state, -1);
let status = ffi::lua_status(thread_state); let status = ffi::lua_status(thread_state);
if status == ffi::LUA_YIELD || ffi::lua_gettop(thread_state) > 0 { if status != ffi::LUA_YIELD && ffi::lua_gettop(thread_state) == 0 {
return Err(LuaErrorKind::CoroutineInactive.into());
}
ffi::lua_pop(lua.state, 1); ffi::lua_pop(lua.state, 1);
let args = args.to_lua_multi(lua)?; let args = args.to_lua_multi(lua)?;
@ -550,11 +555,7 @@ impl<'lua> LuaThread<'lua> {
for _ in 0..nresults { for _ in 0..nresults {
results.push_front(lua.pop_value(thread_state)?); results.push_front(lua.pop_value(thread_state)?);
} }
R::from_lua_multi(results, lua).map(Some) R::from_lua_multi(results, lua)
} else {
ffi::lua_pop(lua.state, 1);
Ok(None)
}
}) })
} }
} }

View File

@ -458,12 +458,12 @@ fn test_error() {
match lua_error.call::<_, ()>(()) { match lua_error.call::<_, ()>(()) {
Err(LuaError(LuaErrorKind::ScriptError(_), _)) => {} Err(LuaError(LuaErrorKind::ScriptError(_), _)) => {}
Err(_) => panic!("error is not ScriptError kind"), Err(_) => panic!("error is not ScriptError kind"),
_ => panic!("error not thrown"), _ => panic!("error not returned"),
} }
match rust_error.call::<_, ()>(()) { match rust_error.call::<_, ()>(()) {
Err(LuaError(LuaErrorKind::CallbackError(_), _)) => {} Err(LuaError(LuaErrorKind::CallbackError(_), _)) => {}
Err(_) => panic!("error is not CallbackError kind"), Err(_) => panic!("error is not CallbackError kind"),
_ => panic!("error not thrown"), _ => panic!("error not returned"),
} }
test_pcall.call::<_, ()>(()).unwrap(); test_pcall.call::<_, ()>(()).unwrap();
@ -537,15 +537,15 @@ fn test_thread() {
).unwrap(); ).unwrap();
assert_eq!(thread.status().unwrap(), LuaThreadStatus::Active); assert_eq!(thread.status().unwrap(), LuaThreadStatus::Active);
assert_eq!(thread.resume::<_, i64>(0).unwrap(), Some(0)); assert_eq!(thread.resume::<_, i64>(0).unwrap(), 0);
assert_eq!(thread.status().unwrap(), LuaThreadStatus::Active); assert_eq!(thread.status().unwrap(), LuaThreadStatus::Active);
assert_eq!(thread.resume::<_, i64>(1).unwrap(), Some(1)); assert_eq!(thread.resume::<_, i64>(1).unwrap(), 1);
assert_eq!(thread.status().unwrap(), LuaThreadStatus::Active); assert_eq!(thread.status().unwrap(), LuaThreadStatus::Active);
assert_eq!(thread.resume::<_, i64>(2).unwrap(), Some(3)); assert_eq!(thread.resume::<_, i64>(2).unwrap(), 3);
assert_eq!(thread.status().unwrap(), LuaThreadStatus::Active); assert_eq!(thread.status().unwrap(), LuaThreadStatus::Active);
assert_eq!(thread.resume::<_, i64>(3).unwrap(), Some(6)); assert_eq!(thread.resume::<_, i64>(3).unwrap(), 6);
assert_eq!(thread.status().unwrap(), LuaThreadStatus::Active); assert_eq!(thread.status().unwrap(), LuaThreadStatus::Active);
assert_eq!(thread.resume::<_, i64>(4).unwrap(), Some(10)); assert_eq!(thread.resume::<_, i64>(4).unwrap(), 10);
assert_eq!(thread.status().unwrap(), LuaThreadStatus::Dead); assert_eq!(thread.status().unwrap(), LuaThreadStatus::Dead);
let accumulate = lua.create_thread( let accumulate = lua.create_thread(
@ -563,7 +563,7 @@ fn test_thread() {
for i in 0..4 { for i in 0..4 {
accumulate.resume::<_, ()>(i).unwrap(); accumulate.resume::<_, ()>(i).unwrap();
} }
assert_eq!(accumulate.resume::<_, i64>(4).unwrap(), Some(10)); assert_eq!(accumulate.resume::<_, i64>(4).unwrap(), 10);
assert_eq!(accumulate.status().unwrap(), LuaThreadStatus::Active); assert_eq!(accumulate.status().unwrap(), LuaThreadStatus::Active);
assert!(accumulate.resume::<_, ()>("error").is_err()); assert!(accumulate.resume::<_, ()>("error").is_err());
assert_eq!(accumulate.status().unwrap(), LuaThreadStatus::Error); assert_eq!(accumulate.status().unwrap(), LuaThreadStatus::Error);
@ -578,7 +578,7 @@ fn test_thread() {
"#, "#,
).unwrap(); ).unwrap();
assert_eq!(thread.status().unwrap(), LuaThreadStatus::Active); assert_eq!(thread.status().unwrap(), LuaThreadStatus::Active);
assert_eq!(thread.resume::<_, i64>(()).unwrap(), Some(42)); assert_eq!(thread.resume::<_, i64>(()).unwrap(), 42);
let thread: LuaThread = lua.eval( let thread: LuaThread = lua.eval(
r#" r#"
@ -591,9 +591,14 @@ fn test_thread() {
"#, "#,
).unwrap(); ).unwrap();
assert_eq!(thread.resume::<_, u32>(42).unwrap(), Some(123)); assert_eq!(thread.resume::<_, u32>(42).unwrap(), 123);
assert_eq!(thread.resume::<_, u32>(43).unwrap(), Some(987)); assert_eq!(thread.resume::<_, u32>(43).unwrap(), 987);
assert_eq!(thread.resume::<_, u32>(()).unwrap(), None);
match thread.resume::<_, u32>(()) {
Err(LuaError(LuaErrorKind::CoroutineInactive, _)) => {}
Err(_) => panic!("resuming dead coroutine error is not CoroutineInactive kind"),
_ => panic!("resuming dead coroutine did not return error"),
}
} }
#[test] #[test]