diff --git a/src/conversion.rs b/src/conversion.rs index 53ff838..4a94793 100644 --- a/src/conversion.rs +++ b/src/conversion.rs @@ -74,6 +74,21 @@ impl<'lua> FromLua<'lua> for LuaUserData<'lua> { } } +impl<'lua> ToLua<'lua> for LuaThread<'lua> { + fn to_lua(self, _: &'lua Lua) -> LuaResult> { + Ok(LuaValue::Thread(self)) + } +} + +impl<'lua> FromLua<'lua> for LuaThread<'lua> { + fn from_lua(value: LuaValue<'lua>, _: &'lua Lua) -> LuaResult> { + match value { + LuaValue::Thread(t) => Ok(t), + _ => Err("cannot convert lua value to thread".into()), + } + } +} + impl<'lua, T: LuaUserDataType> ToLua<'lua> for T { fn to_lua(self, lua: &'lua Lua) -> LuaResult> { lua.create_userdata(self).map(LuaValue::UserData) diff --git a/src/ffi.rs b/src/ffi.rs index 314fe5e..0bdf0d0 100644 --- a/src/ffi.rs +++ b/src/ffi.rs @@ -21,6 +21,7 @@ pub type lua_KFunction = unsafe extern "C" fn(state: *mut lua_State, pub type lua_CFunction = unsafe extern "C" fn(state: *mut lua_State) -> c_int; pub const LUA_OK: c_int = 0; +pub const LUA_YIELD: c_int = 1; pub const LUA_ERRRUN: c_int = 2; pub const LUA_ERRSYNTAX: c_int = 3; pub const LUA_ERRMEM: c_int = 4; @@ -58,6 +59,8 @@ extern "C" { ctx: lua_KContext, k: Option) -> c_int; + pub fn lua_resume(state: *mut lua_State, from: *mut lua_State, nargs: c_int) -> c_int; + pub fn lua_status(state: *mut lua_State) -> c_int; pub fn lua_pushnil(state: *mut lua_State); pub fn lua_pushvalue(state: *mut lua_State, index: c_int); @@ -73,6 +76,7 @@ extern "C" { pub fn lua_toboolean(state: *mut lua_State, index: c_int) -> c_int; pub fn lua_tonumberx(state: *mut lua_State, index: c_int, isnum: *mut c_int) -> lua_Number; pub fn lua_touserdata(state: *mut lua_State, index: c_int) -> *mut c_void; + pub fn lua_tothread(state: *mut lua_State, index: c_int) -> *mut lua_State; pub fn lua_gettop(state: *const lua_State) -> c_int; pub fn lua_settop(state: *mut lua_State, n: c_int); @@ -95,6 +99,7 @@ extern "C" { pub fn lua_createtable(state: *mut lua_State, narr: c_int, nrec: c_int); pub fn lua_newuserdata(state: *mut lua_State, size: usize) -> *mut c_void; + pub fn lua_newthread(state: *mut lua_State) -> *mut lua_State; pub fn lua_settable(state: *mut lua_State, index: c_int); pub fn lua_setmetatable(state: *mut lua_State, index: c_int); @@ -116,8 +121,8 @@ extern "C" { pub fn luaL_ref(state: *mut lua_State, table: c_int) -> c_int; pub fn luaL_unref(state: *mut lua_State, table: c_int, lref: c_int); pub fn luaL_checkstack(state: *mut lua_State, size: c_int, msg: *const c_char); - pub fn luaL_traceback(state: *mut lua_State, - push_state: *mut lua_State, + pub fn luaL_traceback(push_state: *mut lua_State, + state: *mut lua_State, msg: *const c_char, level: c_int); } diff --git a/src/lua.rs b/src/lua.rs index 72e44c9..2c79d8d 100644 --- a/src/lua.rs +++ b/src/lua.rs @@ -29,6 +29,7 @@ pub enum LuaValue<'lua> { Table(LuaTable<'lua>), Function(LuaFunction<'lua>), UserData(LuaUserData<'lua>), + Thread(LuaThread<'lua>), } pub use self::LuaValue::Nil as LuaNil; @@ -323,6 +324,76 @@ impl<'lua> LuaFunction<'lua> { } } +#[derive(Debug, Copy, Clone, Eq, PartialEq)] +pub enum LuaThreadStatus { + Dead, + Active, + Error, +} + +/// Handle to an an internal lua coroutine +#[derive(Clone, Debug)] +pub struct LuaThread<'lua>(LuaRef<'lua>); + +impl<'lua> LuaThread<'lua> { + /// If this thread has yielded a value, will return Some, otherwise the thread is finished and + /// this will return None. + pub fn resume, R: FromLuaMulti<'lua>>(&self, + args: A) + -> LuaResult> { + let lua = self.0.lua; + unsafe { + stack_guard(lua.state, 0, || { + check_stack(lua.state, 1)?; + + lua.push_ref(&self.0); + let thread_state = ffi::lua_tothread(lua.state, -1); + ffi::lua_pop(lua.state, 1); + + let args = args.to_lua_multi(lua)?; + let nargs = args.len() as c_int; + check_stack(thread_state, nargs)?; + + for arg in args { + push_value(thread_state, arg)?; + } + + handle_error(lua.state, + resume_with_traceback(thread_state, lua.state, nargs))?; + + let nresults = ffi::lua_gettop(thread_state); + let mut results = LuaMultiValue::new(); + for _ in 0..nresults { + results.push_front(pop_value(thread_state, lua)?); + } + R::from_lua_multi(results, lua).map(|r| Some(r)) + }) + } + } + + pub fn status(&self) -> LuaResult { + let lua = self.0.lua; + unsafe { + stack_guard(lua.state, 0, || { + check_stack(lua.state, 1)?; + + lua.push_ref(&self.0); + let thread_state = ffi::lua_tothread(lua.state, -1); + ffi::lua_pop(lua.state, 1); + + let status = ffi::lua_status(thread_state); + if status != ffi::LUA_OK && status != ffi::LUA_YIELD { + Ok(LuaThreadStatus::Error) + } else if status == ffi::LUA_YIELD || ffi::lua_gettop(thread_state) > 0 { + Ok(LuaThreadStatus::Active) + } else { + Ok(LuaThreadStatus::Dead) + } + }) + } + } +} + /// These are the metamethods that can be overridden using this API #[derive(Debug, Copy, Clone, Eq, PartialEq, Hash)] pub enum LuaMetaMethod { @@ -538,7 +609,7 @@ impl Lua { ffi::lua_settable(state, ffi::LUA_REGISTRYINDEX); Ok(()) }) - .unwrap(); + .unwrap(); stack_guard(state, 0, || { ffi::lua_pushlightuserdata(state, @@ -558,7 +629,7 @@ impl Lua { ffi::lua_settable(state, ffi::LUA_REGISTRYINDEX); Ok(()) }) - .unwrap(); + .unwrap(); stack_guard(state, 0, || { ffi::lua_rawgeti(state, ffi::LUA_REGISTRYINDEX, ffi::LUA_RIDX_GLOBALS); @@ -574,7 +645,7 @@ impl Lua { ffi::lua_pop(state, 1); Ok(()) }) - .unwrap(); + .unwrap(); Lua { state, @@ -690,6 +761,18 @@ impl Lua { self.create_callback_function(Box::new(func)) } + pub fn create_thread<'lua>(&'lua self, func: LuaFunction<'lua>) -> LuaResult> { + unsafe { + stack_guard(self.state, 0, move || { + check_stack(self.state, 1)?; + + let thread_state = ffi::lua_newthread(self.state); + push_ref(thread_state, &func.0); + Ok(LuaThread(self.pop_ref())) + }) + } + } + pub fn create_userdata(&self, data: T) -> LuaResult where T: LuaUserDataType { @@ -875,103 +958,6 @@ impl Lua { } } - unsafe fn push_value(&self, value: LuaValue) -> LuaResult<()> { - stack_guard(self.state, 1, move || { - match value { - LuaValue::Nil => { - ffi::lua_pushnil(self.state); - } - - LuaValue::Boolean(b) => { - ffi::lua_pushboolean(self.state, if b { 1 } else { 0 }); - } - - LuaValue::Integer(i) => { - ffi::lua_pushinteger(self.state, i); - } - - LuaValue::Number(n) => { - ffi::lua_pushnumber(self.state, n); - } - - LuaValue::String(s) => { - self.push_ref(&s.0); - } - - LuaValue::Table(t) => { - self.push_ref(&t.0); - } - - LuaValue::Function(f) => { - self.push_ref(&f.0); - } - - LuaValue::UserData(ud) => { - self.push_ref(&ud.0); - } - } - Ok(()) - }) - } - - unsafe fn pop_value(&self) -> LuaResult { - stack_guard(self.state, -1, || match ffi::lua_type(self.state, -1) { - ffi::LUA_TNIL => { - ffi::lua_pop(self.state, 1); - Ok(LuaNil) - } - - ffi::LUA_TBOOLEAN => { - let b = LuaValue::Boolean(ffi::lua_toboolean(self.state, -1) != 0); - ffi::lua_pop(self.state, 1); - Ok(b) - } - - ffi::LUA_TNUMBER => { - if ffi::lua_isinteger(self.state, -1) != 0 { - let i = LuaValue::Integer(ffi::lua_tointeger(self.state, -1)); - ffi::lua_pop(self.state, 1); - Ok(i) - } else { - let n = LuaValue::Number(ffi::lua_tonumber(self.state, -1)); - ffi::lua_pop(self.state, 1); - Ok(n) - } - } - - ffi::LUA_TSTRING => Ok(LuaValue::String(LuaString(self.pop_ref()))), - - ffi::LUA_TTABLE => Ok(LuaValue::Table(LuaTable(self.pop_ref()))), - - ffi::LUA_TFUNCTION => Ok(LuaValue::Function(LuaFunction(self.pop_ref()))), - - ffi::LUA_TUSERDATA => Ok(LuaValue::UserData(LuaUserData(self.pop_ref()))), - - _ => Err("Unsupported type in pop_value".into()), - }) - } - - fn push_ref(&self, lref: &LuaRef) { - unsafe { - assert_eq!(lref.lua.state, - self.state, - "Lua instance passed LuaValue created from a different Lua"); - ffi::lua_rawgeti(self.state, - ffi::LUA_REGISTRYINDEX, - lref.registry_id as ffi::lua_Integer); - } - } - - fn pop_ref(&self) -> LuaRef { - unsafe { - let registry_id = ffi::luaL_ref(self.state, ffi::LUA_REGISTRYINDEX); - LuaRef { - lua: self, - registry_id: registry_id, - } - } - } - unsafe fn userdata_metatable(&self) -> LuaResult { // Used if both an __index metamethod is set and regular methods, checks methods table // first, then __index metamethod. @@ -1078,6 +1064,22 @@ impl Lua { } }) } + + fn push_value(&self, value: LuaValue) -> LuaResult<()> { + unsafe { push_value(self.state, value) } + } + + fn push_ref(&self, lref: &LuaRef) { + unsafe { push_ref(self.state, lref) } + } + + fn pop_value(&self) -> LuaResult { + unsafe { pop_value(self.state, self) } + } + + fn pop_ref(&self) -> LuaRef { + unsafe { pop_ref(self.state, self) } + } } static LUA_USERDATA_REGISTRY_KEY: u8 = 0; @@ -1086,9 +1088,105 @@ static FUNCTION_METATABLE_REGISTRY_KEY: u8 = 0; // If the return code is not LUA_OK, pops the error off of the stack and returns Err. If the error // was actually a rust panic, clears the current lua stack and panics. unsafe fn handle_error(state: *mut ffi::lua_State, ret: c_int) -> LuaResult<()> { - if ret != ffi::LUA_OK { + if ret != ffi::LUA_OK && ret != ffi::LUA_YIELD { Err(pop_error(state)) } else { Ok(()) } } + +unsafe fn push_value(state: *mut ffi::lua_State, value: LuaValue) -> LuaResult<()> { + stack_guard(state, 1, move || { + match value { + LuaValue::Nil => { + ffi::lua_pushnil(state); + } + + LuaValue::Boolean(b) => { + ffi::lua_pushboolean(state, if b { 1 } else { 0 }); + } + + LuaValue::Integer(i) => { + ffi::lua_pushinteger(state, i); + } + + LuaValue::Number(n) => { + ffi::lua_pushnumber(state, n); + } + + LuaValue::String(s) => { + push_ref(state, &s.0); + } + + LuaValue::Table(t) => { + push_ref(state, &t.0); + } + + LuaValue::Function(f) => { + push_ref(state, &f.0); + } + + LuaValue::UserData(ud) => { + push_ref(state, &ud.0); + } + + LuaValue::Thread(t) => { + push_ref(state, &t.0); + } + } + Ok(()) + }) +} + +unsafe fn push_ref(state: *mut ffi::lua_State, lref: &LuaRef) { + ffi::lua_rawgeti(state, + ffi::LUA_REGISTRYINDEX, + lref.registry_id as ffi::lua_Integer); +} + +unsafe fn pop_value(state: *mut ffi::lua_State, lua: &Lua) -> LuaResult { + stack_guard(state, -1, || match ffi::lua_type(state, -1) { + ffi::LUA_TNIL => { + ffi::lua_pop(state, 1); + Ok(LuaNil) + } + + ffi::LUA_TBOOLEAN => { + let b = LuaValue::Boolean(ffi::lua_toboolean(state, -1) != 0); + ffi::lua_pop(state, 1); + Ok(b) + } + + ffi::LUA_TNUMBER => { + if ffi::lua_isinteger(state, -1) != 0 { + let i = LuaValue::Integer(ffi::lua_tointeger(state, -1)); + ffi::lua_pop(state, 1); + Ok(i) + } else { + let n = LuaValue::Number(ffi::lua_tonumber(state, -1)); + ffi::lua_pop(state, 1); + Ok(n) + } + } + + ffi::LUA_TSTRING => Ok(LuaValue::String(LuaString(pop_ref(state, lua)))), + + ffi::LUA_TTABLE => Ok(LuaValue::Table(LuaTable(pop_ref(state, lua)))), + + ffi::LUA_TFUNCTION => Ok(LuaValue::Function(LuaFunction(pop_ref(state, lua)))), + + ffi::LUA_TUSERDATA => Ok(LuaValue::UserData(LuaUserData(pop_ref(state, lua)))), + + ffi::LUA_TTHREAD => Ok(LuaValue::Thread(LuaThread(pop_ref(state, lua)))), + + _ => Err("Unsupported type in pop_value".into()), + }) +} + +unsafe fn pop_ref(state: *mut ffi::lua_State, lua: &Lua) -> LuaRef { + let registry_id = ffi::luaL_ref(state, ffi::LUA_REGISTRYINDEX); + LuaRef { + lua: lua, + registry_id: registry_id, + } +} diff --git a/src/tests.rs b/src/tests.rs index 2ec8a90..1e77def 100644 --- a/src/tests.rs +++ b/src/tests.rs @@ -20,7 +20,7 @@ fn test_load() { lua.load(r#" res = 'foo'..'bar' "#, - None) + None) .unwrap(); assert_eq!(lua.get::<_, String>("res").unwrap(), "foobar"); } @@ -56,7 +56,7 @@ fn test_table() { table2 = {} table3 = {1, 2, nil, 4, 5} "#, - None) + None) .unwrap(); let table1 = lua.get::<_, LuaTable>("table1").unwrap(); @@ -82,7 +82,7 @@ fn test_function() { return arg1 .. arg2 end "#, - None) + None) .unwrap(); let concat = lua.get::<_, LuaFunction>("concat").unwrap(); @@ -102,7 +102,7 @@ fn test_bind() { return res end "#, - None) + None) .unwrap(); let mut concat = lua.get::<_, LuaFunction>("concat").unwrap(); @@ -124,7 +124,7 @@ fn test_rust_function() { -- Test to make sure chunk return is ignored return 1 "#, - None) + None) .unwrap(); let lua_function = lua.get::<_, LuaFunction>("lua_function").unwrap(); @@ -182,7 +182,7 @@ fn test_methods() { return userdata:set_value(i) end "#, - None) + None) .unwrap(); let get = lua.get::<_, LuaFunction>("get_it").unwrap(); let set = lua.get::<_, LuaFunction>("set_it").unwrap(); @@ -238,7 +238,7 @@ fn test_scope() { tin = {1, 2, 3} } "#, - None) + None) .unwrap(); // Make sure that table gets do not borrow the table, but instead just borrow lua. @@ -276,7 +276,7 @@ fn test_lua_multi() { return 1, 2, 3, 4, 5, 6 end "#, - None) + None) .unwrap(); let concat = lua.get::<_, LuaFunction>("concat").unwrap(); @@ -299,7 +299,7 @@ fn test_coercion() { str = "123" num = 123.0 "#, - None) + None) .unwrap(); assert_eq!(lua.get::<_, String>("int").unwrap(), "123"); @@ -367,7 +367,7 @@ fn test_error() { understand_recursion() end "#, - None) + None) .unwrap(); let rust_error_function = @@ -404,7 +404,7 @@ fn test_error() { pcall(function () rust_panic_function() end) end "#, - None)?; + None)?; let rust_panic_function = lua.create_function(|_, _| { panic!("expected panic, this panic should be caught in rust") })?; @@ -426,7 +426,7 @@ fn test_error() { xpcall(function() rust_panic_function() end, function() end) end "#, - None)?; + None)?; let rust_panic_function = lua.create_function(|_, _| { panic!("expected panic, this panic should be caught in rust") })?; @@ -441,3 +441,54 @@ fn test_error() { Err(_) => {} }; } + +#[test] +fn test_thread() { + let lua = Lua::new(); + let thread = lua.create_thread(lua.eval::(r#"function (s) + local sum = s + for i = 1,4 do + sum = sum + coroutine.yield(sum) + end + return sum + end"#) + .unwrap()) + .unwrap(); + + assert_eq!(thread.status().unwrap(), LuaThreadStatus::Active); + assert_eq!(thread.resume::<_, i64>(0).unwrap(), Some(0)); + assert_eq!(thread.status().unwrap(), LuaThreadStatus::Active); + assert_eq!(thread.resume::<_, i64>(1).unwrap(), Some(1)); + assert_eq!(thread.status().unwrap(), LuaThreadStatus::Active); + assert_eq!(thread.resume::<_, i64>(2).unwrap(), Some(3)); + assert_eq!(thread.status().unwrap(), LuaThreadStatus::Active); + assert_eq!(thread.resume::<_, i64>(3).unwrap(), Some(6)); + assert_eq!(thread.status().unwrap(), LuaThreadStatus::Active); + assert_eq!(thread.resume::<_, i64>(4).unwrap(), Some(10)); + assert_eq!(thread.status().unwrap(), LuaThreadStatus::Dead); + + let accumulate = lua.create_thread(lua.eval::(r#"function (sum) + while true do + sum = sum + coroutine.yield(sum) + end + end"#) + .unwrap()) + .unwrap(); + + for i in 0..4 { + accumulate.resume::<_, ()>(i).unwrap(); + } + assert_eq!(accumulate.resume::<_, i64>(4).unwrap(), Some(10)); + assert_eq!(accumulate.status().unwrap(), LuaThreadStatus::Active); + assert!(accumulate.resume::<_, ()>("error").is_err()); + assert_eq!(accumulate.status().unwrap(), LuaThreadStatus::Error); + + let thread = lua.eval::(r#"coroutine.create(function () + while true do + coroutine.yield(42) + end + end)"#) + .unwrap(); + assert_eq!(thread.status().unwrap(), LuaThreadStatus::Active); + assert_eq!(thread.resume::<_, i64>(()).unwrap(), Some(42)); +} diff --git a/src/util.rs b/src/util.rs index ad8bf60..a58a699 100644 --- a/src/util.rs +++ b/src/util.rs @@ -129,7 +129,8 @@ pub unsafe fn pop_error(state: *mut ffi::lua_State) -> LuaError { } else { ffi::lua_pop(state, 1); - LuaErrorKind::ScriptError("".to_owned()).into() + LuaErrorKind::ScriptError("".to_owned()) + .into() } } @@ -169,6 +170,34 @@ pub unsafe fn pcall_with_traceback(state: *mut ffi::lua_State, ret } +pub unsafe fn resume_with_traceback(state: *mut ffi::lua_State, + from: *mut ffi::lua_State, + nargs: c_int) + -> c_int { + let res = ffi::lua_resume(state, from, nargs); + if res != ffi::LUA_OK && res != ffi::LUA_YIELD { + if is_wrapped_error(state, 1) { + if !is_panic_error(state, 1) { + let error = pop_error(state); + ffi::luaL_traceback(from, state, ptr::null(), 0); + let traceback = CStr::from_ptr(ffi::lua_tolstring(from, 1, ptr::null_mut())) + .to_str() + .unwrap() + .to_owned(); + push_error(from, WrappedError::Error(LuaError::with_chain(error, LuaErrorKind::CallbackError(traceback)))); + } + } else { + let s = ffi::lua_tolstring(state, 1, ptr::null_mut()); + if !s.is_null() { + ffi::luaL_traceback(from, state, s, 0); + } else { + ffi::luaL_traceback(from, state, cstr!(""), 0); + } + } + } + res +} + // A variant of pcall that does not allow lua to catch panic errors from callback_error pub unsafe extern "C" fn safe_pcall(state: *mut ffi::lua_State) -> c_int { if ffi::lua_pcall(state, ffi::lua_gettop(state) - 1, ffi::LUA_MULTRET, 0) != ffi::LUA_OK {