#![allow(unused_imports)] use std::panic::catch_unwind; use std::rc::Rc; use std::time::Duration; use futures_executor::block_on; use futures_util::stream::TryStreamExt; use mlua::{Error, Function, Lua, Result, Thread, ThreadStatus}; #[test] fn test_thread() -> Result<()> { let lua = Lua::new(); let thread = lua.create_thread( lua.load( r#" function (s) local sum = s for i = 1,4 do sum = sum + coroutine.yield(sum) end return sum end "#, ) .eval()?, )?; assert_eq!(thread.status(), ThreadStatus::Resumable); assert_eq!(thread.resume::<_, i64>(0)?, 0); assert_eq!(thread.status(), ThreadStatus::Resumable); assert_eq!(thread.resume::<_, i64>(1)?, 1); assert_eq!(thread.status(), ThreadStatus::Resumable); assert_eq!(thread.resume::<_, i64>(2)?, 3); assert_eq!(thread.status(), ThreadStatus::Resumable); assert_eq!(thread.resume::<_, i64>(3)?, 6); assert_eq!(thread.status(), ThreadStatus::Resumable); assert_eq!(thread.resume::<_, i64>(4)?, 10); assert_eq!(thread.status(), ThreadStatus::Unresumable); let accumulate = lua.create_thread( lua.load( r#" function (sum) while true do sum = sum + coroutine.yield(sum) end end "#, ) .eval::()?, )?; for i in 0..4 { accumulate.resume::<_, ()>(i)?; } assert_eq!(accumulate.resume::<_, i64>(4)?, 10); assert_eq!(accumulate.status(), ThreadStatus::Resumable); assert!(accumulate.resume::<_, ()>("error").is_err()); assert_eq!(accumulate.status(), ThreadStatus::Error); let thread = lua .load( r#" coroutine.create(function () while true do coroutine.yield(42) end end) "#, ) .eval::()?; assert_eq!(thread.status(), ThreadStatus::Resumable); assert_eq!(thread.resume::<_, i64>(())?, 42); let thread: Thread = lua .load( r#" coroutine.create(function(arg) assert(arg == 42) local yieldarg = coroutine.yield(123) assert(yieldarg == 43) return 987 end) "#, ) .eval()?; assert_eq!(thread.resume::<_, u32>(42)?, 123); assert_eq!(thread.resume::<_, u32>(43)?, 987); match thread.resume::<_, u32>(()) { Err(Error::CoroutineInactive) => {} Err(_) => panic!("resuming dead coroutine error is not CoroutineInactive kind"), _ => panic!("resuming dead coroutine did not return error"), } Ok(()) } #[cfg(feature = "async")] #[tokio::test] async fn test_thread_stream() -> Result<()> { let lua = Lua::new(); let thread = lua.create_thread( lua.load( r#" function (s) local sum = s for i = 1,10 do sum = sum + i coroutine.yield(sum) end return sum end "#, ) .eval()?, )?; let mut s = thread.into_async::<_, i64>(0); let mut sum = 0; while let Some(n) = s.try_next().await? { sum += n; } assert_eq!(sum, 275); Ok(()) } #[test] fn coroutine_from_closure() -> Result<()> { let lua = Lua::new(); let thrd_main = lua.create_function(|_, ()| Ok(()))?; lua.globals().set("main", thrd_main)?; #[cfg(any(feature = "lua53", feature = "lua52", feature = "luajit"))] let thrd: Thread = lua.load("coroutine.create(main)").eval()?; #[cfg(feature = "lua51")] let thrd: Thread = lua .load("coroutine.create(function(...) return main(unpack(arg)) end)") .eval()?; thrd.resume::<_, ()>(())?; Ok(()) } #[test] fn coroutine_panic() { match catch_unwind(|| -> Result<()> { // check that coroutines propagate panics correctly let lua = Lua::new(); let thrd_main = lua.create_function(|_, ()| -> Result<()> { panic!("test_panic"); })?; lua.globals().set("main", thrd_main.clone())?; let thrd: Thread = lua.create_thread(thrd_main)?; thrd.resume(()) }) { Ok(r) => panic!("coroutine panic not propagated, instead returned {:?}", r), Err(p) => assert!(*p.downcast::<&str>().unwrap() == "test_panic"), } } #[cfg(feature = "async")] #[tokio::test] async fn test_thread_async() -> Result<()> { let lua = Lua::new(); let cnt = Rc::new(1); // sleep 1 second let cnt2 = cnt.clone(); let f = lua.create_async_function(move |_lua, ()| { let cnt3 = cnt2.clone(); async move { futures_timer::Delay::new(Duration::from_secs(*cnt3.as_ref())).await; Ok("hello") } })?; let mut thread_s = lua.create_thread(f)?.into_async(()); let val: String = thread_s.try_next().await?.unwrap_or_default(); // thread_s is non-resumable and subject to garbage collection lua.gc_collect()?; assert_eq!(Rc::strong_count(&cnt), 1); assert_eq!(val, "hello"); Ok(()) }