From 1fe583027bce76a4b980242a61ad641e6df30a16 Mon Sep 17 00:00:00 2001 From: Alex Orlenko Date: Sun, 4 Jul 2021 23:51:51 +0100 Subject: [PATCH] Add new functions: `lua.load_from_function()` and `lua.create_c_function()` This should be useful to register embedded C modules to Lua state. Provides a solution for #61 --- src/ffi/lua.rs | 5 ++-- src/lib.rs | 3 +-- src/lua.rs | 59 +++++++++++++++++++++++++++++++++++++++++++++++ tests/function.rs | 17 ++++++++++++++ tests/tests.rs | 28 ++++++++++++++++++++++ 5 files changed, 108 insertions(+), 4 deletions(-) diff --git a/src/ffi/lua.rs b/src/ffi/lua.rs index 82516b4..ac0ba33 100644 --- a/src/ffi/lua.rs +++ b/src/ffi/lua.rs @@ -80,6 +80,7 @@ pub const LUA_ERRERR: c_int = 5; #[cfg(any(feature = "lua53", feature = "lua52"))] pub const LUA_ERRERR: c_int = 6; +/// A raw Lua Lua state associated with a thread. pub type lua_State = c_void; // basic types @@ -121,14 +122,14 @@ pub type lua_Number = luaconf::LUA_NUMBER; /// A Lua integer, usually equivalent to `i64`. pub type lua_Integer = luaconf::LUA_INTEGER; -// unsigned integer type +/// A Lua unsigned integer, usually equivalent to `u64`. pub type lua_Unsigned = luaconf::LUA_UNSIGNED; // type for continuation-function contexts #[cfg(any(feature = "lua54", feature = "lua53"))] pub type lua_KContext = luaconf::LUA_KCONTEXT; -/// Type for native functions that can be passed to Lua. +/// Type for native C functions that can be passed to Lua. pub type lua_CFunction = unsafe extern "C" fn(L: *mut lua_State) -> c_int; // Type for continuation functions diff --git a/src/lib.rs b/src/lib.rs index 80af7f4..91d002f 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -98,8 +98,7 @@ mod userdata; mod util; mod value; -#[doc(hidden)] -pub use crate::ffi::lua_State; +pub use crate::{ffi::lua_CFunction, ffi::lua_State}; pub use crate::error::{Error, ExternalError, ExternalResult, Result}; pub use crate::function::Function; diff --git a/src/lua.rs b/src/lua.rs index 735485d..b81e442 100644 --- a/src/lua.rs +++ b/src/lua.rs @@ -510,6 +510,52 @@ impl Lua { res } + /// Calls the Lua function `func` with the string `modname` as an argument, sets + /// the call result to `package.loaded[modname]` and returns copy of the result. + /// + /// If `package.loaded[modname]` value is not nil, returns copy of the value without + /// calling the function. + /// + /// If the function does not return a non-nil value then this method assigns true to + /// `package.loaded[modname]`. + /// + /// Behavior is similar to Lua's [`require`] function. + /// + /// [`require`]: https://www.lua.org/manual/5.3/manual.html#pdf-require + pub fn load_from_function<'lua, S, T>( + &'lua self, + modname: &S, + func: Function<'lua>, + ) -> Result + where + S: AsRef<[u8]> + ?Sized, + T: FromLua<'lua>, + { + unsafe { + let _sg = StackGuard::new(self.state); + check_stack(self.state, 3)?; + + protect_lua(self.state, 0, 1, |state| { + ffi::luaL_getsubtable(state, ffi::LUA_REGISTRYINDEX, cstr!("_LOADED")); + })?; + let loaded = Table(self.pop_ref()); + + let modname = self.create_string(modname)?; + let value = match loaded.raw_get(modname.clone())? { + Value::Nil => { + let result = match func.call(modname.clone())? { + Value::Nil => Value::Boolean(true), + res => res, + }; + loaded.raw_set(modname, result.clone())?; + result + } + res => res, + }; + T::from_lua(value, self) + } + } + /// Consumes and leaks `Lua` object, returning a static reference `&'static Lua`. /// /// This function is useful when the `Lua` object is supposed to live for the remainder @@ -1034,6 +1080,19 @@ impl Lua { }) } + /// Wraps a C function, creating a callable Lua function handle to it. + /// + /// # Safety + /// This function is unsafe because provides a way to execute unsafe C function. + pub unsafe fn create_c_function(&self, func: ffi::lua_CFunction) -> Result { + let _sg = StackGuard::new(self.state); + check_stack(self.state, 3)?; + protect_lua(self.state, 0, 1, |state| { + ffi::lua_pushcfunction(state, func); + })?; + Ok(Function(self.pop_ref())) + } + /// Wraps a Rust async function or closure, creating a callable Lua function handle to it. /// /// While executing the function Rust will poll Future and if the result is not ready, call diff --git a/tests/function.rs b/tests/function.rs index b29fb81..a7bdd9f 100644 --- a/tests/function.rs +++ b/tests/function.rs @@ -76,6 +76,23 @@ fn test_rust_function() -> Result<()> { Ok(()) } +#[test] +fn test_c_function() -> Result<()> { + let lua = Lua::new(); + + unsafe extern "C" fn c_function(state: *mut mlua::lua_State) -> std::os::raw::c_int { + let lua = Lua::init_from_ptr(state); + lua.globals().set("c_function", true).unwrap(); + 0 + } + + let func = unsafe { lua.create_c_function(c_function)? }; + func.call(())?; + assert_eq!(lua.globals().get::<_, bool>("c_function")?, true); + + Ok(()) +} + #[test] fn test_dump() -> Result<()> { let lua = unsafe { Lua::unsafe_new() }; diff --git a/tests/tests.rs b/tests/tests.rs index d3c2d1a..9b450ae 100644 --- a/tests/tests.rs +++ b/tests/tests.rs @@ -2,6 +2,7 @@ use std::collections::HashMap; use std::iter::FromIterator; use std::panic::{catch_unwind, AssertUnwindSafe}; use std::string::String as StdString; +use std::sync::atomic::{AtomicU32, Ordering}; use std::sync::Arc; use std::{error, f32, f64, fmt}; @@ -1086,3 +1087,30 @@ fn test_jit_version() -> Result<()> { .contains("LuaJIT")); Ok(()) } + +#[test] +fn test_load_from_function() -> Result<()> { + let lua = Lua::new(); + + let i = Arc::new(AtomicU32::new(0)); + let i2 = i.clone(); + let func = lua.create_function(move |lua, modname: String| { + i2.fetch_add(1, Ordering::Relaxed); + let t = lua.create_table()?; + t.set("__name", modname)?; + Ok(t) + })?; + + let t: Table = lua.load_from_function("my_module", func.clone())?; + assert_eq!(t.get::<_, String>("__name")?, "my_module"); + assert_eq!(i.load(Ordering::Relaxed), 1); + + let _: Value = lua.load_from_function("my_module", func)?; + assert_eq!(i.load(Ordering::Relaxed), 1); + + let func_nil = lua.create_function(move |_, _: String| Ok(Value::Nil))?; + let v: Value = lua.load_from_function("my_module2", func_nil)?; + assert_eq!(v, Value::Boolean(true)); + + Ok(()) +}