From 5293b8d6d2d5c197b4c5a2800fb08b344b6d3f86 Mon Sep 17 00:00:00 2001 From: Alex Orlenko Date: Wed, 5 May 2021 11:11:32 +0100 Subject: [PATCH] Add `Thread::reset()` for luajit/lua54 --- src/ffi/lua.rs | 18 +++++++++++++++++- src/ffi/mod.rs | 7 +++++-- src/thread.rs | 40 ++++++++++++++++++++++++++++++++++++++++ src/userdata.rs | 5 ++++- tests/thread.rs | 43 +++++++++++++++++++++++++++++++++++++++++++ 5 files changed, 109 insertions(+), 4 deletions(-) diff --git a/src/ffi/lua.rs b/src/ffi/lua.rs index 10dc971..705e0df 100644 --- a/src/ffi/lua.rs +++ b/src/ffi/lua.rs @@ -156,8 +156,13 @@ extern "C" { pub fn lua_newstate(f: lua_Alloc, ud: *mut c_void) -> *mut lua_State; pub fn lua_close(L: *mut lua_State); pub fn lua_newthread(L: *mut lua_State) -> *mut lua_State; + #[cfg(feature = "lua54")] - pub fn lua_resetthread(L: *mut lua_State) -> c_int; + #[link_name = "lua_resetthread"] + pub fn lua_resetthread_54(L: *mut lua_State) -> c_int; + #[cfg(all(feature = "luajit", feature = "vendored"))] + #[link_name = "lua_resetthread"] + pub fn lua_resetthread_jit(L: *mut lua_State, th: *mut lua_State); pub fn lua_atpanic(L: *mut lua_State, panicf: lua_CFunction) -> lua_CFunction; @@ -216,6 +221,17 @@ extern "C" { pub fn lua_topointer(L: *mut lua_State, idx: c_int) -> *const c_void; } +#[cfg(any(feature = "lua54", all(feature = "luajit", feature = "vendored")))] +pub unsafe fn lua_resetthread(_L: *mut lua_State, th: *mut lua_State) -> c_int { + #[cfg(all(feature = "luajit", feature = "vendored"))] + { + lua_resetthread_jit(_L, th); + LUA_OK + } + #[cfg(feature = "lua54")] + lua_resetthread_54(th) +} + // Comparison and arithmetic functions pub const LUA_OPADD: c_int = 0; pub const LUA_OPSUB: c_int = 1; diff --git a/src/ffi/mod.rs b/src/ffi/mod.rs index 462860b..87bfa9b 100644 --- a/src/ffi/mod.rs +++ b/src/ffi/mod.rs @@ -160,8 +160,8 @@ pub use self::lua::{ #[cfg(feature = "lua54")] pub use self::lua::{ - lua_getiuservalue, lua_newuserdatauv, lua_resetthread, lua_setcstacklimit, lua_setiuservalue, - lua_setwarnf, lua_toclose, lua_warning, + lua_getiuservalue, lua_newuserdatauv, lua_setcstacklimit, lua_setiuservalue, lua_setwarnf, + lua_toclose, lua_warning, }; #[cfg(any(feature = "lua54", feature = "lua53"))] @@ -170,6 +170,9 @@ pub use self::lua::{lua_isyieldable, lua_version}; #[cfg(any(feature = "lua54", feature = "lua53", feature = "lua52"))] pub use self::lua::{lua_callk, lua_pcallk, lua_upvalueid, lua_upvaluejoin, lua_yieldk}; +#[cfg(any(feature = "lua54", all(feature = "luajit", feature = "vendored")))] +pub use self::lua::lua_resetthread; + // auxiliary library types pub use self::lauxlib::luaL_Reg; diff --git a/src/thread.rs b/src/thread.rs index 1da71ed..20692bd 100644 --- a/src/thread.rs +++ b/src/thread.rs @@ -7,6 +7,9 @@ use crate::types::LuaRef; use crate::util::{assert_stack, check_stack, pop_error, StackGuard}; use crate::value::{FromLuaMulti, MultiValue, ToLuaMulti}; +#[cfg(any(feature = "lua54", all(feature = "luajit", feature = "vendored"), doc))] +use crate::function::Function; + #[cfg(feature = "async")] use { crate::{ @@ -170,6 +173,43 @@ impl<'lua> Thread<'lua> { } } + /// Resets a thread + /// + /// In [Lua 5.4]: cleans its call stack and closes all pending to-be-closed variables. + /// Returns a error in case of either the original error that stopped the thread or errors + /// in closing methods. + /// + /// In [LuaJIT]: resets to the initial state of a newly created Lua thread. + /// Lua threads in arbitrary states (like yielded or errored) can be reset properly. + /// + /// Sets a Lua function for the thread afterwards. + /// + /// Requires `feature = "lua54"` OR `feature = "luajit,vendored"` + /// + /// [Lua 5.4]: https://www.lua.org/manual/5.4/manual.html#lua_resetthread + /// [LuaJIT]: https://github.com/openresty/luajit2#lua_resetthread + #[cfg(any(feature = "lua54", all(feature = "luajit", feature = "vendored"), doc))] + pub fn reset(&self, func: Function<'lua>) -> Result<()> { + let lua = self.0.lua; + unsafe { + let _sg = StackGuard::new(lua.state); + check_stack(lua.state, 2)?; + + lua.push_ref(&self.0); + let thread_state = ffi::lua_tothread(lua.state, -1); + + let ret = ffi::lua_resetthread(lua.state, thread_state); + if ret != ffi::LUA_OK { + return Err(pop_error(thread_state, ret)); + } + + lua.push_ref(&func.0); + ffi::lua_xmove(lua.state, thread_state, 1); + + Ok(()) + } + } + /// Converts Thread to an AsyncThread which implements Future and Stream traits. /// /// `args` are passed as arguments to the thread function for first call. diff --git a/src/userdata.rs b/src/userdata.rs index 1fe62ab..133f518 100644 --- a/src/userdata.rs +++ b/src/userdata.rs @@ -21,7 +21,10 @@ use crate::lua::Lua; use crate::table::{Table, TablePairs}; use crate::types::{LuaRef, MaybeSend}; use crate::util::{check_stack, get_destructed_userdata_metatable, get_userdata, StackGuard}; -use crate::value::{FromLua, FromLuaMulti, ToLua, ToLuaMulti, Value}; +use crate::value::{FromLua, FromLuaMulti, ToLua, ToLuaMulti}; + +#[cfg(any(feature = "lua52", feature = "lua51", feature = "luajit"))] +use crate::value::Value; /// Kinds of metamethods that can be overridden. /// diff --git a/tests/thread.rs b/tests/thread.rs index fd2f20c..ebb68a9 100644 --- a/tests/thread.rs +++ b/tests/thread.rs @@ -93,6 +93,49 @@ fn test_thread() -> Result<()> { Ok(()) } +#[cfg(any(feature = "lua54", all(feature = "luajit", feature = "vendored")))] +#[test] +fn test_thread_reset() -> Result<()> { + use mlua::{AnyUserData, UserData}; + use std::sync::Arc; + + let lua = Lua::new(); + + struct MyUserData(Arc<()>); + impl UserData for MyUserData {} + + let arc = Arc::new(()); + + let func: Function = lua.load(r#"function(ud) coroutine.yield(ud) end"#).eval()?; + let thread = lua.create_thread(func.clone())?; + + for _ in 0..2 { + assert_eq!(thread.status(), ThreadStatus::Resumable); + let _ = thread.resume::<_, AnyUserData>(MyUserData(arc.clone()))?; + assert_eq!(thread.status(), ThreadStatus::Resumable); + assert_eq!(Arc::strong_count(&arc), 2); + thread.resume::<_, ()>(())?; + assert_eq!(thread.status(), ThreadStatus::Unresumable); + thread.reset(func.clone())?; + lua.gc_collect()?; + assert_eq!(Arc::strong_count(&arc), 1); + } + + // Check for errors (Lua 5.4 only) + #[cfg(feature = "lua54")] + { + let func: Function = lua.load(r#"function(ud) error("test error") end"#).eval()?; + let thread = lua.create_thread(func.clone())?; + let _ = thread.resume::<_, AnyUserData>(MyUserData(arc.clone())); + assert_eq!(thread.status(), ThreadStatus::Error); + assert_eq!(Arc::strong_count(&arc), 2); + assert!(thread.reset(func.clone()).is_err()); + assert_eq!(thread.status(), ThreadStatus::Error); + } + + Ok(()) +} + #[test] fn coroutine_from_closure() -> Result<()> { let lua = Lua::new();