From 1dc32452e61326073da78551ca968e1b1db6cb82 Mon Sep 17 00:00:00 2001 From: Alex Orlenko Date: Sat, 3 Jun 2023 12:56:49 +0100 Subject: [PATCH] Support getting and setting environment for Lua functions. Closes #218 --- src/function.rs | 105 ++++++++++++++++++++++++++++++++++++++++++++++ tests/function.rs | 62 ++++++++++++++++++++++++++- 2 files changed, 166 insertions(+), 1 deletion(-) diff --git a/src/function.rs b/src/function.rs index 636194b..df6b952 100644 --- a/src/function.rs +++ b/src/function.rs @@ -1,4 +1,5 @@ use std::cell::RefCell; +use std::ffi::CStr; use std::mem; use std::os::raw::{c_int, c_void}; use std::ptr; @@ -7,6 +8,7 @@ use std::slice; use crate::error::{Error, Result}; use crate::lua::Lua; use crate::memory::MemoryState; +use crate::table::Table; use crate::types::{Callback, LuaRef, MaybeSend}; use crate::util::{ assert_stack, check_stack, error_traceback, pop_error, ptr_to_cstr_bytes, StackGuard, @@ -289,6 +291,109 @@ impl<'lua> Function<'lua> { .call((self.clone(), args_wrapper)) } + /// Returns the environment of the Lua function. + /// + /// By default Lua functions shares a global environment. + /// + /// This function always returns `None` for Rust/C functions. + pub fn environment(&self) -> Option { + let lua = self.0.lua; + let state = lua.state(); + unsafe { + let _sg = StackGuard::new(state); + assert_stack(state, 1); + + lua.push_ref(&self.0); + + let mut ar: ffi::lua_Debug = mem::zeroed(); + #[cfg(not(feature = "luau"))] + { + ffi::lua_pushvalue(state, -1); + ffi::lua_getinfo(state, cstr!(">S"), &mut ar); + } + #[cfg(feature = "luau")] + ffi::lua_getinfo(state, -1, cstr!("s"), &mut ar); + + if ptr_to_cstr_bytes(ar.what) == Some(b"C") { + return None; + } + + #[cfg(any(feature = "lua51", feature = "luajit", feature = "luau"))] + ffi::lua_getfenv(state, -1); + #[cfg(any(feature = "lua54", feature = "lua53", feature = "lua52"))] + for i in 1..=255 { + // Traverse upvalues until we find the _ENV one + match ffi::lua_getupvalue(state, -1, i) { + s if s.is_null() => break, + s if CStr::from_ptr(s as _).to_bytes() == b"_ENV" => break, + _ => ffi::lua_pop(state, 1), + } + } + + if ffi::lua_type(state, -1) != ffi::LUA_TTABLE { + return None; + } + Some(Table(lua.pop_ref())) + } + } + + /// Sets the environment of the Lua function. + /// + /// The environment is a table that is used as the global environment for the function. + /// Returns `true` if environment successfully changed, `false` otherwise. + /// + /// This function does nothing for Rust/C functions. + pub fn set_environment(&self, env: Table) -> Result { + let lua = self.0.lua; + let state = lua.state(); + unsafe { + let _sg = StackGuard::new(state); + check_stack(state, 2)?; + + lua.push_ref(&self.0); + + let mut ar: ffi::lua_Debug = mem::zeroed(); + #[cfg(not(feature = "luau"))] + { + ffi::lua_pushvalue(state, -1); + ffi::lua_getinfo(state, cstr!(">S"), &mut ar); + } + #[cfg(feature = "luau")] + ffi::lua_getinfo(state, -1, cstr!("s"), &mut ar); + + if ptr_to_cstr_bytes(ar.what) == Some(b"C") { + return Ok(false); + } + + #[cfg(any(feature = "lua51", feature = "luajit", feature = "luau"))] + { + lua.push_ref(&env.0); + ffi::lua_setfenv(state, -2); + } + #[cfg(any(feature = "lua54", feature = "lua53", feature = "lua52"))] + for i in 1..=255 { + match ffi::lua_getupvalue(state, -1, i) { + s if s.is_null() => return Ok(false), + s if CStr::from_ptr(s as _).to_bytes() == b"_ENV" => { + ffi::lua_pop(state, 1); + // Create an anonymous function with the new environment + let f_with_env = lua + .load("return _ENV") + .set_environment(env) + .try_cache() + .into_function()?; + lua.push_ref(&f_with_env.0); + ffi::lua_upvaluejoin(state, -2, i, -1, 1); + break; + } + _ => ffi::lua_pop(state, 1), + } + } + + Ok(true) + } + } + /// Returns information about the function. /// /// Corresponds to the `>Sn` what mask for [`lua_getinfo`] when applied to the function. diff --git a/tests/function.rs b/tests/function.rs index bc99ec0..234227e 100644 --- a/tests/function.rs +++ b/tests/function.rs @@ -1,4 +1,4 @@ -use mlua::{Function, Lua, Result, String}; +use mlua::{Function, Lua, Result, String, Table}; #[test] fn test_function() -> Result<()> { @@ -114,6 +114,66 @@ fn test_dump() -> Result<()> { Ok(()) } +#[test] +fn test_function_environment() -> Result<()> { + let lua = Lua::new(); + + // We must not get or set environment for C functions + let rust_func = lua.create_function(|_, ()| Ok("hello"))?; + assert_eq!(rust_func.environment(), None); + assert_eq!(rust_func.set_environment(lua.globals()).ok(), Some(false)); + + // Test getting Lua function environment + lua.globals().set("hello", "global")?; + let lua_func = lua + .load( + r#" + local t = "" + return function() + -- two upvalues + return t .. hello + end + "#, + ) + .eval::()?; + let lua_func2 = lua.load("return hello").into_function()?; + assert_eq!(lua_func.call::<_, String>(())?, "global"); + assert_eq!(lua_func.environment(), Some(lua.globals())); + + // Test changing the environment + let env = lua.create_table_from([("hello", "local")])?; + assert!(lua_func.set_environment(env.clone())?); + assert_eq!(lua_func.call::<_, String>(())?, "local"); + assert_eq!(lua_func2.call::<_, String>(())?, "global"); + + // More complex case + lua.load( + r#" + local number = 15 + function lucky() return tostring("number is "..number) end + new_env = { + tostring = function() return tostring(number) end, + } + "#, + ) + .exec()?; + let lucky = lua.globals().get::<_, Function>("lucky")?; + assert_eq!(lucky.call::<_, String>(())?, "number is 15"); + let new_env = lua.globals().get::<_, Table>("new_env")?; + lucky.set_environment(new_env)?; + assert_eq!(lucky.call::<_, String>(())?, "15"); + + // Test inheritance + let lua_func2 = lua + .load(r#"return function() return (function() return hello end)() end"#) + .eval::()?; + assert!(lua_func2.set_environment(env.clone())?); + lua.gc_collect()?; + assert_eq!(lua_func2.call::<_, String>(())?, "local"); + + Ok(()) +} + #[test] fn test_function_info() -> Result<()> { let lua = Lua::new();