Slightly more palatable coroutine API

This commit is contained in:
kyren 2017-06-17 23:50:40 -04:00
parent 4d34ccdba1
commit 56a30fba2e
3 changed files with 51 additions and 42 deletions

View File

@ -39,6 +39,9 @@ error_chain! {
IncompleteStatement(err: String) {
display("Incomplete lua statement {}", err)
}
CoroutineInactive {
display("Cannot resume inactive coroutine")
}
}
foreign_links {

View File

@ -512,14 +512,16 @@ impl<'lua> LuaThread<'lua> {
/// are passed to its main function.
///
/// If the thread is no longer in `Active` state (meaning it has finished execution or
/// encountered an error), returns `None`. Otherwise, returns `Some` as follows:
/// encountered an error), this will return Err(CoroutineInactive),
/// otherwise will return Ok as follows:
///
/// If the thread calls `coroutine.yield`, returns the values passed to `yield`. If the thread
/// `return`s values from its main function, returns those.
pub fn resume<A: ToLuaMulti<'lua>, R: FromLuaMulti<'lua>>(
&self,
args: A,
) -> LuaResult<Option<R>> {
pub fn resume<A, R>(&self, args: A) -> LuaResult<R>
where
A: ToLuaMulti<'lua>,
R: FromLuaMulti<'lua>,
{
let lua = self.0.lua;
unsafe {
stack_guard(lua.state, 0, || {
@ -529,7 +531,10 @@ impl<'lua> LuaThread<'lua> {
let thread_state = ffi::lua_tothread(lua.state, -1);
let status = ffi::lua_status(thread_state);
if status == ffi::LUA_YIELD || ffi::lua_gettop(thread_state) > 0 {
if status != ffi::LUA_YIELD && ffi::lua_gettop(thread_state) == 0 {
return Err(LuaErrorKind::CoroutineInactive.into());
}
ffi::lua_pop(lua.state, 1);
let args = args.to_lua_multi(lua)?;
@ -550,11 +555,7 @@ impl<'lua> LuaThread<'lua> {
for _ in 0..nresults {
results.push_front(lua.pop_value(thread_state)?);
}
R::from_lua_multi(results, lua).map(Some)
} else {
ffi::lua_pop(lua.state, 1);
Ok(None)
}
R::from_lua_multi(results, lua)
})
}
}

View File

@ -458,12 +458,12 @@ fn test_error() {
match lua_error.call::<_, ()>(()) {
Err(LuaError(LuaErrorKind::ScriptError(_), _)) => {}
Err(_) => panic!("error is not ScriptError kind"),
_ => panic!("error not thrown"),
_ => panic!("error not returned"),
}
match rust_error.call::<_, ()>(()) {
Err(LuaError(LuaErrorKind::CallbackError(_), _)) => {}
Err(_) => panic!("error is not CallbackError kind"),
_ => panic!("error not thrown"),
_ => panic!("error not returned"),
}
test_pcall.call::<_, ()>(()).unwrap();
@ -537,15 +537,15 @@ fn test_thread() {
).unwrap();
assert_eq!(thread.status().unwrap(), LuaThreadStatus::Active);
assert_eq!(thread.resume::<_, i64>(0).unwrap(), Some(0));
assert_eq!(thread.resume::<_, i64>(0).unwrap(), 0);
assert_eq!(thread.status().unwrap(), LuaThreadStatus::Active);
assert_eq!(thread.resume::<_, i64>(1).unwrap(), Some(1));
assert_eq!(thread.resume::<_, i64>(1).unwrap(), 1);
assert_eq!(thread.status().unwrap(), LuaThreadStatus::Active);
assert_eq!(thread.resume::<_, i64>(2).unwrap(), Some(3));
assert_eq!(thread.resume::<_, i64>(2).unwrap(), 3);
assert_eq!(thread.status().unwrap(), LuaThreadStatus::Active);
assert_eq!(thread.resume::<_, i64>(3).unwrap(), Some(6));
assert_eq!(thread.resume::<_, i64>(3).unwrap(), 6);
assert_eq!(thread.status().unwrap(), LuaThreadStatus::Active);
assert_eq!(thread.resume::<_, i64>(4).unwrap(), Some(10));
assert_eq!(thread.resume::<_, i64>(4).unwrap(), 10);
assert_eq!(thread.status().unwrap(), LuaThreadStatus::Dead);
let accumulate = lua.create_thread(
@ -563,7 +563,7 @@ fn test_thread() {
for i in 0..4 {
accumulate.resume::<_, ()>(i).unwrap();
}
assert_eq!(accumulate.resume::<_, i64>(4).unwrap(), Some(10));
assert_eq!(accumulate.resume::<_, i64>(4).unwrap(), 10);
assert_eq!(accumulate.status().unwrap(), LuaThreadStatus::Active);
assert!(accumulate.resume::<_, ()>("error").is_err());
assert_eq!(accumulate.status().unwrap(), LuaThreadStatus::Error);
@ -578,7 +578,7 @@ fn test_thread() {
"#,
).unwrap();
assert_eq!(thread.status().unwrap(), LuaThreadStatus::Active);
assert_eq!(thread.resume::<_, i64>(()).unwrap(), Some(42));
assert_eq!(thread.resume::<_, i64>(()).unwrap(), 42);
let thread: LuaThread = lua.eval(
r#"
@ -591,9 +591,14 @@ fn test_thread() {
"#,
).unwrap();
assert_eq!(thread.resume::<_, u32>(42).unwrap(), Some(123));
assert_eq!(thread.resume::<_, u32>(43).unwrap(), Some(987));
assert_eq!(thread.resume::<_, u32>(()).unwrap(), None);
assert_eq!(thread.resume::<_, u32>(42).unwrap(), 123);
assert_eq!(thread.resume::<_, u32>(43).unwrap(), 987);
match thread.resume::<_, u32>(()) {
Err(LuaError(LuaErrorKind::CoroutineInactive, _)) => {}
Err(_) => panic!("resuming dead coroutine error is not CoroutineInactive kind"),
_ => panic!("resuming dead coroutine did not return error"),
}
}
#[test]