Support getting and setting environment for Lua functions.

Closes #218
This commit is contained in:
Alex Orlenko 2023-06-03 12:56:49 +01:00
parent 9785722d61
commit 1dc32452e6
No known key found for this signature in database
GPG Key ID: 4C150C250863B96D
2 changed files with 166 additions and 1 deletions

View File

@ -1,4 +1,5 @@
use std::cell::RefCell;
use std::ffi::CStr;
use std::mem;
use std::os::raw::{c_int, c_void};
use std::ptr;
@ -7,6 +8,7 @@ use std::slice;
use crate::error::{Error, Result};
use crate::lua::Lua;
use crate::memory::MemoryState;
use crate::table::Table;
use crate::types::{Callback, LuaRef, MaybeSend};
use crate::util::{
assert_stack, check_stack, error_traceback, pop_error, ptr_to_cstr_bytes, StackGuard,
@ -289,6 +291,109 @@ impl<'lua> Function<'lua> {
.call((self.clone(), args_wrapper))
}
/// Returns the environment of the Lua function.
///
/// By default Lua functions shares a global environment.
///
/// This function always returns `None` for Rust/C functions.
pub fn environment(&self) -> Option<Table> {
let lua = self.0.lua;
let state = lua.state();
unsafe {
let _sg = StackGuard::new(state);
assert_stack(state, 1);
lua.push_ref(&self.0);
let mut ar: ffi::lua_Debug = mem::zeroed();
#[cfg(not(feature = "luau"))]
{
ffi::lua_pushvalue(state, -1);
ffi::lua_getinfo(state, cstr!(">S"), &mut ar);
}
#[cfg(feature = "luau")]
ffi::lua_getinfo(state, -1, cstr!("s"), &mut ar);
if ptr_to_cstr_bytes(ar.what) == Some(b"C") {
return None;
}
#[cfg(any(feature = "lua51", feature = "luajit", feature = "luau"))]
ffi::lua_getfenv(state, -1);
#[cfg(any(feature = "lua54", feature = "lua53", feature = "lua52"))]
for i in 1..=255 {
// Traverse upvalues until we find the _ENV one
match ffi::lua_getupvalue(state, -1, i) {
s if s.is_null() => break,
s if CStr::from_ptr(s as _).to_bytes() == b"_ENV" => break,
_ => ffi::lua_pop(state, 1),
}
}
if ffi::lua_type(state, -1) != ffi::LUA_TTABLE {
return None;
}
Some(Table(lua.pop_ref()))
}
}
/// Sets the environment of the Lua function.
///
/// The environment is a table that is used as the global environment for the function.
/// Returns `true` if environment successfully changed, `false` otherwise.
///
/// This function does nothing for Rust/C functions.
pub fn set_environment(&self, env: Table) -> Result<bool> {
let lua = self.0.lua;
let state = lua.state();
unsafe {
let _sg = StackGuard::new(state);
check_stack(state, 2)?;
lua.push_ref(&self.0);
let mut ar: ffi::lua_Debug = mem::zeroed();
#[cfg(not(feature = "luau"))]
{
ffi::lua_pushvalue(state, -1);
ffi::lua_getinfo(state, cstr!(">S"), &mut ar);
}
#[cfg(feature = "luau")]
ffi::lua_getinfo(state, -1, cstr!("s"), &mut ar);
if ptr_to_cstr_bytes(ar.what) == Some(b"C") {
return Ok(false);
}
#[cfg(any(feature = "lua51", feature = "luajit", feature = "luau"))]
{
lua.push_ref(&env.0);
ffi::lua_setfenv(state, -2);
}
#[cfg(any(feature = "lua54", feature = "lua53", feature = "lua52"))]
for i in 1..=255 {
match ffi::lua_getupvalue(state, -1, i) {
s if s.is_null() => return Ok(false),
s if CStr::from_ptr(s as _).to_bytes() == b"_ENV" => {
ffi::lua_pop(state, 1);
// Create an anonymous function with the new environment
let f_with_env = lua
.load("return _ENV")
.set_environment(env)
.try_cache()
.into_function()?;
lua.push_ref(&f_with_env.0);
ffi::lua_upvaluejoin(state, -2, i, -1, 1);
break;
}
_ => ffi::lua_pop(state, 1),
}
}
Ok(true)
}
}
/// Returns information about the function.
///
/// Corresponds to the `>Sn` what mask for [`lua_getinfo`] when applied to the function.

View File

@ -1,4 +1,4 @@
use mlua::{Function, Lua, Result, String};
use mlua::{Function, Lua, Result, String, Table};
#[test]
fn test_function() -> Result<()> {
@ -114,6 +114,66 @@ fn test_dump() -> Result<()> {
Ok(())
}
#[test]
fn test_function_environment() -> Result<()> {
let lua = Lua::new();
// We must not get or set environment for C functions
let rust_func = lua.create_function(|_, ()| Ok("hello"))?;
assert_eq!(rust_func.environment(), None);
assert_eq!(rust_func.set_environment(lua.globals()).ok(), Some(false));
// Test getting Lua function environment
lua.globals().set("hello", "global")?;
let lua_func = lua
.load(
r#"
local t = ""
return function()
-- two upvalues
return t .. hello
end
"#,
)
.eval::<Function>()?;
let lua_func2 = lua.load("return hello").into_function()?;
assert_eq!(lua_func.call::<_, String>(())?, "global");
assert_eq!(lua_func.environment(), Some(lua.globals()));
// Test changing the environment
let env = lua.create_table_from([("hello", "local")])?;
assert!(lua_func.set_environment(env.clone())?);
assert_eq!(lua_func.call::<_, String>(())?, "local");
assert_eq!(lua_func2.call::<_, String>(())?, "global");
// More complex case
lua.load(
r#"
local number = 15
function lucky() return tostring("number is "..number) end
new_env = {
tostring = function() return tostring(number) end,
}
"#,
)
.exec()?;
let lucky = lua.globals().get::<_, Function>("lucky")?;
assert_eq!(lucky.call::<_, String>(())?, "number is 15");
let new_env = lua.globals().get::<_, Table>("new_env")?;
lucky.set_environment(new_env)?;
assert_eq!(lucky.call::<_, String>(())?, "15");
// Test inheritance
let lua_func2 = lua
.load(r#"return function() return (function() return hello end)() end"#)
.eval::<Function>()?;
assert!(lua_func2.set_environment(env.clone())?);
lua.gc_collect()?;
assert_eq!(lua_func2.call::<_, String>(())?, "local");
Ok(())
}
#[test]
fn test_function_info() -> Result<()> {
let lua = Lua::new();