diff --git a/src/lua.rs b/src/lua.rs index 2460657..9737853 100644 --- a/src/lua.rs +++ b/src/lua.rs @@ -444,27 +444,34 @@ impl<'lua> LuaThread<'lua> { lua.push_ref(lua.state, &self.0); let thread_state = ffi::lua_tothread(lua.state, -1); - ffi::lua_pop(lua.state, 1); - let args = args.to_lua_multi(lua)?; - let nargs = args.len() as c_int; - check_stack(thread_state, nargs)?; + let status = ffi::lua_status(thread_state); + if status == ffi::LUA_YIELD || ffi::lua_gettop(thread_state) > 0 { + ffi::lua_pop(lua.state, 1); - for arg in args { - lua.push_value(thread_state, arg)?; + let args = args.to_lua_multi(lua)?; + let nargs = args.len() as c_int; + check_stack(thread_state, nargs)?; + + for arg in args { + lua.push_value(thread_state, arg)?; + } + + handle_error( + lua.state, + resume_with_traceback(thread_state, lua.state, nargs), + )?; + + let nresults = ffi::lua_gettop(thread_state); + let mut results = LuaMultiValue::new(); + for _ in 0..nresults { + results.push_front(lua.pop_value(thread_state)?); + } + R::from_lua_multi(results, lua).map(Some) + } else { + ffi::lua_pop(lua.state, 1); + Ok(None) } - - handle_error( - lua.state, - resume_with_traceback(thread_state, lua.state, nargs), - )?; - - let nresults = ffi::lua_gettop(thread_state); - let mut results = LuaMultiValue::new(); - for _ in 0..nresults { - results.push_front(lua.pop_value(thread_state)?); - } - R::from_lua_multi(results, lua).map(Some) }) } } diff --git a/src/tests.rs b/src/tests.rs index 13aa7d1..58493b5 100644 --- a/src/tests.rs +++ b/src/tests.rs @@ -524,13 +524,15 @@ fn test_thread() { let lua = Lua::new(); let thread = lua.create_thread( lua.eval::( - r#"function (s) - local sum = s - for i = 1,4 do - sum = sum + coroutine.yield(sum) - end - return sum - end"#, + r#" + function (s) + local sum = s + for i = 1,4 do + sum = sum + coroutine.yield(sum) + end + return sum + end + "#, ).unwrap(), ).unwrap(); @@ -548,11 +550,13 @@ fn test_thread() { let accumulate = lua.create_thread( lua.eval::( - r#"function (sum) - while true do - sum = sum + coroutine.yield(sum) - end - end"#, + r#" + function (sum) + while true do + sum = sum + coroutine.yield(sum) + end + end + "#, ).unwrap(), ).unwrap(); @@ -565,14 +569,31 @@ fn test_thread() { assert_eq!(accumulate.status().unwrap(), LuaThreadStatus::Error); let thread = lua.eval::( - r#"coroutine.create(function () - while true do - coroutine.yield(42) - end - end)"#, + r#" + coroutine.create(function () + while true do + coroutine.yield(42) + end + end) + "#, ).unwrap(); assert_eq!(thread.status().unwrap(), LuaThreadStatus::Active); assert_eq!(thread.resume::<_, i64>(()).unwrap(), Some(42)); + + let thread: LuaThread = lua.eval( + r#" + coroutine.create(function(arg) + assert(arg == 42) + local yieldarg = coroutine.yield(123) + assert(yieldarg == 43) + return 987 + end) + "#, + ).unwrap(); + + assert_eq!(thread.resume::<_, u32>(42).unwrap(), Some(123)); + assert_eq!(thread.resume::<_, u32>(43).unwrap(), Some(987)); + assert_eq!(thread.resume::<_, u32>(()).unwrap(), None); } #[test] @@ -580,9 +601,11 @@ fn test_lightuserdata() { let lua = Lua::new(); let globals = lua.globals().unwrap(); lua.load::<()>( - r#"function id(a) - return a - end"#, + r#" + function id(a) + return a + end + "#, None, ).unwrap(); let res = globals @@ -632,7 +655,9 @@ fn test_result_conversions() { let globals = lua.globals().unwrap(); let err = lua.create_function(|lua, _| { - lua.pack(Result::Err::("only through failure can we succeed".to_string())) + lua.pack(Result::Err::( + "only through failure can we succeed".to_string(), + )) }).unwrap(); let ok = lua.create_function(|lua, _| { lua.pack(Result::Ok::("!".to_string()))